#!/usr/bin/env python
# =============================================================================
# MODULE DOCSTRING
# =============================================================================
"""
Masked linear transformations for PyTorch.
The module includes both a functional API (``masked_linear``) and a ``Module``
API (``MaskedLinear``) to implement a masked linear transformation.
It also contains functions to implement weight normalization in masked linear
layers (``masked_weight_norm``). Indeed, the mask may cause NaNs in the native
PyTorch implementation.
"""
# =============================================================================
# GLOBAL IMPORTS
# =============================================================================
import torch.autograd
import torch.nn.functional
from torch import norm_except_dim
from torch.nn.parameter import Parameter
from torch.nn.utils.weight_norm import WeightNorm
# =============================================================================
# CREATE AUTOREGRESSIVE MASKS
# =============================================================================
[docs]
def create_autoregressive_mask(
        degrees_in,
        degrees_out,
        strictly_less=True,
        transpose=False,
        dtype=None
):
    """Build the connectivity mask of an autoregressive layer.

    An output node is connected to an input node when the degree of the
    input is strictly less than the degree of the output. With
    ``strictly_less=False`` the connection is also made when the two
    degrees are equal.

    This can be used to build the masks proposed in the MADE paper [1] by
    setting ``strictly_less=False`` for hidden layers and ``True`` for the
    output layer (see Eq. 13 in the MADE paper).

    Parameters
    ----------
    degrees_in : numpy.ndarray[int] or torch.Tensor[int]
        Shape ``(n_input_nodes,)``. ``degrees_in[k]`` is the integer degree
        assigned to the ``k``-th input node (i.e., :math:`m^{l-1}(k)` in
        the MADE paper).
    degrees_out : numpy.ndarray[int] or torch.Tensor[int]
        Shape ``(n_output_nodes,)``. ``degrees_out[k]`` is the integer
        degree assigned to the ``k``-th output node (i.e., :math:`m^l(k)`
        in the MADE paper).
    strictly_less : bool, optional
        If ``True`` (default), connect an output only to inputs with a
        strictly smaller degree. If ``False``, inputs with an equal degree
        are connected as well.
    transpose : bool, optional
        If ``True``, the returned mask is transposed, i.e. it is indexed
        as ``mask[output][input]`` rather than ``mask[input][output]``.
    dtype : torch.dtype, optional
        Data type of the returned mask. Defaults to the current default
        PyTorch dtype.

    Returns
    -------
    mask : torch.Tensor
        Shape ``(n_input_nodes, n_output_nodes)`` if ``transpose`` is
        ``False``, otherwise ``(n_output_nodes, n_input_nodes)``. In the
        first case ``mask[i][j]`` is ``1`` if the i-th input is connected
        to the j-th output; in the second case ``mask[i][j]`` is ``1`` if
        the i-th output is connected to the j-th input. This corresponds
        to :math:`W^l` in the MADE paper.

    References
    ----------
    [1] Germain M, Gregor K, Murray I, Larochelle H. Made: Masked
        autoencoder for distribution estimation. In International
        Conference on Machine Learning 2015 Jun 1 (pp. 881-889).

    """
    # Place the singleton axes so that broadcasting directly produces the
    # requested orientation of the mask.
    if transpose:
        out_deg, in_deg = degrees_out[:, None], degrees_in[None, :]
    else:
        out_deg, in_deg = degrees_out[None, :], degrees_in[:, None]

    connected = out_deg > in_deg if strictly_less else out_deg >= in_deg

    # The comparison yields a numpy array when the degrees are numpy arrays.
    if not torch.is_tensor(connected):
        connected = torch.from_numpy(connected)

    target_dtype = torch.get_default_dtype() if dtype is None else dtype
    if connected.dtype == target_dtype:
        return connected
    return connected.type(target_dtype)
# =============================================================================
# MASKED LINEAR MODULE API
# =============================================================================
[docs]
class MaskedLinear(torch.nn.Linear):
    r"""Implement the masked linear transformation: :math:`y = x \cdot (M \circ A)^T + b`.

    Parameters
    ----------
    in_features : int
        Size of each input sample.
    out_features : int
        Size of each output sample.
    bias : bool, optional
        If set to ``False``, the layer will not learn an additive bias.
        Default is ``True``.
    mask : torch.Tensor, optional
        The mask of zeros and ones of shape ``(out_features, in_features)``
        to apply to the scaling matrix. Default is ``None``, in which case
        the layer behaves like a plain ``torch.nn.Linear``.

    Attributes
    ----------
    weight : torch.Tensor
        The learnable weights of the module of shape ``(out_features, in_features)``.
        The values are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`,
        where :math:`k = \frac{1}{\text{in\_features}}`. Masked entries are
        set to zero at construction.
    bias : torch.Tensor
        The learnable bias of the module of shape ``(out_features)``.
        If :attr:`bias` is ``True``, the values are initialized from
        :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
        :math:`k = \frac{1}{\text{in\_features}}`.
    mask : torch.Tensor or None
        The mask passed during initialization (registered as a buffer).

    See Also
    --------
    functions.MaskedLinearFunc
        The autograd ``Function`` object used to implement the module.

    Examples
    --------

    >>> in_features, out_features, batch_size = 8, 5, 20
    >>> # Lower triangular mask.
    >>> mask = torch.tril(torch.ones(out_features, in_features, dtype=torch.bool))
    >>> m = MaskedLinear(in_features, out_features, mask=mask)
    >>> input = torch.randn(batch_size, in_features)
    >>> output = m(input)
    >>> print(output.size())
    torch.Size([20, 5])

    """

    def __init__(self, in_features, out_features, bias=True, mask=None):
        # Let nn.Linear register and initialize the parameters.
        super().__init__(in_features, out_features, bias=bias)

        # We don't need to propagate gradients through the mask so we
        # register it as a buffer.
        self.register_buffer('mask', mask)

        # Set the masked weights to 0.0. This effectively sets the
        # gradient of the masked parameters to zero even when weight
        # normalization (whose gradient has a component that depends
        # on the gradient w.r.t. g) is used.
        # Without a mask there is nothing to zero out (and multiplying
        # a tensor by None would raise a TypeError).
        if self.mask is not None:
            self.weight.data = self.weight.data * self.mask

    def n_parameters(self):
        """Return the total number of (unmasked) parameters.

        Returns
        -------
        n_parameters : int
            The number of weights that are not masked (all the weights if
            there is no mask) plus the number of bias elements, if any.

        """
        if self.mask is None:
            n_parameters = self.weight.numel()
        else:
            # Convert the 0-dim tensor returned by sum() to a Python int so
            # that the return type does not depend on the presence of a mask.
            n_parameters = int((self.mask != 0).sum())
        if self.bias is not None:
            n_parameters += self.bias.numel()
        return n_parameters

    def forward(self, input):
        """
        Performs the forward computation.

        Parameters
        ----------
        input : torch.Tensor
            Input of shape ``(batch_size, *, in_features)`` where ``*``
            means any number of additional dimensions.

        Returns
        -------
        output : torch.Tensor
            Output of shape ``(batch_size, *, out_features)`` where ``*``
            is the same number of additional dimensions in ``input``.

        """
        # If there is no mask, fall back to normal linear behavior.
        if self.mask is None:
            return super().forward(input)
        return masked_linear(input, self.weight, self.bias, self.mask)
# =============================================================================
# MASKED LINEAR FUNCTIONAL API
# =============================================================================
[docs]
class MaskedLinearFunc(torch.autograd.Function):
    r"""Implement the masked linear transformation: :math:`y = x \cdot (M \circ A)^T + b`.

    This is based on :func:`torch.nn.functional.linear`, but with an extra
    keyword argument ``mask`` having the same shape as ``weight``.

    Note that the function does not perform a sparse multiplication, but
    simply implements the mask with an element-wise multiplication of the
    weight matrix before evaluating the linear transformation.

    A functional shortcut to ``MaskedLinearFunc`` is available in this same
    module with ``masked_linear``.

    The return value is a ``Tensor`` of shape ``(batch_size, *, n_out_features)``,
    where ``*`` corresponds to the same number of additional dimensions
    in the `input` argument.

    Parameters
    ----------
    input : torch.Tensor
        Input tensor x of shape ``(batch_size, *, n_in_features)``, where
        ``*`` means any number of additional dimensions.
    weight : torch.Tensor
        Scaling tensor A of shape ``(n_out_features, n_in_features)``.
    bias : torch.Tensor, optional
        Shifting tensor b of shape ``(n_out_features)``.
    mask : torch.Tensor, optional
        Mask of A of shape ``(n_out_features, n_in_features)``.

    Examples
    --------

    >>> batch_size = 2
    >>> in_features = 3
    >>> out_features = 5
    >>> input = torch.randn(batch_size, in_features, dtype=torch.double)
    >>> weight = torch.randn(out_features, in_features, dtype=torch.double)
    >>> bias = torch.randn(out_features, dtype=torch.double)
    >>> # Lower triangular mask.
    >>> mask = torch.tril(torch.ones(out_features, in_features, dtype=torch.bool))
    >>> output = masked_linear(input, weight, bias, mask)

    """

    @staticmethod
    def forward(ctx, input, weight, bias=None, mask=None):
        """Compute ``linear(input, weight * mask, bias)`` and save tensors for backward."""
        # Check if we need to mask the weights.
        if mask is not None:
            # Mask weight matrix.
            weight = weight * mask

        # We save the MASKED weights for backward propagation so that
        # we don't need to perform the element-wise multiplication.
        ctx.save_for_backward(input, weight, bias, mask)

        # Compute the linear transformation.
        return torch.nn.functional.linear(input, weight, bias)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradients w.r.t. ``input``, ``weight``, ``bias`` and ``mask``.

        The gradient w.r.t. ``mask`` is always ``None``. Gradients that
        are not required (according to ``ctx.needs_input_grad``) are
        returned as ``None`` as well.
        """
        # Unpack previously stored tensors.
        input, masked_weight, bias, mask = ctx.saved_tensors

        # We still need to return None for grad_mask even if we don't
        # compute its gradient.
        grad_input = grad_weight = grad_bias = grad_mask = None

        # Compute gradients w.r.t. input parameters.
        if ctx.needs_input_grad[0]:
            # matmul (rather than mm) broadcasts over any leading
            # dimension of grad_output.
            grad_input = grad_output.matmul(masked_weight)

        if ctx.needs_input_grad[1]:
            # Collapse all leading (batch-like) dimensions so that the
            # weight gradient is accumulated over every sample.
            flat_grad_output = grad_output.reshape(-1, grad_output.shape[-1])
            flat_input = input.reshape(-1, input.shape[-1])
            grad_weight = flat_grad_output.t().mm(flat_input)
            # Mask the gradients.
            if mask is not None:
                grad_weight.mul_(mask)

        if bias is not None and ctx.needs_input_grad[2]:
            # Sum over every dimension except the feature one.
            grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0)

        return grad_input, grad_weight, grad_bias, grad_mask


# Functional notation.
masked_linear = MaskedLinearFunc.apply
# =============================================================================
# WEIGHT NORMALIZATION FOR MASKED LINEAR LAYER
# =============================================================================
[docs]
def masked_weight_norm(module, name='weight', dim=0):
    """Apply NaN-free weight normalization to a (possibly masked) module.

    The standard weight normalization in :func:`torch.nn.utils.weight_norm`
    produces NaN entries when the mask covers an entire vector, since the
    norm of that vector is then zero. This variant handles that case.

    The mask is read from the ``mask`` attribute of ``module``. A module
    without such an attribute is treated as having no mask.

    Parameters
    ----------
    module : torch.nn.Module
        The module whose parameter is reparametrized.
    name : str, optional
        Name of the weight parameter. Default is ``'weight'``.
    dim : int, optional
        Dimension over which the norm is computed. Default is ``0``.

    Returns
    -------
    module : torch.nn.Module
        The same ``module`` passed as input, with the hook registered.

    See Also
    --------
    torch.nn.utils.weight_norm.weight_norm

    """
    module_mask = getattr(module, 'mask', None)
    MaskedWeightNorm.apply(module, name, dim, module_mask)
    return module
[docs]
def remove_masked_weight_norm(module, name='weight'):
    """Remove the masked weight normalization hook from a module.

    Parameters
    ----------
    module : torch.nn.Module
        The module on which masked weight normalization was applied.
    name : str, optional
        Name of the normalized weight parameter. Default is ``'weight'``.

    Returns
    -------
    module : torch.nn.Module
        The same ``module`` passed as input, without the hook.

    Raises
    ------
    ValueError
        If no ``MaskedWeightNorm`` hook for ``name`` is registered on
        ``module``.

    See Also
    --------
    torch.nn.utils.weight_norm.remove_weight_norm

    """
    pre_hooks = module._forward_pre_hooks
    for key, candidate in pre_hooks.items():
        if not isinstance(candidate, MaskedWeightNorm):
            continue
        if candidate.name != name:
            continue
        # Deleting while iterating is safe here because we return right
        # after the deletion.
        candidate.remove(module)
        del pre_hooks[key]
        return module

    raise ValueError("weight_norm of '{}' not found in {}"
                     .format(name, module))
[docs]
class MaskedWeightNorm(WeightNorm):
    """NaN-free implementation of weight normalization.

    Applying the normal weight normalization implemented with :func:`torch.nn.utils.weight_norm`
    results in NaN entries in the matrices when the mask covers an entire
    vector (thus making its norm zero). This takes care of this special
    case.

    Parameters
    ----------
    name : str
        Name of the weight parameter to normalize.
    dim : int
        Dimension over which the norm is computed.
    mask : torch.Tensor or None
        The mask applied to the weight. If ``None``, the weight is left
        untouched after normalization.

    See Also
    --------
    torch.nn.utils.weight_norm.WeightNorm

    """

    def __init__(self, name, dim, mask):
        super().__init__(name, dim)
        self.apply_mask = _ApplyMask(mask)

    def compute_weight(self, module):
        """Return the normalized weight with the masked entries set to zero."""
        weight = super().compute_weight(module)
        masked_weight = self.apply_mask(weight)
        # If the mask helper yields None (i.e., no mask is configured),
        # fall back to the plain normalized weight rather than setting
        # the module's weight to None.
        if masked_weight is None:
            return weight
        return masked_weight

    @staticmethod
    def apply(module, name, dim, mask):
        """Reparametrize ``module.<name>`` and register the normalization hooks.

        Parameters
        ----------
        module : torch.nn.Module
            The module to modify.
        name : str
            Name of the weight parameter.
        dim : int or None
            Dimension over which the norm is computed. ``None`` is
            converted to ``-1``.
        mask : torch.Tensor or None
            The mask applied to the weight.

        Returns
        -------
        fn : MaskedWeightNorm
            The object registered as forward pre-hook of ``module``.

        Raises
        ------
        RuntimeError
            If a ``MaskedWeightNorm`` hook is already registered for the
            parameter ``name``.

        """
        for hook in module._forward_pre_hooks.values():
            if isinstance(hook, MaskedWeightNorm) and hook.name == name:
                raise RuntimeError("Cannot register two weight_norm hooks on "
                                   "the same parameter {}".format(name))

        if dim is None:
            dim = -1

        fn = MaskedWeightNorm(name, dim, mask)

        weight = getattr(module, name)

        # remove w from parameter list
        del module._parameters[name]

        # add g and v as new parameters and express w as g/||v|| * v
        g = Parameter(norm_except_dim(weight, 2, dim).data)
        v = Parameter(weight.data)
        module.register_parameter(name + '_g', g)
        module.register_parameter(name + '_v', v)
        setattr(module, name, fn.compute_weight(module))

        # recompute weight before every forward()
        module.register_forward_pre_hook(fn)

        # Register hook to zero out gradient in the masked weights.
        g.register_hook(_ApplyMask(mask, dim, norm=True))
        v.register_hook(_ApplyMask(mask))

        return fn
class _ApplyMask:
    """NaN-safe mask application.

    Calling the object on a tensor sets its masked entries to zero by
    direct assignment, so that NaN entries are overwritten as well (an
    element-wise multiplication would leave them NaN).

    Parameters
    ----------
    mask : torch.Tensor or None
        The mask. Entries equal to zero are considered masked. If
        ``None``, calling the object returns its argument unchanged.
    dim : int, optional
        Only used if ``norm`` is ``True``. The dimension passed to
        ``norm_except_dim`` to compute the norm of the mask vectors.
        Default is 0.
    norm : bool, optional
        If True, the mask is applied to a norm vector (i.e., g) rather
        than a matrix (i.e., v or w). Default is False.
    inplace : bool, optional
        If True, the tensor is modified in place when the ``_ApplyMask``
        object is called. Otherwise, a copy is created.

    """

    def __init__(self, mask, dim=0, norm=False, inplace=True):
        # Precompute the masked indices.
        self.inplace = inplace
        self._zero_indices = None
        if mask is None:
            return

        if norm:
            # For g, we need to set to zero only those vectors
            # that have zero norm because of the mask. The norm is taken
            # over a floating point indicator of the nonzero entries so
            # that boolean and integer masks are accepted as well.
            unmasked = (mask != 0).to(torch.get_default_dtype())
            self._zero_indices = torch.nonzero(
                norm_except_dim(unmasked, 2, dim).flatten() == 0.0)
        else:
            self._zero_indices = mask == 0.0

    def __call__(self, w):
        """Return ``w`` with the masked entries set to zero."""
        # Without a mask the tensor passes through unchanged, so the object
        # is usable both as a gradient hook and as a weight transform.
        if self._zero_indices is None:
            return w

        # An element-wise multiplication doesn't work if there are NaNs.
        if not self.inplace:
            w = w.clone()
        w.data[self._zero_indices] = 0.0
        return w