Source code for tfep.nn.transformers.quatprod
#!/usr/bin/env python
# =============================================================================
# MODULE DOCSTRING
# =============================================================================
"""
Quaternion product transformation for autoregressive normalizing flows.
"""
# =============================================================================
# GLOBAL IMPORTS
# =============================================================================
import torch
from tfep.nn.transformers.transformer import MAFTransformer
# =============================================================================
# QUATERNION PRODUCT TRANSFORMER
# =============================================================================
[docs]
class QuaternionProductTransformer(MAFTransformer):
r"""Quaternion product transformer.
This is a volume-preserving transformation that can be applied to
quaternions. For each (normalized) quaternion in the input, the conditioner
must provide a 4-dimensional vector (possibly unnormalized). As quaternions
typically model the orientation of a molecule, the transformation is
equivalent to applying a separate rigid rotation to each molecule and thus
has a unit Jacobian.
"""
[docs]
def forward(self, x: torch.Tensor, parameters: torch.Tensor) -> tuple[torch.Tensor]:
"""Apply the transformation.
Parameters
----------
x : torch.Tensor
Shape ``(batch_size, n_quaternions*4)``. The quaternions elements
are contiguous (i.e., the first and second input quaternions are
``x[:4]`` and ``x[4:8]``.
parameters : torch.Tensor
Shape ``(batch_size, n_quaternions*4)``. The parameters interpreted
as (unnormalized) quaternions that will multiply those in ``x``.
These are normalized in the function.
Returns
-------
y : torch.Tensor
Shape ``(batch_size, n_quaternions*4)``. The transformed normalized
quaternions.
log_det_J : torch.Tensor
Shape ``(batch_size,)``. The logarithm of the absolute value of the Jacobian
determinant ``dy / dx`` (i.e., always zero).
"""
# roma is an optional dependency at the moment
import roma
# From (batch, n_quaternions*4) to (batch*n_quaternions, 4).
batch_size = x.shape[0]
x = x.reshape(-1, 4)
parameters = parameters.reshape(-1, 4)
# Transform.
y = roma.quat_product(roma.quat_normalize(parameters), x)
log_det_J = torch.zeros(batch_size).to(x)
return y.reshape(batch_size, -1), log_det_J
[docs]
def inverse(self, y: torch.Tensor, parameters: torch.Tensor) -> tuple[torch.Tensor]:
"""Reverse the transformation.
Parameters
----------
y : torch.Tensor
Shape ``(batch_size, n_quaternions*4)``. The quaternions elements
are contiguous (i.e., the first and second input quaternions are
``y[:4]`` and ``y[4:8]``.
parameters : torch.Tensor
Shape ``(batch_size, n_quaternions*4)``. The parameters interpreted
as (unnormalized) quaternions that will multiply those in ``y``.
These are normalized in the function.
Returns
-------
x : torch.Tensor
Shape ``(batch_size, n_quaternions*4)``. The transformed normalized
quaternions.
log_det_J : torch.Tensor
Shape ``(batch_size,)``. The logarithm of the absolute value of the Jacobian
determinant ``dy / dx`` (i.e., always zero).
"""
# roma is an optional dependency at the moment
import roma
# From (batch, n_quaternions*4) to (batch*n_quaternions, 4).
batch_size = y.shape[0]
y = y.reshape(-1, 4)
parameters = parameters.reshape(-1, 4)
# Transform.
x = roma.quat_product(roma.quat_conjugation(roma.quat_normalize(parameters)), y)
log_det_J = torch.zeros(batch_size).to(y)
return x.reshape(batch_size, -1), log_det_J
[docs]
def get_identity_parameters(self, n_features: int) -> torch.Tensor:
"""Return the value of the parameters that makes this the identity function.
This can be used to initialize the normalizing flow to perform the identity
transformation.
Parameters
----------
n_features : int
The dimension of the input vector passed to the transformer. Must
be divisible by 4.
Returns
-------
parameters : torch.Tensor
A tensor of shape ``(n_features,)`` representing the parameter
vector to perform the identity function with a Moebius transformer.
"""
# roma is an optional dependency at the moment
import roma
return roma.identity_quat(n_features//4).flatten()
[docs]
def get_degrees_out(self, degrees_in: torch.Tensor) -> torch.Tensor:
"""Returns the degrees associated to the conditioner's output.
Parameters
----------
degrees_in : torch.Tensor
Shape ``(n_transformed_features,)``. The autoregressive degrees
associated to the features provided as input to the transformer.
Returns
-------
degrees_out : torch.Tensor
Shape ``(n_parameters,)``. The autoregressive degrees associated
to each output of the conditioner that will be fed to the
transformer as parameters.
"""
return degrees_in.detach().clone()