Source code for tfep.potentials.base

#!/usr/bin/env python

# =============================================================================
# MODULE DOCSTRING
# =============================================================================

"""
Base class to implement potential energy functions.

"""


# =============================================================================
# GLOBAL IMPORTS
# =============================================================================

from typing import Optional

import pint
import torch


# =============================================================================
# TORCH MODULE API
# =============================================================================

[docs] class PotentialBase(torch.nn.Module): """Base class for potential energy functions. This ``Module`` implements units related utilities to easily define default units and handle ``pint`` unit registries. To inherit from this class one needs to define the following class variables. - :attr:`~PotentialBase.DEFAULT_ENERGY_UNIT` - :attr:`~PotentialBase.DEFAULT_POSITIONS_UNIT` """ #: The default energy unit. DEFAULT_ENERGY_UNIT : str = '' #: The default positions unit. DEFAULT_POSITIONS_UNIT : str = ''
[docs] def __init__( self, positions_unit: Optional[pint.Quantity] = None, energy_unit: Optional[pint.Quantity] = None, ): r"""Constructor. Parameters ---------- positions_unit : pint.Unit, optional The unit of the positions passed to the class methods. Since input ``Tensor``\ s do not have units attached, this is used to appropriately convert ``batch_positions`` to ASE units. If ``None``, no conversion is performed, which assumes that the input positions are in the units specified by the class attribute :attr:`~PotentialBase.DEFAULT_POSITIONS_UNIT`. energy_unit : pint.Unit, optional The unit used for the returned energies (and as a consequence forces). Since ``Tensor``\ s do not have units attached, this is used to appropriately convert ASE energies into the desired units. If ``None``, no conversion is performed, which means that energies will be returned in the units specified by the class attribute :attr:`~PotentialBase.DEFAULT_ENERGY_UNIT`. """ super().__init__() self._positions_unit = positions_unit self._energy_unit = energy_unit
@property def positions_unit(self) -> pint.Quantity: """The positions unit requested for the input.""" if self._positions_unit is None: ureg = self._get_unit_registry() return getattr(ureg, self.DEFAULT_POSITIONS_UNIT) return self._positions_unit @property def energy_unit(self) -> pint.Quantity: """The energy units of the returned potential.""" if self._energy_unit is None: ureg = self._get_unit_registry() return getattr(ureg, self.DEFAULT_ENERGY_UNIT) return self._energy_unit
[docs] @classmethod def default_positions_unit(cls, unit_registry) -> pint.Quantity: """Return the default positions units.""" return getattr(unit_registry, cls.DEFAULT_POSITIONS_UNIT)
[docs] @classmethod def default_energy_unit(cls, unit_registry) -> pint.Quantity: """Return the default energy units.""" return getattr(unit_registry, cls.DEFAULT_ENERGY_UNIT)
def _get_unit_registry(self): """Return a unit registry. The class tries to obtain a ``pint.UnitRegistry`` from the units passed on initialization. If none was found, it creates a new one. """ if self._positions_unit is not None: return self._positions_unit._REGISTRY if self._energy_unit is not None: return self._energy_unit._REGISTRY return pint.UnitRegistry()