tfep.app.base.TFEPMapBase

class tfep.app.base.TFEPMapBase(potential_energy_func: Module, topology_file_path: str, coordinates_file_path: str | Sequence[str], temperature: Quantity, batch_size: int = 1, mapped_atoms: Sequence[int] | str | None = None, conditioning_atoms: Sequence[int] | str | None = None, origin_atom: int | str | None = None, axes_atoms: Sequence[int] | str | None = None, tfep_logger_dir_path: str = 'tfep_logs', dataloader_kwargs: Dict | None = None)

Bases: ABC, LightningModule

A LightningModule to run TFEP calculations.

This abstract class implements several data-related utilities that are shared by all TFEP maps. In particular, it provides utilities to:

  • Support correct mid-epoch resuming.

  • Log vectorial quantities such as the calculated potential energies and the absolute Jacobian terms that can later be used to estimate the free energy.

  • Identify and perform consistency checks on three atoms (origin, axis, and plane) that can be used to define a relative frame of reference for the flow.

  • Identify the mapped, conditioning, and fixed atom indices and handle fixed atoms.

For the latter, mapped atoms are those transformed by the flow. Conditioning atoms are not mapped but are given as input to the flow to condition the mapping. Fixed atoms are instead ignored. Note that the flow defined by the child class must handle only the mapped and conditioning atoms. The flow is then automatically wrapped in a PartialFlow that handles the fixed atoms.
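Conceptually, wrapping the flow in a PartialFlow means that only the mapped and conditioning DOFs are passed to the user-defined flow, while the fixed DOFs are copied to the output unchanged. The pure-Python sketch below illustrates this idea under simplifying assumptions: it is not the tfep implementation, apply_partial is a hypothetical helper, and positions are a flat list of coordinates for a single configuration.

```python
from typing import Callable, List, Sequence


def apply_partial(
    flow: Callable[[List[float]], List[float]],
    positions: Sequence[float],
    fixed_dof_indices: Sequence[int],
) -> List[float]:
    """Hypothetical sketch: run ``flow`` on the non-fixed DOFs only."""
    fixed = set(fixed_dof_indices)
    nonfixed = [i for i in range(len(positions)) if i not in fixed]
    # The wrapped flow only ever sees the mapped and conditioning DOFs.
    transformed = flow([positions[i] for i in nonfixed])
    # Fixed DOFs are copied through unchanged; the others are overwritten.
    output = list(positions)
    for i, value in zip(nonfixed, transformed):
        output[i] = value
    return output


# Example: a toy "flow" that shifts every DOF it receives by 1.
shifted = apply_partial(lambda x: [v + 1.0 for v in x], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0], [3, 4, 5])
print(shifted)  # [1.0, 2.0, 3.0, 3.0, 4.0, 5.0]
```

Because the fixed DOFs pass through as the identity, they contribute nothing to the log absolute Jacobian determinant, so a wrapper of this kind only needs to forward the log-Jacobian computed by the inner flow.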

The class further provides the convenience methods get_mapped_indices() and get_conditioning_indices() to retrieve the indices of the mapped and conditioning atoms/degrees of freedom, optionally after the fixed atoms are removed (see example below). A similar method, get_reference_atoms_indices(), exists for the indices of the reference frame atoms.
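The effect of removing the fixed atoms on the returned indices can be illustrated with a small pure-Python sketch. It assumes that the remaining atoms are renumbered contiguously while preserving their order; reindex_without_fixed is a hypothetical helper, not part of tfep.

```python
from typing import Dict, List, Sequence


def reindex_without_fixed(
    indices: Sequence[int], fixed_indices: Sequence[int], n_atoms: int
) -> List[int]:
    """Hypothetical helper: map atom indices to their positions once fixed atoms are dropped."""
    fixed = set(fixed_indices)
    # Non-fixed atoms keep their relative order and are renumbered contiguously.
    new_index: Dict[int, int] = {}
    for atom_idx in range(n_atoms):
        if atom_idx not in fixed:
            new_index[atom_idx] = len(new_index)
    # Raises KeyError if a fixed atom is requested, since it has no new index.
    return [new_index[i] for i in indices]


# 8 atoms, of which 0, 1, and 5 are fixed.
print(reindex_without_fixed([2, 3], fixed_indices=[0, 1, 5], n_atoms=8))     # [0, 1]
print(reindex_without_fixed([4, 6, 7], fixed_indices=[0, 1, 5], n_atoms=8))  # [2, 3, 4]
```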

The only required method to implement a concrete class is configure_flow().

Warning

Currently, this class is neither multi-process-safe nor thread-safe. Running with multiple processes may corrupt the logged potentials and Jacobians.

Examples

Here is an example of how to implement a working map using TFEPMapBase.

>>> import tfep.nn.conditioners
>>> import tfep.nn.flows
>>> from tfep.app.base import TFEPMapBase
>>>
>>> class TFEPMap(TFEPMapBase):
...
...     def configure_flow(self):
...         # A simple 1-layer affine autoregressive flow.
...         # The flow must fix the relative frame of reference or raise
...         # an error when origin and axes atoms are set.
...         reference_atoms_indices = self.get_reference_atoms_indices(remove_fixed=True)
...         if reference_atoms_indices is not None:
...             raise NotImplementedError('Relative frame of reference is not supported.')
...
...         # The flow must take care only of the mapped and conditioning atoms.
...         conditioning_indices = self.get_conditioning_indices(
...             idx_type="dof", remove_fixed=True)
...         return tfep.nn.flows.MAF(
...             degrees_in=tfep.nn.conditioners.generate_degrees(
...                 n_features=self.n_nonfixed_dofs, conditioning_indices=conditioning_indices)
...         )
...

The TFEP calculation can then be run as follows.

>>> import lightning
>>> import pint
>>> from tfep.potentials.psi4 import Psi4Potential
>>>
>>> units = pint.UnitRegistry()
>>>
>>> tfep_map = TFEPMap(
...     potential_energy_func=Psi4Potential(name='mp2'),
...     topology_file_path='path/to/topology.psf',
...     coordinates_file_path='path/to/trajectory.dcd',
...     temperature=300*units.kelvin,
...     batch_size=64,
...     mapped_atoms='resname MOL',  # MDAnalysis selection syntax.
...     conditioning_atoms=range(10, 20),
... )
>>>
>>> # Train the flow and save the potential energies.
>>> trainer = lightning.Trainer()
>>> trainer.fit(tfep_map)

__init__(potential_energy_func: Module, topology_file_path: str, coordinates_file_path: str | Sequence[str], temperature: Quantity, batch_size: int = 1, mapped_atoms: Sequence[int] | str | None = None, conditioning_atoms: Sequence[int] | str | None = None, origin_atom: int | str | None = None, axes_atoms: Sequence[int] | str | None = None, tfep_logger_dir_path: str = 'tfep_logs', dataloader_kwargs: Dict | None = None)

Constructor.

Parameters:
  • potential_energy_func (torch.nn.Module) – A PyTorch module encapsulating the target potential energy function (e.g., tfep.potentials.psi4.Psi4Potential).

  • topology_file_path (str) – The path to the topology file. The file can be in any format supported by MDAnalysis, which is automatically detected from the file extension.

  • coordinates_file_path (str or Sequence[str]) – The path(s) to the trajectory file(s). If a sequence of files is given, the trajectories are concatenated into a single large dataset. The file(s) can be in any format supported by MDAnalysis, which is automatically detected from the file extension.

  • temperature (pint.Quantity) – The temperature of the ensemble.

  • batch_size (int, optional) – The batch size.

  • mapped_atoms (Sequence[int] or str or None, optional) – The indices (0-based) of the atoms to map or a selection string in MDAnalysis syntax. If not passed, all atoms that are not conditioning are mapped (i.e., all atoms are mapped if conditioning_atoms is also not given).

  • conditioning_atoms (Sequence[int] or str or None, optional) – The indices (0-based) of the atoms conditioning the mapping or a selection string in MDAnalysis syntax. If not passed, no atom will condition the map.

  • origin_atom (int or str or None, optional) – The index (0-based) or a selection string in MDAnalysis syntax of an atom on which to center the origin of the relative frame of reference. While this atom affects the mapping of the mapped atoms, its position will be constrained during the mapping, and thus it must be a conditioning atom by definition.

  • axes_atoms (Sequence[int] or str or None, optional) – A pair of indices (0-based) or a selection string in MDAnalysis syntax for the two atoms determining the relative frame of reference.

    The details of how these are used to fix the orientation of the frame of reference depend on the implementation of configure_flow(). For example, the axes_atoms[0]-th atom may lie on the z axis, and the axes_atoms[1]-th atom may lie on the plane spanned by the x and z axes.

    These atoms can be either conditioning or mapped. axes_atoms[0] has only 1 degree of freedom (DOF) while axes_atoms[1] has 2. These DOFs are mapped if the corresponding atom is mapped and act as conditioning DOFs if the atom is conditioning.

  • tfep_logger_dir_path (str, optional) – The path to the directory where TFEP-related information (potential energies, sample indices, etc.) is saved.

  • dataloader_kwargs (Dict, optional) – Extra keyword arguments to pass to torch.utils.data.DataLoader.
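When coordinates_file_path is a sequence, each sample of the concatenated dataset corresponds to one frame of one trajectory file. Assuming the files are concatenated in the order given, mapping a sample index of the combined dataset back to a (file, frame) pair can be sketched with cumulative frame counts. locate_frame is a hypothetical helper, not part of tfep.

```python
import bisect
import itertools
from typing import List, Tuple


def locate_frame(n_frames_per_file: List[int], sample_idx: int) -> Tuple[int, int]:
    """Hypothetical helper: map a global sample index to (file index, frame index)."""
    # Cumulative frame counts, e.g. [100, 250, 300] for files of 100, 150, and 50 frames.
    cumulative = list(itertools.accumulate(n_frames_per_file))
    if not cumulative or not 0 <= sample_idx < cumulative[-1]:
        raise IndexError(f"sample index {sample_idx} out of range")
    # First file whose cumulative count exceeds the sample index.
    file_idx = bisect.bisect_right(cumulative, sample_idx)
    start = cumulative[file_idx - 1] if file_idx > 0 else 0
    return file_idx, sample_idx - start


print(locate_frame([100, 150, 50], 0))    # (0, 0)
print(locate_frame([100, 150, 50], 100))  # (1, 0)
print(locate_frame([100, 150, 50], 299))  # (2, 49)
```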

Methods

__init__(potential_energy_func, ...[, ...])

Constructor.

add_module(name, module)

Add a child module to the current module.

all_gather(data[, group, sync_grads])

Gather tensors or collections of tensors from multiple processes.

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

are_axes_atoms_mapped()

Return whether the two axes atoms (if any) are mapped.

backward(loss, *args, **kwargs)

Called to perform backward on the loss returned in training_step().

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

children()

Return an iterator over immediate children modules.

clip_gradients(optimizer[, ...])

Handles gradient clipping internally.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

configure_callbacks()

Configure model-specific callbacks.

configure_flow()

Initialize the normalizing flow.

configure_gradient_clipping(optimizer[, ...])

Perform gradient clipping for the optimizer parameters.

configure_model()

Hook to create modules in a strategy and precision aware context.

configure_optimizers()

Lightning method.

configure_sharded_model()

Deprecated.

cpu()

See torch.nn.Module.cpu().

create_dataset()

Create and return the Dataset object.

create_partial_flow(flow[, return_partial])

Wrap the flow to remove the fixed DOFs.

create_universe()

Create and return the MDAnalysis Universe.

cuda([device])

Moves all model parameters and buffers to the GPU.

determine_atom_indices()

Determine mapped, conditioning, fixed, and reference frame atom indices.

double()

See torch.nn.Module.double().

eval()

Set the module in evaluation mode.

extra_repr()

Set the extra representation of the module.

float()

See torch.nn.Module.float().

forward(batch)

Execute the normalizing flow in the forward direction.

freeze()

Freeze all params for inference.

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_conditioning_indices(idx_type, remove_fixed)

Return the indices of the conditioning atom or degrees of freedom (DOF).

get_extra_state()

Return any extra state to include in the module's state_dict.

get_mapped_indices(idx_type, remove_fixed)

Return the indices of the mapped atom or degrees of freedom (DOF).

get_nonfixed_indices(idx_type, remove_fixed)

Return the indices of the mapped and conditioning atom or degrees of freedom (DOF).

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_reference_atoms_indices(remove_fixed[, ...])

Return the atom indices of the origin and axes atoms.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

See torch.nn.Module.half().

inverse(batch)

Execute the normalizing flow in the inverse direction.

ipu([device])

Move all model parameters and buffers to the IPU.

load_from_checkpoint(checkpoint_path[, ...])

Primary way of loading a model from a checkpoint.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.

log(name, value[, prog_bar, logger, ...])

Log a key, value pair.

log_dict(dictionary[, prog_bar, logger, ...])

Log a dictionary of values at once.

lr_scheduler_step(scheduler, metric)

Override this method to adjust the default way the Trainer calls each scheduler.

lr_schedulers()

Returns the learning rate scheduler(s) that are being used during training.

manual_backward(loss, *args, **kwargs)

Call this directly from your training_step() when doing optimizations manually.

modules()

Return an iterator over all modules in the network.

mtia([device])

Move all model parameters and buffers to the MTIA.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

on_after_backward()

Called after loss.backward() and before optimizers are stepped.

on_after_batch_transfer(batch, dataloader_idx)

Override to alter or apply batch augmentations to your batch after it is transferred to the device.

on_before_backward(loss)

Called before loss.backward().

on_before_batch_transfer(batch, dataloader_idx)

Override to alter or apply batch augmentations to your batch before it is transferred to the device.

on_before_optimizer_step(optimizer)

Called before optimizer.step().

on_before_zero_grad(optimizer)

Called after training_step() and before optimizer.zero_grad().

on_fit_end()

Called at the very end of fit.

on_fit_start()

Called at the very beginning of fit.

on_load_checkpoint(checkpoint)

Lightning hook.

on_predict_batch_end(outputs, batch, batch_idx)

Called in the predict loop after the batch.

on_predict_batch_start(batch, batch_idx[, ...])

Called in the predict loop before anything happens for that batch.

on_predict_end()

Called at the end of predicting.

on_predict_epoch_end()

Called at the end of predicting.

on_predict_epoch_start()

Called at the beginning of predicting.

on_predict_model_eval()

Called when the predict loop starts.

on_predict_start()

Called at the beginning of predicting.

on_save_checkpoint(checkpoint)

Lightning hook.

on_test_batch_end(outputs, batch, batch_idx)

Called in the test loop after the batch.

on_test_batch_start(batch, batch_idx[, ...])

Called in the test loop before anything happens for that batch.

on_test_end()

Called at the end of testing.

on_test_epoch_end()

Called in the test loop at the very end of the epoch.

on_test_epoch_start()

Called in the test loop at the very beginning of the epoch.

on_test_model_eval()

Called when the test loop starts.

on_test_model_train()

Called when the test loop ends.

on_test_start()

Called at the beginning of testing.

on_train_batch_end(outputs, batch, batch_idx)

Called in the training loop after the batch.

on_train_batch_start(batch, batch_idx)

Called in the training loop before anything happens for that batch.

on_train_end()

Called at the end of training before logger experiment is closed.

on_train_epoch_end()

Called in the training loop at the very end of the epoch.

on_train_epoch_start()

Called in the training loop at the very beginning of the epoch.

on_train_start()

Called at the beginning of training after sanity check.

on_validation_batch_end(outputs, batch, ...)

Called in the validation loop after the batch.

on_validation_batch_start(batch, batch_idx)

Called in the validation loop before anything happens for that batch.

on_validation_end()

Called at the end of validation.

on_validation_epoch_end()

Called in the validation loop at the very end of the epoch.

on_validation_epoch_start()

Called in the validation loop at the very beginning of the epoch.

on_validation_model_eval()

Called when the validation loop starts.

on_validation_model_train()

Called when the validation loop ends.

on_validation_model_zero_grad()

Called by the training loop to release gradients before entering the validation loop.

on_validation_start()

Called at the beginning of validation.

optimizer_step(epoch, batch_idx, optimizer)

Override this method to adjust the default way the Trainer calls the optimizer.

optimizer_zero_grad(epoch, batch_idx, optimizer)

Override this method to change the default behaviour of optimizer.zero_grad().

optimizers([use_pl_optimizer])

Returns the optimizer(s) that are being used during training.

parameters([recurse])

Return an iterator over module parameters.

predict_dataloader()

An iterable or collection of iterables specifying prediction samples.

predict_step(*args, **kwargs)

Step function called during predict().

prepare_data()

Use this to download and prepare data.

print(*args, **kwargs)

Prints only from process 0.

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Register a post-hook to be run after module's load_state_dict() is called.

register_load_state_dict_pre_hook(hook)

Register a pre-hook to be run before module's load_state_dict() is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_post_hook(hook)

Register a post-hook for the state_dict() method.

register_state_dict_pre_hook(hook)

Register a pre-hook for the state_dict() method.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

save_hyperparameters(*args[, ignore, frame, ...])

Save arguments to hparams attribute.

set_extra_state(state)

Set extra state contained in the loaded state_dict.

set_submodule(target, module)

Set the submodule given by target if it exists, otherwise throw an error.

setup([stage])

Lightning method.

share_memory()

See torch.Tensor.share_memory_().

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

teardown(stage)

Called at the end of fit (train + validate), validate, test, or predict.

test_dataloader()

An iterable or collection of iterables specifying test samples.

test_step(*args, **kwargs)

Operates on a single batch of data from the test set.

to(*args, **kwargs)

See torch.nn.Module.to().

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

to_onnx(file_path[, input_sample])

Saves the model in ONNX format.

to_torchscript([file_path, method, ...])

By default compiles the whole model to a ScriptModule.

toggle_optimizer(optimizer)

Makes sure only the gradients of the current optimizer's parameters are calculated in the training step to prevent dangling gradients in multiple-optimizer setup.

train([mode])

Set the module in training mode.

train_dataloader()

Lightning method.

training_step(batch, batch_idx)

Lightning method.

transfer_batch_to_device(batch, device, ...)

Override this hook if your DataLoader returns tensors wrapped in a custom data structure.

type(dst_type)

See torch.nn.Module.type().

unfreeze()

Unfreeze all parameters for training.

untoggle_optimizer(optimizer)

Resets the state of required gradients that were toggled with toggle_optimizer().

val_dataloader()

An iterable or collection of iterables specifying validation samples.

validation_step(*args, **kwargs)

Operates on a single batch of data from the validation set.

xpu([device])

Move all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Reset gradients of all model parameters.

Attributes

CHECKPOINT_HYPER_PARAMS_KEY

CHECKPOINT_HYPER_PARAMS_NAME

CHECKPOINT_HYPER_PARAMS_TYPE

T_destination

automatic_optimization

If set to False you are responsible for calling .backward(), .step(), .zero_grad().

call_super_init

current_epoch

The current epoch in the Trainer, or 0 if not attached.

device

device_mesh

Strategies like ModelParallelStrategy will create a device mesh that can be accessed in the configure_model() hook to parallelize the LightningModule.

dtype

dump_patches

example_input_array

The example input array is a specification of what the module can consume in the forward() method.

fabric

global_rank

The index of the current process across all nodes and devices.

global_step

Total training batches seen across all epochs.

hparams

The collection of hyperparameters saved with save_hyperparameters().

hparams_initial

The collection of hyperparameters saved with save_hyperparameters().

local_rank

The index of the current process within a single node.

logger

Reference to the logger object in the Trainer.

loggers

Reference to the list of loggers in the Trainer.

n_conditioning_atoms

The number of conditioning atoms.

n_conditioning_dofs

The number of conditioning degrees of freedom (excluding the constrained DOFs of the reference frame atoms).

n_fixed_atoms

The number of fixed atoms.

n_mapped_atoms

The number of mapped atoms.

n_mapped_dofs

The number of mapped degrees of freedom (excluding the constrained DOFs of the reference frame atoms).

n_nonfixed_atoms

Total number of mapped and conditioning atoms.

n_nonfixed_dofs

Total number of mapped and conditioning degrees of freedom (excluding the constrained DOFs of the reference frame atoms).

on_gpu

Returns True if this model is currently located on a GPU.

strict_loading

Determines how Lightning loads this model using .load_state_dict(..., strict=model.strict_loading).

trainer

dataset

The dataset.

training

are_axes_atoms_mapped()

Return whether the two axes atoms (if any) are mapped.

Returns:

are_mapped – A pair (is_axes_atom_0_mapped, is_axes_atom_1_mapped) or None if there are no axes atoms.

Return type:

None or Tuple[bool]
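The shape of the returned value can be illustrated with a stand-alone sketch that works on plain Python indices rather than on the object's internal state. The function signature below is hypothetical.

```python
from typing import Optional, Sequence, Set, Tuple


def are_axes_atoms_mapped(
    axes_atoms: Optional[Sequence[int]], mapped_atoms: Set[int]
) -> Optional[Tuple[bool, bool]]:
    """Hypothetical sketch: report whether each axes atom is mapped."""
    if axes_atoms is None:
        # No relative frame of reference was requested.
        return None
    axis_atom, plane_atom = axes_atoms
    return axis_atom in mapped_atoms, plane_atom in mapped_atoms


print(are_axes_atoms_mapped([3, 7], mapped_atoms={1, 3, 5}))  # (True, False)
print(are_axes_atoms_mapped(None, mapped_atoms={1, 3, 5}))    # None
```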

abstract configure_flow() -> Module

Initialize the normalizing flow.

Note that the flow must handle only the mapped and conditioning atoms. The flow is then automatically wrapped in a PartialFlow that handles the fixed atoms.

The method must also configure the flow to fix the relative frame of reference based on the origin and axes atoms, for example, by using tfep.nn.flows.OrientedFlow and tfep.nn.flows.CenteredCentroidFlow.

Returns:

flow – The normalizing flow.

Return type:

torch.nn.Module

configure_optimizers()

Lightning method.

Returns:

optimizer – The optimizer to use for the training.

Return type:

torch.optim.optimizer.Optimizer

create_dataset()

Create and return the Dataset object.

Returns:

dataset – The PyTorch dataset.

Return type:

torch.utils.data.Dataset

create_partial_flow(flow: Module, return_partial: bool = False) -> Module

Wrap the flow to remove the fixed DOFs.

Parameters:
  • flow (torch.nn.Module) – The flow to be wrapped in the Partial and/or Oriented/CenteredCentroid flows.

  • return_partial (bool, optional) – The return_partial flag of the PartialFlow.

Returns:

flow – The wrapped flow.

Return type:

torch.nn.Module

create_universe()

Create and return the MDAnalysis Universe.

Returns:

universe – The MDAnalysis Universe object.

Return type:

MDAnalysis.Universe

dataset: tfep.io.dataset.TrajectoryDataset | None

The dataset.

determine_atom_indices()

Determine mapped, conditioning, fixed, and reference frame atom indices.

This initializes the following attributes:

  • self._mapped_atom_indices

  • self._conditioning_atom_indices

  • self._fixed_atom_indices

  • self._origin_atom_idx

  • self._axes_atoms_indices
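The defaults described for the constructor parameters (all non-conditioning atoms are mapped if mapped_atoms is not given; no atom conditions the map if conditioning_atoms is not given; everything else is fixed) can be summarized in a pure-Python sketch. partition_atoms is hypothetical, and raising ValueError on an inconsistent selection is an assumption about how such consistency checks could look, not necessarily what tfep does.

```python
from typing import Iterable, List, Optional, Tuple


def partition_atoms(
    n_atoms: int,
    mapped_atoms: Optional[Iterable[int]] = None,
    conditioning_atoms: Optional[Iterable[int]] = None,
    origin_atom: Optional[int] = None,
    axes_atoms: Optional[Iterable[int]] = None,
) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical sketch: return sorted (mapped, conditioning, fixed) atom indices."""
    conditioning = set(conditioning_atoms) if conditioning_atoms is not None else set()
    if mapped_atoms is None:
        # Default: every atom that is not conditioning is mapped.
        mapped = set(range(n_atoms)) - conditioning
    else:
        mapped = set(mapped_atoms)
    if mapped & conditioning:
        raise ValueError("An atom cannot be both mapped and conditioning.")
    # The origin atom is constrained, so it must be a conditioning atom.
    if origin_atom is not None and origin_atom not in conditioning:
        raise ValueError("The origin atom must be a conditioning atom.")
    # Axes atoms can be mapped or conditioning, but never fixed.
    for idx in axes_atoms or ():
        if idx not in mapped and idx not in conditioning:
            raise ValueError(f"Axes atom {idx} must be mapped or conditioning.")
    fixed = set(range(n_atoms)) - mapped - conditioning
    return sorted(mapped), sorted(conditioning), sorted(fixed)


print(partition_atoms(6, conditioning_atoms=[0, 1]))                  # ([2, 3, 4, 5], [0, 1], [])
print(partition_atoms(6, mapped_atoms=[2, 3], conditioning_atoms=[0]))  # ([2, 3], [0], [1, 4, 5])
```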

forward(batch: Dict) -> dict[str, Tensor]

Execute the normalizing flow in the forward direction.

Parameters:

batch (dict[str, torch.Tensor]) – Batch data. Must have the key 'positions'.

Returns:

result – The output of the normalizing flow with at least the following keys:

  • 'positions': Shape (batch_size, n_atoms*3). The mapped coordinates of the flow.

  • 'log_det_J': Shape (batch_size,). The log weight of the transformation. This is usually the logarithm of the absolute value of the Jacobian determinant for deterministic maps, but it can also be the trace for continuous flows, or the log-ratio of the forward and backward paths in stochastic flows.

  • 'regularization': Shape (batch_size,). Optional. Arbitrary regularization terms. Their mean over the batch is added to the loss.

Any other tensor in this dictionary that is a scalar or of shape (batch_size,) is automatically logged. Any other tensor is ignored.

Return type:

dict[str, torch.Tensor]
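The automatic-logging rule above depends only on the shape of each extra tensor. The pure-Python illustration below represents each output tensor by its shape tuple alone and assumes that "any other tensor" refers to keys other than 'positions', 'log_det_J', and 'regularization'. keys_to_log is a hypothetical helper.

```python
from typing import Dict, List, Tuple

# Keys with a documented meaning in the output of forward()/inverse().
DOCUMENTED_KEYS = {"positions", "log_det_J", "regularization"}


def keys_to_log(shapes: Dict[str, Tuple[int, ...]], batch_size: int) -> List[str]:
    """Hypothetical sketch: return the extra keys whose tensors would be logged."""
    return [
        key
        for key, shape in shapes.items()
        # Only scalars or per-sample vectors of shape (batch_size,) qualify.
        if key not in DOCUMENTED_KEYS and shape in ((), (batch_size,))
    ]


output_shapes = {
    "positions": (4, 30),
    "log_det_J": (4,),
    "kl_estimate": (),
    "per_sample_term": (4,),
    "matrix": (4, 2),
}
print(keys_to_log(output_shapes, batch_size=4))  # ['kl_estimate', 'per_sample_term']
```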

get_conditioning_indices(idx_type: Literal['atom', 'dof'], remove_fixed: bool) -> Tensor

Return the indices of the conditioning atom or degrees of freedom (DOF).

Each atom generally has 3 degrees of freedom, except for the atoms used to set the relative frame of reference. If either or both of the axes_atoms are conditioning atoms, the returned conditioning DOF indices also include the DOFs of the axes_atoms that are not constrained, i.e., the x coordinate of axes_atoms[0] and the x,y coordinates of axes_atoms[1]. The origin_atom is always a conditioning atom by definition, and it is thus included in the returned indices.

Parameters:
  • idx_type (Literal[‘atom’, ‘dof’]) – Whether to return the indices of the atom or the degrees of freedom.

  • remove_fixed (bool) – If True, the returned tensor represents the indices after the fixed atoms have been removed.

Returns:

indices – The conditioning atom/DOFs indices.

Return type:

torch.Tensor
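For an atom whose 3 coordinates are all free, converting atom indices to DOF indices follows from the flattened (batch_size, n_atoms*3) layout of 'positions' used by forward(). The sketch below shows only this generic case and deliberately ignores the reference-frame atoms, whose DOFs are partially constrained as described above. atom_to_dof_indices is a hypothetical helper.

```python
from typing import List, Sequence


def atom_to_dof_indices(atom_indices: Sequence[int]) -> List[int]:
    """Hypothetical helper: expand atom indices into the indices of their x, y, z coordinates."""
    # With positions flattened to n_atoms*3 values, atom i owns DOFs 3i, 3i+1, 3i+2.
    return [3 * i + d for i in atom_indices for d in range(3)]


print(atom_to_dof_indices([0, 2]))  # [0, 1, 2, 6, 7, 8]
```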

get_mapped_indices(idx_type: Literal['atom', 'dof'], remove_fixed: bool) -> Tensor

Return the indices of the mapped atom or degrees of freedom (DOF).

Each atom generally has 3 degrees of freedom, except for the atoms used to set the relative frame of reference. If either or both of the axes_atoms are mapped atoms, the returned mapped DOF indices also include the DOFs of the axes_atoms that are not constrained, i.e., the x coordinate of axes_atoms[0] and the x,y coordinates of axes_atoms[1].

Parameters:
  • idx_type (Literal[‘atom’, ‘dof’]) – Whether to return the indices of the atom or the degrees of freedom.

  • remove_fixed (bool) – If True, the returned tensor represents the indices after the fixed atoms have been removed.

Returns:

indices – The mapped atom/DOFs indices.

Return type:

torch.Tensor

get_nonfixed_indices(idx_type: Literal['atom', 'dof'], remove_fixed: bool) -> Tensor

Return the indices of the mapped and conditioning atom or degrees of freedom (DOF).

This is a more efficient way of obtaining all mapped and conditioning indices than concatenating and sorting the results of TFEPMapBase.get_mapped_indices() and TFEPMapBase.get_conditioning_indices().

Parameters:
  • idx_type (Literal[‘atom’, ‘dof’]) – Whether to return the indices of the atom or the degrees of freedom.

  • remove_fixed (bool) – If True, the returned tensor represents the indices after the fixed atoms have been removed.

Returns:

indices – The mapped and conditioning atom/DOF indices.

Return type:

torch.Tensor
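One reason such a shortcut can be cheaper is that the non-fixed atoms are simply the complement of the fixed ones, and, once the fixed atoms are removed, they are just a contiguous range. The sketch below shows this reasoning at the atom level; it is an assumption about how the shortcut could work, not necessarily how tfep implements it, and nonfixed_atom_indices is hypothetical.

```python
from typing import List, Sequence


def nonfixed_atom_indices(
    n_atoms: int, fixed_atom_indices: Sequence[int], remove_fixed: bool
) -> List[int]:
    """Hypothetical sketch: indices of the mapped and conditioning atoms."""
    fixed = set(fixed_atom_indices)
    if remove_fixed:
        # Once the fixed atoms are dropped, every remaining atom is non-fixed.
        return list(range(n_atoms - len(fixed)))
    # Otherwise, take the complement of the fixed atoms (already sorted).
    return [i for i in range(n_atoms) if i not in fixed]


mapped, conditioning, fixed = [2, 3], [0, 5], [1, 4]
# Same result as concatenating and sorting the mapped and conditioning indices.
print(nonfixed_atom_indices(6, fixed, remove_fixed=False))  # [0, 2, 3, 5]
print(nonfixed_atom_indices(6, fixed, remove_fixed=True))   # [0, 1, 2, 3]
```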

get_reference_atoms_indices(remove_fixed: bool, separate_origin_axes: bool = False) -> Tensor | None | List[Tensor | None]

Return the atom indices of the origin and axes atoms.

Parameters:
  • remove_fixed (bool) – If True, the returned tensor represents the indices after the fixed atoms have been removed.

  • separate_origin_axes (bool, optional) – If True, the origin and axes atoms are returned separately in two Tensors. Otherwise, a single Tensor is returned.

Returns:

reference_atom_indices – If separate_origin_axes is False, a single Tensor including the indices, in this order, of the origin, axis, and plane atoms (if they exist) or None if there are no origin and axes atoms.

If separate_origin_axes is True, this is a pair of Tensors (or None) holding the origin atom index and the axes atom indices.

Return type:

torch.Tensor or None or List[torch.Tensor | None]
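The two return formats can be mimicked with plain lists standing in for Tensors. This is a hypothetical sketch of the documented behavior, not the tfep implementation.

```python
from typing import List, Optional, Sequence, Union


def reference_atoms_indices(
    origin_atom: Optional[int],
    axes_atoms: Optional[Sequence[int]],
    separate_origin_axes: bool = False,
) -> Union[Optional[List[int]], List[Optional[List[int]]]]:
    """Hypothetical sketch: lists stand in for the Tensors returned by tfep."""
    origin = [origin_atom] if origin_atom is not None else None
    axes = list(axes_atoms) if axes_atoms is not None else None
    if separate_origin_axes:
        # A pair (origin, axes), where either element may be None.
        return [origin, axes]
    # Origin, axis, and plane atoms in this order, or None if none exist.
    combined = (origin or []) + (axes or [])
    return combined or None


print(reference_atoms_indices(0, [3, 5]))                             # [0, 3, 5]
print(reference_atoms_indices(0, [3, 5], separate_origin_axes=True))  # [[0], [3, 5]]
print(reference_atoms_indices(None, None))                            # None
```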

inverse(batch: Dict) -> dict[str, Tensor]

Execute the normalizing flow in the inverse direction.

See TFEPMapBase.forward() for the documentation on the input parameters and returned value.

property n_conditioning_atoms: int

The number of conditioning atoms.

property n_conditioning_dofs: int

The number of conditioning degrees of freedom (excluding the constrained DOFs of the reference frame atoms).

property n_fixed_atoms: int

The number of fixed atoms.

property n_mapped_atoms: int

The number of mapped atoms.

property n_mapped_dofs: int

The number of mapped degrees of freedom (excluding the constrained DOFs of the reference frame atoms).

property n_nonfixed_atoms: int

Total number of mapped and conditioning atoms.

property n_nonfixed_dofs: int

Total number of mapped and conditioning degrees of freedom (excluding the constrained DOFs of the reference frame atoms).
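The DOF counts exclude the constrained coordinates of the reference frame atoms. Following the constructor documentation (axes_atoms[0] keeps 1 free DOF and axes_atoms[1] keeps 2), the counting can be sketched as below. The assumption that all 3 coordinates of the origin atom are constrained is ours, and count_free_dofs is a hypothetical helper.

```python
from typing import Optional, Sequence, Set


def count_free_dofs(
    atoms: Set[int],
    origin_atom: Optional[int] = None,
    axes_atoms: Optional[Sequence[int]] = None,
) -> int:
    """Hypothetical sketch: count the unconstrained DOFs of a set of atoms."""
    n_dofs = 3 * len(atoms)
    # Assumption: all 3 coordinates of the origin atom are constrained.
    if origin_atom is not None and origin_atom in atoms:
        n_dofs -= 3
    if axes_atoms is not None:
        axis_atom, plane_atom = axes_atoms
        # axes_atoms[0] keeps 1 free DOF (2 constrained).
        if axis_atom in atoms:
            n_dofs -= 2
        # axes_atoms[1] keeps 2 free DOFs (1 constrained).
        if plane_atom in atoms:
            n_dofs -= 1
    return n_dofs


mapped, conditioning = {3, 4, 5}, {0, 1}
n_mapped_dofs = count_free_dofs(mapped, origin_atom=0, axes_atoms=(3, 5))
n_conditioning_dofs = count_free_dofs(conditioning, origin_atom=0, axes_atoms=(3, 5))
print(n_mapped_dofs, n_conditioning_dofs)  # 6 3
```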

on_load_checkpoint(checkpoint: Dict[str, Any])

Lightning hook.

Used to restore the state of the batch sampler for mid-epoch resuming.

on_save_checkpoint(checkpoint: Dict[str, Any])

Lightning hook.

Used to store the state of the batch sampler for mid-epoch resuming.
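Mid-epoch resuming requires a batch sampler that can report how far into the current epoch it is and reproduce the same shuffled order after a restart. The class below is a hypothetical stdlib-only sketch of that idea (deterministic per-epoch shuffling plus a batch counter); it is not the sampler used by tfep, and the checkpoint key in the comments is likewise hypothetical.

```python
import random
from typing import Any, Dict, Iterator, List


class ResumableBatchSampler:
    """Hypothetical batch sampler whose position can be saved and restored mid-epoch."""

    def __init__(self, n_samples: int, batch_size: int, seed: int = 0):
        self.n_samples = n_samples
        self.batch_size = batch_size
        self.seed = seed
        self.epoch = 0
        self.n_batches_done = 0  # Batches yielded so far in the current epoch.

    def _permutation(self) -> List[int]:
        # Deterministic per-epoch shuffle, so a restored sampler reproduces the order.
        indices = list(range(self.n_samples))
        random.Random(self.seed + self.epoch).shuffle(indices)
        return indices

    def __iter__(self) -> Iterator[List[int]]:
        indices = self._permutation()
        # Skip the batches already yielded before the checkpoint.
        start_sample = self.n_batches_done * self.batch_size
        for start in range(start_sample, self.n_samples, self.batch_size):
            self.n_batches_done += 1  # Counted as soon as the batch is yielded.
            yield indices[start:start + self.batch_size]
        # Epoch completed: the next iteration starts a new shuffled epoch.
        self.epoch += 1
        self.n_batches_done = 0

    def state_dict(self) -> Dict[str, Any]:
        return {"epoch": self.epoch, "n_batches_done": self.n_batches_done}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.epoch = state["epoch"]
        self.n_batches_done = state["n_batches_done"]


# In the checkpoint hooks one could then store/restore the state, e.g.:
#   checkpoint["batch_sampler_state"] = sampler.state_dict()
#   sampler.load_state_dict(checkpoint["batch_sampler_state"])
```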

setup(stage: str = 'fit')

Lightning method.

This is executed on all processes by Lightning in DDP mode (unlike __init__) and can be used to initialize objects like the Dataset and all data-dependent objects.

train_dataloader()

Lightning method.

Returns:

data_loader – The training data loader.

Return type:

torch.utils.data.DataLoader

training_step(batch, batch_idx)

Lightning method.

Execute a training step.