Source code for tfep.nn.transformers.affine

#!/usr/bin/env python


# =============================================================================
# MODULE DOCSTRING
# =============================================================================

"""
Affine and volume-preserving shift transformers for autoregressive normalizing flows.
"""


# =============================================================================
# GLOBAL IMPORTS
# =============================================================================

from typing import Optional

import torch

from tfep.nn.transformers.transformer import MAFTransformer


# =============================================================================
# AFFINE
# =============================================================================

class AffineTransformer(MAFTransformer):
    r"""Affine transformer module for autoregressive normalizing flows.

    This is an implementation of the transformation

    :math:`y_i = exp(a_i) * x_i + b_i`

    where :math:`a_i` and :math:`b_i` are the log scale and shift parameters
    of the transformation that are usually generated by a conditioner.

    See Also
    --------
    :func:`.affine_transformer`
        Functional API for the transformer.

    """
    # Number of parameters needed by the transformer for each input dimension.
    n_parameters_per_feature = 2

    def forward(self, x: torch.Tensor, parameters: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the affine transformation to the input.

        Parameters
        ----------
        x : torch.Tensor
            Shape ``(batch_size, n_features)``. The input features.
        parameters : torch.Tensor
            Shape ``(batch_size, 2*n_features)``. The parameters for the
            affine transformation where ``parameters[:, i]`` is the shift
            parameter :math:`b_i` and ``parameters[:, n_features+i]`` is the
            log scale :math:`a_i` for the ``i``-th feature.

        Returns
        -------
        y : torch.Tensor
            Shape ``(batch_size, n_features)``. The transformed features.
        log_det_J : torch.Tensor
            Shape ``(batch_size,)``. The log absolute value of the Jacobian
            determinant of the transformation.

        """
        shift, log_scale = self._split_parameters(parameters)
        return affine_transformer(x, shift, log_scale)

    def inverse(self, y: torch.Tensor, parameters: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Reverse the affine transformation.

        Parameters
        ----------
        y : torch.Tensor
            Shape ``(batch_size, n_features)``. The input features.
        parameters : torch.Tensor
            Shape ``(batch_size, 2*n_features)``. The parameters for the
            affine transformation where ``parameters[:, i]`` is the shift
            parameter :math:`b_i` and ``parameters[:, n_features+i]`` is the
            log scale :math:`a_i` for the ``i``-th feature.

        Returns
        -------
        x : torch.Tensor
            Shape ``(batch_size, n_features)``. The transformed features.
        log_det_J : torch.Tensor
            Shape ``(batch_size,)``. The log absolute value of the Jacobian
            determinant of the inverse transformation.

        """
        shift, log_scale = self._split_parameters(parameters)
        return affine_transformer_inverse(y, shift, log_scale)

    def get_identity_parameters(self, n_features: int) -> torch.Tensor:
        """Return the value of the parameters that makes this the identity function.

        This can be used to initialize the normalizing flow to perform the
        identity transformation. Both the shift and the log scale must be
        zero for the affine transformation to be the identity.

        Parameters
        ----------
        n_features : int
            The dimension of the input vector of the transformer.

        Returns
        -------
        parameters : torch.Tensor
            Shape ``(2*n_features,)``. The parameters for the identity.

        """
        return torch.zeros(size=(self.n_parameters_per_feature*n_features,))

    def get_degrees_out(self, degrees_in: torch.Tensor) -> torch.Tensor:
        """Returns the degrees associated to the conditioner's output.

        Parameters
        ----------
        degrees_in : torch.Tensor
            Shape ``(n_transformed_features,)``. The autoregressive degrees
            associated to the features provided as input to the transformer.

        Returns
        -------
        degrees_out : torch.Tensor
            Shape ``(n_parameters,)``. The autoregressive degrees associated
            to each output of the conditioner that will be fed to the
            transformer as parameters.

        """
        # The parameters are laid out as [all shifts, all log scales] (see
        # _split_parameters), so repeating the input degrees once per
        # parameter block assigns each parameter the degree of its feature.
        return degrees_in.tile((self.n_parameters_per_feature,))

    def _split_parameters(self, parameters):
        """Divide shift from log scale.

        ``parameters`` has shape ``(batch, 2*n_features)`` with all shifts
        first and all log scales after. Returns the tuple
        ``(shift, log_scale)``, each of shape ``(batch, n_features)``.
        """
        # From (batch, 2*n_features) to (batch, 2, n_features).
        batch_size = parameters.shape[0]
        parameters = parameters.reshape(batch_size, self.n_parameters_per_feature, -1)
        return parameters[:, 0], parameters[:, 1]
# ============================================================================= # VOLUME PRESERVING TRANSFORMER # =============================================================================
class VolumePreservingShiftTransformer(MAFTransformer):
    r"""Implement a volume-preserving transformer for autoregressive normalizing flows.

    This is an implementation of the transformation

    :math:`y_i = x_i + b_i`

    where :math:`b_i` is the shift parameter of the transformation, which is
    usually generated by a conditioner.

    See Also
    --------
    :func:`.volume_preserving_shift_transformer`
        Functional API for the transformer.

    """
    # Number of parameters needed by the transformer for each input dimension.
    n_parameters_per_feature = 1

    def __init__(
            self,
            periodic_indices: Optional[torch.Tensor] = None,
            periodic_limits: Optional[torch.Tensor] = None,
    ):
        """Constructor.

        Parameters
        ----------
        periodic_indices : torch.Tensor, optional
            If provided, the features indexed by ``periodic_indices`` will be
            treated as periodic and wrapped back into the domain given by
            ``periodic_limits`` after the shift.
        periodic_limits : torch.Tensor, optional
            A pair ``(lower, upper)`` defining the domain ``[lower, upper)``
            of the periodic features. The period is ``upper - lower``.
            Required when ``periodic_indices`` is given.

        """
        super().__init__()
        # NOTE(review): stored as plain attributes, not registered buffers; if
        # MAFTransformer is a torch.nn.Module, they will not follow .to(device).
        self.periodic_indices = periodic_indices
        self.periodic_limits = periodic_limits

    def forward(self, x: torch.Tensor, parameters: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the volume-preserving shift transformation to the input.

        Parameters
        ----------
        x : torch.Tensor
            Shape ``(batch_size, n_features)``. The input features.
        parameters : torch.Tensor
            Shape ``(batch_size, n_features)``. The parameters for the volume-
            preserving transformation, where ``parameters[:, i]`` is the shift
            parameter :math:`b_i` for the ``i``-th feature.

        Returns
        -------
        y : torch.Tensor
            Shape ``(batch_size, n_features)``. The transformed features.
        log_det_J : torch.Tensor
            Shape ``(batch_size,)``. The log absolute value of the Jacobian
            determinant of the transformation (always zero).

        """
        return volume_preserving_shift_transformer(
            x, shift=parameters,
            periodic_indices=self.periodic_indices,
            periodic_limits=self.periodic_limits)

    def inverse(self, y: torch.Tensor, parameters: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Reverse the volume-preserving shift transformation.

        Parameters
        ----------
        y : torch.Tensor
            Shape ``(batch_size, n_features)``. The input features.
        parameters : torch.Tensor
            Shape ``(batch_size, n_features)``. The parameters for the volume-
            preserving transformation, where ``parameters[:, i]`` is the shift
            parameter :math:`b_i` for the ``i``-th feature.

        Returns
        -------
        x : torch.Tensor
            Shape ``(batch_size, n_features)``. The transformed features.
        log_det_J : torch.Tensor
            Shape ``(batch_size,)``. The log absolute value of the Jacobian
            determinant of the transformation (always zero).

        """
        return volume_preserving_shift_transformer_inverse(
            y, shift=parameters,
            periodic_indices=self.periodic_indices,
            periodic_limits=self.periodic_limits)

    def get_identity_parameters(self, n_features: int) -> torch.Tensor:
        """Return the value of the parameters that makes this the identity function.

        This can be used to initialize the normalizing flow to perform the
        identity transformation. The shift must be zero for the transformation
        to be the identity.

        Parameters
        ----------
        n_features : int
            The dimension of the input vector of the transformer.

        Returns
        -------
        parameters : torch.Tensor
            Shape ``(n_features,)``. The parameters for the identity.

        """
        return torch.zeros(size=(self.n_parameters_per_feature*n_features,))

    def get_degrees_out(self, degrees_in: torch.Tensor) -> torch.Tensor:
        """Returns the degrees associated to the conditioner's output.

        Parameters
        ----------
        degrees_in : torch.Tensor
            Shape ``(n_transformed_features,)``. The autoregressive degrees
            associated to the features provided as input to the transformer.

        Returns
        -------
        degrees_out : torch.Tensor
            Shape ``(n_parameters,)``. The autoregressive degrees associated
            to each output of the conditioner that will be fed to the
            transformer as parameters.

        """
        # With one parameter per feature, tiling once returns the same degrees.
        return degrees_in.tile((self.n_parameters_per_feature,))
# ============================================================================= # FUNCTIONAL API # =============================================================================
def affine_transformer(x, shift, log_scale):
    r"""Implement an affine transformer for autoregressive normalizing flows.

    Functional counterpart of the ``AffineTransformer`` layer. It computes

    :math:`y_i = exp(a_i) * x_i + b_i`

    where :math:`a_i` (log scale) and :math:`b_i` (shift) are the parameters
    of the transformation, typically produced by a conditioner.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor x of shape ``(batch_size, n_features)``.
    shift : torch.Tensor
        The shift coefficients of shape ``(batch_size, n_features)`` (i.e.
        the ``b`` coefficients).
    log_scale : torch.Tensor
        The logarithm of the scale coefficients of shape
        ``(batch_size, n_features)`` (i.e. the ``a`` coefficients).

    Returns
    -------
    y : torch.Tensor
        Output tensor of shape ``(batch_size, n_features)``.
    log_det_J : torch.Tensor
        The logarithm of the absolute value of the determinant of the
        Jacobian of the transformation with shape ``(batch_size,)``.

    See Also
    --------
    :class:`.AffineTransformer`
        Object-oriented API for the transformer.

    """
    scale = torch.exp(log_scale)
    y = x * scale + shift
    # The Jacobian is diagonal with entries exp(a_i), so its log-determinant
    # is the sum of the log scales over the feature dimension.
    log_det_J = log_scale.sum(dim=1)
    return y, log_det_J
def affine_transformer_inverse(y, shift, log_scale):
    r"""Inverse function of ``affine_transformer``.

    Functional counterpart of ``AffineTransformer.inverse``. It undoes

    :math:`y_i = exp(a_i) * x_i + b_i`

    by computing :math:`x_i = (y_i - b_i) * exp(-a_i)`, where :math:`a_i`
    (log scale) and :math:`b_i` (shift) are the parameters of the
    transformation, typically produced by a conditioner.

    Parameters
    ----------
    y : torch.Tensor
        Input tensor y of shape ``(batch_size, n_features)``.
    shift : torch.Tensor
        The shift coefficients of shape ``(batch_size, n_features)`` (i.e.
        the ``b`` coefficients).
    log_scale : torch.Tensor
        The logarithm of the scale coefficients of shape
        ``(batch_size, n_features)`` (i.e. the ``a`` coefficients).

    Returns
    -------
    x : torch.Tensor
        Output tensor of shape ``(batch_size, n_features)``.
    log_det_J : torch.Tensor
        The logarithm of the absolute value of the determinant of the
        Jacobian of the inverse transformation with shape ``(batch_size,)``.

    See Also
    --------
    :class:`.AffineTransformer`
        Object-oriented API for the transformer.

    """
    inv_scale = torch.exp(-log_scale)
    centered = y - shift
    # The inverse Jacobian has diagonal entries exp(-a_i), so its
    # log-determinant is the negated sum of the log scales.
    return centered * inv_scale, -log_scale.sum(dim=1)
def volume_preserving_shift_transformer(x, shift, periodic_indices=None, periodic_limits=None):
    r"""Implement a volume-preserving transformer for autoregressive normalizing flows.

    This provides a functional API to the ``VolumePreservingShiftTransformer``
    layer. It implements the transformation

    :math:`y_i = x_i + b_i`

    where :math:`b_i` is the shift parameter of the transformation, which is
    usually generated by a conditioner.

    Periodic features (selected by ``periodic_indices``) are wrapped back
    into ``[lower, upper)`` after the shift. The wrap is measured from
    ``lower``, so values already inside the domain are left unchanged by a
    zero shift.

    The function returns the transformed feature as a ``Tensor`` of shape
    ``(batch_size, n_features)`` and the log absolute determinant of its
    Jacobian (always a zero vector) as a ``Tensor`` of shape ``(batch_size,)``.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor x of shape ``(batch_size, n_features)``.
    shift : torch.Tensor
        The shift coefficients of shape ``(batch_size, n_features)`` (i.e.
        the ``b`` coefficients).
    periodic_indices : torch.Tensor, optional
        If provided, the features indexed by ``periodic_indices`` will be
        treated as periodic on the domain given by ``periodic_limits``.
    periodic_limits : torch.Tensor, optional
        A pair ``(lower, upper)`` defining the domain ``[lower, upper)`` of
        the periodic features. The period is ``upper - lower``. Required
        when ``periodic_indices`` is given.

    Returns
    -------
    y : torch.Tensor
        Output tensor of shape ``(batch_size, n_features)``.
    log_det_J : torch.Tensor
        The logarithm of the absolute value of the determinant of the
        Jacobian of the transformation (always zero) with shape
        ``(batch_size,)``.

    Raises
    ------
    ValueError
        If ``periodic_indices`` is given but ``periodic_limits`` is not.

    See Also
    --------
    :class:`.VolumePreservingShiftTransformer`
        Object-oriented API for the transformer.

    """
    # x + shift allocates a new tensor, so the in-place wrap below never
    # modifies the caller's input.
    y = x + shift
    if periodic_indices is not None:
        if periodic_limits is None:
            raise ValueError('periodic_limits must be provided when periodic_indices is given.')
        lower = periodic_limits[0]
        period = periodic_limits[1] - lower
        # Wrap relative to the lower limit. torch's % takes the sign of the
        # divisor, so the result lands in [lower, upper) for a positive period.
        y[:, periodic_indices] = (y[:, periodic_indices] - lower) % period + lower
    log_det_J = torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)
    return y, log_det_J
def volume_preserving_shift_transformer_inverse(y, shift, periodic_indices=None, periodic_limits=None):
    r"""Inverse function of ``volume_preserving_shift_transformer``.

    This provides a functional API to the ``VolumePreservingShiftTransformer``
    layer. It implements the inverse of the transformation

    :math:`y_i = x_i + b_i`

    where :math:`b_i` is the shift parameter of the transformation, which is
    usually generated by a conditioner.

    Periodic features (selected by ``periodic_indices``) are wrapped back
    into ``[lower, upper)`` after removing the shift. The wrap is measured
    from ``lower``, so values already inside the domain are left unchanged
    by a zero shift.

    The function returns the transformed feature as a ``Tensor`` of shape
    ``(batch_size, n_features)`` and the log absolute determinant of its
    Jacobian (always a zero vector) as a ``Tensor`` of shape ``(batch_size,)``.

    Parameters
    ----------
    y : torch.Tensor
        Input tensor y of shape ``(batch_size, n_features)``.
    shift : torch.Tensor
        The shift coefficients of shape ``(batch_size, n_features)`` (i.e.
        the ``b`` coefficients).
    periodic_indices : torch.Tensor, optional
        If provided, the features indexed by ``periodic_indices`` will be
        treated as periodic on the domain given by ``periodic_limits``.
    periodic_limits : torch.Tensor, optional
        A pair ``(lower, upper)`` defining the domain ``[lower, upper)`` of
        the periodic features. The period is ``upper - lower``. Required
        when ``periodic_indices`` is given.

    Returns
    -------
    x : torch.Tensor
        Output tensor of shape ``(batch_size, n_features)``.
    log_det_J : torch.Tensor
        The logarithm of the absolute value of the determinant of the
        Jacobian of the transformation (always zero) with shape
        ``(batch_size,)``.

    Raises
    ------
    ValueError
        If ``periodic_indices`` is given but ``periodic_limits`` is not.

    """
    # y - shift allocates a new tensor, so the in-place wrap below never
    # modifies the caller's input.
    x = y - shift
    if periodic_indices is not None:
        if periodic_limits is None:
            raise ValueError('periodic_limits must be provided when periodic_indices is given.')
        lower = periodic_limits[0]
        period = periodic_limits[1] - lower
        # Wrap relative to the lower limit. torch's % takes the sign of the
        # divisor, so the result lands in [lower, upper) for a positive period.
        x[:, periodic_indices] = (x[:, periodic_indices] - lower) % period + lower
    log_det_J = torch.zeros(y.shape[0], dtype=y.dtype, device=y.device)
    return x, log_det_J