Source code for tfep.utils.misc

#!/usr/bin/env python

# =============================================================================
# MODULE DOCSTRING
# =============================================================================

"""Miscellanea utility functions."""


# =============================================================================
# GLOBAL IMPORTS
# =============================================================================

from collections.abc import Sequence
import contextlib
import os
from typing import Union

import numpy as np
import pint
import torch


# =============================================================================
# CONVERSION UTILITY FUNCTIONS
# =============================================================================

def flattened_to_atom(positions, space_dimension=3):
    """Reshape positions from flattened to standard atom format.

    A configuration stored with shape ``(n_atoms*space_dimension,)`` (or a
    batch of them with shape ``(batch_size, n_atoms*space_dimension)``) is
    reshaped so that each atom occupies its own row of length
    ``space_dimension``.

    The function only relies on the ``shape`` attribute and the ``reshape``
    method of the input, so it handles ``torch.Tensor`` and
    ``numpy.ndarray``, with and without ``pint`` units.

    Parameters
    ----------
    positions : torch.Tensor, numpy.ndarray, or pint.Quantity
        The input can have the following shapes:
        ``(batch_size, n_atoms*space_dimension)`` or
        ``(n_atoms*space_dimension,)``.
    space_dimension : int, optional
        The dimensionality of the phase space (default is 3).

    Returns
    -------
    reshaped_positions : torch.Tensor, numpy.ndarray, or pint.Quantity
        The result of ``positions.reshape`` with shape
        ``(batch_size, n_atoms, space_dimension)`` or
        ``(n_atoms, space_dimension)``.

    See Also
    --------
    atom_to_flattened

    """
    shape = tuple(positions.shape)
    n_atoms = shape[-1] // space_dimension

    # Only the first axis is treated as the batch axis, and only when the
    # input has more than one dimension.
    batch_dims = shape[:1] if len(shape) > 1 else ()
    return positions.reshape(batch_dims + (n_atoms, space_dimension))
def atom_to_flattened(positions):
    """Reshape positions from standard atom to flattened format.

    This is the inverse operation of :func:`.flattened_to_atom`: the last two
    axes (atoms and spatial coordinates) are merged into a single axis.

    Parameters
    ----------
    positions : torch.Tensor, numpy.ndarray, or pint.Quantity
        The input can have the following shapes: ``(batch_size, n_atoms, N)``
        or ``(n_atoms, N)``, where ``N`` is the dimensionality of the phase
        space.

    Returns
    -------
    reshaped_positions : torch.Tensor, numpy.ndarray, or pint.Quantity
        The result of ``positions.reshape`` with shape
        ``(batch_size, n_atoms*N)`` or ``(n_atoms*N,)``.

    See Also
    --------
    flattened_to_atom

    """
    shape = tuple(positions.shape)
    n_dofs = shape[-2] * shape[-1]

    # Only the first axis is treated as the batch axis, and only when the
    # input has more than two dimensions.
    batch_dims = shape[:1] if len(shape) > 2 else ()
    return positions.reshape(batch_dims + (n_dofs,))
def atom_to_flattened_indices(atom_indices, space_dimension=3):
    """Convert atom indices to the indices of their degrees of freedom in flattened format.

    Atom ``a`` maps to the ``space_dimension`` consecutive indices
    ``a*space_dimension, ..., a*space_dimension + space_dimension - 1``.

    Parameters
    ----------
    atom_indices : torch.Tensor or numpy.ndarray
        The input can have the following shapes: ``(batch_size, n_atoms)`` or
        ``(n_atoms,)``. Anything that is not a ``numpy.ndarray`` is handled
        with ``torch`` functions.
    space_dimension : int, optional
        The dimensionality of the coordinate space (default is 3).

    Returns
    -------
    flattened_indices : torch.Tensor or numpy.ndarray
        The indices of the corresponding degrees of freedom in flattened
        format with shape ``(batch_size, n_atoms*space_dimension)`` or
        ``(n_atoms*space_dimension,)``. The input is not modified.

    Examples
    --------

    The function works both with ``Tensor``s and numpy arrays.

    >>> atom_indices_np = np.array([0, 2])
    >>> list(atom_to_flattened_indices(atom_indices_np))
    [0, 1, 2, 6, 7, 8]

    >>> atom_indices_torch = torch.tensor(atom_indices_np)
    >>> atom_to_flattened_indices(atom_indices_torch, space_dimension=2).tolist()
    [0, 1, 4, 5]

    Batches of indices are supported.

    >>> atom_indices = torch.tensor([[0, 2], [1, 3]])
    >>> atom_to_flattened_indices(atom_indices).tolist()
    [[0, 1, 2, 6, 7, 8], [3, 4, 5, 9, 10, 11]]

    """
    single_configuration = len(atom_indices.shape) == 1

    # Index of the first degree of freedom of each atom (a new array/tensor).
    first_dof = atom_indices * space_dimension

    # A temporary batch axis is added to unbatched input so that the same
    # 2D code path serves both cases.
    if isinstance(atom_indices, np.ndarray):
        if single_configuration:
            first_dof = np.expand_dims(first_dof, axis=0)
        # Each atom index is repeated once per spatial dimension.
        dof_indices = np.repeat(first_dof, space_dimension, axis=1)
    else:  # Tensor.
        if single_configuration:
            first_dof = torch.unsqueeze(first_dof, dim=0)
        dof_indices = torch.repeat_interleave(first_dof, space_dimension, dim=1)

    # The k-th copy of each atom index is offset by k to point to the k-th
    # coordinate of that atom.
    for offset in range(1, space_dimension):
        dof_indices[:, offset::space_dimension] += offset

    return dof_indices[0] if single_configuration else dof_indices
def ensure_tensor_sequence(
        x: Union[str, int, float, Sequence],
        dtype=None,
) -> Union[str, int, float, torch.Tensor]:
    r"""If x is a sequence, return it as a torch.Tensor without copying the memory (if possible).

    Parameters
    ----------
    x : str, int, float, or Sequence
        The input. Sequences (that are not strings) are converted to
        ``torch.Tensor``\ s.
    dtype : dtype or None, optional
        If set, forces the tensor to a data type.

    Returns
    -------
    converted_x : str, int, float, or torch.Tensor
        The input, eventually converted to a tensor. Inputs for which
        ``numpy.isscalar`` is true, and inputs for which ``torch.as_tensor``
        raises ``TypeError`` or ``RuntimeError``, are returned unchanged. Any
        other exception raised by ``torch.as_tensor`` is propagated.

    """
    # Scalars (including strings) are passed through untouched, even though
    # as_tensor would accept numeric scalars.
    if np.isscalar(x):
        return x

    try:
        return torch.as_tensor(x, dtype=dtype)
    except (TypeError, RuntimeError):
        # Best effort: an input that cannot be converted (e.g., None) is
        # handed back as is rather than raising.
        return x
def energies_array_to_tensor(energies, energy_unit=None, dtype=None):
    """Helper function to convert a batch of energies from a numpy array with units to a PyTorch tensor.

    Parameters
    ----------
    energies : pint.Quantity
        The energies with shape ``(batch_size,)`` with units.
    energy_unit : pint.Unit, optional
        The units of energy used in the returned energies. If ``None``, no
        conversion is performed, and the energy will be in the same units as
        the input.
    dtype : type, optional
        The ``torch`` data type to be used for the returned ``Tensor``.

    Returns
    -------
    energies : torch.Tensor
        The energies with shape ``(batch_size,)`` as a unitless ``Tensor`` in
        units of ``energy_unit``.

    """
    # Without a target unit, only the magnitude is extracted.
    if energy_unit is None:
        return torch.tensor(energies.magnitude, dtype=dtype)

    avogadro = energy_unit._REGISTRY.avogadro_constant
    try:
        # First attempt the conversion after multiplying by the Avogadro
        # constant (e.g., a per-particle energy to a molar energy unit).
        converted = (energies * avogadro).to(energy_unit)
    except pint.errors.DimensionalityError:
        # The scaled quantity is not dimensionally compatible with the
        # requested unit, so the energies are converted directly.
        converted = energies.to(energy_unit)

    # The magnitude of the Pint quantity is copied into a new tensor.
    return torch.tensor(converted.magnitude, dtype=dtype)
def forces_array_to_tensor(forces, distance_unit=None, energy_unit=None, dtype=None):
    """Helper function to convert a batch of forces from a numpy array with units to a PyTorch tensor.

    ``distance_unit`` and ``energy_unit`` must be passed together. If they are
    both ``None`` no conversion is performed. If only one of them is ``None``
    an error is raised.

    Parameters
    ----------
    forces : pint.Quantity
        The forces with shape ``(batch_size, n_atoms, 3)`` with units.
    distance_unit : pint.Unit, optional
        The units of distance used in the returned forces. If ``None``, no
        conversion is performed, and the forces will be in the same units as
        the input.
    energy_unit : pint.Unit, optional
        The units of energy used in the returned forces. If ``None``, no
        conversion is performed, and the forces will be in the same units as
        the input.
    dtype : type, optional
        The ``torch`` data type to be used for the returned ``Tensor``.

    Returns
    -------
    forces : torch.Tensor
        The forces with shape ``(batch_size, n_atoms*3)`` as a unitless
        ``Tensor`` in units of ``energy_unit/distance_unit``.

    Raises
    ------
    ValueError
        If only one between ``distance_unit`` and ``energy_unit`` is passed.

    """
    has_energy_unit = energy_unit is not None
    has_distance_unit = distance_unit is not None

    # The two units only make sense as a pair.
    if has_energy_unit != has_distance_unit:
        raise ValueError('Both or neither energy_unit and distance_unit must be passed.')

    if has_energy_unit:
        force_unit = energy_unit / distance_unit
        avogadro = force_unit._REGISTRY.avogadro_constant
        try:
            # First attempt the conversion after multiplying by the Avogadro
            # constant (e.g., a per-particle force to a molar force unit).
            forces = (forces * avogadro).to(force_unit)
        except pint.errors.DimensionalityError:
            # The scaled quantity is not dimensionally compatible with the
            # requested unit, so the forces are converted directly.
            forces = forces.to(force_unit)

    # The tensor must be unitless and in flattened format.
    flattened_forces = atom_to_flattened(forces)
    return torch.tensor(flattened_forces.magnitude, dtype=dtype)
def remove_and_shift_sorted_indices(
        indices: torch.Tensor,
        removed_indices: torch.Tensor,
        remove: bool = True,
        shift: bool = True,
) -> torch.Tensor:
    """Remove from ``indices`` the indices in ``removed_indices`` (by value).

    Both ``indices`` and ``removed_indices`` must be sorted tensors of
    non-negative integers.

    The surviving entries of ``indices`` are (optionally) shifted so that they
    can be used to point to elements of an array from which the elements at
    ``removed_indices`` have been removed.

    Parameters
    ----------
    indices : torch.Tensor
        The tensor from which to remove indices.
    removed_indices : torch.Tensor
        The indices that must be removed from ``indices``.
    remove : bool, optional
        If ``indices`` and ``removed_indices`` do not overlap, and only
        shifting is necessary, this can be set to ``False``. Default ``True``.
    shift : bool, optional
        If ``False`` shifting the indices is not performed. Default ``True``.

    Returns
    -------
    out : torch.Tensor
        The ``indices`` tensor after removing and shifting the elements.

    Examples
    --------
    >>> remove_and_shift_sorted_indices(
    ...     indices=torch.tensor([0, 3, 9, 13]),
    ...     removed_indices=torch.tensor([3, 12]),
    ...     shift=False,
    ... ).tolist()
    [0, 9, 13]

    >>> remove_and_shift_sorted_indices(
    ...     indices=torch.tensor([0, 3, 9, 13]),
    ...     removed_indices=torch.tensor([3, 12]),
    ...     shift=True,
    ... ).tolist()
    [0, 8, 11]

    """
    # For each index, the position at which it would be inserted into
    # removed_indices, i.e., the number of removed indices that are strictly
    # smaller than it. This is also the amount by which it must be shifted.
    n_smaller_removed = torch.searchsorted(removed_indices, indices)

    if remove:
        # searchsorted can return len(removed_indices), so one sentinel is
        # appended to keep the lookup in bounds. The sentinel is -1 because it
        # can never match a (non-negative) index.
        lookup = torch.nn.functional.pad(removed_indices, pad=(0, 1), value=-1)

        # An index has to be removed exactly when the element found at its
        # insertion position is the index itself.
        keep = lookup[n_smaller_removed] != indices
        indices, n_smaller_removed = indices[keep], n_smaller_removed[keep]

    return indices - n_smaller_removed if shift else indices


# =============================================================================
# I/O
# =============================================================================

@contextlib.contextmanager
def temporary_cd(dir_path):
    """Context manager that temporarily sets the working directory.

    The working directory that was current when the context was entered is
    restored on exit, even if the body of the ``with`` statement raises.

    Parameters
    ----------
    dir_path : str or None
        The path to the temporary working directory. If ``None``, the working
        directory is not changed. This might be useful to avoid branching code.

    """
    # Nothing to change and nothing to restore.
    if dir_path is None:
        yield
        return

    previous_dir_path = os.getcwd()
    os.chdir(dir_path)
    try:
        yield
    finally:
        # Always go back to where we started.
        os.chdir(previous_dir_path)