tfep.app.cartesianmaf.CartesianMAFMap
- class tfep.app.cartesianmaf.CartesianMAFMap(potential_energy_func: Module, topology_file_path: str, coordinates_file_path: str | Sequence[str], temperature: Quantity, batch_size: int = 1, mapped_atoms: Sequence[int] | str | None = None, conditioning_atoms: Sequence[int] | str | None = None, origin_atom: int | str | None = None, axes_atoms: Sequence[int] | str | None = None, tfep_logger_dir_path: str = 'tfep_logs', n_maf_layers: int = 6, dataloader_kwargs: Dict | None = None, **kwargs)
Bases: TFEPMapBase
A TFEP map using a masked autoregressive flow in Cartesian coordinates.
The class divides the atoms of the entire system into mapped, conditioning, and fixed atoms. Mapped atoms are those that the flow maps. Conditioning atoms are not mapped but are given as input to the flow to condition the mapping. Fixed atoms are ignored.
Optionally, the flow can map the atoms in a relative frame of reference based on the positions of an origin_atom and two axes_atoms, which determine the origin and the orientation of the axes, respectively.
The class further supports logging the potential energies computed during training (required for the multimap TFEP analysis) and mid-epoch resuming.
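As a rough illustration of the mapped/conditioning/fixed division (not tfep's actual implementation), the sketch below partitions atom indices in plain Python. The helper name partition_atoms is hypothetical, it accepts integer indices only (no MDAnalysis selection strings), and treating an atom that is both mapped and conditioning as an error is an assumption of this sketch.

```python
def partition_atoms(n_atoms, mapped_atoms=None, conditioning_atoms=None):
    """Split atom indices into mapped, conditioning, and fixed (hypothetical helper)."""
    conditioning = sorted(set(conditioning_atoms or []))
    if mapped_atoms is None:
        # Documented default: every atom that is not conditioning is mapped.
        mapped = [i for i in range(n_atoms) if i not in set(conditioning)]
    else:
        mapped = sorted(set(mapped_atoms))
    overlap = set(mapped) & set(conditioning)
    if overlap:
        # Assumption of this sketch: the two groups must be disjoint.
        raise ValueError(f"atoms both mapped and conditioning: {sorted(overlap)}")
    # Whatever is neither mapped nor conditioning is fixed (ignored by the flow).
    nonfixed = set(mapped) | set(conditioning)
    fixed = [i for i in range(n_atoms) if i not in nonfixed]
    return mapped, conditioning, fixed


mapped, conditioning, fixed = partition_atoms(
    n_atoms=8, mapped_atoms=[0, 1, 2], conditioning_atoms=range(3, 5)
)
print(mapped, conditioning, fixed)
```

With no arguments besides n_atoms, every atom ends up mapped, which mirrors the documented default behavior.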
Warning
Currently, this class is neither multi-process safe nor thread-safe. Running with multiple processes may corrupt the logged potentials and Jacobians.
See also
tfep.app.base.TFEPMapBase – The base class for TFEP maps, with more detailed explanations of the division into mapped/conditioning/fixed atoms.
Examples
>>> import pint
>>> from tfep.app.cartesianmaf import CartesianMAFMap
>>> from tfep.potentials.psi4 import Psi4Potential
>>> units = pint.UnitRegistry()
>>>
>>> tfep_map = CartesianMAFMap(
...     potential_energy_func=Psi4Potential(name='mp2'),
...     topology_file_path='path/to/topology.psf',
...     coordinates_file_path='path/to/trajectory.dcd',
...     temperature=300*units.kelvin,
...     batch_size=64,
...     mapped_atoms='resname MOL',  # MDAnalysis selection syntax.
...     conditioning_atoms=range(10, 20),
...     origin_atom=12,  # Fix the origin of the relative reference frame on atom 12.
...     axes_atoms=[13, 16],  # Determine the orientation of the reference frame.
... )
>>>
>>> # Train the flow and save the potential energies.
>>> import lightning
>>> trainer = lightning.Trainer()
>>> trainer.fit(tfep_map)
- __init__(potential_energy_func: Module, topology_file_path: str, coordinates_file_path: str | Sequence[str], temperature: Quantity, batch_size: int = 1, mapped_atoms: Sequence[int] | str | None = None, conditioning_atoms: Sequence[int] | str | None = None, origin_atom: int | str | None = None, axes_atoms: Sequence[int] | str | None = None, tfep_logger_dir_path: str = 'tfep_logs', n_maf_layers: int = 6, dataloader_kwargs: Dict | None = None, **kwargs)
Constructor.
- Parameters:
potential_energy_func (torch.nn.Module) – A PyTorch module encapsulating the target potential energy function (e.g., tfep.potentials.psi4.Psi4Potential).
topology_file_path (str) – The path to the topology file. The file can be in any format supported by MDAnalysis; the format is detected automatically from the file extension.
coordinates_file_path (str or Sequence[str]) – The path(s) to the trajectory file(s). If a sequence of files is given, the trajectories are concatenated into a single large dataset. The file(s) can be in any format supported by MDAnalysis; the format is detected automatically from the file extension.
temperature (pint.Quantity) – The temperature of the ensemble.
batch_size (int, optional) – The batch size.
mapped_atoms (Sequence[int] or str, optional) – The indices (0-based) of the atoms to map or a selection string in MDAnalysis syntax. If not passed, all atoms that are not conditioning are mapped (i.e., all atoms are mapped if conditioning_atoms is not given either).
conditioning_atoms (Sequence[int] or str, optional) – The indices (0-based) of the atoms conditioning the mapping or a selection string in MDAnalysis syntax. If not passed, no atom will condition the map.
origin_atom (int or str or None, optional) – The index (0-based) or a selection string in MDAnalysis syntax of an atom on which to center the origin of the relative frame of reference. While this atom affects the mapping of the mapped atoms, its position will be constrained during the mapping, and thus it must be a conditioning atom by definition.
axes_atoms (Sequence[int] or str or None, optional) – A pair of indices (0-based) or a selection string in MDAnalysis syntax for the two atoms determining the relative frame of reference. The atom axes_atoms[0] will lie on the z axis, and the atom axes_atoms[1] will lie on the plane spanned by the x and z axes. The y axis will be set as the cross product of x and z.
These atoms can be either conditioning or mapped. axes_atoms[0] has only 1 degree of freedom (DOF) while axes_atoms[1] has 2. Whether these DOFs are mapped or not depends on whether their atoms are indicated as mapped or conditioning, respectively.
tfep_logger_dir_path (str, optional) – The path where to save TFEP-related information (potential energies, sample indices, etc.).
n_maf_layers (int, optional) – The number of MAF layers.
dataloader_kwargs (Dict, optional) – Extra keyword arguments to pass to torch.utils.data.DataLoader.
**kwargs – Other keyword arguments to pass to the constructor of tfep.nn.flows.MAF.
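To make the relative frame of reference concrete, here is a minimal pure-Python sketch of one way to build it. It is not tfep's implementation: the function names are hypothetical, and the right-handed choice y = z × x is an assumption of this sketch. It places the origin atom at the origin, axes_atoms[0] on the z axis, and axes_atoms[1] on the xz plane, as described for axes_atoms.

```python
import math


def sub(a, b):
    return [a[i] - b[i] for i in range(3)]


def dot(a, b):
    return sum(a[i] * b[i] for i in range(3))


def cross(a, b):
    return [a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0]]


def normalize(a):
    norm = math.sqrt(dot(a, a))
    return [x / norm for x in a]


def to_relative_frame(positions, origin_idx, axes_idx):
    """Express positions in the frame set by the origin and axes atoms (sketch)."""
    origin = positions[origin_idx]
    # z axis: points from the origin atom to the first axes atom.
    z_hat = normalize(sub(positions[axes_idx[0]], origin))
    # x axis: component of the second axes atom orthogonal to z.
    v = sub(positions[axes_idx[1]], origin)
    x_hat = normalize([v[i] - dot(v, z_hat) * z_hat[i] for i in range(3)])
    # y axis: completes a right-handed frame (assumption of this sketch).
    y_hat = cross(z_hat, x_hat)
    frame = []
    for r in positions:
        d = sub(r, origin)
        frame.append([dot(d, x_hat), dot(d, y_hat), dot(d, z_hat)])
    return frame


positions = [[1.0, 1.0, 1.0], [1.0, 3.0, 1.0], [2.0, 2.0, 1.0], [0.0, 0.0, 5.0]]
for row in to_relative_frame(positions, origin_idx=0, axes_idx=[1, 2]):
    print(row)
```

In the resulting frame, the first axes atom keeps only its z coordinate and the second keeps only x and z, which is why they contribute 1 and 2 degrees of freedom, respectively.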
Methods
__init__(potential_energy_func, ...[, ...]) – Constructor.
add_module(name, module) – Add a child module to the current module.
all_gather(data[, group, sync_grads]) – Gather tensors or collections of tensors from multiple processes.
apply(fn) – Apply fn recursively to every submodule (as returned by .children()) as well as self.
are_axes_atoms_mapped() – Return whether the two axes atoms (if any) are mapped.
backward(loss, *args, **kwargs) – Called to perform backward on the loss returned in training_step().
bfloat16() – Casts all floating point parameters and buffers to bfloat16 datatype.
buffers([recurse]) – Return an iterator over module buffers.
children() – Return an iterator over immediate children modules.
clip_gradients(optimizer[, ...]) – Handles gradient clipping internally.
compile(*args, **kwargs) – Compile this Module's forward using torch.compile().
configure_callbacks() – Configure model-specific callbacks.
configure_flow() – Initialize the normalizing flow.
configure_gradient_clipping(optimizer[, ...]) – Perform gradient clipping for the optimizer parameters.
configure_model() – Hook to create modules in a strategy and precision aware context.
configure_optimizers() – Lightning method.
configure_sharded_model() – Deprecated.
cpu() – See torch.nn.Module.cpu().
create_dataset() – Create and return the Dataset object.
create_partial_flow(flow[, return_partial]) – Wrap the flow to remove the fixed DOFs.
create_universe() – Create and return the MDAnalysis Universe.
cuda([device]) – Moves all model parameters and buffers to the GPU.
determine_atom_indices() – Override the base-class determine_atom_indices() to check that the origin atom is a conditioning atom.
double() – See torch.nn.Module.double().
eval() – Set the module in evaluation mode.
extra_repr() – Set the extra representation of the module.
float() – See torch.nn.Module.float().
forward(batch) – Execute the normalizing flow in the forward direction.
freeze() – Freeze all params for inference.
get_buffer(target) – Return the buffer given by target if it exists, otherwise throw an error.
get_conditioning_indices(idx_type, remove_fixed) – Return the indices of the conditioning atoms or degrees of freedom (DOFs).
get_extra_state() – Return any extra state to include in the module's state_dict.
get_mapped_indices(idx_type, remove_fixed[, ...]) – Return the indices of the mapped atoms or degrees of freedom (DOFs).
get_nonfixed_indices(idx_type, remove_fixed) – Return the indices of the mapped and conditioning atoms or degrees of freedom (DOFs).
get_parameter(target) – Return the parameter given by target if it exists, otherwise throw an error.
get_reference_atoms_indices(remove_fixed[, ...]) – Return the atom indices of the origin and axes atoms.
get_submodule(target) – Return the submodule given by target if it exists, otherwise throw an error.
half() – See torch.nn.Module.half().
inverse(batch) – Execute the normalizing flow in the inverse direction.
ipu([device]) – Move all model parameters and buffers to the IPU.
load_from_checkpoint(checkpoint_path[, ...]) – Primary way of loading a model from a checkpoint.
load_state_dict(state_dict[, strict, assign]) – Copy parameters and buffers from state_dict into this module and its descendants.
log(name, value[, prog_bar, logger, ...]) – Log a key, value pair.
log_dict(dictionary[, prog_bar, logger, ...]) – Log a dictionary of values at once.
lr_scheduler_step(scheduler, metric) – Override this method to adjust the default way the Trainer calls each scheduler.
lr_schedulers() – Returns the learning rate scheduler(s) that are being used during training.
manual_backward(loss, *args, **kwargs) – Call this directly from your training_step() when doing optimizations manually.
modules() – Return an iterator over all modules in the network.
mtia([device]) – Move all model parameters and buffers to the MTIA.
named_buffers([prefix, recurse, ...]) – Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
named_children() – Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
named_modules([memo, prefix, remove_duplicate]) – Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
named_parameters([prefix, recurse, ...]) – Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
on_after_backward() – Called after loss.backward() and before optimizers are stepped.
on_after_batch_transfer(batch, dataloader_idx) – Override to alter or apply batch augmentations to your batch after it is transferred to the device.
on_before_backward(loss) – Called before loss.backward().
on_before_batch_transfer(batch, dataloader_idx) – Override to alter or apply batch augmentations to your batch before it is transferred to the device.
on_before_optimizer_step(optimizer) – Called before optimizer.step().
on_before_zero_grad(optimizer) – Called after training_step() and before optimizer.zero_grad().
on_fit_end() – Called at the very end of fit.
on_fit_start() – Called at the very beginning of fit.
on_load_checkpoint(checkpoint) – Lightning hook.
on_predict_batch_end(outputs, batch, batch_idx) – Called in the predict loop after the batch.
on_predict_batch_start(batch, batch_idx[, ...]) – Called in the predict loop before anything happens for that batch.
on_predict_end() – Called at the end of predicting.
on_predict_epoch_end() – Called at the end of predicting.
on_predict_epoch_start() – Called at the beginning of predicting.
on_predict_model_eval() – Called when the predict loop starts.
on_predict_start() – Called at the beginning of predicting.
on_save_checkpoint(checkpoint) – Lightning hook.
on_test_batch_end(outputs, batch, batch_idx) – Called in the test loop after the batch.
on_test_batch_start(batch, batch_idx[, ...]) – Called in the test loop before anything happens for that batch.
on_test_end() – Called at the end of testing.
on_test_epoch_end() – Called in the test loop at the very end of the epoch.
on_test_epoch_start() – Called in the test loop at the very beginning of the epoch.
on_test_model_eval() – Called when the test loop starts.
on_test_model_train() – Called when the test loop ends.
on_test_start() – Called at the beginning of testing.
on_train_batch_end(outputs, batch, batch_idx) – Called in the training loop after the batch.
on_train_batch_start(batch, batch_idx) – Called in the training loop before anything happens for that batch.
on_train_end() – Called at the end of training before logger experiment is closed.
on_train_epoch_end() – Called in the training loop at the very end of the epoch.
on_train_epoch_start() – Called in the training loop at the very beginning of the epoch.
on_train_start() – Called at the beginning of training after sanity check.
on_validation_batch_end(outputs, batch, ...) – Called in the validation loop after the batch.
on_validation_batch_start(batch, batch_idx) – Called in the validation loop before anything happens for that batch.
on_validation_end() – Called at the end of validation.
on_validation_epoch_end() – Called in the validation loop at the very end of the epoch.
on_validation_epoch_start() – Called in the validation loop at the very beginning of the epoch.
on_validation_model_eval() – Called when the validation loop starts.
on_validation_model_train() – Called when the validation loop ends.
on_validation_model_zero_grad() – Called by the training loop to release gradients before entering the validation loop.
on_validation_start() – Called at the beginning of validation.
optimizer_step(epoch, batch_idx, optimizer) – Override this method to adjust the default way the Trainer calls the optimizer.
optimizer_zero_grad(epoch, batch_idx, optimizer) – Override this method to change the default behaviour of optimizer.zero_grad().
optimizers([use_pl_optimizer]) – Returns the optimizer(s) that are being used during training.
parameters([recurse]) – Return an iterator over module parameters.
predict_dataloader() – An iterable or collection of iterables specifying prediction samples.
predict_step(*args, **kwargs) – Step function called during predict().
prepare_data() – Use this to download and prepare data.
print(*args, **kwargs) – Prints only from process 0.
register_backward_hook(hook) – Register a backward hook on the module.
register_buffer(name, tensor[, persistent]) – Add a buffer to the module.
register_forward_hook(hook, *[, prepend, ...]) – Register a forward hook on the module.
register_forward_pre_hook(hook, *[, ...]) – Register a forward pre-hook on the module.
register_full_backward_hook(hook[, prepend]) – Register a backward hook on the module.
register_full_backward_pre_hook(hook[, prepend]) – Register a backward pre-hook on the module.
register_load_state_dict_post_hook(hook) – Register a post-hook to be run after module's load_state_dict() is called.
register_load_state_dict_pre_hook(hook) – Register a pre-hook to be run before module's load_state_dict() is called.
register_module(name, module) – Alias for add_module().
register_parameter(name, param) – Add a parameter to the module.
register_state_dict_post_hook(hook) – Register a post-hook for the state_dict() method.
register_state_dict_pre_hook(hook) – Register a pre-hook for the state_dict() method.
requires_grad_([requires_grad]) – Change if autograd should record operations on parameters in this module.
save_hyperparameters(*args[, ignore, frame, ...]) – Save arguments to hparams attribute.
set_extra_state(state) – Set extra state contained in the loaded state_dict.
set_submodule(target, module) – Set the submodule given by target if it exists, otherwise throw an error.
setup([stage]) – Lightning method.
share_memory() – See torch.Tensor.share_memory_().
state_dict(*args[, destination, prefix, ...]) – Return a dictionary containing references to the whole state of the module.
teardown(stage) – Called at the end of fit (train + validate), validate, test, or predict.
test_dataloader() – An iterable or collection of iterables specifying test samples.
test_step(*args, **kwargs) – Operates on a single batch of data from the test set.
to(*args, **kwargs) – See torch.nn.Module.to().
to_empty(*, device[, recurse]) – Move the parameters and buffers to the specified device without copying storage.
to_onnx(file_path[, input_sample]) – Saves the model in ONNX format.
to_torchscript([file_path, method, ...]) – By default compiles the whole model to a ScriptModule.
toggle_optimizer(optimizer) – Makes sure only the gradients of the current optimizer's parameters are calculated in the training step to prevent dangling gradients in multiple-optimizer setup.
train([mode]) – Set the module in training mode.
train_dataloader() – Lightning method.
training_step(batch, batch_idx) – Lightning method.
transfer_batch_to_device(batch, device, ...) – Override this hook if your DataLoader returns tensors wrapped in a custom data structure.
type(dst_type) – See torch.nn.Module.type().
unfreeze() – Unfreeze all parameters for training.
untoggle_optimizer(optimizer) – Resets the state of required gradients that were toggled with toggle_optimizer().
val_dataloader() – An iterable or collection of iterables specifying validation samples.
validation_step(*args, **kwargs) – Operates on a single batch of data from the validation set.
xpu([device]) – Move all model parameters and buffers to the XPU.
zero_grad([set_to_none]) – Reset gradients of all model parameters.
Attributes
CHECKPOINT_HYPER_PARAMS_KEY
CHECKPOINT_HYPER_PARAMS_NAME
CHECKPOINT_HYPER_PARAMS_TYPE
T_destination
automatic_optimization – If set to False you are responsible for calling .backward(), .step(), .zero_grad().
call_super_init
current_epoch – The current epoch in the Trainer, or 0 if not attached.
device
device_mesh – Strategies like ModelParallelStrategy will create a device mesh that can be accessed in the configure_model() hook to parallelize the LightningModule.
dtype
dump_patches
example_input_array – The example input array is a specification of what the module can consume in the forward() method.
fabric
global_rank – The index of the current process across all nodes and devices.
global_step – Total training batches seen across all epochs.
hparams – The collection of hyperparameters saved with save_hyperparameters().
hparams_initial – The collection of hyperparameters saved with save_hyperparameters().
local_rank – The index of the current process within a single node.
logger – Reference to the logger object in the Trainer.
loggers – Reference to the list of loggers in the Trainer.
n_conditioning_atoms – The number of conditioning atoms.
n_conditioning_dofs – The number of conditioning degrees of freedom (excluding the constrained DOFs of the reference frame atoms).
n_fixed_atoms – The number of fixed atoms.
n_mapped_atoms – The number of mapped atoms.
n_mapped_dofs – The number of mapped degrees of freedom (excluding the constrained DOFs of the reference frame atoms).
n_nonfixed_atoms – Total number of mapped and conditioning atoms.
n_nonfixed_dofs – Total number of mapped and conditioning degrees of freedom (excluding the constrained DOFs of the reference frame atoms).
on_gpu – Returns True if this model is currently located on a GPU.
strict_loading – Determines how Lightning loads this model using .load_state_dict(..., strict=model.strict_loading).
trainer
dataset – The dataset.
training
- are_axes_atoms_mapped()
Return whether the two axes atoms (if any) are mapped.
- Returns:
are_mapped – A pair (is_axes_atom_0_mapped, is_axes_atom_1_mapped) or None if there are no axes atoms.
- Return type:
None or Tuple[bool, bool]
- configure_flow() → Module
Initialize the normalizing flow.
- Returns:
flow – The normalizing flow.
- Return type:
torch.nn.Module
- configure_optimizers()
Lightning method.
- Returns:
optimizer – The optimizer to use for the training.
- Return type:
torch.optim.optimizer.Optimizer
- create_dataset()
Create and return the Dataset object.
- Returns:
dataset – The PyTorch dataset.
- Return type:
torch.utils.data.Dataset
- create_partial_flow(flow: Module, return_partial: bool = False) → Module
Wrap the flow to remove the fixed DOFs.
- Parameters:
flow (torch.nn.Module) – The flow to be wrapped in the Partial and/or Oriented/CenteredCentroid flows.
return_partial (bool, optional) – The return_partial flag of the PartialFlow.
- Returns:
flow – The wrapped flow.
- Return type:
torch.nn.Module
- create_universe()
Create and return the MDAnalysis Universe.
- Returns:
universe – The MDAnalysis Universe object.
- Return type:
MDAnalysis.Universe
- dataset: tfep.io.dataset.TrajectoryDataset | None
The dataset.
- determine_atom_indices()
Override the base-class determine_atom_indices() to check that the origin atom is a conditioning atom.
- forward(batch: Dict) → dict[str, Tensor]
Execute the normalizing flow in the forward direction.
- Parameters:
batch (dict[str, torch.Tensor]) – Batch data. Must have the key 'positions'.
- Returns:
result – The output of the normalizing flow with at least the following keys:
'positions': Shape (batch_size, n_atoms*3). The mapped coordinates of the flow.
'log_det_J': Shape (batch_size,). The log weight of the transformation. This is usually the logarithm of the absolute value of the Jacobian determinant for deterministic maps, but it can also be the trace for continuous flows, or the log-ratio of the forward and backward paths in stochastic flows.
'regularization': Shape (batch_size,). Optional. Arbitrary regularization terms. The mean across batches is added to the loss.
Any other tensor in this dictionary that is a scalar or of shape (batch_size,) is automatically logged. Any other tensor is ignored.
- Return type:
dict[str, torch.Tensor]
- get_conditioning_indices(idx_type: Literal['atom', 'dof'], remove_fixed: bool, remove_reference: bool = False) → Tensor
Return the indices of the conditioning atoms or degrees of freedom (DOFs).
Each atom generally has 3 degrees of freedom, except for the atoms used to set the relative frame of reference. If the axes_atoms (or only one of them) have been indicated as conditioning, the returned conditioning DOF indices also include the DOFs of the axes_atoms that are not constrained, i.e., the x coordinate of axes_atoms[0], and the x,y coordinates of axes_atoms[1]. The origin_atom is always a conditioning atom by definition, and it is thus included in the returned indices unless remove_reference is True.
- Parameters:
idx_type (Literal['atom', 'dof']) – Whether to return the indices of the atoms or of the degrees of freedom.
remove_fixed (bool) – If True, the returned tensor represents the indices after the fixed atoms have been removed.
remove_reference (bool, optional) – If True, the returned tensor represents the indices after the reference frame atoms (i.e., origin and axes atoms) have been removed. Note that if idx_type == 'dof', only the constrained DOFs of the reference frame atoms are removed.
- Returns:
indices – The conditioning atom/DOFs indices.
- Return type:
torch.Tensor
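The relation between 'atom' and 'dof' indices can be sketched as follows. This is an illustration only: it assumes coordinates are flattened atom by atom as (x, y, z), which is consistent with the (batch_size, n_atoms*3) shape documented for forward() but is not confirmed here, and it ignores the constrained DOFs of the reference frame atoms. The helper name is hypothetical.

```python
def atom_to_dof_indices(atom_indices):
    """Expand atom indices into flattened DOF indices, 3 per atom (hypothetical helper)."""
    # Atom i owns the flattened coordinates 3*i (x), 3*i + 1 (y), and 3*i + 2 (z).
    return [3 * i + k for i in atom_indices for k in range(3)]


print(atom_to_dof_indices([1, 4]))  # [3, 4, 5, 12, 13, 14]
```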
- get_mapped_indices(idx_type: Literal['atom', 'dof'], remove_fixed: bool, remove_reference: bool = False) → Tensor
Return the indices of the mapped atoms or degrees of freedom (DOFs).
Each atom generally has 3 degrees of freedom, except for the atoms used to set the relative frame of reference. If the axes_atoms (or only one of them) have been indicated as mapped, the returned mapped DOF indices also include the DOFs of the axes_atoms that are not constrained, i.e., the x coordinate of axes_atoms[0], and the x,y coordinates of axes_atoms[1].
- Parameters:
idx_type (Literal['atom', 'dof']) – Whether to return the indices of the atoms or of the degrees of freedom.
remove_fixed (bool) – If True, the returned tensor represents the indices after the fixed atoms have been removed.
remove_reference (bool, optional) – If True, the returned tensor represents the indices after the reference frame atoms (i.e., origin and axes atoms) have been removed. Note that if idx_type == 'dof', only the constrained DOFs of the reference frame atoms are removed.
- Returns:
indices – The mapped atom/DOFs indices.
- Return type:
torch.Tensor
- get_nonfixed_indices(idx_type: Literal['atom', 'dof'], remove_fixed: bool) → Tensor
Return the indices of the mapped and conditioning atoms or degrees of freedom (DOFs).
This is a more efficient way of obtaining all mapped and conditioning indices than concatenating and sorting the results of TFEPMapBase.get_mapped_indices() and TFEPMapBase.get_conditioning_indices().
- Parameters:
idx_type (Literal['atom', 'dof']) – Whether to return the indices of the atoms or of the degrees of freedom.
remove_fixed (bool) – If True, the returned tensor represents the indices after the fixed atoms have been removed.
- Returns:
indices – The mapped and conditioning atom/DOF indices.
- Return type:
torch.Tensor
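The effect of remove_fixed can be pictured as renumbering the non-fixed atoms consecutively once the fixed atoms are dropped. The sketch below shows that renumbering in plain Python; the helper name is hypothetical and the actual tfep implementation may differ.

```python
def reindex_without_fixed(indices, nonfixed_indices):
    """Map full-system atom indices to their positions after fixed atoms are dropped."""
    # The k-th smallest non-fixed atom becomes atom k in the reduced system.
    new_position = {old: new for new, old in enumerate(sorted(nonfixed_indices))}
    return [new_position[i] for i in indices]


nonfixed = [2, 3, 7, 9]  # mapped + conditioning atoms, full-system numbering
mapped = [3, 9]          # mapped atoms, full-system numbering
print(reindex_without_fixed(mapped, nonfixed))  # [1, 3]
```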
- get_reference_atoms_indices(remove_fixed: bool, separate_origin_axes: bool = False) → Tensor | None | List[Tensor | None]
Return the atom indices of the origin and axes atoms.
- Parameters:
remove_fixed (bool) – If True, the returned tensor represents the indices after the fixed atoms have been removed.
separate_origin_axes (bool, optional) – If True, the origin and axes atoms are returned separately in two Tensors. Otherwise, a single Tensor is returned.
- Returns:
reference_atom_indices – If separate_origin_axes is False, a single Tensor including the indices, in this order, of the origin, axis, and plane atoms (if they exist) or None if there are no origin and axes atoms.
If separate_origin_axes is True, this is a pair of Tensors (or None) holding the origin atom index and the axes atom indices.
- Return type:
torch.Tensor or None or List[torch.Tensor | None]
- inverse(batch: Dict) → dict[str, Tensor]
Execute the normalizing flow in the inverse direction.
See TFEPMapBase.forward() for the documentation on the input parameters and returned value.
- property n_conditioning_atoms: int
The number of conditioning atoms.
- property n_conditioning_dofs: int
The number of conditioning degrees of freedom (excluding the constrained DOFs of the reference frame atoms).
- property n_fixed_atoms: int
The number of fixed atoms.
- property n_mapped_atoms: int
The number of mapped atoms.
- property n_mapped_dofs: int
The number of mapped degrees of freedom (excluding the constrained DOFs of the reference frame atoms).
- property n_nonfixed_atoms: int
Total number of mapped and conditioning atoms.
- property n_nonfixed_dofs: int
Total number of mapped and conditioning degrees of freedom (excluding the constrained DOFs of the reference frame atoms).
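As a worked example of how the DOF counts exclude the constrained DOFs, the sketch below counts 3 DOFs per ordinary atom, 1 for axes_atoms[0], 2 for axes_atoms[1], and 0 for the origin atom. Treating the origin atom as fully constrained is this sketch's reading of the description of origin_atom, and count_dofs is a hypothetical helper, not part of tfep.

```python
def count_dofs(atoms, origin_atom=None, axes_atoms=None):
    """Count the unconstrained DOFs of a group of atoms (hypothetical helper)."""
    n_dofs = 0
    for atom in atoms:
        if atom == origin_atom:
            continue  # Pinned at the origin: no free DOF (assumption of this sketch).
        elif axes_atoms is not None and atom == axes_atoms[0]:
            n_dofs += 1  # Constrained to an axis: 1 free DOF.
        elif axes_atoms is not None and atom == axes_atoms[1]:
            n_dofs += 2  # Constrained to a plane: 2 free DOFs.
        else:
            n_dofs += 3
    return n_dofs


# Conditioning atoms 10..19 with the reference atoms from the class example.
print(count_dofs(range(10, 20), origin_atom=12, axes_atoms=[13, 16]))  # 24
```

Here 7 ordinary atoms contribute 21 DOFs, and the two axes atoms contribute 1 + 2, giving 24 instead of the naive 30.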
- on_load_checkpoint(checkpoint: Dict[str, Any])
Lightning hook.
Used to restore the state of the batch sampler for mid-epoch resuming.
- on_save_checkpoint(checkpoint: Dict[str, Any])
Lightning hook.
Used to store the state of the batch sampler for mid-epoch resuming.
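The idea behind mid-epoch resuming can be sketched with a toy batch sampler that exposes its position as a small state dictionary. This is only an illustration under stated assumptions: ResumableBatchSampler, its state_dict()/load_state_dict() methods, and the "batch_sampler" checkpoint key are hypothetical and do not describe tfep's actual sampler or checkpoint layout.

```python
import random


class ResumableBatchSampler:
    """Toy batch sampler whose position within an epoch can be saved and restored."""

    def __init__(self, n_samples, batch_size, seed=0):
        self.n_samples = n_samples
        self.batch_size = batch_size
        self.seed = seed
        self.epoch = 0
        self.n_batches_served = 0

    def _epoch_batches(self):
        # Seed with (seed + epoch) so the same shuffling can be rebuilt after a restart.
        order = list(range(self.n_samples))
        random.Random(self.seed + self.epoch).shuffle(order)
        return [order[i:i + self.batch_size]
                for i in range(0, self.n_samples, self.batch_size)]

    def __iter__(self):
        batches = self._epoch_batches()
        # Start from the first batch that was not served before the checkpoint.
        while self.n_batches_served < len(batches):
            batch = batches[self.n_batches_served]
            self.n_batches_served += 1
            yield batch
        self.epoch += 1
        self.n_batches_served = 0

    def state_dict(self):
        return {"epoch": self.epoch, "n_batches_served": self.n_batches_served}

    def load_state_dict(self, state):
        self.epoch = state["epoch"]
        self.n_batches_served = state["n_batches_served"]


sampler = ResumableBatchSampler(n_samples=10, batch_size=3)
first = next(iter(sampler))
# Roughly what an on_save_checkpoint hook could store (hypothetical key).
checkpoint = {"batch_sampler": sampler.state_dict()}

resumed = ResumableBatchSampler(n_samples=10, batch_size=3)
# Roughly what an on_load_checkpoint hook could restore.
resumed.load_state_dict(checkpoint["batch_sampler"])
remaining = list(resumed)
print(first, remaining)
```

Because the shuffling is a deterministic function of the seed and the epoch, the resumed sampler serves exactly the batches that the interrupted epoch had not yet seen.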
- setup(stage: str = 'fit')
Lightning method.
This is executed on all processes by Lightning in DDP mode (contrary to __init__) and can be used to initialize objects like the Dataset and all data-dependent objects.
- train_dataloader()
Lightning method.
- Returns:
data_loader – The training data loader.
- Return type:
torch.utils.data.DataLoader
- training_step(batch, batch_idx)
Lightning method.
Execute a training step.