tfep.nn.masked.masked_weight_norm

tfep.nn.masked.masked_weight_norm(module, name='weight', dim=0)

NaN-free implementation of weight normalization.

Applying standard weight normalization, as implemented by torch.nn.utils.weight_norm(), produces NaN entries in the weight matrix when the mask zeroes out an entire vector. That vector's norm is then zero, so normalizing it computes 0/0. This function handles that special case.
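The failure mode and one possible fix can be illustrated with a minimal PyTorch sketch. It writes weight normalization as w = g * v / ||v||, with the norm taken over every dimension except `dim` (the same convention as torch.nn.utils.weight_norm()). The helper names `weight_norm_naive`, `weight_norm_nan_free` and `_squared_norm_except_dim` are hypothetical. The zero-norm substitution shown here is an assumed technique chosen for illustration, not necessarily how tfep implements masked_weight_norm.

```python
import torch


def _squared_norm_except_dim(v: torch.Tensor, dim: int) -> torch.Tensor:
    # Sum of squares over every dimension except `dim`, keeping dims so the
    # result broadcasts against v (same convention as torch.nn.utils.weight_norm).
    reduce_dims = [d for d in range(v.dim()) if d != dim]
    if not reduce_dims:
        # 1-D tensor with dim=0: each entry is its own vector.
        return v ** 2
    return (v ** 2).sum(dim=reduce_dims, keepdim=True)


def weight_norm_naive(v: torch.Tensor, g: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """w = g * v / ||v||; yields NaN wherever a slice of v is all zeros."""
    return g * v / _squared_norm_except_dim(v, dim).sqrt()


def weight_norm_nan_free(v: torch.Tensor, g: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """Like weight_norm_naive, but all-zero slices of v map to 0 instead of NaN."""
    sq_norm = _squared_norm_except_dim(v, dim)
    # Replace zero squared norms with 1 *before* the sqrt, so neither the
    # forward pass (0/0) nor the backward pass (infinite derivative of sqrt
    # at 0) produces NaN. Those slices of v are all zeros, so w is 0 there.
    sq_norm = torch.where(sq_norm == 0, torch.ones_like(sq_norm), sq_norm)
    return g * v / sq_norm.sqrt()


# A 3x4 weight whose mask zeroes out the entire second row (output unit).
torch.manual_seed(0)
weight = torch.randn(3, 4)
mask = torch.ones(3, 4)
mask[1] = 0.0
v = (weight * mask).requires_grad_()
g = torch.ones(3, 1)

print(torch.isnan(weight_norm_naive(v, g)).any().item())  # True
w = weight_norm_nan_free(v, g)
print(torch.isnan(w).any().item())  # False
w.sum().backward()
print(torch.isnan(v.grad).any().item())  # False
```

Substituting the zero norm before the square root, rather than after it, matters. If the substitution is done after the square root, the forward pass is clean, but the gradient of sqrt at 0 still injects NaN into v.grad.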

See also

torch.nn.utils.weight_norm.weight_norm