Source code for tfep.nn.transformers.mixed
#!/usr/bin/env python
# =============================================================================
# MODULE DOCSTRING
# =============================================================================
"""
A transformer applying different transformers to different features.
"""
# =============================================================================
# GLOBAL IMPORTS
# =============================================================================
from collections.abc import Sequence
import torch
from tfep.nn.transformers.transformer import MAFTransformer
from tfep.utils.misc import ensure_tensor_sequence
# =============================================================================
# MIXED TRANSFORMER
# =============================================================================
[docs]
class MixedTransformer(MAFTransformer):
    """A transformer applying different transformers to different features.

    Each wrapped transformer acts on its own subset of the input features
    (selected through ``indices``) and receives its own contiguous slice of
    the parameter vector. The per-transformer log-determinants are summed.

    """

    def __init__(
            self,
            transformers: Sequence[MAFTransformer],
            indices: Sequence[Sequence[int]],
    ):
        """Constructor.

        Parameters
        ----------
        transformers : Sequence[MAFTransformer]
            The transformers to mix.
        indices : Sequence[Sequence[int]]
            A list of length ``len(transformers)``. ``indices[i]`` is another
            list containing the indices of the input features for the ``i``-th
            transformer. The sum of all the lengths must equal the number of
            features.

        Raises
        ------
        ValueError
            If fewer than two transformers are given, or if ``indices`` and
            ``transformers`` do not have the same length.

        """
        super().__init__()

        # Input checking.
        if len(transformers) < 2:
            raise ValueError('The number of transformers must be greater than 1.')
        if len(transformers) != len(indices):
            raise ValueError('The number of elements in indices must equal that in transformers.')

        # Wrap in a ModuleList so that the sub-transformers are registered as
        # submodules. A plain Python sequence is invisible to .to(), .cuda(),
        # state_dict() and parameters(), which would leave any parameter or
        # buffer owned by a sub-transformer behind.
        self._transformers = torch.nn.ModuleList(transformers)

        # Save the indices into buffers so that they follow the module's device.
        for idx, ind in enumerate(indices):
            self.register_buffer(f'_indices{idx}', ensure_tensor_sequence(ind))

        # Cache the boundaries used to split the parameters by transformer.
        # The last length is dropped since N chunks need only N-1 split points.
        par_lengths = [len(transformer.get_identity_parameters(len(ind)))
                       for transformer, ind in zip(transformers, indices)]
        split_indices = torch.cumsum(torch.tensor(par_lengths[:-1]), dim=0)

        # The buffer is still registered so that the state_dict keys are
        # unchanged, but it is NOT what _run() uses: torch.tensor_split()
        # only accepts a tensor of indices if it lives on the CPU, and a
        # buffer is moved to the GPU together with the module.
        self.register_buffer('_parameters_split_indices', split_indices)
        self._parameters_split_points = split_indices.tolist()

    def forward(
            self,
            x: torch.Tensor,
            parameters: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the transformation.

        Parameters
        ----------
        x : torch.Tensor
            Shape ``(batch_size, n_features)``. The input features.
        parameters : torch.Tensor
            Shape ``(batch_size, n_parameters)``. The parameters for the
            transformers expected grouped by transformer (i.e., first all
            parameters for the first transformer, then those of the second one
            etc.).

        Returns
        -------
        y : torch.Tensor
            Shape ``(batch_size, n_features)``. The transformed vectors.
        log_det_J : torch.Tensor
            Shape ``(batch_size,)``. The logarithm of the absolute value of the
            Jacobian determinant ``dy / dx``.

        """
        return self._run(x, parameters, inverse=False)

    def inverse(
            self,
            y: torch.Tensor,
            parameters: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Reverse the transformation.

        Parameters
        ----------
        y : torch.Tensor
            Shape ``(batch_size, n_features)``. The input features.
        parameters : torch.Tensor
            Shape ``(batch_size, n_parameters)``. The parameters for the
            transformers expected grouped by transformer (i.e., first all
            parameters for the first transformer, then those of the second one
            etc.).

        Returns
        -------
        x : torch.Tensor
            Shape ``(batch_size, n_features)``. The transformed vectors.
        log_det_J : torch.Tensor
            Shape ``(batch_size,)``. The logarithm of the absolute value of the
            Jacobian determinant ``dx / dy``.

        """
        return self._run(y, parameters, inverse=True)

    def get_identity_parameters(self, n_features: int) -> torch.Tensor:
        """Return the value of the parameters that makes this the identity function.

        This can be used to initialize the normalizing flow to perform the
        identity transformation.

        Parameters
        ----------
        n_features : int
            The dimension of the input vector passed to the transformer. This
            value is not used here: each sub-transformer is queried with the
            number of features assigned to it at construction.

        Returns
        -------
        parameters : torch.Tensor
            Shape ``(n_parameters,)``. The parameters for the identity function,
            concatenated in the same order as the transformers.

        """
        parameters = [transformer.get_identity_parameters(len(indices))
                      for transformer, indices in zip(self._transformers, self._indices)]
        return torch.cat(parameters, dim=-1)

    def get_degrees_out(self, degrees_in: torch.Tensor) -> torch.Tensor:
        """Returns the degrees associated to the conditioner's output.

        Parameters
        ----------
        degrees_in : torch.Tensor
            Shape ``(n_transformed_features,)``. The autoregressive degrees
            associated to the features provided as input to the transformer.

        Returns
        -------
        degrees_out : torch.Tensor
            Shape ``(n_parameters,)``. The autoregressive degrees associated
            to each output of the conditioner that will be fed to the
            transformer as parameters.

        """
        degrees_out = [transformer.get_degrees_out(degrees_in[indices])
                       for transformer, indices in zip(self._transformers, self._indices)]
        return torch.cat(degrees_out, dim=-1)

    @property
    def _indices(self):
        """The feature-index buffers, one per transformer, as a list."""
        return [getattr(self, f'_indices{idx}')
                for idx in range(len(self._transformers))]

    def _run(self, x, parameters, inverse):
        """Execute the transformation in the forward or inverse direction."""
        # Write into a fresh tensor to avoid modifying the input in place.
        y = torch.empty_like(x)
        cumulative_log_det_J = 0.0

        # Split the parameters by transformer. The split points are plain
        # Python ints so that this works regardless of the module's device.
        parameters = torch.tensor_split(parameters, self._parameters_split_points, dim=1)

        # Run each transformer on its own features and parameters.
        for transformer, indices, par in zip(self._transformers, self._indices, parameters):
            transform = transformer.inverse if inverse else transformer
            y[:, indices], log_det_J = transform(x[:, indices], par)
            cumulative_log_det_J = cumulative_log_det_J + log_det_J

        return y, cumulative_log_det_J