tfep.potentials.ase.ASEPotential
- class tfep.potentials.ase.ASEPotential(calculator, symbols=None, numbers=None, pbc=None, positions_unit=None, energy_unit=None, parallelization_strategy=None, **atoms_kwargs)
Bases: PotentialBase
Potential energy and forces with ASE.
This Module wraps ASEPotentialEnergyFunc to provide a differentiable potential energy function for training.
Warning
Currently double-backpropagation is not supported, which means force matching cannot be performed during training.
- __init__(calculator, symbols=None, numbers=None, pbc=None, positions_unit=None, energy_unit=None, parallelization_strategy=None, **atoms_kwargs)
Constructor.
- Parameters:
calculator (ase.calculators.calculator.Calculator) – The ASE calculator used to compute energies and forces.
symbols (str or List[str]) – The symbols of the atom elements used to initialize the ase.Atoms object. It can be a string formula, a list of symbols, or a list of ase.Atom objects. Examples: 'H2O', 'COPt12', ['H', 'H', 'O'], [Atom('Ne', (x, y, z)), ...].
numbers (List[int]) – Atomic numbers (pass either symbols or numbers, not both).
pbc (bool or three bool) – Periodic boundary conditions flags. Examples: True, False, 0, 1, (1, 1, 0), (True, False, False).
positions_unit (pint.Unit, optional) – The unit of the positions passed to the class methods. Since input Tensors do not have units attached, this is used to convert batch_positions to ASE units. If None, no conversion is performed, which assumes that the input positions are in the same units used by ASE.
energy_unit (pint.Unit, optional) – The unit used for the returned energies (and, as a consequence, forces). Since Tensors do not have units attached, this is used to convert ASE energies into the desired units. If None, no conversion is performed, which means that energies and forces are returned in ASE units.
parallelization_strategy (tfep.utils.parallel.ParallelizationStrategy, optional) – The parallelization strategy used to distribute batches of energy and gradient calculations. By default, these are executed serially.
**atoms_kwargs – Other keyword arguments for ase.Atoms.
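The positions_unit and energy_unit rules above can be illustrated with a minimal sketch. It is not the tfep implementation: it replaces pint with a hypothetical hard-coded factor table (TO_ASE), and the helper names positions_to_ase, energy_from_ase and forces_from_ase are hypothetical. It assumes ASE's native units are angstrom and eV, and shows that forces, having units of energy/length, pick up both conversions.

```python
# Hypothetical conversion table: multiply a value in the given unit by the
# factor to express it in ASE's native units (angstrom for lengths, eV for
# energies). A real setup would use pint instead of hard-coded numbers.
TO_ASE = {
    "angstrom": 1.0,
    "nanometer": 10.0,            # 1 nm = 10 angstrom
    "eV": 1.0,
    "kJ/mol": 1.0 / 96.48533212,  # 1 eV per particle is about 96.485 kJ/mol
}


def positions_to_ase(positions, positions_unit=None):
    """Convert input positions to angstrom; None means no conversion is performed."""
    if positions_unit is None:
        return list(positions)
    return [x * TO_ASE[positions_unit] for x in positions]


def energy_from_ase(energy, energy_unit=None):
    """Convert an energy in eV to energy_unit; None means the energy stays in eV."""
    if energy_unit is None:
        return energy
    return energy / TO_ASE[energy_unit]


def forces_from_ase(forces, positions_unit=None, energy_unit=None):
    """Convert forces in eV/angstrom to energy_unit/positions_unit.

    Forces have units of energy/length, so both the energy and the length
    conversions apply; None leaves the corresponding ASE unit in place.
    """
    to_energy = 1.0 if energy_unit is None else 1.0 / TO_ASE[energy_unit]
    to_length = 1.0 if positions_unit is None else 1.0 / TO_ASE[positions_unit]
    return [f * to_energy / to_length for f in forces]
```

For example, with positions in nanometers and energies in kJ/mol, a force of 1 eV/angstrom becomes roughly 964.85 kJ/mol/nm, since both the energy factor (about 96.485) and the length factor (10) enter the result.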
See also
ASEPotentialEnergyFunc – More details on input parameters and implementation.
Methods
__init__(calculator[, symbols, numbers, ...]) – Constructor.
add_module(name, module) – Add a child module to the current module.
apply(fn) – Apply fn recursively to every submodule (as returned by .children()) as well as self.
bfloat16() – Casts all floating point parameters and buffers to bfloat16 datatype.
buffers([recurse]) – Return an iterator over module buffers.
children() – Return an iterator over immediate children modules.
compile(*args, **kwargs) – Compile this Module's forward using torch.compile().
cpu() – Move all model parameters and buffers to the CPU.
cuda([device]) – Move all model parameters and buffers to the GPU.
default_energy_unit(unit_registry) – Return the default energy units.
default_positions_unit(unit_registry) – Return the default positions units.
double() – Casts all floating point parameters and buffers to double datatype.
eval() – Set the module in evaluation mode.
extra_repr() – Set the extra representation of the module.
float() – Casts all floating point parameters and buffers to float datatype.
forward(batch_positions[, batch_cell]) – Compute a differentiable potential energy for a batch of configurations.
get_buffer(target) – Return the buffer given by target if it exists, otherwise throw an error.
get_extra_state() – Return any extra state to include in the module's state_dict.
get_parameter(target) – Return the parameter given by target if it exists, otherwise throw an error.
get_submodule(target) – Return the submodule given by target if it exists, otherwise throw an error.
half() – Casts all floating point parameters and buffers to half datatype.
ipu([device]) – Move all model parameters and buffers to the IPU.
load_state_dict(state_dict[, strict, assign]) – Copy parameters and buffers from state_dict into this module and its descendants.
modules() – Return an iterator over all modules in the network.
mtia([device]) – Move all model parameters and buffers to the MTIA.
named_buffers([prefix, recurse, ...]) – Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
named_children() – Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
named_modules([memo, prefix, remove_duplicate]) – Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
named_parameters([prefix, recurse, ...]) – Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
parameters([recurse]) – Return an iterator over module parameters.
register_backward_hook(hook) – Register a backward hook on the module.
register_buffer(name, tensor[, persistent]) – Add a buffer to the module.
register_forward_hook(hook, *[, prepend, ...]) – Register a forward hook on the module.
register_forward_pre_hook(hook, *[, ...]) – Register a forward pre-hook on the module.
register_full_backward_hook(hook[, prepend]) – Register a backward hook on the module.
register_full_backward_pre_hook(hook[, prepend]) – Register a backward pre-hook on the module.
register_load_state_dict_post_hook(hook) – Register a post-hook to be run after module's load_state_dict() is called.
register_load_state_dict_pre_hook(hook) – Register a pre-hook to be run before module's load_state_dict() is called.
register_module(name, module) – Alias for add_module().
register_parameter(name, param) – Add a parameter to the module.
register_state_dict_post_hook(hook) – Register a post-hook for the state_dict() method.
register_state_dict_pre_hook(hook) – Register a pre-hook for the state_dict() method.
requires_grad_([requires_grad]) – Change if autograd should record operations on parameters in this module.
set_extra_state(state) – Set extra state contained in the loaded state_dict.
set_submodule(target, module) – Set the submodule given by target if it exists, otherwise throw an error.
share_memory() – See torch.Tensor.share_memory_().
state_dict(*args[, destination, prefix, ...]) – Return a dictionary containing references to the whole state of the module.
to(*args, **kwargs) – Move and/or cast the parameters and buffers.
to_empty(*, device[, recurse]) – Move the parameters and buffers to the specified device without copying storage.
train([mode]) – Set the module in training mode.
type(dst_type) – Casts all parameters and buffers to dst_type.
xpu([device]) – Move all model parameters and buffers to the XPU.
zero_grad([set_to_none]) – Reset gradients of all model parameters.
Attributes
DEFAULT_ENERGY_UNIT – The default energy unit.
DEFAULT_POSITIONS_UNIT – The default positions unit.
T_destination
call_super_init
dump_patches
energy_unit – The energy units of the returned potential.
positions_unit – The positions unit requested for the input.
training
- DEFAULT_ENERGY_UNIT: str = 'eV'
The default energy unit.
- DEFAULT_POSITIONS_UNIT: str = 'angstrom'
The default positions unit.
- classmethod default_energy_unit(unit_registry) → Quantity
Return the default energy units.
- classmethod default_positions_unit(unit_registry) → Quantity
Return the default positions units.
- property energy_unit: Quantity
The energy units of the returned potential.
- forward(batch_positions, batch_cell=None)
Compute a differentiable potential energy for a batch of configurations.
- Parameters:
batch_positions (torch.Tensor) – Shape (batch_size, 3*n_atoms). The atom positions in units of self.positions_unit.
batch_cell (torch.Tensor, optional) – Shape (batch_size, 3, 3) or (batch_size, 3) or (batch_size, 6). Unit cell vectors. Can also be given as just three numbers for orthorhombic cells, or as six numbers, where the first three are the lengths of the unit cell vectors (in units of self.positions_unit) and the other three are the angles between them (in degrees), in the following order: [len(a), len(b), len(c), angle(b,c), angle(a,c), angle(a,b)]. The first vector lies along the x-direction, the second in the xy-plane, and the third in the positive-z half-space.
- Returns:
potential_energy – potential_energy[i] is the potential energy of configuration batch_positions[i] in units of self.energy_unit.
- Return type:
torch.Tensor
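The input conventions of forward can be made concrete with a small pure-Python sketch. It is not the tfep code path. It shows how one flat row of batch_positions maps to per-atom coordinates, how the three accepted cell formats map to three cell vectors following the stated convention (ASE exposes a comparable conversion as ase.geometry.cellpar_to_cell), and how energies line up with batch rows. All function names are hypothetical, and toy_energy is a made-up stand-in for an ASE calculator. Evaluation is serial by default, mirroring the default parallelization behaviour. The optional executor argument is a hypothetical stand-in for a parallelization strategy, not the ParallelizationStrategy API.

```python
import math
from concurrent.futures import ThreadPoolExecutor


def unflatten_positions(flat, n_atoms):
    """Reshape one row of length 3*n_atoms into n_atoms [x, y, z] triples."""
    if len(flat) != 3 * n_atoms:
        raise ValueError(f"expected {3 * n_atoms} coordinates, got {len(flat)}")
    return [list(flat[3 * i:3 * i + 3]) for i in range(n_atoms)]


def cell_to_matrix(cell):
    """Return the three cell vectors (as rows) for any accepted cell format."""
    if len(cell) == 3 and all(isinstance(row, (list, tuple)) for row in cell):
        # Already a 3x3 matrix of cell vectors.
        return [list(row) for row in cell]
    if len(cell) == 3:
        # Orthorhombic box: the numbers are the box lengths along x, y and z.
        a, b, c = cell
        return [[a, 0.0, 0.0], [0.0, b, 0.0], [0.0, 0.0, c]]
    if len(cell) == 6:
        # [len(a), len(b), len(c), angle(b,c), angle(a,c), angle(a,b)], degrees.
        a, b, c, alpha, beta, gamma = cell
        cos_a, cos_b, cos_g = (math.cos(math.radians(x)) for x in (alpha, beta, gamma))
        sin_g = math.sin(math.radians(gamma))
        va = [a, 0.0, 0.0]                         # first vector along x
        vb = [b * cos_g, b * sin_g, 0.0]           # second vector in the xy-plane
        cx = c * cos_b                             # fixed by angle(a, c)
        cy = c * (cos_a - cos_b * cos_g) / sin_g   # fixed by angle(b, c)
        cz = math.sqrt(max(c * c - cx * cx - cy * cy, 0.0))  # positive z, |vc| = c
        return [va, vb, [cx, cy, cz]]
    raise ValueError("cell must be a 3x3 matrix, 3 numbers, or 6 numbers")


def toy_energy(positions, cell=None):
    """Hypothetical calculator stand-in: E = 0.5 * sum of squared coordinates.

    The cell is accepted only to mirror the calculator inputs and is ignored.
    """
    return 0.5 * sum(x * x + y * y + z * z for x, y, z in positions)


def batch_energies(batch_positions, n_atoms, batch_cell=None,
                   energy_fn=toy_energy, executor=None):
    """Evaluate energy_fn per configuration; result[i] belongs to batch_positions[i].

    Runs serially by default. Passing a concurrent.futures executor distributes
    the work; a thread pool accepts the lambda below, a process pool would not.
    """
    cells = batch_cell if batch_cell is not None else [None] * len(batch_positions)
    jobs = [
        (unflatten_positions(row, n_atoms), None if cell is None else cell_to_matrix(cell))
        for row, cell in zip(batch_positions, cells)
    ]
    mapper = map if executor is None else executor.map
    return list(mapper(lambda job: energy_fn(*job), jobs))
```

For a batch of two 2-atom configurations, [[1, 0, 0, 0, 2, 0], [0, 0, 0, 0, 0, 0]], the toy potential gives [2.5, 0.0]. The i-th energy always corresponds to the i-th row, whether the rows are evaluated serially or through an executor.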
- property positions_unit: Quantity
The positions unit requested for the input.