tfep.app.mixedmaf.MixedMAFMap

class tfep.app.mixedmaf.MixedMAFMap(potential_energy_func: Module, topology_file_path: str, coordinates_file_path: str | Sequence[str], temperature: Quantity, batch_size: int = 1, mapped_atoms: Sequence[int] | str | None = None, conditioning_atoms: Sequence[int] | str | None = None, origin_atom: int | str | None = None, axes_atoms: Sequence[int] | str | None = None, remove_translation: bool = False, remove_rotation: bool = False, tfep_logger_dir_path: str = 'tfep_logs', n_maf_layers: int = 6, distance_lower_limit_displacement: Quantity | None = None, dataloader_kwargs: Dict | None = None, **kwargs)[source]

Bases: TFEPMapBase

A TFEP map using a masked autoregressive flow in mixed internal and Cartesian coordinates.

The class divides the atoms of the entire system into mapped, conditioning, and fixed atoms. Mapped atoms are those that the flow maps. Conditioning atoms are not mapped but are given as input to the flow to condition the mapping. Fixed atoms are ignored.
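As a rough illustration of this three-way division, the sketch below partitions atom indices with plain Python sets. The helper partition_atoms is hypothetical and not part of tfep; it only mirrors the rules stated in the parameter documentation (mapped and conditioning atoms cannot overlap, and when mapped atoms are not given, every non-conditioning atom is mapped).

```python
def partition_atoms(n_atoms, mapped=None, conditioning=None):
    """Split atom indices into mapped, conditioning, and fixed (hypothetical helper)."""
    conditioning = set(conditioning or [])
    if mapped is None:
        # Default: every atom that is not conditioning is mapped.
        mapped = set(range(n_atoms)) - conditioning
    else:
        mapped = set(mapped)
    if mapped & conditioning:
        raise ValueError("mapped and conditioning atoms cannot overlap")
    # Everything else is fixed, i.e., ignored by the flow.
    fixed = set(range(n_atoms)) - mapped - conditioning
    return sorted(mapped), sorted(conditioning), sorted(fixed)


mapped, conditioning, fixed = partition_atoms(
    n_atoms=8, mapped=[0, 1, 2], conditioning=[4, 5]
)
print(mapped, conditioning, fixed)
```

In this toy system of 8 atoms, atoms 3, 6, and 7 end up fixed because they are neither mapped nor conditioning.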

Before going through the MAF, the coordinates are transformed into a mixed Cartesian/internal coordinate system. To this end, the system is first divided into connected fragments based on the bond topology, and a separate Z-matrix is built for each fragment.

The Z-matrix is automatically determined based on a heuristic. The first atom is chosen as the center of the graph representing the molecule. The graph is then traversed breadth first from the center and, for each atom, the bond, angle, and dihedral atoms are selected from those in the current Z-matrix according to these priorities: 1) closest to the inserted atom; 2) only for angle and dihedral atoms, closest to the bond atom; 3) most recently added to the Z-matrix; 4) only for heavy atoms, hydrogens are de-prioritized. In particular, 2) and 3) limit the occurrence of undefined angles and instability during training as a result of a triplet of collinear atoms.
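The first two steps of the heuristic (choosing the graph center and traversing breadth first) can be sketched with the standard library alone. This is a simplified illustration, not the library's implementation: the function names are hypothetical, the center is taken as the atom with the smallest eccentricity with ties broken by lowest index (an assumption), and the priority rules for selecting bond, angle, and dihedral atoms are not reproduced.

```python
from collections import deque


def bfs_distances(bonds, start):
    """Return the number of bonds separating `start` from every reachable atom."""
    distances = {start: 0}
    queue = deque([start])
    while queue:
        atom = queue.popleft()
        for neighbor in bonds[atom]:
            if neighbor not in distances:
                distances[neighbor] = distances[atom] + 1
                queue.append(neighbor)
    return distances


def graph_center(bonds):
    """Return the atom with the smallest eccentricity (ties: lowest index)."""
    return min(
        sorted(bonds),
        key=lambda atom: max(bfs_distances(bonds, atom).values()),
    )


def insertion_order(bonds):
    """Order the atoms breadth first, starting from the graph center."""
    distances = bfs_distances(bonds, graph_center(bonds))
    # Dicts preserve insertion order, which here is the BFS visiting order.
    return list(distances)


# Toy bond graph of a linear 5-atom chain: 0-1-2-3-4.
bonds = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
print(graph_center(bonds), insertion_order(bonds))
```

For the chain above, the middle atom is the center and the atoms are inserted outward from it, so atoms close to each other in the bond graph also end up close in the Z-matrix.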

The first three atoms of each molecule’s Z-matrix and all conditioning atoms are represented as Cartesian, while the remaining mapped atoms are converted to internal coordinates.
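To make the internal coordinates concrete, the following sketch computes the bond length, valence angle, and dihedral of one atom from Cartesian positions. It is a minimal pure-Python illustration under stated assumptions: the helper names are hypothetical, the sign convention of the dihedral is an assumption, and tfep's own conversion may differ in conventions and is implemented on batched tensors.

```python
import math


def sub(a, b):
    return [a[i] - b[i] for i in range(3)]


def dot(a, b):
    return sum(a[i] * b[i] for i in range(3))


def cross(a, b):
    return [
        a[1] * b[2] - a[2] * b[1],
        a[2] * b[0] - a[0] * b[2],
        a[0] * b[1] - a[1] * b[0],
    ]


def norm(a):
    return math.sqrt(dot(a, a))


def internal_coordinates(atom, bond_atom, angle_atom, dihedral_atom):
    """Return (bond length, valence angle, dihedral angle); angles in radians."""
    # Bond length: distance between the inserted atom and its bond atom.
    to_atom = sub(atom, bond_atom)
    distance = norm(to_atom)
    # Valence angle: angle at the bond atom between the inserted and angle atoms.
    to_angle = sub(angle_atom, bond_atom)
    cos_angle = dot(to_atom, to_angle) / (distance * norm(to_angle))
    angle = math.acos(max(-1.0, min(1.0, cos_angle)))
    # Dihedral: signed angle between the planes (atom, bond, angle) and
    # (bond, angle, dihedral). The sign convention is an assumption.
    b1 = sub(bond_atom, atom)
    b2 = to_angle
    b3 = sub(dihedral_atom, angle_atom)
    n1, n2 = cross(b1, b2), cross(b2, b3)
    b2_unit = [c / norm(b2) for c in b2]
    dihedral = math.atan2(dot(cross(n1, b2_unit), n2), dot(n1, n2))
    return distance, angle, dihedral


# A planar, trans arrangement of four atoms (arbitrary length units).
distance, angle, dihedral = internal_coordinates(
    atom=[0.0, 1.0, 0.0],
    bond_atom=[0.0, 0.0, 0.0],
    angle_atom=[1.0, 0.0, 0.0],
    dihedral_atom=[1.0, -1.0, 0.0],
)
print(distance, math.degrees(angle), abs(math.degrees(dihedral)))
```

The valence angle is undefined when the three atoms are collinear, which is the situation that priorities 2) and 3) of the Z-matrix heuristic are meant to limit.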

The flow also rototranslates the Cartesian coordinates into a relative frame of reference, which is based on the position of an origin_atom and two axes_atoms that determine the origin and the orientation of the axes, respectively. When given, these atoms are prioritized for the choice of the first three atoms of a molecule’s Z-matrix. If not passed, these three atoms are automatically chosen as the first three atoms in the Z-matrix of the largest fragment. Optionally, the roto-translational degrees of freedom can be removed from the mapping with remove_translation and remove_rotation.

When remove_rotation is True, the axes_atoms are represented in internal coordinates (2 distances w.r.t. the origin atom and 1 angle). When it is False, the axis/plane atoms are represented in Cartesian/cylindrical coordinates. This choice is made only because it simplifies the support for removing the global rotational degrees of freedom with remove_rotation.
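The relative frame of reference can be illustrated with a small pure-Python sketch. It assumes the convention given in the axes_atoms parameter description (first axes atom on the z axis, second on the x-z plane) and a right-handed frame; the function to_relative_frame and its helpers are hypothetical, and the code does not handle the degenerate case of collinear reference atoms.

```python
import math


def sub(a, b):
    return [a[i] - b[i] for i in range(3)]


def dot(a, b):
    return sum(a[i] * b[i] for i in range(3))


def cross(a, b):
    return [
        a[1] * b[2] - a[2] * b[1],
        a[2] * b[0] - a[0] * b[2],
        a[0] * b[1] - a[1] * b[0],
    ]


def normalize(a):
    length = math.sqrt(dot(a, a))
    return [c / length for c in a]


def to_relative_frame(positions, origin_idx, axis_idx, plane_idx):
    """Express positions in the frame defined by the origin and axes atoms."""
    origin = positions[origin_idx]
    # The z axis points from the origin atom to the first axes atom.
    z_axis = normalize(sub(positions[axis_idx], origin))
    # The x axis is the part of the second axes atom's displacement that is
    # orthogonal to z (one Gram-Schmidt step), so that atom lies on the x-z plane.
    v = sub(positions[plane_idx], origin)
    proj = dot(v, z_axis)
    x_axis = normalize([v[i] - proj * z_axis[i] for i in range(3)])
    # Assumption: right-handed frame, hence y = z cross x.
    y_axis = cross(z_axis, x_axis)
    return [
        [dot(sub(p, origin), axis) for axis in (x_axis, y_axis, z_axis)]
        for p in positions
    ]


positions = [
    [1.0, 1.0, 1.0],  # origin atom
    [1.0, 1.0, 3.0],  # first axes atom
    [1.0, 4.0, 1.0],  # second axes atom
    [2.0, 2.0, 2.0],  # any other atom
]
relative = to_relative_frame(positions, origin_idx=0, axis_idx=1, plane_idx=2)
for row in relative:
    print(row)
```

After the transformation, six coordinates are zero by construction (three for the origin atom, two for the first axes atom, one for the second), which is why the reference atoms contribute fewer degrees of freedom than the other atoms.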

The class further supports logging the potential energies computed during training (required for the multimap TFEP analysis) and mid-epoch resuming.

Warning

Currently, this class is not multi-process or thread safe. Running with multiple processes may corrupt the logging of the potentials and Jacobians.

See also

tfep.app.base.TFEPMapBase

The base class for TFEP maps with more detailed explanations of how the relative reference frame and the division in mapped/conditioning/fixed atoms work.

Examples

>>> import pint
>>> from tfep.app.mixedmaf import MixedMAFMap
>>> from tfep.potentials.psi4 import Psi4Potential
>>> units = pint.UnitRegistry()
>>>
>>> tfep_map = MixedMAFMap(
...     potential_energy_func=Psi4Potential(name='mp2'),
...     topology_file_path='path/to/topology.psf',
...     coordinates_file_path='path/to/trajectory.dcd',
...     temperature=300*units.kelvin,
...     batch_size=64,
...     mapped_atoms='resname MOL',  # MDAnalysis selection syntax.
...     conditioning_atoms=range(10, 20),
...     origin_atom=12,  # Fix the origin of the relative reference frame on atom 12.
...     axes_atoms=[13, 16],  # Determine the orientation of the reference frame.
... )
>>>
>>> # Train the flow and save the potential energies.
>>> import lightning
>>> trainer = lightning.Trainer()
>>> trainer.fit(tfep_map)  
__init__(potential_energy_func: Module, topology_file_path: str, coordinates_file_path: str | Sequence[str], temperature: Quantity, batch_size: int = 1, mapped_atoms: Sequence[int] | str | None = None, conditioning_atoms: Sequence[int] | str | None = None, origin_atom: int | str | None = None, axes_atoms: Sequence[int] | str | None = None, remove_translation: bool = False, remove_rotation: bool = False, tfep_logger_dir_path: str = 'tfep_logs', n_maf_layers: int = 6, distance_lower_limit_displacement: Quantity | None = None, dataloader_kwargs: Dict | None = None, **kwargs)[source]

Constructor.

Parameters:
  • potential_energy_func (torch.nn.Module) – A PyTorch module encapsulating the target potential energy function (e.g. tfep.potentials.psi4.Psi4Potential).

  • topology_file_path (str) – The path to the topology file. The file can be in any format supported by MDAnalysis, which is automatically detected from the file extension.

  • coordinates_file_path (str or Sequence[str]) – The path(s) to the trajectory file(s). If a sequence of files is given, the trajectories are concatenated into a single large dataset. The file(s) can be in any format supported by MDAnalysis, which is automatically detected from the file extension.

  • temperature (pint.Quantity) – The temperature of the ensemble.

  • batch_size (int, optional) – The batch size.

  • mapped_atoms (Sequence[int] or str, optional) – The indices (0-based) of the atoms to map or a selection string in MDAnalysis syntax. If not passed, all atoms that are not conditioning are mapped (i.e., all atoms are mapped if conditioning_atoms is not given either).

  • conditioning_atoms (Sequence[int] or str, optional) – The indices (0-based) of the atoms conditioning the mapping or a selection string in MDAnalysis syntax. If not passed, no atom will condition the map. These atoms cannot overlap with mapped_atoms.

  • origin_atom (int or str or None, optional) – The index (0-based) or a selection string in MDAnalysis syntax of an atom on which to center the origin of the relative frame of reference. If it is a conditioning atom, its coordinates are not passed to the flow since they would always be zero. By default, this is chosen as the 1st atom in the Z-matrix of the largest fragment.

  • axes_atoms (Sequence[int] or str or None, optional) – A pair of indices (0-based) or a selection string in MDAnalysis syntax for the two atoms determining the orientation of the relative frame of reference. The axes_atoms[0]-th atom will lie on the z axis, and the axes_atoms[1]-th atom will lie on the plane spanned by the x and z axes. The y axis is then perpendicular to this plane, set as the cross product of the other two axes.

    If these are conditioning atoms, their coordinates that are 0 after the rotation are not passed to the flow. The other degrees of freedom are converted into two distances (from the origin atom) and a valence angle. By default, these are chosen as the 2nd and 3rd atoms in the Z-matrix of the largest fragment.

  • remove_translation (bool, optional) – If True, the 3 degrees of freedom of the origin_atom are not mapped even if origin_atom is mapped. When origin_atom is conditioning, this option has no effect.

  • remove_rotation (bool, optional) – If True, the 3 rotational degrees of freedom of axes_atoms are not mapped even if axes_atoms are mapped atoms. In this case, only their 2 distances from the origin atom and the valence angle between them are passed to the flow. When axes_atoms are conditioning, this option has no effect.

  • tfep_logger_dir_path (str, optional) – The path where to save TFEP-related information (potential energies, sample indices, etc.).

  • n_maf_layers (int, optional) – The number of MAF layers.

  • distance_lower_limit_displacement (pint.Quantity or None) – This controls the (fixed) lower limit for the neural spline used to transform bond lengths. This lower limit is set to max(0, min_observed - distance_lower_limit_displacement) where min_observed is the minimum bond length observed for the specific bond on a random sample in the dataset. The default value is 0.3 Angstrom.

    Note that the same maximum displacement is applied to control the two distances between the axes atoms and the origin.

  • dataloader_kwargs (Dict, optional) – Extra keyword arguments to pass to torch.utils.data.DataLoader.

  • **kwargs – Other keyword arguments to pass to the constructor of tfep.nn.flows.MAF.
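The lower limit described for distance_lower_limit_displacement amounts to a one-line formula. The sketch below applies it to plain floats expressed in Angstrom; the function name is hypothetical, and the actual parameter is a pint.Quantity rather than a float.

```python
def distance_lower_limit(observed_bond_lengths, displacement=0.3):
    """Return max(0, min_observed - displacement), all in the same length unit.

    Hypothetical helper: plain floats stand in for pint quantities.
    """
    min_observed = min(observed_bond_lengths)
    return max(0.0, min_observed - displacement)


# Bond lengths (in Angstrom) observed for one bond on a sample of frames.
print(distance_lower_limit([1.09, 1.12, 1.07]))
# A very short minimum distance is clipped so that the limit is never negative.
print(distance_lower_limit([0.20, 0.25]))
```

Clipping at zero keeps the lower bound physically meaningful when the displacement exceeds the shortest observed distance.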

Methods

__init__(potential_energy_func, ...[, ...])

Constructor.

add_module(name, module)

Add a child module to the current module.

all_gather(data[, group, sync_grads])

Gather tensors or collections of tensors from multiple processes.

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

are_axes_atoms_mapped()

Return whether the two axes atoms (if any) are mapped.

backward(loss, *args, **kwargs)

Called to perform backward on the loss returned in training_step().

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

children()

Return an iterator over immediate children modules.

clip_gradients(optimizer[, ...])

Handles gradient clipping internally.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

configure_callbacks()

Configure model-specific callbacks.

configure_flow()

Initialize the normalizing flow.

configure_gradient_clipping(optimizer[, ...])

Perform gradient clipping for the optimizer parameters.

configure_model()

Hook to create modules in a strategy and precision aware context.

configure_optimizers()

Lightning method.

configure_sharded_model()

Deprecated.

cpu()

See torch.nn.Module.cpu().

create_dataset()

Create and return the Dataset object.

create_partial_flow(flow[, return_partial])

Wrap the flow to remove the fixed DOFs.

create_universe()

Create and return the MDAnalysis Universe.

cuda([device])

Moves all model parameters and buffers to the GPU.

determine_atom_indices()

Determine mapped, conditioning, fixed, and reference frame atom indices.

double()

See torch.nn.Module.double().

eval()

Set the module in evaluation mode.

extra_repr()

Set the extra representation of the module.

float()

See torch.nn.Module.float().

forward(batch)

Execute the normalizing flow in the forward direction.

freeze()

Freeze all params for inference.

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_conditioning_indices(idx_type, remove_fixed)

Return the indices of the conditioning atom or degrees of freedom (DOF).

get_extra_state()

Return any extra state to include in the module's state_dict.

get_mapped_indices(idx_type, remove_fixed)

Return the indices of the mapped atom or degrees of freedom (DOF).

get_nonfixed_indices(idx_type, remove_fixed)

Return the indices of the mapped and conditioning atom or degrees of freedom (DOF).

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_reference_atoms_indices(remove_fixed[, ...])

Return the atom indices of the origin and axes atoms.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

See torch.nn.Module.half().

inverse(batch)

Execute the normalizing flow in the inverse direction.

ipu([device])

Move all model parameters and buffers to the IPU.

load_from_checkpoint(checkpoint_path[, ...])

Primary way of loading a model from a checkpoint.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.

log(name, value[, prog_bar, logger, ...])

Log a key, value pair.

log_dict(dictionary[, prog_bar, logger, ...])

Log a dictionary of values at once.

lr_scheduler_step(scheduler, metric)

Override this method to adjust the default way the Trainer calls each scheduler.

lr_schedulers()

Returns the learning rate scheduler(s) that are being used during training.

manual_backward(loss, *args, **kwargs)

Call this directly from your training_step() when doing optimizations manually.

modules()

Return an iterator over all modules in the network.

mtia([device])

Move all model parameters and buffers to the MTIA.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

on_after_backward()

Called after loss.backward() and before optimizers are stepped.

on_after_batch_transfer(batch, dataloader_idx)

Override to alter or apply batch augmentations to your batch after it is transferred to the device.

on_before_backward(loss)

Called before loss.backward().

on_before_batch_transfer(batch, dataloader_idx)

Override to alter or apply batch augmentations to your batch before it is transferred to the device.

on_before_optimizer_step(optimizer)

Called before optimizer.step().

on_before_zero_grad(optimizer)

Called after training_step() and before optimizer.zero_grad().

on_fit_end()

Called at the very end of fit.

on_fit_start()

Called at the very beginning of fit.

on_load_checkpoint(checkpoint)

Lightning hook.

on_predict_batch_end(outputs, batch, batch_idx)

Called in the predict loop after the batch.

on_predict_batch_start(batch, batch_idx[, ...])

Called in the predict loop before anything happens for that batch.

on_predict_end()

Called at the end of predicting.

on_predict_epoch_end()

Called at the end of predicting.

on_predict_epoch_start()

Called at the beginning of predicting.

on_predict_model_eval()

Called when the predict loop starts.

on_predict_start()

Called at the beginning of predicting.

on_save_checkpoint(checkpoint)

Lightning hook.

on_test_batch_end(outputs, batch, batch_idx)

Called in the test loop after the batch.

on_test_batch_start(batch, batch_idx[, ...])

Called in the test loop before anything happens for that batch.

on_test_end()

Called at the end of testing.

on_test_epoch_end()

Called in the test loop at the very end of the epoch.

on_test_epoch_start()

Called in the test loop at the very beginning of the epoch.

on_test_model_eval()

Called when the test loop starts.

on_test_model_train()

Called when the test loop ends.

on_test_start()

Called at the beginning of testing.

on_train_batch_end(outputs, batch, batch_idx)

Called in the training loop after the batch.

on_train_batch_start(batch, batch_idx)

Called in the training loop before anything happens for that batch.

on_train_end()

Called at the end of training before logger experiment is closed.

on_train_epoch_end()

Called in the training loop at the very end of the epoch.

on_train_epoch_start()

Called in the training loop at the very beginning of the epoch.

on_train_start()

Called at the beginning of training after sanity check.

on_validation_batch_end(outputs, batch, ...)

Called in the validation loop after the batch.

on_validation_batch_start(batch, batch_idx)

Called in the validation loop before anything happens for that batch.

on_validation_end()

Called at the end of validation.

on_validation_epoch_end()

Called in the validation loop at the very end of the epoch.

on_validation_epoch_start()

Called in the validation loop at the very beginning of the epoch.

on_validation_model_eval()

Called when the validation loop starts.

on_validation_model_train()

Called when the validation loop ends.

on_validation_model_zero_grad()

Called by the training loop to release gradients before entering the validation loop.

on_validation_start()

Called at the beginning of validation.

optimizer_step(epoch, batch_idx, optimizer)

Override this method to adjust the default way the Trainer calls the optimizer.

optimizer_zero_grad(epoch, batch_idx, optimizer)

Override this method to change the default behaviour of optimizer.zero_grad().

optimizers([use_pl_optimizer])

Returns the optimizer(s) that are being used during training.

parameters([recurse])

Return an iterator over module parameters.

predict_dataloader()

An iterable or collection of iterables specifying prediction samples.

predict_step(*args, **kwargs)

Step function called during predict().

prepare_data()

Use this to download and prepare data.

print(*args, **kwargs)

Prints only from process 0.

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Register a post-hook to be run after module's load_state_dict() is called.

register_load_state_dict_pre_hook(hook)

Register a pre-hook to be run before module's load_state_dict() is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_post_hook(hook)

Register a post-hook for the state_dict() method.

register_state_dict_pre_hook(hook)

Register a pre-hook for the state_dict() method.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

save_hyperparameters(*args[, ignore, frame, ...])

Save arguments to hparams attribute.

set_extra_state(state)

Set extra state contained in the loaded state_dict.

set_submodule(target, module)

Set the submodule given by target if it exists, otherwise throw an error.

setup([stage])

Lightning method.

share_memory()

See torch.Tensor.share_memory_().

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

teardown(stage)

Called at the end of fit (train + validate), validate, test, or predict.

test_dataloader()

An iterable or collection of iterables specifying test samples.

test_step(*args, **kwargs)

Operates on a single batch of data from the test set.

to(*args, **kwargs)

See torch.nn.Module.to().

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

to_onnx(file_path[, input_sample])

Saves the model in ONNX format.

to_torchscript([file_path, method, ...])

By default compiles the whole model to a ScriptModule.

toggle_optimizer(optimizer)

Makes sure only the gradients of the current optimizer's parameters are calculated in the training step to prevent dangling gradients in multiple-optimizer setup.

train([mode])

Set the module in training mode.

train_dataloader()

Lightning method.

training_step(batch, batch_idx)

Lightning method.

transfer_batch_to_device(batch, device, ...)

Override this hook if your DataLoader returns tensors wrapped in a custom data structure.

type(dst_type)

See torch.nn.Module.type().

unfreeze()

Unfreeze all parameters for training.

untoggle_optimizer(optimizer)

Resets the state of required gradients that were toggled with toggle_optimizer().

val_dataloader()

An iterable or collection of iterables specifying validation samples.

validation_step(*args, **kwargs)

Operates on a single batch of data from the validation set.

xpu([device])

Move all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Reset gradients of all model parameters.

Attributes

CHECKPOINT_HYPER_PARAMS_KEY

CHECKPOINT_HYPER_PARAMS_NAME

CHECKPOINT_HYPER_PARAMS_TYPE

T_destination

automatic_optimization

If set to False you are responsible for calling .backward(), .step(), .zero_grad().

call_super_init

current_epoch

The current epoch in the Trainer, or 0 if not attached.

device

device_mesh

Strategies like ModelParallelStrategy will create a device mesh that can be accessed in the configure_model() hook to parallelize the LightningModule.

dtype

dump_patches

example_input_array

The example input array is a specification of what the module can consume in the forward() method.

fabric

global_rank

The index of the current process across all nodes and devices.

global_step

Total training batches seen across all epochs.

hparams

The collection of hyperparameters saved with save_hyperparameters().

hparams_initial

The collection of hyperparameters saved with save_hyperparameters().

local_rank

The index of the current process within a single node.

logger

Reference to the logger object in the Trainer.

loggers

Reference to the list of loggers in the Trainer.

n_conditioning_atoms

The number of conditioning atoms.

n_conditioning_dofs

The number of conditioning degrees of freedom (excluding the constrained DOFs of the reference frame atoms).

n_fixed_atoms

The number of fixed atoms.

n_mapped_atoms

The number of mapped atoms.

n_mapped_dofs

The number of mapped degrees of freedom (excluding the constrained DOFs of the reference frame atoms).

n_nonfixed_atoms

Total number of mapped and conditioning atoms.

n_nonfixed_dofs

Total number of mapped and conditioning degrees of freedom (excluding the constrained DOFs of the reference frame atoms).

on_gpu

Returns True if this model is currently located on a GPU.

strict_loading

Determines how Lightning loads this model using .load_state_dict(..., strict=model.strict_loading).

trainer

dataset

The dataset.

training

are_axes_atoms_mapped()

Return whether the two axes atoms (if any) are mapped.

Returns:

are_mapped – A pair (is_axes_atom_0_mapped, is_axes_atom_1_mapped) or None if there are no axes atoms.

Return type:

None or Tuple[bool]

configure_flow() Module[source]

Initialize the normalizing flow.

Returns:

flow – The normalizing flow.

Return type:

torch.nn.Module

configure_optimizers()

Lightning method.

Returns:

optimizer – The optimizer to use for the training.

Return type:

torch.optim.optimizer.Optimizer

create_dataset()

Create and return the Dataset object.

Returns:

dataset – The PyTorch dataset.

Return type:

torch.utils.data.Dataset

create_partial_flow(flow: Module, return_partial: bool = False) Module

Wrap the flow to remove the fixed DOFs.

Parameters:
  • flow (torch.nn.Module) – The flow to be wrapped in the Partial and/or Oriented/CenteredCentroid flows.

  • return_partial (bool, optional) – The return_partial flag of the PartialFlow.

Returns:

flow – The wrapped flow.

Return type:

torch.nn.Module

create_universe()

Create and return the MDAnalysis Universe.

Returns:

universe – The MDAnalysis Universe object.

Return type:

MDAnalysis.Universe

dataset: tfep.io.dataset.TrajectoryDataset | None

The dataset.

determine_atom_indices()

Determine mapped, conditioning, fixed, and reference frame atom indices.

This initializes the following attributes:

  • self._mapped_atom_indices

  • self._conditioning_atom_indices

  • self._fixed_atom_indices

  • self._origin_atom_idx

  • self._axes_atoms_indices

forward(batch: Dict) dict[str, Tensor]

Execute the normalizing flow in the forward direction.

Parameters:

batch (dict[str, torch.Tensor]) – Batch data. Must have the key 'positions'.

Returns:

result – The output of the normalizing flow with at least the following keys:

  • 'positions': Shape (batch_size, n_atoms*3). The mapped coordinates of the flow.

  • 'log_det_J': Shape (batch_size,). The log weight of the transformation. This is usually the logarithm of the absolute value of the Jacobian determinant for deterministic maps, but it can be also the trace for continuous flows, or the log-ratio of the forward and backward paths in stochastic flows.

  • 'regularization': Shape (batch_size,). Optional. Arbitrary regularization terms. The mean across batches is added to the loss.

Any other tensor in this dictionary that is a scalar or of shape (batch_size,) is automatically logged. Any other tensor is ignored.

Return type:

dict[str, torch.Tensor]

get_conditioning_indices(idx_type: Literal['atom', 'dof'], remove_fixed: bool) Tensor

Return the indices of the conditioning atom or degrees of freedom (DOF).

Each atom generally has 3 degrees of freedom, except for the atoms used to set the relative frame of reference. If the axes_atoms (or only one of them) have been indicated as conditioning, the returned conditioning DOFs indices also include the DOFs of the axes_atoms that are not constrained, i.e., the x coordinate of axes_atoms[0], and the x,y coordinates of axes_atoms[1]. The origin_atom is always a conditioning atom by definition, and it is thus included in the returned indices.

Parameters:
  • idx_type (Literal['atom', 'dof']) – Whether to return the indices of the atoms or of the degrees of freedom.

  • remove_fixed (bool) – If True, the returned tensor represents the indices after the fixed atoms have been removed.

Returns:

indices – The conditioning atom/DOFs indices.

Return type:

torch.Tensor
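The effect of idx_type and remove_fixed can be pictured with a small sketch on plain Python lists. Both helpers are hypothetical. The sketch assumes that each atom's x, y, z degrees of freedom are stored contiguously in the flattened positions, and it ignores the constrained degrees of freedom of the reference frame atoms, so it does not reproduce the library's exact output.

```python
def remove_fixed_atoms(atom_indices, fixed_atom_indices):
    """Renumber atom indices as if the fixed atoms had been deleted."""
    fixed = sorted(set(fixed_atom_indices))
    remapped = []
    for idx in atom_indices:
        if idx in fixed:
            raise ValueError(f"atom {idx} is fixed")
        # Shift down by the number of fixed atoms that precede this atom.
        n_before = sum(1 for f in fixed if f < idx)
        remapped.append(idx - n_before)
    return remapped


def atom_to_dof_indices(atom_indices):
    """Expand atom indices into the indices of their x, y, z DOFs."""
    return [3 * idx + k for idx in atom_indices for k in range(3)]


conditioning_atoms = [4, 5]
fixed_atoms = [0, 3]
atoms_without_fixed = remove_fixed_atoms(conditioning_atoms, fixed_atoms)
print(atoms_without_fixed)
print(atom_to_dof_indices(atoms_without_fixed))
```

Here two fixed atoms precede the conditioning atoms, so their atom indices shift down by two before being expanded into DOF indices.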

get_mapped_indices(idx_type: Literal['atom', 'dof'], remove_fixed: bool) Tensor

Return the indices of the mapped atom or degrees of freedom (DOF).

Each atom generally has 3 degrees of freedom, except for the atoms used to set the relative frame of reference. If the axes_atoms (or only one of them) have been indicated as mapped, the returned mapped DOFs indices also include the DOFs of the axes_atoms that are not constrained, i.e., the x coordinate of axes_atoms[0], and the x,y coordinates of axes_atoms[1].

Parameters:
  • idx_type (Literal['atom', 'dof']) – Whether to return the indices of the atoms or of the degrees of freedom.

  • remove_fixed (bool) – If True, the returned tensor represents the indices after the fixed atoms have been removed.

Returns:

indices – The mapped atom/DOFs indices.

Return type:

torch.Tensor

get_nonfixed_indices(idx_type: Literal['atom', 'dof'], remove_fixed: bool) Tensor

Return the indices of the mapped and conditioning atom or degrees of freedom (DOF).

This is a more efficient way of obtaining all mapped and conditioning indices than concatenating and sorting the results of TFEPMapBase.get_mapped_indices() and TFEPMapBase.get_conditioning_indices().

Parameters:
  • idx_type (Literal['atom', 'dof']) – Whether to return the indices of the atoms or of the degrees of freedom.

  • remove_fixed (bool) – If True, the returned tensor represents the indices after the fixed atoms have been removed.

Returns:

indices – The mapped and conditioning atom/DOFs indices.

Return type:

torch.Tensor

get_reference_atoms_indices(remove_fixed: bool, separate_origin_axes: bool = False) Tensor | None | List[Tensor | None]

Return the atom indices of the origin and axes atoms.

Parameters:
  • remove_fixed (bool) – If True, the returned tensor represents the indices after the fixed atoms have been removed.

  • separate_origin_axes (bool, optional) – If True, the origin and axes atoms are returned separately in two Tensors. Otherwise, a single Tensor is returned.

Returns:

reference_atom_indices – If separate_origin_axes is False, a single Tensor including the indices, in this order, of the origin, axis, and plane atoms (if they exist) or None if there are no origin and axes atoms.

If separate_origin_axes is True, this is a pair of Tensors (or None) holding the origin atom index and the axes atom indices.

Return type:

torch.Tensor or None or List[torch.Tensor | None]

inverse(batch: Dict) dict[str, Tensor]

Execute the normalizing flow in the inverse direction.

See TFEPMapBase.forward() for the documentation on the input parameters and returned value.

property n_conditioning_atoms: int

The number of conditioning atoms.

property n_conditioning_dofs: int

The number of conditioning degrees of freedom (excluding the constrained DOFs of the reference frame atoms).

property n_fixed_atoms: int

The number of fixed atoms.

property n_mapped_atoms: int

The number of mapped atoms.

property n_mapped_dofs: int

The number of mapped degrees of freedom (excluding the constrained DOFs of the reference frame atoms).

property n_nonfixed_atoms: int

Total number of mapped and conditioning atoms.

property n_nonfixed_dofs: int

Total number of mapped and conditioning degrees of freedom (excluding the constrained DOFs of the reference frame atoms).

on_load_checkpoint(checkpoint: Dict[str, Any])

Lightning hook.

Used to restore the state of the batch sampler for mid-epoch resuming.

on_save_checkpoint(checkpoint: Dict[str, Any])

Lightning hook.

Used to store the state of the batch sampler for mid-epoch resuming.

setup(stage: str = 'fit')

Lightning method.

This is executed on all processes by Lightning in DDP mode (contrary to __init__) and can be used to initialize objects like the Dataset and all data-dependent objects.

train_dataloader()

Lightning method.

Returns:

data_loader – The training data loader.

Return type:

torch.utils.data.DataLoader

training_step(batch, batch_idx)

Lightning method.

Execute a training step.