tfep.nn.masked
Masked linear transformations for PyTorch.
The module includes both a functional API (masked_linear) and a Module API
(MaskedLinear) implementing a masked linear transformation.
It also contains functions that apply weight normalization to masked linear
layers (masked_weight_norm). These are needed because the mask can cause NaNs in
PyTorch's native weight normalization, for example when the mask zeroes an entire
row of the weight matrix so that the norm of that row is zero.
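The NaN problem described above can be reproduced in a few lines. The sketch below is a minimal illustration, not tfep's masked_weight_norm: it assumes the normalization is applied row-wise (one norm per output unit, matching the default `dim=0` convention of `torch.nn.utils.weight_norm` for linear layers) to the masked weight \(M \circ V\) with a per-row gain \(g\). The helper name `weight_norm_sketch` is hypothetical.

```python
import torch


def weight_norm_sketch(v, g, mask):
    """Hypothetical helper: row-wise weight normalization of a masked weight.

    Returns g * (M∘V) / ||M∘V||, with one L2 norm per output row.
    """
    masked_v = mask * v
    # One L2 norm per output row; keepdim=True so it broadcasts over columns.
    norm = masked_v.norm(dim=1, keepdim=True)
    # A fully masked row has norm 0, and 0/0 would produce NaN.
    # Dividing such rows by 1 instead leaves them at exactly zero.
    safe_norm = torch.where(norm > 0, norm, torch.ones_like(norm))
    return g * masked_v / safe_norm


# Example: the second output row is entirely masked out.
v = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
mask = torch.tensor([[1.0, 1.0], [0.0, 0.0]])
g = torch.tensor([[2.0], [5.0]])

naive = g * (mask * v) / (mask * v).norm(dim=1, keepdim=True)
print(naive)                           # second row is NaN
print(weight_norm_sketch(v, g, mask))  # second row is zero
```

Replacing zero norms with one is only one possible guard; adding a small epsilon to the norm is another common choice, at the cost of slightly biasing the norm of non-empty rows.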
Functions
- Create an autoregressive mask between input and output connections.
- NaN-free implementation of weight normalization.
- Remove masked weight normalization hooks.
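One common way to build an autoregressive mask between input and output connections is the degree-based construction used in MADE (Germain et al., 2015): every unit gets an integer degree, a hidden unit may connect to inputs of lower or equal degree, and an output unit only to units of strictly lower degree. The sketch below assumes this construction; whether tfep's mask function uses exactly these rules is not confirmed here, and the name `autoregressive_mask_sketch` is hypothetical. The mask has shape `(out_features, in_features)` so that it matches \(A\) in \(y = x \cdot (M \circ A)^T + b\).

```python
import torch


def autoregressive_mask_sketch(in_degrees, out_degrees, is_output_layer=False):
    """Hypothetical MADE-style mask of shape (len(out_degrees), len(in_degrees)).

    Entry [i, j] is 1 if output unit i may depend on input unit j.
    """
    in_degrees = torch.as_tensor(in_degrees)
    out_degrees = torch.as_tensor(out_degrees)
    if is_output_layer:
        # Outputs must depend only on strictly lower degrees (no self-connection).
        mask = out_degrees[:, None] > in_degrees[None, :]
    else:
        # Hidden units may depend on inputs of lower or equal degree.
        mask = out_degrees[:, None] >= in_degrees[None, :]
    return mask.to(torch.get_default_dtype())


in_deg = [1, 2, 3]   # degree of each input variable
hidden_deg = [1, 2]  # degrees assigned to the hidden units
m_hidden = autoregressive_mask_sketch(in_deg, hidden_deg)
m_output = autoregressive_mask_sketch(hidden_deg, in_deg, is_output_layer=True)

# Input-to-output connectivity through the hidden layer: output i
# depends only on inputs j < i (a strictly lower-triangular pattern).
print((m_output @ m_hidden) > 0)
```

Using a strict inequality only in the last layer is what guarantees that no output depends on its own input, while still allowing hidden units to see the variables of their own degree.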
Classes
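- Implement the masked linear transformation: \(y = x \cdot (M \circ A)^T + b\).
- Implement the masked linear transformation: \(y = x \cdot (M \circ A)^T + b\).
- NaN-free implementation of weight normalization.

Here \(M\) is the mask, \(A\) the weight matrix, \(\circ\) the element-wise (Hadamard) product, and \(b\) the bias.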
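The transformation \(y = x \cdot (M \circ A)^T + b\) maps directly onto `torch.nn.functional.linear`, which computes \(x W^T + b\); passing \(W = M \circ A\) gives the masked version. The module below is a minimal sketch under that reading, not tfep's MaskedLinear: the class name `MaskedLinearSketch` is hypothetical, and it wraps a standard `nn.Linear` to hold \(A\) and \(b\).

```python
import torch
import torch.nn.functional as F
from torch import nn


class MaskedLinearSketch(nn.Module):
    """Hypothetical module computing y = x (M∘A)^T + b."""

    def __init__(self, mask, bias=True):
        super().__init__()
        mask = torch.as_tensor(mask, dtype=torch.get_default_dtype())
        out_features, in_features = mask.shape
        # nn.Linear stores A with shape (out_features, in_features).
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        # A buffer follows .to()/.cuda() and is saved in state_dict,
        # but it is not a trainable parameter.
        self.register_buffer("mask", mask)

    def forward(self, x):
        # F.linear(x, W, b) computes x W^T + b; here W = M∘A.
        return F.linear(x, self.mask * self.linear.weight, self.linear.bias)


torch.manual_seed(0)
# Strictly lower-triangular mask: output i only sees inputs j < i.
layer = MaskedLinearSketch(torch.tril(torch.ones(3, 3), diagonal=-1))
x = torch.randn(3)
jac = torch.autograd.functional.jacobian(layer, x)
print(jac)  # diagonal and upper triangle are zero
```

Because the mask multiplies the weight inside the forward pass, the masked entries of \(A\) also receive zero gradient from the loss, so the layer keeps its connectivity pattern during training.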