Source code for tfep.utils.geometry

#!/usr/bin/env python


# =============================================================================
# MODULE DOCSTRING
# =============================================================================

"""
Math and geometry utility functions to manipulate coordinates.
"""


# =============================================================================
# GLOBAL IMPORTS
# =============================================================================

from typing import Dict, List, Literal, Optional, Tuple

import torch

from tfep.utils.math import batchwise_dot, batchwise_outer


# =============================================================================
# INTERNAL COORDINATES
# =============================================================================

def pdist(x, pairs=None, return_diff=False):
    """Compute p-norm distances between pairs of row vectors.

    In comparison to ``torch.nn.functional.pdist``, the function can handle
    batches and compute distances between a subset of all pairs. Only the
    Euclidean norm is currently supported.

    Parameters
    ----------
    x : torch.Tensor
        Positions of particles with shape ``(*, n_particles, D)``, where ``*``
        is any number of leading batch dimensions (typically a single
        ``batch_size`` dimension, possibly none) and ``D`` is the
        dimensionality of the vector space.
    pairs : torch.Tensor, optional
        An integer tensor of shape ``(2, n_pairs)``. For each batch sample,
        the function will compute the ``i``-th distance between the
        ``pairs[0, i]``-th and the ``pairs[1, i]``-th particles. If not
        passed, the distances between all pairs ``(i, j)`` with ``i < j`` are
        computed, in row-major order.
    return_diff : bool, optional
        If ``True``, the difference vectors between pairs of particles are
        also returned.

    Returns
    -------
    distances : torch.Tensor
        This has shape ``(*, n_pairs)``, and ``distances[b, i]`` is the
        distance between the particles of the ``i``-th pair for the ``b``-th
        batch sample.
    diff : torch.Tensor, optional
        This has shape ``(*, n_pairs, D)``, and ``diff[b, i]`` is the vector
        ``p1 - p0``, where ``pX`` is the position of particle ``pairs[X, i]``.
        This is returned only if ``return_diff`` is ``True``.

    """
    n_particles = x.shape[-2]
    if pairs is None:
        # Create the indices on the same device as x to avoid a host-device
        # transfer at indexing time.
        pairs = torch.triu_indices(n_particles, n_particles, offset=1, device=x.device)

    # Index the particle dimension explicitly (second to last) so that the
    # function works with any number of leading batch dimensions. For the
    # usual (batch_size, n_particles, D) input this is identical to x[:, idx].
    diff = x[..., pairs[1], :] - x[..., pairs[0], :]
    distances = torch.sqrt(torch.sum(diff**2, dim=-1))

    if return_diff:
        return distances, diff
    return distances
def vector_vector_angle(x1, x2):
    """Return the angle in radians between two vectors.

    When ``x1`` and ``x2`` hold multiple vectors, one angle is computed for
    each corresponding pair along the leading dimensions.

    Parameters
    ----------
    x1 : torch.Tensor
        A tensor of shape ``(*, D)``, where ``D`` is the vector dimensionality.
    x2 : torch.Tensor
        A tensor of shape ``(*, D)``, where ``D`` is the vector dimensionality.

    Returns
    -------
    angles : torch.Tensor
        A tensor of shape ``(*,)`` with values in ``[0, pi]``. For example,
        for two inputs of shape ``(batch_size, D)``, ``angles`` has shape
        ``(batch_size,)`` and ``angles[i]`` is the angle between ``x1[i]``
        and ``x2[i]``.

    """
    denominator = torch.linalg.vector_norm(x1, dim=-1) * torch.linalg.vector_norm(x2, dim=-1)
    cosine = batchwise_dot(x1, x2) / denominator

    # Floating-point round-off can push the cosine marginally outside
    # [-1, 1], where acos is undefined.
    return cosine.clamp(min=-1, max=1).acos()
def vector_plane_angle(x, plane):
    """Return the signed angle in radians between a batch of vectors and a plane.

    Parameters
    ----------
    x : torch.Tensor
        A tensor of shape ``(batch_size, N)`` or ``(N,)``.
    plane : torch.Tensor
        A tensor of shape ``(N,)`` representing a vector normal to the plane.

    Returns
    -------
    angle : torch.Tensor
        A tensor of shape ``(batch_size,)`` where ``angle[i]`` is the angle
        between vector ``x[i]`` and the plane with normal ``plane``. Values
        lie in ``[-pi/2, pi/2]`` and are positive when ``x[i]`` points to the
        same side of the plane as the normal.

    """
    norms_product = torch.linalg.vector_norm(x, dim=-1) * torch.linalg.vector_norm(plane, dim=-1)
    # Cosine of the angle between x and the plane normal, clamped against
    # round-off errors.
    normal_cosine = (batchwise_dot(x, plane) / norms_product).clamp(min=-1, max=1)

    # The angle with the plane is the complement of the angle with its
    # normal: pi/2 - acos(c) == asin(c).
    return torch.asin(normal_cosine)
def proper_dihedral_angle(x1, x2, x3):
    """Compute the proper dihedral angle between the planes ``x1``-``x2`` and ``x2``-``x3``.

    When the inputs hold multiple vectors, one dihedral is computed for each
    corresponding triplet along the leading dimensions.

    The parameters are described using four atoms at positions p0, p1, p2,
    and p3 as an example.

    Parameters
    ----------
    x1 : torch.Tensor
        The vector p1 - p0 with shape ``(*, D)``, where ``D`` is the vector
        dimensionality.
    x2 : torch.Tensor
        The vector p2 - p1 with shape ``(*, D)``, where ``D`` is the vector
        dimensionality.
    x3 : torch.Tensor
        The vector p3 - p2 with shape ``(*, D)``, where ``D`` is the vector
        dimensionality.

    Returns
    -------
    dihedrals : torch.Tensor
        A tensor of shape ``(*,)`` with values in ``[-pi, pi]`` (output of
        ``atan2``). For example, for inputs of shape ``(batch_size, D)``,
        ``dihedrals[i]`` is the angle between the planes ``x1[i]``-``x2[i]``
        and ``x2[i]``-``x3[i]``.

    """
    # Praxeolitic's formulation, see:
    # https://stackoverflow.com/questions/20305272/dihedral-torsion-angle-from-four-points-in-cartesian-coordinates-in-python
    b0 = -x1

    # Use a unit central bond so it does not scale the rejections below.
    b1_hat = x2 / torch.linalg.vector_norm(x2, dim=-1, keepdim=True)

    # Reject b0 and x3 from the central bond, i.e., project both onto the
    # plane perpendicular to it.
    v = b0 - batchwise_dot(b0, b1_hat, keepdim=True) * b1_hat
    w = x3 - batchwise_dot(x3, b1_hat, keepdim=True) * b1_hat

    # The torsion is the angle between v and w within that plane. Neither
    # needs to be normalized because atan2 only depends on the ratio y/x.
    cos_part = batchwise_dot(v, w)
    sin_part = batchwise_dot(torch.cross(b1_hat, v, dim=-1), w)
    return torch.atan2(sin_part, cos_part)
# =============================================================================
# ROTATION
# =============================================================================
def rotation_matrix_3d(angles, directions):
    """Return the matrices rotating vectors by the given angles about the given directions.

    The matrices are built with Rodrigues' rotation formula
    ``R = cos(a) I + (1 - cos(a)) k k^T + sin(a) [k]_x``, where ``k`` is the
    unit rotation axis and ``[k]_x`` its cross-product matrix.

    Parameters
    ----------
    angles : torch.Tensor
        A tensor of shape ``(batch_size,)``.
    directions : torch.Tensor
        A tensor of shape ``(batch_size, 3)`` or ``(3,)``. The directions do
        not need to be normalized.

    Returns
    -------
    R : torch.Tensor
        Shape ``(batch_size, 3, 3)``. ``R[i]`` is the 3 by 3 matrix rotating
        by the angle ``angles[i]`` about the vector ``directions[i]`` (or
        about ``directions`` if a single direction is given).

    """
    n_batch = len(angles)

    # Unit rotation axes, always carrying a leading batch dimension.
    unit_axes = torch.nn.functional.normalize(directions, dim=-1)
    if unit_axes.dim() < 2:
        unit_axes = unit_axes.unsqueeze(0)

    # (batch_size, 1, 1) so that it scales whole 3x3 matrices.
    cos_t = torch.cos(angles)[..., None, None]

    # First two Rodrigues terms: cos(a) I + (1 - cos(a)) k k^T. The outer
    # product k k^T is formed by broadcasting a column against a row.
    identity_term = cos_t * torch.eye(3).expand(n_batch, 3, 3).to(cos_t)
    outer_kk = unit_axes[:, :, None] * unit_axes[:, None, :]
    rotation = identity_term + (1 - cos_t) * outer_kk

    # Last term: sin(a) [k]_x, the cross-product (skew) matrix of sin(a) k.
    sx, sy, sz = (torch.sin(angles)[..., None] * unit_axes).unbind(dim=-1)
    zero = torch.zeros_like(angles)
    skew = torch.stack(
        [
            torch.stack([zero, -sz, sy], dim=-1),
            torch.stack([sz, zero, -sx], dim=-1),
            torch.stack([-sy, sx, zero], dim=-1),
        ],
        dim=-2,
    )  # Shape (batch_size, 3, 3).

    return rotation + skew
def batchwise_rotate(x, rotation_matrices, inverse=False):
    """Rotate each configuration in a batch with its own rotation matrix.

    Parameters
    ----------
    x : torch.Tensor
        A tensor of shape ``(batch_size, n_vectors, 3)``.
    rotation_matrices : torch.Tensor
        A tensor of shape ``(batch_size, 3, 3)``.
    inverse : bool, optional
        If ``True`` the inverse rotation is performed.

    Returns
    -------
    y : torch.Tensor
        A tensor of shape ``(batch_size, n_vectors, 3)`` where ``y[b][i]`` is
        the result of rotating vector ``x[b][i]`` with the rotation matrix
        ``rotation_matrices[b]``.

    """
    # Vectors are stored as rows, so applying R is x @ R^T. For a rotation
    # matrix R^-1 == R^T, hence the inverse rotation is simply x @ R.
    if inverse:
        right_factor = rotation_matrices
    else:
        right_factor = rotation_matrices.transpose(1, 2)
    return torch.bmm(x, right_factor)
# ============================================================================= # COORDINATE TRANSFORMATIONS # ============================================================================= # Map the name of an axis to its 3D unit vector representation. We instantiate # the tensor in get_axis_from_name to make sure it is represented by the default # floating type, which might not be set on import. _AXIS_NAME_TO_VECTOR: Dict[Literal['x', 'y', 'z'], List] = { 'x': [1.0, 0.0, 0.0], 'y': [0.0, 1.0, 0.0], 'z': [0.0, 0.0, 1.0], }
def get_axis_from_name(name: Literal['x', 'y', 'z']) -> torch.Tensor:
    """Return the 3D unit vector corresponding to a named axis.

    Parameters
    ----------
    name : Literal['x', 'y', 'z']
        The name of the axis.

    Returns
    -------
    axis : torch.Tensor
        Shape ``(3,)``. The unit vector representation of the axis, created
        with the default floating type active at call time.

    """
    components = _AXIS_NAME_TO_VECTOR[name]
    return torch.tensor(components)
def reference_frame_rotation_matrix(
        axis_atom_positions: torch.Tensor,
        plane_atom_positions: torch.Tensor,
        axis: torch.Tensor,
        plane_axis: torch.Tensor,
        plane_normal: Optional[torch.Tensor] = None,
        project_on_positive_axis: bool = False
) -> torch.Tensor:
    """Return the rotation matrix required to rotate the frame of reference based on two atoms.

    After the rotation matrix is applied to the coordinates, ``axis_atom_positions``
    lie on the given ``axis`` vector while ``plane_atom_positions`` lie on the
    plane spanned by the ``axis`` and ``plane_axis`` vectors.

    Parameters
    ----------
    axis_atom_positions : torch.Tensor
        Shape ``(batch_size, 3)``. The position of the atom placed on ``axis``.
    plane_atom_positions : torch.Tensor
        Shape ``(batch_size, 3)``. The position of the atom placed on the
        ``axis``-``plane_axis`` plane.
    axis : torch.Tensor
        Shape ``(3,)``. The axis on which the axis atom is placed. Must be a
        unit vector.
    plane_axis : torch.Tensor
        Shape ``(3,)``. The second axis used to determine the plane where the
        plane atom is placed. Must be a unit vector and not parallel to
        ``axis``.
    plane_normal : Optional[torch.Tensor]
        Shape ``(3,)``. The vector normal to ``axis`` and ``plane_axis``. If
        not given, it is computed here as ``torch.cross(axis, plane_axis)``.
        The sign logic for the second rotation assumes this orientation, so a
        custom normal should point in the same direction.
    project_on_positive_axis : bool
        If ``True``, the axis atom is rotated so that it always lies on the
        positive ``axis``. Otherwise, it is rotated on the positive or
        negative ``axis`` based on whichever is closest. Note that if this is
        ``True``, a transformation that flips the sign of the coordinate of
        the axis atom might become impossible to invert in practice.

    Returns
    -------
    rotation_matrices : torch.Tensor
        Shape ``(batch_size, 3, 3)``. The rotation matrices.

    Examples
    --------

    >>> # Initialize the coordinates.
    >>> batch_size, n_atoms = 2, 4
    >>> coordinates = torch.randn(batch_size, n_atoms, 3)
    >>> # Fix the orientation of the coordinate frames based on the 2nd and 4th atoms.
    >>> axis_atom_pos = coordinates[:, 1]
    >>> plane_atom_pos = coordinates[:, 3]
    >>> rotation_matrices = reference_frame_rotation_matrix(
    ...     axis_atom_pos,
    ...     plane_atom_pos,
    ...     axis=torch.tensor([1.0, 0, 0]),  # axis atom lies on x-axis
    ...     plane_axis=torch.tensor([0.0, 0, 1]),  # plane atom lies on x-z plane
    ... )
    ...
    >>> # Rotate the coordinates.
    >>> new_coordinates = batchwise_rotate(coordinates, rotation_matrices)
    >>> # Reverse the change of reference frame.
    >>> old_coordinates = batchwise_rotate(new_coordinates, rotation_matrices, inverse=True)

    """
    # Default argument.
    if plane_normal is None:
        plane_normal = torch.cross(axis, plane_axis, dim=0)

    # Find the direction perpendicular to the plane formed by the axis atom
    # and the axis. rotation_vectors has shape (batch_size, 3).
    rotation_vectors = torch.cross(axis_atom_positions, axis.unsqueeze(0), dim=1)

    # If axis_atom_positions lies exactly on axis, any perpendicular vector
    # will do. Shape (batch_size,). The zero reference is built with
    # zeros_like so that it has the same dtype and device as rotation_vectors
    # (torch.isclose rejects mismatched dtypes, and a non-scalar CPU tensor
    # cannot be combined with CUDA inputs).
    is_parallel = torch.isclose(rotation_vectors, torch.zeros_like(rotation_vectors)).all(dim=1)
    rotation_vectors[is_parallel] = torch.cross(plane_axis, axis, dim=0)

    # Find the first rotation angle. r1_angles has shape (batch_size,).
    r1_angles = vector_vector_angle(axis_atom_positions, axis)

    # r1_angles goes from 0 to pi. We want to rotate the point onto the
    # negative/positive axis, depending on which is closest.
    if not project_on_positive_axis:
        r1_angles = r1_angles - torch.pi * (r1_angles > torch.pi/2).to(r1_angles.dtype)

    # These are the rotation matrices that bring the axis points onto the axis.
    r1_rotation_matrices = rotation_matrix_3d(r1_angles, rotation_vectors)

    # To bring the plane atom in position, we perform a rotation about axis
    # so that we don't modify the position of the axis atom. For now, we
    # apply the first rotation only to the atom position that determines the
    # next rotation matrix so that a single matmul is run on all atoms later.
    plane_points = plane_atom_positions.unsqueeze(1)
    plane_points = batchwise_rotate(plane_points, r1_rotation_matrices)
    plane_points = plane_points.squeeze(1)

    # Project the atom on the plane perpendicular to the rotation axis to
    # measure the rotation angle.
    plane_points = plane_points - axis*batchwise_dot(plane_points, axis, keepdim=True)
    r2_angles = vector_plane_angle(plane_points, plane_normal)

    # r2_angles is positive on the side of the plane where plane_normal
    # points and negative on the other side. A positive rotation about axis
    # turns plane_axis towards plane_normal, so the sign of the angle must
    # be fixed based on which half-plane (+/- plane_axis) the point lies in:
    # this is -sign(dot). When dot == 0, the point lies exactly on the
    # normal and -sign(0) == 0 would skip the rotation altogether, leaving
    # the atom off the plane. Either direction is valid in that case, so it
    # is mapped to -1 (same as the dot > 0 branch).
    plane_axis_component = batchwise_dot(plane_points, plane_axis)
    r2_angles_sign = 1 - 2 * (plane_axis_component >= 0).to(r2_angles.dtype)
    r2_rotation_matrices = rotation_matrix_3d(r2_angles_sign * r2_angles, axis)

    # Now build the rotation composition.
    rotation_matrices = torch.bmm(r2_rotation_matrices, r1_rotation_matrices)
    return rotation_matrices
def cartesian_to_polar(x: torch.Tensor, y: torch.Tensor, return_log_det_J: bool = False) -> Tuple[torch.Tensor, ...]:
    """Transform Cartesian coordinates into polar.

    Parameters
    ----------
    x : torch.Tensor
        Shape ``(batch_size,)``. The x Cartesian coordinate.
    y : torch.Tensor
        Shape ``(batch_size,)``. The y Cartesian coordinate.
    return_log_det_J: bool, optional
        If ``True``, the logarithm of the absolute value of the Jacobian
        determinant of the transformation is also returned.

    Returns
    -------
    r : torch.Tensor
        Shape ``(batch_size,)``. The radius coordinate.
    angle : torch.Tensor
        Shape ``(batch_size,)``. The angle coordinate in radians, in
        ``[-pi, pi]`` (output of ``atan2``).
    log_det_J : torch.Tensor, optional
        Shape ``(batch_size,)``. The logarithm of the absolute value of the
        Jacobian determinant of the transformation, i.e., ``-log(r)``. This is
        returned only if ``return_log_det_J`` is ``True``.

    """
    # hypot avoids the intermediate overflow/underflow of squaring x and y.
    r = torch.hypot(x, y)
    angle = torch.atan2(y, x)
    if return_log_det_J:
        # |det d(r, angle)/d(x, y)| = 1 / r.
        return r, angle, -torch.log(r)
    return r, angle
def polar_to_cartesian(r: torch.Tensor, angle: torch.Tensor, return_log_det_J: bool = False) -> Tuple[torch.Tensor, ...]:
    """Transform polar coordinates into Cartesian.

    Parameters
    ----------
    r : torch.Tensor
        Shape ``(batch_size,)``. The radius coordinate.
    angle : torch.Tensor
        Shape ``(batch_size,)``. The angle coordinate in radians.
    return_log_det_J: bool, optional
        If ``True``, the logarithm of the absolute value of the Jacobian
        determinant of the transformation is also returned.

    Returns
    -------
    x : torch.Tensor
        Shape ``(batch_size,)``. The x Cartesian coordinate.
    y : torch.Tensor
        Shape ``(batch_size,)``. The y Cartesian coordinate.
    log_det_J : torch.Tensor, optional
        Shape ``(batch_size,)``. The logarithm of the absolute value of the
        Jacobian determinant of the transformation, i.e., ``log(r)``. This is
        returned only if ``return_log_det_J`` is ``True``.

    """
    x = r * torch.cos(angle)
    y = r * torch.sin(angle)
    if return_log_det_J:
        # |det d(x, y)/d(r, angle)| = r.
        return x, y, torch.log(r)
    return x, y