tfep.nn.masked

Masked linear transformations for PyTorch.

The module includes both a functional API (masked_linear) and a Module API (MaskedLinear) that implement a masked linear transformation.
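
A minimal sketch of the operation follows. It assumes that the mask has the same shape as the weight, (out_features, in_features), as in torch.nn.Linear. The names masked_linear_sketch and MaskedLinearSketch are hypothetical and are not the tfep API.

```python
import torch
import torch.nn.functional as F


def masked_linear_sketch(x, weight, bias, mask):
    # Hypothetical helper: zero out masked connections, then compute y = x (M ∘ A)^T + b.
    return F.linear(x, weight * mask, bias)


class MaskedLinearSketch(torch.nn.Linear):
    """Hypothetical Module wrapper; the mask is stored as a non-trainable buffer."""

    def __init__(self, in_features, out_features, mask, bias=True):
        super().__init__(in_features, out_features, bias=bias)
        # A buffer follows .to()/.cuda() and is saved in state_dict, but is not optimized.
        self.register_buffer("mask", mask.to(self.weight.dtype))

    def forward(self, x):
        return masked_linear_sketch(x, self.weight, self.bias, self.mask)
```

Because the product weight * mask is differentiable, the gradient with respect to each masked weight entry is exactly zero, so masked connections are never updated by the optimizer.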

It also contains functions that implement weight normalization in masked linear layers (masked_weight_norm). These are needed because the mask may cause NaNs in the native PyTorch implementation.

Functions

create_autoregressive_mask(degrees_in, ...)

Create an autoregressive mask between input and output connections.
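
To illustrate the idea, the sketch below builds a MADE-style mask from integer degrees. Output unit j may see input unit i only if degrees_in[i] <= degrees_out[j], or degrees_in[i] < degrees_out[j] for the final layer. The function autoregressive_mask_sketch, its degrees_out and strictly_less parameters, and this exact rule are assumptions; the real create_autoregressive_mask signature is elided above and may differ.

```python
import torch


def autoregressive_mask_sketch(degrees_in, degrees_out, strictly_less=False):
    """Hypothetical MADE-style mask of shape (len(degrees_out), len(degrees_in))."""
    d_in = torch.as_tensor(degrees_in).unsqueeze(0)    # shape (1, n_in)
    d_out = torch.as_tensor(degrees_out).unsqueeze(1)  # shape (n_out, 1)
    # Broadcasting compares every output degree with every input degree.
    if strictly_less:
        mask = d_in < d_out   # typically used for the output layer
    else:
        mask = d_in <= d_out  # typically used for hidden layers
    return mask.to(torch.get_default_dtype())
```

For inputs with degrees [1, 2, 3] and one hidden layer with degrees [1, 2], chaining the hidden mask with a strict output mask means that output 1 depends on no input, output 2 depends on input 1, and output 3 depends on inputs 1 and 2.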

masked_weight_norm(module[, name, dim])

NaN-free implementation of weight normalization.
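
One way NaNs can arise: if the norm is taken over the masked weight and an entire row is masked, that row's norm is zero, and g * v / ||v|| evaluates 0/0. The sketch below avoids this by substituting 1 for zero squared norms before taking the square root. The function nan_free_masked_weight_norm and the per-row norm (normalization over input features, one g per output row) are assumptions. This is not necessarily how tfep implements masked_weight_norm.

```python
import torch


def nan_free_masked_weight_norm(v, g, mask):
    """Hypothetical: per-row weight normalization g * (mask ∘ v) / ||mask ∘ v||.

    v and mask have shape (out_features, in_features); g has shape (out_features, 1).
    """
    masked_v = v * mask
    # Squared row norms: the gradient 2 * masked_v stays finite for all-zero rows.
    sq_norm = masked_v.pow(2).sum(dim=1, keepdim=True)
    # Fully masked rows are divided by 1 instead of 0, so they stay exactly zero.
    safe_norm = torch.where(sq_norm > 0, sq_norm, torch.ones_like(sq_norm)).sqrt()
    return g * masked_v / safe_norm
```

Taking the square root only after the substitution also keeps the backward pass finite, because the derivative of sqrt is infinite at zero.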

remove_masked_weight_norm(module[, name])

Remove masked weight normalization hooks.
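
Applying and removing weight normalization through hooks can be sketched as follows. The design is modeled on the forward-pre-hook approach of PyTorch's torch.nn.utils.weight_norm, which splits a weight into name_g and name_v parameters and recomputes the weight before each forward call. The functions apply_masked_weight_norm_sketch and remove_masked_weight_norm_sketch, the name_mask buffer, and the per-row norm are hypothetical assumptions, not tfep's implementation.

```python
import torch


def _normalize(v, g, mask):
    # NaN-free per-row norm of mask ∘ v: fully masked rows are divided by 1 and stay zero.
    masked_v = v * mask
    sq_norm = masked_v.pow(2).sum(dim=1, keepdim=True)
    safe_norm = torch.where(sq_norm > 0, sq_norm, torch.ones_like(sq_norm)).sqrt()
    return g * masked_v / safe_norm


def _compute_weight(module, name):
    return _normalize(getattr(module, name + "_v"), getattr(module, name + "_g"),
                      getattr(module, name + "_mask"))


def apply_masked_weight_norm_sketch(module, mask, name="weight"):
    """Hypothetical: replace module.<name> by g * (mask ∘ v) / ||mask ∘ v|| (one norm per row)."""
    weight = getattr(module, name).detach()
    delattr(module, name)  # removes the original Parameter from module._parameters
    mask = mask.to(weight.dtype)
    module.register_buffer(name + "_mask", mask)
    module.register_parameter(name + "_v", torch.nn.Parameter(weight.clone()))
    # Initialize g to the row norms so the effective weight starts as mask ∘ weight.
    g0 = (weight * mask).norm(dim=1, keepdim=True)
    module.register_parameter(name + "_g", torch.nn.Parameter(g0))

    def hook(mod, inputs):
        # Recompute the effective weight from v and g before every forward call.
        setattr(mod, name, _compute_weight(mod, name))

    setattr(module, name, _compute_weight(module, name))
    setattr(module, "_" + name + "_wn_hook", module.register_forward_pre_hook(hook))
    return module


def remove_masked_weight_norm_sketch(module, name="weight"):
    """Hypothetical inverse: remove the hook and store the current weight as a plain Parameter."""
    hook_attr = "_" + name + "_wn_hook"
    getattr(module, hook_attr).remove()
    delattr(module, hook_attr)
    weight = _compute_weight(module, name).detach()
    delattr(module, name)
    for suffix in ("_v", "_g", "_mask"):
        delattr(module, name + suffix)
    module.register_parameter(name, torch.nn.Parameter(weight))
    return module
```

Removing the hook bakes the current normalized weight back into an ordinary parameter, so the layer's output is unchanged at the moment of removal.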

Classes

MaskedLinear(in_features, out_features[, ...])

Implement the masked linear transformation: \(y = x \cdot (M \circ A)^T + b\).

MaskedLinearFunc(*args, **kwargs)

Implement the masked linear transformation: \(y = x \cdot (M \circ A)^T + b\).
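
A custom torch.autograd.Function lets the backward pass of the masked transformation be written explicitly. The following is a minimal sketch for 2D inputs, taking M as the mask and A as the weight matrix. The class name MaskedLinearFuncSketch and its argument order (x, weight, bias, mask) are hypothetical; tfep's MaskedLinearFunc may differ in signature and details.

```python
import torch


class MaskedLinearFuncSketch(torch.autograd.Function):
    """Hypothetical autograd Function for y = x (M ∘ A)^T + b with 2D input x."""

    @staticmethod
    def forward(ctx, x, weight, bias, mask):
        masked_weight = weight * mask
        ctx.save_for_backward(x, masked_weight, mask)
        ctx.has_bias = bias is not None
        out = x @ masked_weight.t()
        if bias is not None:
            out = out + bias
        return out

    @staticmethod
    def backward(ctx, grad_out):
        x, masked_weight, mask = ctx.saved_tensors
        grad_x = grad_out @ masked_weight
        # Masked entries of the weight receive exactly zero gradient.
        grad_weight = (grad_out.t() @ x) * mask
        grad_bias = grad_out.sum(dim=0) if ctx.has_bias else None
        # The mask is treated as a constant, so it gets no gradient.
        return grad_x, grad_weight, grad_bias, None
```

It is called through MaskedLinearFuncSketch.apply(x, weight, bias, mask). Writing the backward by hand makes it explicit that the weight gradient is multiplied by the mask, which matches what autograd would compute for weight * mask.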

MaskedWeightNorm(name, dim, mask)

NaN-free implementation of weight normalization.