Source code for tfep.nn.transformers.affine
#!/usr/bin/env python
# =============================================================================
# MODULE DOCSTRING
# =============================================================================
"""
Affine transformer for autoregressive normalizing flows.
"""
# =============================================================================
# GLOBAL IMPORTS
# =============================================================================
from typing import Optional
import torch
from tfep.nn.transformers.transformer import MAFTransformer
# =============================================================================
# AFFINE
# =============================================================================
class AffineTransformer(MAFTransformer):
    r"""Affine transformer module for autoregressive normalizing flows.

    This is an implementation of the transformation

    :math:`y_i = exp(a_i) * x_i + b_i`

    where :math:`a_i` and :math:`b_i` are the log scale and shift parameters of
    the transformation that are usually generated by a conditioner.

    See Also
    --------
    :func:`.affine_transformer`
        Functional API for the transformer.

    """
    # Number of parameters needed by the transformer for each input dimension.
    n_parameters_per_feature = 2

    def forward(self, x: torch.Tensor, parameters: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the affine transformation to the input.

        Parameters
        ----------
        x : torch.Tensor
            Shape ``(batch_size, n_features)``. The input features.
        parameters : torch.Tensor
            Shape ``(batch_size, 2*n_features)``. The parameters for the affine
            transformation where ``parameters[:, i]`` is the shift parameter
            :math:`b_i` and ``parameters[:, n_features+i]`` is the log scale
            :math:`a_i` for the ``i``-th feature.

        Returns
        -------
        y : torch.Tensor
            Shape ``(batch_size, n_features)``. The transformed features.
        log_det_J : torch.Tensor
            Shape ``(batch_size,)``. The log absolute value of the Jacobian
            determinant of the transformation.

        """
        shift, log_scale = self._split_parameters(parameters)
        return affine_transformer(x, shift, log_scale)

    def inverse(self, y: torch.Tensor, parameters: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Reverse the affine transformation.

        Parameters
        ----------
        y : torch.Tensor
            Shape ``(batch_size, n_features)``. The input features.
        parameters : torch.Tensor
            Shape ``(batch_size, 2*n_features)``. The parameters for the affine
            transformation where ``parameters[:, i]`` is the shift parameter
            :math:`b_i` and ``parameters[:, n_features+i]`` is the log scale
            :math:`a_i` for the ``i``-th feature.

        Returns
        -------
        x : torch.Tensor
            Shape ``(batch_size, n_features)``. The transformed features.
        log_det_J : torch.Tensor
            Shape ``(batch_size,)``. The log absolute value of the Jacobian
            determinant of the transformation.

        """
        shift, log_scale = self._split_parameters(parameters)
        return affine_transformer_inverse(y, shift, log_scale)

    def get_identity_parameters(self, n_features: int) -> torch.Tensor:
        """Return the value of the parameters that makes this the identity function.

        This can be used to initialize the normalizing flow to perform the identity
        transformation. Both the shift and the log scale must be zero for the affine
        transformation to be the identity.

        Parameters
        ----------
        n_features : int
            The dimension of the input vector of the transformer.

        Returns
        -------
        parameters : torch.Tensor
            Shape ``(2*n_features,)``. The parameters for the identity.

        """
        return torch.zeros(size=(self.n_parameters_per_feature*n_features,))

    def get_degrees_out(self, degrees_in: torch.Tensor) -> torch.Tensor:
        """Returns the degrees associated to the conditioner's output.

        Parameters
        ----------
        degrees_in : torch.Tensor
            Shape ``(n_transformed_features,)``. The autoregressive degrees
            associated to the features provided as input to the transformer.

        Returns
        -------
        degrees_out : torch.Tensor
            Shape ``(n_parameters,)``. The autoregressive degrees associated
            to each output of the conditioner that will be fed to the
            transformer as parameters.

        """
        # Repeating the whole vector (rather than each element) matches the
        # parameter layout used by _split_parameters: all shifts first, then
        # all log scales.
        return degrees_in.tile((self.n_parameters_per_feature,))

    def _split_parameters(self, parameters):
        """Divide shift from log scale.

        ``parameters`` of shape ``(batch_size, 2*n_features)`` is split into
        ``shift`` and ``log_scale``, each of shape ``(batch_size, n_features)``,
        taken respectively from the first and second half of the last dimension.
        """
        # From (batch, 2*n_features) to (batch, 2, n_features).
        batch_size = parameters.shape[0]
        parameters = parameters.reshape(batch_size, self.n_parameters_per_feature, -1)
        return parameters[:, 0], parameters[:, 1]
# =============================================================================
# VOLUME PRESERVING TRANSFORMER
# =============================================================================
class VolumePreservingShiftTransformer(MAFTransformer):
    r"""Implement a volume-preserving transformer for autoregressive normalizing flows.

    This is an implementation of the transformation

    :math:`y_i = x_i + b_i`

    where :math:`b_i` is the shift parameter of the transformation that is
    usually generated by a conditioner.

    See Also
    --------
    :func:`.volume_preserving_shift_transformer`
        Functional API for the transformer.

    """
    # Number of parameters needed by the transformer for each input dimension.
    n_parameters_per_feature = 1

    def __init__(
            self,
            periodic_indices: Optional[torch.Tensor] = None,
            periodic_limits: Optional[torch.Tensor] = None,
    ):
        """Constructor.

        Parameters
        ----------
        periodic_indices : torch.Tensor, optional
            If provided, the features indexed by ``periodic_indices`` will be
            treated as periodic within the limits ``periodic_limits``.
        periodic_limits : torch.Tensor, optional
            A pair ``(lower, upper)`` with the limits of the periodic features.
            The period is ``upper - lower``. Required when ``periodic_indices``
            is given.

        Raises
        ------
        ValueError
            If ``periodic_indices`` is given without ``periodic_limits``.

        """
        super().__init__()
        # Fail early rather than with an opaque TypeError on the first forward.
        if periodic_indices is not None and periodic_limits is None:
            raise ValueError('periodic_limits must be provided together with periodic_indices.')
        self.periodic_indices = periodic_indices
        self.periodic_limits = periodic_limits

    def forward(self, x: torch.Tensor, parameters: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the shift transformation to the input.

        Parameters
        ----------
        x : torch.Tensor
            Shape ``(batch_size, n_features)``. The input features.
        parameters : torch.Tensor
            Shape ``(batch_size, n_features)``. The parameters for the volume-
            preserving transformation, where ``parameters[:, i]`` is the shift
            parameter :math:`b_i` for the ``i``-th feature.

        Returns
        -------
        y : torch.Tensor
            Shape ``(batch_size, n_features)``. The transformed features.
        log_det_J : torch.Tensor
            Shape ``(batch_size,)``. The log absolute value of the Jacobian
            determinant of the transformation.

        """
        return volume_preserving_shift_transformer(
            x, shift=parameters, periodic_indices=self.periodic_indices,
            periodic_limits=self.periodic_limits)

    def inverse(self, y: torch.Tensor, parameters: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Reverse the shift transformation.

        Parameters
        ----------
        y : torch.Tensor
            Shape ``(batch_size, n_features)``. The input features.
        parameters : torch.Tensor
            Shape ``(batch_size, n_features)``. The parameters for the volume-
            preserving transformation, where ``parameters[:, i]`` is the shift
            parameter :math:`b_i` for the ``i``-th feature.

        Returns
        -------
        x : torch.Tensor
            Shape ``(batch_size, n_features)``. The transformed features.
        log_det_J : torch.Tensor
            Shape ``(batch_size,)``. The log absolute value of the Jacobian
            determinant of the transformation.

        """
        return volume_preserving_shift_transformer_inverse(
            y, shift=parameters, periodic_indices=self.periodic_indices,
            periodic_limits=self.periodic_limits)

    def get_identity_parameters(self, n_features: int) -> torch.Tensor:
        """Return the value of the parameters that makes this the identity function.

        This can be used to initialize the normalizing flow to perform the identity
        transformation. The shift must be zero for the transformation to be the
        identity.

        Parameters
        ----------
        n_features : int
            The dimension of the input vector of the transformer.

        Returns
        -------
        parameters : torch.Tensor
            Shape ``(n_features,)``. The parameters for the identity.

        """
        return torch.zeros(size=(self.n_parameters_per_feature*n_features,))

    def get_degrees_out(self, degrees_in: torch.Tensor) -> torch.Tensor:
        """Returns the degrees associated to the conditioner's output.

        Parameters
        ----------
        degrees_in : torch.Tensor
            Shape ``(n_transformed_features,)``. The autoregressive degrees
            associated to the features provided as input to the transformer.

        Returns
        -------
        degrees_out : torch.Tensor
            Shape ``(n_parameters,)``. The autoregressive degrees associated
            to each output of the conditioner that will be fed to the
            transformer as parameters.

        """
        # With one parameter per feature this is a copy of degrees_in.
        return degrees_in.tile((self.n_parameters_per_feature,))
# =============================================================================
# FUNCTIONAL API
# =============================================================================
def affine_transformer(x, shift, log_scale):
    r"""Implement an affine transformer for autoregressive normalizing flows.

    Functional counterpart of the ``AffineTransformer`` layer. It computes

    :math:`y_i = exp(a_i) * x_i + b_i`

    where :math:`a_i` is the log scale and :math:`b_i` the shift of the
    transformation, both usually produced by a conditioner.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_size, n_features)``.
    shift : torch.Tensor
        Shift coefficients (the ``b`` coefficients) of shape
        ``(batch_size, n_features)``.
    log_scale : torch.Tensor
        Logarithm of the scale coefficients (the ``a`` coefficients) of shape
        ``(batch_size, n_features)``.

    Returns
    -------
    y : torch.Tensor
        Transformed tensor of shape ``(batch_size, n_features)``.
    log_det_J : torch.Tensor
        Log absolute determinant of the Jacobian, shape ``(batch_size,)``.
        The Jacobian is diagonal with entries :math:`exp(a_i)`, so this is the
        sum of the log scales over the feature dimension.

    See Also
    --------
    :class:`.AffineTransformer`
        Object-oriented API for the transformer.

    """
    scale = log_scale.exp()
    transformed = x * scale + shift
    return transformed, log_scale.sum(dim=1)
def affine_transformer_inverse(y, shift, log_scale):
    r"""Inverse function of ``affine_transformer``.

    Functional counterpart of ``AffineTransformer.inverse``. It inverts

    :math:`y_i = exp(a_i) * x_i + b_i`

    as :math:`x_i = (y_i - b_i) * exp(-a_i)`, where :math:`a_i` is the log scale
    and :math:`b_i` the shift of the transformation, both usually produced by a
    conditioner.

    Parameters
    ----------
    y : torch.Tensor
        Input tensor y of shape ``(batch_size, n_features)``.
    shift : torch.Tensor
        Shift coefficients (the ``b`` coefficients) of shape
        ``(batch_size, n_features)``.
    log_scale : torch.Tensor
        Logarithm of the scale coefficients (the ``a`` coefficients) of shape
        ``(batch_size, n_features)``.

    Returns
    -------
    x : torch.Tensor
        Output tensor of shape ``(batch_size, n_features)``.
    log_det_J : torch.Tensor
        Log absolute determinant of the Jacobian of the inverse, shape
        ``(batch_size,)``. It is the negated sum of the log scales.

    See Also
    --------
    :func:`.affine_transformer`
        The forward transformation.

    """
    inverse_scale = log_scale.neg().exp()
    recovered = (y - shift) * inverse_scale
    return recovered, log_scale.sum(dim=1).neg()
def volume_preserving_shift_transformer(x, shift, periodic_indices=None, periodic_limits=None):
    r"""Implement a volume-preserving transformer for autoregressive normalizing flows.

    This provides a functional API to the ``VolumePreservingShiftTransformer``
    layer. It implements the transformation

    :math:`y_i = x_i + b_i`

    where :math:`b_i` is the shift parameter of the transformation that is
    usually generated by a conditioner. Periodic features are wrapped back into
    the interval ``[lower, upper)`` given by ``periodic_limits``.

    The function returns the transformed feature as a ``Tensor`` of shape
    ``(batch_size, n_features)`` and the log absolute determinant of its
    Jacobian (always a zero vector) as a ``Tensor`` of shape ``(batch_size,)``.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor x of shape ``(batch_size, n_features)``.
    shift : torch.Tensor
        The shift coefficients of shape ``(batch_size, n_features)``
        (i.e. the ``b`` coefficients).
    periodic_indices : torch.Tensor, optional
        If provided, the features indexed by ``periodic_indices`` will be treated
        as periodic within the limits ``periodic_limits``.
    periodic_limits : torch.Tensor, optional
        A pair ``(lower, upper)`` with the limits of the periodic features. The
        period is ``upper - lower``. Required when ``periodic_indices`` is given.

    Returns
    -------
    y : torch.Tensor
        Output tensor of shape ``(batch_size, n_features)``.
    log_det_J : torch.Tensor
        The logarithm of the absolute value of the determinant of the Jacobian
        of the transformation (always zero) with shape ``(batch_size,)``.

    Raises
    ------
    ValueError
        If ``periodic_indices`` is given without ``periodic_limits``.

    See Also
    --------
    :class:`.VolumePreservingShiftTransformer`
        Object-oriented API for the transformer.

    """
    # ``x + shift`` allocates a new tensor, so the in-place wrap below never
    # mutates the caller's ``x``.
    y = x + shift
    if periodic_indices is not None:
        if periodic_limits is None:
            raise ValueError('periodic_limits must be provided together with periodic_indices.')
        lower, upper = periodic_limits[0], periodic_limits[1]
        # Take the remainder relative to the lower limit so that values already
        # in [lower, upper) are unchanged (zero shift is the identity). ``%``
        # follows the sign of the divisor, so values below ``lower`` wrap too.
        y[:, periodic_indices] = (y[:, periodic_indices] - lower) % (upper - lower) + lower
    log_det_J = torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)
    return y, log_det_J
def volume_preserving_shift_transformer_inverse(y, shift, periodic_indices=None, periodic_limits=None):
    r"""Inverse function of ``volume_preserving_shift_transformer``.

    This provides a functional API to the ``VolumePreservingShiftTransformer``
    layer. It implements the inverse of the transformation

    :math:`y_i = x_i + b_i`

    where :math:`b_i` is the shift parameter of the transformation that is
    usually generated by a conditioner. Periodic features are wrapped back into
    the interval ``[lower, upper)`` given by ``periodic_limits``.

    The function returns the transformed feature as a ``Tensor`` of shape
    ``(batch_size, n_features)`` and the log absolute determinant of its
    Jacobian (always a zero vector) as a ``Tensor`` of shape ``(batch_size,)``.

    Parameters
    ----------
    y : torch.Tensor
        Input tensor y of shape ``(batch_size, n_features)``.
    shift : torch.Tensor
        The shift coefficients of shape ``(batch_size, n_features)``
        (i.e. the ``b`` coefficients).
    periodic_indices : torch.Tensor, optional
        If provided, the features indexed by ``periodic_indices`` will be treated
        as periodic within the limits ``periodic_limits``.
    periodic_limits : torch.Tensor, optional
        A pair ``(lower, upper)`` with the limits of the periodic features. The
        period is ``upper - lower``. Required when ``periodic_indices`` is given.

    Returns
    -------
    x : torch.Tensor
        Output tensor of shape ``(batch_size, n_features)``.
    log_det_J : torch.Tensor
        The logarithm of the absolute value of the determinant of the Jacobian
        of the transformation (always zero) with shape ``(batch_size,)``.

    Raises
    ------
    ValueError
        If ``periodic_indices`` is given without ``periodic_limits``.

    """
    # ``y - shift`` allocates a new tensor, so the in-place wrap below never
    # mutates the caller's ``y``.
    x = y - shift
    if periodic_indices is not None:
        if periodic_limits is None:
            raise ValueError('periodic_limits must be provided together with periodic_indices.')
        lower, upper = periodic_limits[0], periodic_limits[1]
        # Remainder relative to the lower limit keeps values already in
        # [lower, upper) unchanged and wraps the rest (``%`` follows the sign
        # of the divisor, so values below ``lower`` wrap correctly).
        x[:, periodic_indices] = (x[:, periodic_indices] - lower) % (upper - lower) + lower
    log_det_J = torch.zeros(y.shape[0], dtype=y.dtype, device=y.device)
    return x, log_det_J