tfep.nn.masked.MaskedWeightNorm
class tfep.nn.masked.MaskedWeightNorm(name, dim, mask)
Bases: WeightNorm
NaN-free implementation of weight normalization.
Applying the standard weight normalization implemented with torch.nn.utils.weight_norm() results in NaN entries in the matrices when the mask covers an entire vector (thus making its norm zero). This class handles that special case.
See also
torch.nn.utils.weight_norm.WeightNorm
Methods
__init__(name, dim, mask)
apply(module, name, dim, mask)
compute_weight(module)
remove(module)
Attributes
name
dim
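
Example
The NaN problem described above can be reproduced with a few lines of plain PyTorch. The following is a minimal sketch, not the implementation used by this class: the helper names naive_weight_norm and nan_free_weight_norm are hypothetical, the norm is assumed to be taken over each row of the weight matrix, and replacing a zero norm with one is just one possible way to avoid the 0/0 division.

```python
import torch


def naive_weight_norm(v, g):
    """Standard reparameterization w = g * v / ||v||, with one norm per row."""
    norm = v.norm(p=2, dim=1, keepdim=True)
    return g * v / norm


def nan_free_weight_norm(v, g, mask):
    """Hypothetical NaN-free variant (an assumption, not the tfep code).

    Rows that are entirely masked have zero norm. Their norm is replaced
    by 1 so that the division yields 0 instead of NaN.
    """
    v = v * mask
    norm = v.norm(p=2, dim=1, keepdim=True)
    safe_norm = torch.where(norm == 0, torch.ones_like(norm), norm)
    return g * v / safe_norm


# Direction (v) and magnitude (g) parameters of a 2x2 weight matrix.
v = torch.tensor([[3.0, 4.0], [1.0, 2.0]])
g = torch.tensor([[2.0], [5.0]])

# The mask zeroes out the entire second row.
mask = torch.tensor([[1.0, 1.0], [0.0, 0.0]])

naive = naive_weight_norm(v * mask, g)   # second row is 0/0
safe = nan_free_weight_norm(v, g, mask)  # second row is exactly zero

print(naive)
print(safe)
```

In this sketch the first row is normalized identically by both functions, and only the fully masked row differs: the naive version divides zero by zero, while the guarded version leaves the masked weights at zero.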