tfep.io.dataset.traj.TrajectoryDataset

class tfep.io.dataset.traj.TrajectoryDataset(universe, return_dataset_sample_index=True, return_trajectory_sample_index=True)

Bases: Dataset

PyTorch Dataset wrapping an MDAnalysis trajectory.

The class wraps an MDAnalysis.Universe object and provides the interface of a PyTorch Dataset so that the trajectory can be iterated over in batches.

When iterating, each batch sample is a dictionary that includes the following keys:

  • "positions": The coordinates of the system in MDAnalysis units as a torch.Tensor of shape (batch_size, n_atoms * 3).

  • "dataset_sample_index" (optional): The index in the dataset if return_dataset_sample_index is True. This is useful to match the frame index when the dataset is shuffled.

  • "trajectory_sample_index" (optional): The index in the trajectory if return_trajectory_sample_index is True. This is useful to match the data point to the trajectory frame index.

  • "aux1" (optional): The name of any auxiliary information found in the universe.trajectory.aux dictionary.

  • "aux2"
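
To make the two index keys concrete, the sketch below uses plain Python dictionaries in place of real batch samples. It assumes, which the description above does not state explicitly, that "dataset_sample_index" counts positions among the frames retained in the dataset while "trajectory_sample_index" is the frame number in the original trajectory. The names kept_frames, samples and shuffled_order are hypothetical and are not part of tfep.

```python
import random

# Hypothetical: the dataset keeps trajectory frames 1 and 3 out of 0..4.
kept_frames = [1, 3]

# Assumed meaning of the two keys for every sample in the dataset.
samples = [
    {"dataset_sample_index": i, "trajectory_sample_index": frame}
    for i, frame in enumerate(kept_frames)
]

# Shuffling changes the order of visit, not the indices carried by each sample.
shuffled_order = samples[:]
random.Random(0).shuffle(shuffled_order)

for sample in shuffled_order:
    # The trajectory frame can always be recovered from the sample itself.
    assert kept_frames[sample["dataset_sample_index"]] == sample["trajectory_sample_index"]
```

Under these assumptions, each sample stays traceable to its original frame whatever order a shuffling data loader visits it in.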

Parameters:
  • universe (MDAnalysis.Universe) – An MDAnalysis Universe object encapsulating both the topology and the trajectory.

  • return_dataset_sample_index (bool, optional) – If True, the key "dataset_sample_index" is included in the batch sample when iterating over the dataset.

  • return_trajectory_sample_index (bool, optional) – If True, the key "trajectory_sample_index" is included in the batch sample when iterating over the dataset.

Variables:
  • universe (MDAnalysis.Universe) – The MDAnalysis Universe object encapsulated by the dataset.

  • return_dataset_sample_index (bool, optional) – Whether to return the key "dataset_sample_index" in the batch sample.

  • return_trajectory_sample_index (bool, optional) – Whether to return the key "trajectory_sample_index" in the batch sample.

Examples

First, you need to create an MDAnalysis Universe (see the MDAnalysis documentation). In this case, we load a short trajectory with a timestep of 5 ps.

>>> import os
>>> import MDAnalysis
>>> test_data_dir_path = os.path.join(os.path.dirname(__file__), '..', '..', 'tests', 'data')
>>> pdb_file_path = os.path.join(test_data_dir_path, 'chloro-fluoromethane.pdb')
>>> universe = MDAnalysis.Universe(pdb_file_path, dt=5)  # ps

TrajectoryDataset objects can be used as a normal PyTorch Dataset.

>>> import torch.utils.data
>>> trajectory_dataset = TrajectoryDataset(universe)
>>> data_loader = torch.utils.data.DataLoader(trajectory_dataset, batch_size=2, drop_last=True)
>>> for batch in data_loader:
...     batch_positions = batch['positions']

By default, TrajectoryDataset flattens the coordinates of the trajectory frames so that each batch has shape (batch_size, n_atoms * 3).

>>> trajectory_dataset.n_atoms
6
>>> batch_positions.shape
torch.Size([2, 18])
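
As an illustration of what the flattened layout means, the following sketch uses plain Python lists rather than tensors. It assumes the coordinates are flattened atom by atom (x, y, z of atom 0, then atom 1, and so on); this ordering is an assumption and is not stated above. The names frame, flat and restored are hypothetical.

```python
n_atoms = 6

# Hypothetical frame: one [x, y, z] triple per atom.
frame = [[float(3 * i + k) for k in range(3)] for i in range(n_atoms)]

# Assumed flattening order: x, y, z of atom 0, then atom 1, and so on.
flat = [coord for atom in frame for coord in atom]

# Undo the flattening by regrouping consecutive triples.
restored = [flat[i:i + 3] for i in range(0, len(flat), 3)]
```

For a real batch tensor, batch_positions.reshape(batch_size, -1, 3) would perform the same regrouping under the same ordering assumption.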

The dataset can be restricted to a subset of the trajectory frames by subsampling the trajectory at regular intervals, given either as a number of frames or in units of time. The following call overwrites the previous selection and keeps every other frame, starting from the second one (index 1), until the trajectory reaches 20 ps of length (limits are included).

>>> import pint
>>> ureg = pint.UnitRegistry()
>>> trajectory_dataset.subsample(
...     start=1, stop=20*ureg.picoseconds, step=2)
>>> len(trajectory_dataset)
2

It is also possible to select a subgroup of atoms to include in the dataset. The string must follow MDAnalysis selection syntax. This selects the first 5 atoms.

>>> trajectory_dataset.select_atoms('index 0:4')
>>> trajectory_dataset.n_atoms
5

Auxiliary information in the MDAnalysis Trajectory is also automatically discovered and returned while iterating.

>>> trajectory_dataset.universe.trajectory.add_auxiliary(
...    'my_aux_name', os.path.join(test_data_dir_path, 'auxiliary.xvg'))
>>> data_loader = torch.utils.data.DataLoader(trajectory_dataset, batch_size=2)
>>> for batch in data_loader:
...     aux_info = batch['my_aux_name']

__init__(universe, return_dataset_sample_index=True, return_trajectory_sample_index=True)

Methods

  • __init__(universe[, ...])

  • get_timestep(idx) – Return the MDAnalysis Timestep object for the given index.

  • iterate_as_timestep() – Iterate over the selected frames/atoms as MDAnalysis Timestep objects.

  • select_atoms(selection) – Select a subset of atoms.

  • subsample([start, stop, step, n_frames]) – Select a subset of trajectory frames by subsampling it at regular intervals.

Attributes

  • n_atoms – Number of selected atoms in the dataset.

get_timestep(idx)

Return the MDAnalysis Timestep object for the given index.

Parameters:

idx (int) – The frame index of the timestep. This must lie between 0 and len(TrajectoryDataset). Note that the dataset might contain a smaller number of frames than the full trajectory if a subset of frames was selected (for example, with subsample()).

Returns:

ts – The MDAnalysis Timestep object with coordinate information of the idx-th frame. If a subset of atoms was selected (for example, with select_atoms()), only the coordinates of those atoms are returned.

Return type:

MDAnalysis.coordinates.base.Timestep

iterate_as_timestep()

Iterate over the selected frames/atoms as MDAnalysis Timestep objects.

Iterating over a TrajectoryDataset returns the trajectory information in torch.Tensor format. This method instead iterates over the samples in the dataset as MDAnalysis Timestep objects.

Note that it is still possible to iterate over Timestep objects using the MDAnalysis API and TrajectoryDataset.universe.trajectory. However, in this case, the selections of frames/atoms performed at the TrajectoryDataset level are ignored and all frames/atoms are returned.

Yields:

ts (MDAnalysis.coordinates.base.Timestep) – The current Timestep object.

property n_atoms

Number of selected atoms in the dataset.

select_atoms(selection)

Select a subset of atoms.

Iterating over the dataset after selecting a subset of atoms yields only the coordinates of these atoms.

For more information about the selection syntax consult the MDAnalysis documentation.

Parameters:

selection (str) – The selection string following the MDAnalysis selection syntax.

subsample(start=None, stop=None, step=None, n_frames=None)

Select a subset of trajectory frames by subsampling it at regular intervals.

This function does not modify the trajectory, so TrajectoryDataset.universe.trajectory still has the same number of frames. However, when iterating over the TrajectoryDataset, only the subsampled frames are returned.

start, stop, and step can be given either as a number of frames or in units of time as a pint.Quantity. In the latter case, the initial time of the simulation t0 is taken into account. Note that t0 might not be zero if, for example, the simulation was resumed. For example, if start is 2 ns and the simulation starts at 1 ns, only the first ns of data is discarded.
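
The role of t0 can be sketched with a small, hypothetical helper (time_to_frame is not part of tfep). It assumes that frame k is saved at time t0 + k * dt and that times are rounded to the nearest frame; the library's actual rounding rule is not documented here and may differ.

```python
def time_to_frame(time_ps, t0_ps, dt_ps):
    """Convert a simulation time to a frame index (hypothetical helper)."""
    # Frame k is assumed to be saved at time t0 + k * dt.
    return round((time_ps - t0_ps) / dt_ps)

# Example from the text: the simulation starts at 1 ns and start is 2 ns.
# With an assumed 5 ps spacing, the first 200 frames (1 ns of data) are skipped.
start_frame = time_to_frame(2000.0, t0_ps=1000.0, dt_ps=5.0)

# A trajectory assumed to start at t0 = 0 with dt = 5 ps,
# subsampled with start=1, stop=20 ps, step=2.
stop_frame = time_to_frame(20.0, t0_ps=0.0, dt_ps=5.0)
selected = list(range(1, stop_frame + 1, 2))  # the stop frame is included
```

In the second case, the sketch selects frames 1 and 3, which is two frames in total.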

Parameters:
  • start (int or pint.Quantity, optional) – The first frame to include in the dataset specified either as a frame index or in simulation time. If not provided, the subsampling starts from the first frame in the trajectory.

  • stop (int or pint.Quantity, optional) – The last frame to include in the dataset specified either as a frame index or in simulation time. If not provided, the subsampling ends at the last frame in the trajectory.

  • step (int or pint.Quantity, optional) – The step used for subsampling, specified either as a number of frames or in simulation time. Only one of step and n_frames may be passed.

  • n_frames (int, optional) – The total number of frames to include in the dataset. If this is passed, the step is determined automatically to satisfy this requirement. Note that in this case the samples in the dataset might not be equally spaced if n_frames is not an exact divisor of the number of frames. Only one of step and n_frames may be passed.
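
Why a fixed n_frames can produce unequal spacing is easiest to see on a toy case. The sketch below shows one possible way to pick indices and is not tfep's actual rule, which may differ. The helper pick_frames is hypothetical; it keeps both end points and spreads the remaining indices with integer arithmetic.

```python
def pick_frames(first, last, n_frames):
    """Pick n_frames indices from first to last inclusive (hypothetical helper)."""
    if n_frames == 1:
        return [first]
    span = last - first
    # Integer arithmetic keeps both end points and spreads the rest in between.
    return [first + (i * span) // (n_frames - 1) for i in range(n_frames)]

even = pick_frames(0, 9, 4)    # 9 is divisible by 3: gaps are all equal
uneven = pick_frames(0, 9, 3)  # 9 is not divisible by 2: gaps differ
gaps = [b - a for a, b in zip(uneven, uneven[1:])]
```

With this rule, four frames out of ten land on indices 0, 3, 6, 9, whereas three frames land on 0, 4, 9, so the gaps are 4 and 5.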