Source code for tfep.nn.transformers.moebius
#!/usr/bin/env python
# =============================================================================
# MODULE DOCSTRING
# =============================================================================
"""
Moebius transformation for autoregressive normalizing flows.
"""
# =============================================================================
# GLOBAL IMPORTS
# =============================================================================
import torch
from tfep.nn.transformers.transformer import MAFTransformer
from tfep.utils.math import batchwise_dot, batchwise_outer
# =============================================================================
# MOEBIUS TRANSFORMERS
# =============================================================================
[docs]
class MoebiusTransformer(MAFTransformer):
    r"""Moebius transformer.

    This implements a generalization of the Moebius transformation proposed in
    [1, 2] to non-unit spheres. The transformer will expand/contract the
    distribution on the sphere of radius :math:`r`, where :math:`r` is the norm
    of the input vector.

    The transformation has the form

    :math:`y = \frac{||x||^2 - ||w||^2}{||x - w||^2} (x - w) - w`

    where :math:`y, x, w` are all ``dimension``-dimensional vectors and
    :math:`||w|| < ||x||`. The function automatically rescales the ``w``
    argument following the same strategy as in [2] to satisfy the condition on
    the norm. Consequently, ``w``\ s of any norm can be passed.

    The implementation of the transformation on the unit sphere is slightly
    more efficient and can be toggled with the ``unit_sphere`` argument.

    References
    ----------
    [1] Kato S, McCullagh P. Moebius transformation and a Cauchy family
        on the sphere. arXiv preprint arXiv:1510.07679. 2015 Oct 26.
    [2] Rezende DJ, Papamakarios G, Racanière S, Albergo MS, Kanwar G,
        Shanahan PE, Cranmer K. Normalizing Flows on Tori and Spheres.
        arXiv preprint arXiv:2002.02428. 2020 Feb 6.

    """

    def __init__(self, dimension: int, max_radius: float = 0.99, unit_sphere: bool = False):
        """Constructor.

        Parameters
        ----------
        dimension : int
            The dimensionality of the vectors in ``x`` and ``w``.
        max_radius : float
            Must be strictly less than 1. Rescaling of the ``w`` vectors will
            be performed so that its maximum norm will be ``max_radius * |x|``.
            The value is stored as given; it is not validated here.
        unit_sphere : bool
            If ``True``, the input vectors ``x`` are assumed to be on the unit
            sphere, which makes the implementation slightly faster.

        """
        super().__init__()
        self.dimension = dimension
        self.max_radius = max_radius
        self.unit_sphere = unit_sphere

    def forward(self, x: torch.Tensor, parameters: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the transformation.

        Parameters
        ----------
        x : torch.Tensor
            Shape ``(batch_size, n_vectors*dimension)``. Contiguous elements of
            ``x`` are interpreted as vectors (i.e., the first and second input
            vectors are ``x[:, :dimension]`` and ``x[:, dimension:2*dimension]``).
        parameters : torch.Tensor
            Shape ``(batch_size, n_vectors*dimension)``. The transformation
            parameters. These parameter vectors are automatically rescaled so
            that ``|w| < |x|``.

        Returns
        -------
        y : torch.Tensor
            Shape ``(batch_size, n_vectors*dimension)``. The transformed vectors.
        log_det_J : torch.Tensor
            Shape ``(batch_size,)``. The logarithm of the absolute value of the
            Jacobian determinant ``dy / dx``.

        """
        # From shape (batch, n_vectors*dimension) to (batch, n_vectors, dimension).
        batch_size, n_features = x.shape
        x = x.reshape(batch_size, -1, self.dimension)
        parameters = parameters.reshape(batch_size, -1, self.dimension)

        y, log_det_J = moebius_transformer(
            x,
            parameters,
            max_radius=self.max_radius,
            unit_sphere=self.unit_sphere
        )

        # From shape (batch, n_vectors, dimension) to (batch, n_vectors*dimension).
        y = y.reshape(batch_size, n_features)
        return y, log_det_J

    def inverse(self, y: torch.Tensor, parameters: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Reverse the transformation.

        Parameters
        ----------
        y : torch.Tensor
            Shape ``(batch_size, n_vectors*dimension)``. Contiguous elements of
            ``y`` are interpreted as vectors (i.e., the first and second input
            vectors are ``y[:, :dimension]`` and ``y[:, dimension:2*dimension]``).
        parameters : torch.Tensor
            Shape ``(batch_size, n_vectors*dimension)``. The transformation
            parameters. These parameter vectors are automatically rescaled so
            that ``|w| < |y|``.

        Returns
        -------
        x : torch.Tensor
            Shape ``(batch_size, n_vectors*dimension)``. The inverse-transformed
            vectors.
        log_det_J : torch.Tensor
            Shape ``(batch_size,)``. The logarithm of the absolute value of the
            Jacobian determinant ``dx / dy``.

        """
        # From shape (batch, n_vectors*dimension) to (batch, n_vectors, dimension).
        batch_size, n_features = y.shape
        y = y.reshape(batch_size, -1, self.dimension)
        parameters = parameters.reshape(batch_size, -1, self.dimension)

        # The inverse is computed by applying the same transformation with
        # the sign of the parameters flipped.
        x, log_det_J = moebius_transformer(
            y,
            -parameters,
            max_radius=self.max_radius,
            unit_sphere=self.unit_sphere
        )

        # From shape (batch, n_vectors, dimension) to (batch, n_vectors*dimension).
        x = x.reshape(batch_size, n_features)
        return x, log_det_J

    def get_identity_parameters(self, n_features: int) -> torch.Tensor:
        """Return the value of the parameters that makes this the identity function.

        This can be used to initialize the normalizing flow to perform the
        identity transformation.

        Parameters
        ----------
        n_features : int
            The dimension of the input vector passed to the transformer.

        Returns
        -------
        parameters : torch.Tensor
            A tensor of zeros of shape ``(n_features,)`` representing the
            parameter vector to perform the identity function with a Moebius
            transformer.

        """
        return torch.zeros(size=(n_features,))

    def get_degrees_out(self, degrees_in: torch.Tensor) -> torch.Tensor:
        """Returns the degrees associated to the conditioner's output.

        Parameters
        ----------
        degrees_in : torch.Tensor
            Shape ``(n_transformed_features,)``. The autoregressive degrees
            associated to the features provided as input to the transformer.

        Returns
        -------
        degrees_out : torch.Tensor
            Shape ``(n_parameters,)``. The autoregressive degrees associated
            to each output of the conditioner that will be fed to the
            transformer as parameters. This is a detached copy of
            ``degrees_in`` (one parameter per transformed feature).

        """
        return degrees_in.detach().clone()
[docs]
class SymmetrizedMoebiusTransformer(MAFTransformer):
    r"""Symmetrized Moebius transformer.

    This implements a generalization of the symmetrized Moebius transformation
    proposed in [1] to non-unit spheres. The transformer will expand/contract
    the distribution on the sphere of radius :math:`r`, where :math:`r` is the
    norm of the input vector.

    The transformation has the form

    :math:`y = ||f(x; w)|| \frac{f(x; w) + f(x; -w)}{||f(x; w) + f(x; -w)||}`

    where :math:`f` is the Moebius transform (see :class:`.MoebiusTransformer`),
    and :math:`y, x, w` are all ``dimension``-dimensional vectors with
    :math:`||w|| < ||x||`. The function automatically rescales the ``w``
    argument following the same strategy as in [2] to satisfy the condition on
    the norm. Consequently, ``w``\ s of any norm can be passed.

    The transformer can implement the identity function when :math:`w` is zero.
    However, in this case, the gradient w.r.t. the parameters will also be zero
    and thus the transformer will not be able to learn another function. To
    avoid this :func:`.SymmetrizedMoebiusTransformer.get_identity_parameters`
    returns a very small random tensor rather than exactly zero. How small is
    controlled by the ``identity_eps`` argument.

    References
    ----------
    [1] Köhler J, Invernizzi M, De Haan P, Noé F. Rigid body flows for sampling
        molecular crystal structures. arXiv preprint arXiv:2301.11355. 2023 Jan 26.
    [2] Rezende DJ, Papamakarios G, Racanière S, Albergo MS, Kanwar G,
        Shanahan PE, Cranmer K. Normalizing Flows on Tori and Spheres.
        arXiv preprint arXiv:2002.02428. 2020 Feb 6.

    """

    def __init__(
            self,
            dimension: int,
            max_radius: float = 0.99,
            identity_eps: float = 1e-9,
    ):
        """Constructor.

        Parameters
        ----------
        dimension : int
            The dimensionality of the ``x`` and ``w`` vectors.
        max_radius : float
            Must be strictly less than 1. Rescaling of the ``w`` vectors will
            be performed so that its maximum norm will be ``max_radius * |x|``.
            The value is stored as given; it is not validated here.
        identity_eps : float
            The maximum absolute value randomly generated for the tensor
            elements in
            :func:`.SymmetrizedMoebiusTransformer.get_identity_parameters`.
            Set this to ``0.`` to implement the exact identity function.

        """
        super().__init__()
        self.dimension = dimension
        self.max_radius = max_radius
        self.identity_eps = identity_eps

    def forward(self, x: torch.Tensor, parameters: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the transformation.

        Parameters
        ----------
        x : torch.Tensor
            Shape ``(batch_size, n_vectors*dimension)``. Contiguous elements of
            ``x`` are interpreted as vectors (i.e., the first and second input
            vectors are ``x[:, :dimension]`` and ``x[:, dimension:2*dimension]``).
        parameters : torch.Tensor
            Shape ``(batch_size, n_vectors*dimension)``. The transformation
            parameters. These parameter vectors are automatically rescaled so
            that ``|w| < |x|``.

        Returns
        -------
        y : torch.Tensor
            Shape ``(batch_size, n_vectors*dimension)``. The transformed vectors.
        log_det_J : torch.Tensor
            Shape ``(batch_size,)``. The logarithm of the absolute value of the
            Jacobian determinant ``dy / dx``.

        """
        # From shape (batch, n_vectors*dimension) to (batch, n_vectors, dimension).
        batch_size, n_features = x.shape
        x = x.reshape(batch_size, -1, self.dimension)
        parameters = parameters.reshape(batch_size, -1, self.dimension)

        y, log_det_J = symmetrized_moebius_transformer(
            x,
            parameters,
            max_radius=self.max_radius,
        )

        # From shape (batch, n_vectors, dimension) to (batch, n_vectors*dimension).
        y = y.reshape(batch_size, n_features)
        return y, log_det_J

    def inverse(self, y: torch.Tensor, parameters: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Reverse the transformation.

        Parameters
        ----------
        y : torch.Tensor
            Shape ``(batch_size, n_vectors*dimension)``. Contiguous elements of
            ``y`` are interpreted as vectors (i.e., the first and second input
            vectors are ``y[:, :dimension]`` and ``y[:, dimension:2*dimension]``).
        parameters : torch.Tensor
            Shape ``(batch_size, n_vectors*dimension)``. The transformation
            parameters. These parameter vectors are automatically rescaled so
            that ``|w| < |y|``.

        Returns
        -------
        x : torch.Tensor
            Shape ``(batch_size, n_vectors*dimension)``. The inverse-transformed
            vectors.
        log_det_J : torch.Tensor
            Shape ``(batch_size,)``. The logarithm of the absolute value of the
            Jacobian determinant ``dx / dy``.

        """
        # From shape (batch, n_vectors*dimension) to (batch, n_vectors, dimension).
        batch_size, n_features = y.shape
        y = y.reshape(batch_size, -1, self.dimension)
        parameters = parameters.reshape(batch_size, -1, self.dimension)

        x, log_det_J = symmetrized_moebius_transformer_inverse(
            y,
            parameters,
            max_radius=self.max_radius,
        )

        # From shape (batch, n_vectors, dimension) to (batch, n_vectors*dimension).
        x = x.reshape(batch_size, n_features)
        return x, log_det_J

    def get_identity_parameters(self, n_features: int) -> torch.Tensor:
        """Return parameters that make this (approximately) the identity function.

        This can be used to initialize the normalizing flow to perform the
        identity transformation. The elements are drawn uniformly at random
        in ``[-identity_eps, identity_eps)`` rather than being exactly zero
        (unless ``identity_eps`` is ``0.``).

        Parameters
        ----------
        n_features : int
            The dimension of the input vector passed to the transformer.

        Returns
        -------
        w : torch.Tensor
            A tensor of shape ``(n_features,)`` representing the parameter
            vector to perform the identity function with a Moebius transformer.

        """
        # torch.rand samples in [0, 1); map it to [-1, 1) before scaling.
        par = torch.rand(n_features)
        return (2*par - 1) * self.identity_eps

    def get_degrees_out(self, degrees_in: torch.Tensor) -> torch.Tensor:
        """Returns the degrees associated to the conditioner's output.

        Parameters
        ----------
        degrees_in : torch.Tensor
            Shape ``(n_transformed_features,)``. The autoregressive degrees
            associated to the features provided as input to the transformer.

        Returns
        -------
        degrees_out : torch.Tensor
            Shape ``(n_parameters,)``. The autoregressive degrees associated
            to each output of the conditioner that will be fed to the
            transformer as parameters. This is a detached copy of
            ``degrees_in`` (one parameter per transformed feature).

        """
        return degrees_in.detach().clone()
# =============================================================================
# FUNCTIONAL API
# =============================================================================
[docs]
def moebius_transformer(
        x: torch.Tensor,
        w: torch.Tensor,
        max_radius: float = 0.99,
        unit_sphere: bool = False,
        return_log_det_J: bool = True
) -> tuple[torch.Tensor]:
    r"""Moebius transformer.

    This implements a generalization of the Moebius transformation proposed in
    [1, 2] to non-unit spheres. The transformer will expand/contract the
    distribution on the sphere of radius :math:`r`, where :math:`r` is the norm
    of the input vector.

    The transformation has the form

    :math:`y = \frac{||x||^2 - ||w||^2}{||x - w||^2} (x - w) - w`

    where :math:`y, x, w` are all ``dimension``-dimensional vectors and
    :math:`||w|| < ||x||`. The function automatically rescales the ``w``
    argument following the same strategy as in [2] to satisfy the condition on
    the norm. Consequently, ``w``\ s of any norm can be passed.

    The implementation of the transformation on the unit sphere is slightly
    more efficient and can be toggled with the ``unit_sphere`` argument.

    Parameters
    ----------
    x : torch.Tensor
        Shape ``(batch_size, n_vectors, dimension)``. Input coordinates.
    w : torch.Tensor
        Shape ``(batch_size, n_vectors, dimension)``. The transformation
        parameters. These parameter vectors are automatically rescaled so that
        ``|w| < |x|``.
    max_radius : float
        Must be strictly less than 1. Rescaling of the ``w`` vectors will be
        performed so that its maximum norm will be ``max_radius * |x|``.
    unit_sphere : bool
        If ``True``, the input vectors ``x`` are assumed to be on the unit
        sphere, which makes the implementation slightly faster.
    return_log_det_J : bool, optional
        Whether to return the ``log_det_J`` value. Default is ``True``.

    Returns
    -------
    y : torch.Tensor
        Shape ``(batch_size, n_vectors, dimension)``. The transformed vectors.
        When ``return_log_det_J`` is ``False`` this tensor is returned alone
        (not wrapped in a tuple).
    log_det_J : torch.Tensor, optional
        Shape ``(batch_size,)``. The logarithm of the absolute value of the
        Jacobian determinant ``dy / dx``. This is returned only if
        ``return_log_det_J`` is ``True``.

    References
    ----------
    [1] Kato S, McCullagh P. Moebius transformation and a Cauchy family
        on the sphere. arXiv preprint arXiv:1510.07679. 2015 Oct 26.
    [2] Rezende DJ, Papamakarios G, Racanière S, Albergo MS, Kanwar G,
        Shanahan PE, Cranmer K. Normalizing Flows on Tori and Spheres.
        arXiv preprint arXiv:2002.02428. 2020 Feb 6.

    """
    # The unpacking also enforces that x is a 3D tensor.
    _, _, dimension = x.shape

    # Compute the radius of the parameter vectors.
    w_norm = torch.linalg.norm(w, dim=-1, keepdim=True)

    # Map the w vectors from R^d into the ball of radius max_radius (unit
    # sphere) or max_radius * ||x|| (general case) so that ||w|| < ||x||.
    rescaling = max_radius / (1 + w_norm)
    if not unit_sphere:
        x_norm = torch.linalg.norm(x, dim=-1, keepdim=True)
        rescaling = x_norm * rescaling
    w = rescaling * w
    w_norm = rescaling * w_norm

    # Compute the transformed vectors.
    if unit_sphere:
        numerator = 1 - w_norm**2
    else:
        numerator = x_norm**2 - w_norm**2
    diff = x - w
    diff_norm = torch.linalg.norm(diff, dim=-1, keepdim=True)
    y = numerator / diff_norm.pow(2) * diff - w

    if not return_log_det_J:
        return y

    # Jacobian of numerator * diff / ||diff||^2 w.r.t. x with the numerator
    # held constant (i.e., the transformation on the unit sphere).
    numerator = numerator.unsqueeze(-1)
    diff_norm = diff_norm.unsqueeze(-1)
    dd_outer = batchwise_outer(diff, diff)
    # Create the identity on the same device as the inputs so that the
    # function also works when the tensors do not live on the default device.
    # The dtype is intentionally left to the default to preserve the original
    # type promotion behavior.
    eye = torch.eye(dimension, device=diff.device).expand_as(dd_outer)
    jac = numerator * (eye / diff_norm.pow(2) - 2 / diff_norm.pow(4) * dd_outer)

    # Now compute the Jacobian of the transformation on the sphere of radius x_norm.
    if not unit_sphere:
        x_norm_expand = x_norm.unsqueeze(-1)
        # eye - x x^T / ||x||^2 removes the component along x before applying jac ...
        jac2 = eye - batchwise_outer(x, x) / x_norm_expand**2
        jac = torch.einsum("...ij, ...jk -> ...ik", jac, jac2)  # Batchwise matrix multiplication.
        # ... and the component along x is handled by the y x^T / ||x||^2 term.
        jac = batchwise_outer(y, x) / x_norm_expand**2 + jac

    # One determinant per vector, then sum the log-determinants over the
    # vectors of each batch sample.
    log_det_J = torch.linalg.slogdet(jac)[1]
    log_det_J = log_det_J.sum(dim=-1)

    return y, log_det_J
[docs]
def symmetrized_moebius_transformer(
        x: torch.Tensor,
        w: torch.Tensor,
        max_radius: float = 0.99,
) -> tuple[torch.Tensor]:
    r"""Symmetrized Moebius transformer.

    Generalization to non-unit spheres of the symmetrized Moebius
    transformation proposed in [1]. The distribution is expanded/contracted on
    the sphere of radius :math:`r`, with :math:`r` the norm of the input
    vector.

    The transformation has the form

    :math:`y = ||f(x; w)|| \frac{f(x; w) + f(x; -w)}{||f(x; w) + f(x; -w)||}`

    where :math:`f` is the Moebius transform (see :class:`.MoebiusTransformer`),
    and :math:`y, x, w` are all ``dimension``-dimensional vectors with
    :math:`||w|| < ||x||`. The ``w`` argument is automatically rescaled with
    the same strategy as in [2] so that the condition on the norm holds;
    ``w``\ s of any norm can thus be passed.

    Parameters
    ----------
    x : torch.Tensor
        Shape ``(batch_size, n_vectors, dimension)``. The input coordinates.
    w : torch.Tensor
        Shape ``(batch_size, n_vectors, dimension)``. The transformation
        parameters, automatically rescaled so that ``|w| < |x|``.
    max_radius : float
        Must be strictly less than 1. After rescaling, the maximum norm of
        the ``w`` vectors is ``max_radius * |x|``.

    Returns
    -------
    y : torch.Tensor
        Shape ``(batch_size, n_vectors, dimension)``. The transformed vectors.
    log_det_J : torch.Tensor
        Shape ``(batch_size,)``. The logarithm of the absolute value of the
        Jacobian determinant ``dy / dx``.

    References
    ----------
    [1] Köhler J, Invernizzi M, De Haan P, Noé F. Rigid body flows for sampling
        molecular crystal structures. arXiv preprint arXiv:2301.11355. 2023 Jan 26.
    [2] Rezende DJ, Papamakarios G, Racanière S, Albergo MS, Kanwar G,
        Shanahan PE, Cranmer K. Normalizing Flows on Tori and Spheres.
        arXiv preprint arXiv:2002.02428. 2020 Feb 6.

    """
    batch_size, _, dimension = x.shape

    # Sum the Moebius images obtained with w and with -w.
    moebius_kwargs = dict(max_radius=max_radius, unit_sphere=False, return_log_det_J=False)
    image_plus = moebius_transformer(x, w, **moebius_kwargs)
    image_minus = moebius_transformer(x, -w, **moebius_kwargs)
    image_sum = image_plus + image_minus

    # Bring the summed vector back onto the sphere of radius ||x||.
    radius = torch.linalg.norm(x, dim=-1, keepdim=True)
    image_sum_norm = torch.linalg.norm(image_sum, dim=-1, keepdim=True)
    y = radius / image_sum_norm * image_sum

    # The volume change is evaluated on the unit sphere: x is normalized and
    # w is mapped inside the ball of radius max_radius.
    w = w.reshape(batch_size, -1, dimension)
    w_length = torch.linalg.norm(w, dim=-1, keepdim=True)
    shrink = max_radius / (1 + w_length)
    log_det_J = _symmetrized_moebius_transform_log_det_J(
        x / radius,
        shrink * w,
        (shrink * w_length)**2,
    )

    return y, log_det_J
[docs]
def symmetrized_moebius_transformer_inverse(
        x: torch.Tensor,
        w: torch.Tensor,
        max_radius: float = 0.99,
) -> tuple[torch.Tensor]:
    r"""Inverse symmetrized Moebius transformer.

    See :func:`.symmetrized_moebius_transformer` for the documentation of the
    forward transformation.

    Parameters
    ----------
    x : torch.Tensor
        Shape ``(batch_size, n_vectors, dimension)``. The vectors to map back
        (i.e., the output of the forward transformation).
    w : torch.Tensor
        Shape ``(batch_size, n_vectors, dimension)``. The transformation
        parameters. They are rescaled in the same way as in the forward
        transformation.
    max_radius : float
        Must be strictly less than 1. Controls the rescaling of ``w``.

    Returns
    -------
    x_inv : torch.Tensor
        Shape ``(batch_size, n_vectors, dimension)``. The inverse-transformed
        vectors, with the same norms as the input vectors.
    log_det_J : torch.Tensor
        Shape ``(batch_size,)``. The logarithm of the absolute value of the
        Jacobian determinant of the inverse transformation.

    """
    # We solve the inversion first on the unit sphere, and then project back.
    x_norm = torch.linalg.norm(x, dim=-1, keepdim=True)
    x_unit = x / x_norm

    # Map parameter vector w to the solid unit sphere.
    w_norm = torch.linalg.norm(w, dim=-1, keepdim=True)
    rescaling = max_radius / (1 + w_norm)
    w_unit = rescaling * w
    w_unit_norm = rescaling * w_norm

    # Change the coordinate system so that w = [r, 0, 0, ...] and
    # x = [a, b, 0, ...] with b >= 0.
    # The normalizations below are 0/0 when w is exactly zero (the exact
    # identity parameters) or when x is parallel to w. In both cases the
    # numerator is the zero vector, so dividing by 1 instead leaves a zero
    # basis vector and the formulas below still return the correct point
    # (x itself) rather than NaN. Non-degenerate inputs are unaffected.
    safe_w_unit_norm = torch.where(w_unit_norm > 0, w_unit_norm, torch.ones_like(w_unit_norm))
    da = w_unit / safe_w_unit_norm  # First basis of the new coordinate system: w
    a = batchwise_dot(x_unit, da, keepdim=True)  # Project x on first basis.
    db = x_unit - a * da  # Second orthogonal basis: x - proj(x, w)
    b = torch.linalg.norm(db, dim=-1, keepdim=True)
    safe_b = torch.where(b > 0, b, torch.ones_like(b))
    db = db / safe_b  # Normalize basis.

    # Now the inversion is analytically solvable following Köhler J, Invernizzi M,
    # De Haan P, Noé F. Rigid body flows for sampling molecular crystal structures.
    # arXiv preprint arXiv:2301.11355. 2023 Jan 26.
    r2 = w_unit_norm**2
    numer = - a * (r2 + 1.0)
    denom = torch.sqrt(1 + r2**2 + r2 * (4*a**2 - 2))
    a_inv = numer / denom
    # a_inv**2 <= 1 analytically, but round-off can push it slightly above 1,
    # which would make the square root NaN.
    b_inv = - torch.sqrt((1 - a_inv**2).clamp_min(0))

    # Project back on the unit hyper-sphere.
    x_unit_inv = - (a_inv*da + b_inv*db)

    # Compute change of volume as the negative of the forward transformation.
    # The contribution from dividing and multiplying by ||x|| cancel out.
    log_det_J = - _symmetrized_moebius_transform_log_det_J(x_unit_inv, w_unit, r2)

    # Project back on the hyper-sphere of radius ||x||.
    x_inv = x_norm * x_unit_inv

    return x_inv, log_det_J
# =============================================================================
# INTERNAL USE
# =============================================================================
def _symmetrized_moebius_transform_log_det_J(x, w, r2):
    """Log-determinant of the Jacobian of the symmetrized Moebius transform.

    The expression follows Köhler J, Invernizzi M, De Haan P, Noé F. Rigid body
    flows for sampling molecular crystal structures. arXiv preprint
    arXiv:2301.11355. 2023 Jan 26.

    Parameters
    ----------
    x : torch.Tensor
        Shape ``(batch_size, n_vectors, dimension)``. The input vectors. The
        callers in this module pass vectors normalized to the unit sphere.
    w : torch.Tensor
        Shape ``(batch_size, n_vectors, dimension)``. The parameter vectors,
        already rescaled to lie within the unit sphere.
    r2 : torch.Tensor
        Shape ``(batch_size, n_vectors, 1)``. The callers in this module pass
        the squared norm of ``w``.

    Returns
    -------
    log_det_J : torch.Tensor
        Shape ``(batch_size,)``. The per-vector log volume changes summed over
        the ``n_vectors`` axis.

    """
    n_dims = x.shape[-1]

    # r2 minus the squared projection of w on x.
    x_dot_w = batchwise_dot(x, w, keepdim=True)
    w_perp2 = r2 - x_dot_w**2

    # Volume change of each vector.
    top = (1 - r2) * (1 + r2)**(n_dims-1)
    bottom = (4*w_perp2 + (1 - r2)**2)**(n_dims / 2)
    volume_change = top / bottom

    # Drop the trailing singleton axis and accumulate over the vectors.
    return torch.log(volume_change).squeeze(-1).sum(dim=1)