tfep.nn.masked.MaskedWeightNorm

class tfep.nn.masked.MaskedWeightNorm(name, dim, mask)

Bases: WeightNorm

NaN-free implementation of weight normalization.

Applying the standard weight normalization implemented by torch.nn.utils.weight_norm() produces NaN entries in the weight matrix when the mask zeroes out an entire vector: that vector's norm is then zero, so normalizing it divides by zero. This class handles that special case.
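The formulas below spell out why the NaN appears. They assume a 2-D weight with dim=0 (the torch.nn.utils.weight_norm default), so one norm is taken per row i, and they assume the mask row m_i (entries 0 or 1) is applied element-wise to the direction vector v_i. The zero-fallback in the last case is one possible NaN-free choice, not necessarily the one tfep makes.

```latex
% Standard weight normalization: magnitude g_i times unit direction
\mathbf{w}_i = g_i \, \frac{\mathbf{v}_i}{\lVert \mathbf{v}_i \rVert_2}

% Assumed masked variant
\tilde{\mathbf{v}}_i = \mathbf{v}_i \odot \mathbf{m}_i, \qquad
\mathbf{w}_i = g_i \, \frac{\tilde{\mathbf{v}}_i}{\lVert \tilde{\mathbf{v}}_i \rVert_2}

% If \mathbf{m}_i = \mathbf{0}, then \tilde{\mathbf{v}}_i = \mathbf{0} and
% \lVert \tilde{\mathbf{v}}_i \rVert_2 = 0, so each entry is 0/0 = NaN.
% A NaN-free alternative (assumption):
\mathbf{w}_i =
\begin{cases}
  g_i \, \tilde{\mathbf{v}}_i / \lVert \tilde{\mathbf{v}}_i \rVert_2 & \text{if } \lVert \tilde{\mathbf{v}}_i \rVert_2 > 0, \\
  \mathbf{0} & \text{otherwise.}
\end{cases}
```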

See also

torch.nn.utils.weight_norm.WeightNorm

__init__(name, dim, mask)

Methods

__init__(name, dim, mask)

apply(module, name, dim, mask)

compute_weight(module)

remove(module)
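In torch.nn.utils.weight_norm, compute_weight(module) rebuilds the weight from a magnitude parameter (<name>_g) and a direction parameter (<name>_v). The pure-Python sketch below illustrates what a NaN-free masked version of that computation could look like. It assumes a 2-D weight, dim=0 (one norm per row), and a 0/1 mask multiplied element-wise into the direction. The function name masked_weight_norm and the choice of treating a zero norm as 1.0 are hypothetical illustrations, not tfep's actual implementation.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def masked_weight_norm(
    v: Sequence[Sequence[float]],
    g: Sequence[float],
    mask: Sequence[Sequence[float]],
) -> Matrix:
    """Hypothetical row-wise (dim=0) weight normalization of a masked weight.

    Computes w[i] = g[i] * (v[i] * mask[i]) / ||v[i] * mask[i]||, but returns
    a zero row instead of NaN when the mask zeroes out the whole row.
    """
    w: Matrix = []
    for v_row, g_i, m_row in zip(v, g, mask):
        # Apply the assumed 0/1 mask element-wise to the direction vector.
        masked = [v_ij * m_ij for v_ij, m_ij in zip(v_row, m_row)]
        norm = math.sqrt(sum(x * x for x in masked))
        # A fully masked row has norm 0; in a tensor library the naive
        # division would yield NaN. Use 1.0 so the row simply stays zero.
        safe_norm = norm if norm > 0.0 else 1.0
        w.append([g_i * x / safe_norm for x in masked])
    return w


if __name__ == "__main__":
    v = [[3.0, 4.0], [1.0, 2.0]]
    g = [2.0, 5.0]
    mask = [[1.0, 1.0], [0.0, 0.0]]  # second row is fully masked
    print(masked_weight_norm(v, g, mask))  # prints [[1.2, 1.6], [0.0, 0.0]]
```

Substituting 1.0 for a zero norm leaves a fully masked row at exactly zero, which is what the mask intends, while any row with a nonzero unmasked entry is normalized exactly as in the standard formula (its norm becomes g[i]).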

Attributes

name

dim