#!/usr/bin/env python
# =============================================================================
# MODULE DOCSTRING
# =============================================================================
"""
Utility classes to create PyTorch ``Dataset``s from MDAnalysis trajectories.
The module provides a class :class:`.TrajectoryDataset` that wraps an
MDAnalysis ``Universe`` object (i.e., an object tying a topology and a
trajectory) and implements PyTorch's ``Dataset`` interface. This can be used,
for example, to specify the training dataset for the neural network
implementing the mapping function of targeted free energy perturbation.
The :class:`.TrajectoryDataset` can be subsampled at a constant time interval,
while arbitrary subsets of the trajectory can instead be created using the
:class:`.TrajectorySubset` class.
For usage examples see the documentation of
- :class:`.TrajectoryDataset`
- :class:`.TrajectorySubset`
"""
# =============================================================================
# GLOBAL IMPORTS
# =============================================================================
import copy
import numpy as np
import pint
import torch.utils.data
# =============================================================================
# TRAJECTORY DATASET
# =============================================================================
[docs]
class TrajectoryDataset(torch.utils.data.Dataset):
    """PyTorch ``Dataset`` wrapping an MDAnalysis trajectory.

    The class wraps an ``MDAnalysis.Universe`` object and provides the interface
    of a PyTorch ``Dataset`` to enable the iteration of the trajectory in batches.

    When iterating, each batch sample is a dictionary including the following
    keys:

    - ``"positions"``: The coordinates of the system in MDAnalysis units as a
      ``torch.Tensor`` of shape ``(batch_size, n_atoms * 3)``.
    - ``"dimensions"`` (optional): The ``dimensions`` attribute of the
      MDAnalysis ``Timestep`` as a ``torch.Tensor``. It is included only when
      ``Timestep.dimensions`` is not ``None``.
    - ``"dataset_sample_index"`` (optional): The index in the dataset if
      ``return_dataset_sample_index`` is ``True``. This is useful to match
      the frame index when the dataset is shuffled.
    - ``"trajectory_sample_index"`` (optional): The index in the trajectory if
      ``return_trajectory_sample_index`` is ``True``. This is useful to match
      the data point to the trajectory frame index.
    - ``"aux1"`` (optional): The name of eventual auxiliary information found
      in the ``universe.trajectory.aux`` dictionary.
    - ``"aux2"`` ...

    Parameters
    ----------
    universe : MDAnalysis.Universe
        An MDAnalysis ``Universe`` object encapsulating both the topology and
        the trajectory.
    return_dataset_sample_index : bool, optional
        If ``True`` (default), the keyword ``"dataset_sample_index"`` is
        included in the batch sample when iterating over the dataset.
    return_trajectory_sample_index : bool, optional
        If ``True`` (default), the keyword ``"trajectory_sample_index"`` is
        included in the batch sample when iterating over the dataset.

    Attributes
    ----------
    universe : MDAnalysis.Universe
        The MDAnalysis ``Universe`` object encapsulated by the dataset.
    return_dataset_sample_index : bool
        Whether to return the keyword ``"dataset_sample_index"`` in the batch
        sample.
    return_trajectory_sample_index : bool
        Whether to return the keyword ``"trajectory_sample_index"`` in the batch
        sample.
    trajectory_sample_indices : numpy.ndarray or None
        The trajectory frame indices of the samples in the dataset. ``None``
        means that all the frames of the trajectory are included.

    Examples
    --------

    First, you need to create an MDAnalysis ``Universe`` (see the
    `MDAnalysis documentation <https://userguide.mdanalysis.org/stable/>`_).
    In this case, we load a short trajectory with a timestep of 5 ps.

    >>> import os
    >>> import MDAnalysis
    >>> test_data_dir_path = os.path.join(os.path.dirname(__file__), '..', '..', 'tests', 'data')
    >>> pdb_file_path = os.path.join(test_data_dir_path, 'chloro-fluoromethane.pdb')
    >>> universe = MDAnalysis.Universe(pdb_file_path, dt=5)  # ps

    ``TrajectoryDataset`` objects can be used as a normal PyTorch ``Dataset``.

    >>> import torch.utils.data
    >>> trajectory_dataset = TrajectoryDataset(universe)
    >>> data_loader = torch.utils.data.DataLoader(trajectory_dataset, batch_size=2, drop_last=True)
    >>> for batch in data_loader:
    ...     batch_positions = batch['positions']

    By default, ``TrajectoryDataset`` flattens the coordinates of the trajectory
    frames so that each batch has shape ``(batch_size, n_atoms * 3)``.

    >>> trajectory_dataset.n_atoms
    6
    >>> batch_positions.shape
    torch.Size([2, 18])

    Only a subset of frames in the trajectories can be included in the dataset
    by subsampling the trajectory at regular intervals (which can be given in
    number of frames or in units of time). The following overwrites the previous
    selection and discards the first and every other frame until the trajectory
    reaches 20 ps of length (limits are included).

    >>> import pint
    >>> ureg = pint.UnitRegistry()
    >>> trajectory_dataset.subsample(
    ...     start=1, stop=20*ureg.picoseconds, step=2)
    >>> len(trajectory_dataset)
    2

    It is also possible to select a subgroup of atoms to include in the dataset.
    The string must follow `MDAnalysis selection syntax <https://docs.mdanalysis.org/stable/documentation_pages/selections.html>`_.
    This selects the first 5 atoms.

    >>> trajectory_dataset.select_atoms('index 0:4')
    >>> trajectory_dataset.n_atoms
    5

    Auxiliary information in the MDAnalysis ``Trajectory`` is also automatically
    discovered and returned while iterating.

    >>> trajectory_dataset.universe.trajectory.add_auxiliary(
    ...     'my_aux_name', os.path.join(test_data_dir_path, 'auxiliary.xvg'))
    >>> data_loader = torch.utils.data.DataLoader(trajectory_dataset, batch_size=2)
    >>> for batch in data_loader:
    ...     aux_info = batch['my_aux_name']

    """

    def __init__(
            self,
            universe,
            return_dataset_sample_index=True,
            return_trajectory_sample_index=True,
    ):
        super().__init__()

        self.universe = universe
        self.return_dataset_sample_index = return_dataset_sample_index
        self.return_trajectory_sample_index = return_trajectory_sample_index

        # The indexes of the selected trajectory frames. None means all frames.
        self.trajectory_sample_indices = None

        # The MDAnalysis.core.groups.AtomGroup object encapsulating the atom
        # selection. None means all atoms.
        self._selected_atom_group = None

    @property
    def n_atoms(self):
        """Number of selected atoms in the dataset."""
        if self._selected_atom_group is None:
            return self.universe.atoms.n_atoms
        return self._selected_atom_group.n_atoms

    def __copy__(self):
        """Return a copy of the dataset wrapping a copy of the ``Universe``.

        The frame selection and the atom selection are carried over to the
        new dataset.
        """
        copied_dataset = self.__class__(
            self.universe.copy(),
            self.return_dataset_sample_index,
            self.return_trajectory_sample_index,
        )
        copied_dataset.trajectory_sample_indices = copy.copy(self.trajectory_sample_indices)

        # The atom group must be rebuilt inside the copied universe. A plain
        # copy of the group would remain bound to the original universe so
        # that get_timestep() would advance the copied trajectory but read the
        # coordinates from the original one.
        if self._selected_atom_group is not None:
            atom_indices = self._selected_atom_group.ix
            copied_dataset._selected_atom_group = copied_dataset.universe.atoms[atom_indices]

        return copied_dataset

    def __getitem__(self, idx):
        """Implement the ``__getitem__()`` method required for a PyTorch dataset.

        Parameters
        ----------
        idx : int
            The index of the sample in the dataset. Note that the dataset
            might contain a smaller number of frames than the full trajectory
            if a subset of frames was selected (for example, with
            :func:`~tfep.io.dataset.TrajectoryDataset.subsample`).

        Returns
        -------
        sample : dict
            A dictionary including the positions (as 1D ``Tensor`` of length
            ``n_atoms * 3``), the ``Timestep`` dimensions (if not ``None``),
            the auxiliary information for the sample, and optionally the
            dataset and trajectory sample indices.

        """
        ts = self.get_timestep(idx)

        sample = {}

        # MDAnalysis loads coordinates with np.float32 dtype. We convert
        # it to the default torch dtype and return them in flattened shape.
        sample['positions'] = torch.tensor(np.ravel(ts.positions),
                                           dtype=torch.get_default_dtype())
        if ts.dimensions is not None:
            sample['dimensions'] = torch.tensor(ts.dimensions, dtype=torch.get_default_dtype())

        # Return the configurations and the auxiliary information. If an
        # atom group is selected, this may have lost the auxiliary information
        # so we go back to reading the main Trajectory Timestep for this.
        for aux_name, aux_info in self.universe.trajectory.ts.aux.items():
            sample[aux_name] = torch.tensor(aux_info)

        # Return the requested indices.
        if self.return_dataset_sample_index:
            sample['dataset_sample_index'] = idx
        if self.return_trajectory_sample_index:
            if self.trajectory_sample_indices is None:
                # We have selected all frames. Trajectory and dataset indices are the same.
                sample['trajectory_sample_index'] = idx
            else:
                sample['trajectory_sample_index'] = self.trajectory_sample_indices[idx]

        return sample

    def __len__(self):
        """Number of samples in the dataset (i.e., selected trajectory frames)."""
        if self.trajectory_sample_indices is None:
            return len(self.universe.trajectory)
        return len(self.trajectory_sample_indices)

    def get_timestep(self, idx):
        """Return the MDAnalysis ``Timestep`` object for the given index.

        Parameters
        ----------
        idx : int
            The index of the sample in the dataset. Note that the dataset
            might contain a smaller number of frames than the full trajectory
            if a subset of frames was selected (for example, with
            :func:`~tfep.io.dataset.TrajectoryDataset.subsample`).

        Returns
        -------
        ts : MDAnalysis.coordinates.base.Timestep
            The MDAnalysis ``Timestep`` object with coordinate information of
            the ``idx``-th sample. If a subset of atoms was selected (for
            example with :func:`~tfep.io.dataset.TrajectoryDataset.select_atoms`)
            only the coordinates of those atoms are returned.

        Raises
        ------
        ValueError
            If ``idx`` does not represent an integer value.

        """
        # Make sure this is a Python int and not tensor. An explicit exception
        # is used rather than an assert so that the check survives python -O.
        int_idx = int(idx)
        if idx != int_idx:
            raise ValueError(f'The sample index must be an integer, got {idx}.')

        # First check if index refers to a subset of selected trajectory frames
        # or to the full trajectory. Indexing the trajectory also moves it to
        # the requested frame, which is what the atom group reads below.
        if self.trajectory_sample_indices is None:
            ts = self.universe.trajectory[int_idx]
        else:
            ts = self.universe.trajectory[self.trajectory_sample_indices[int_idx]]

        # If a subset of atoms was selected. Return the Timestep only for those.
        if self._selected_atom_group is None:
            return ts
        return self._selected_atom_group.ts

    def iterate_as_timestep(self):
        """Iterate over the selected frames/atoms as MDAnalysis ``Timestep`` objects.

        Iterating over a ``TrajectoryDataset`` returns the trajectory information
        in ``torch.Tensor`` format. This method enables iterating the samples in
        the dataset as MDAnalysis ``Timestep`` objects.

        Note that it is still possible to iterate over ``Timestep`` objects using
        the MDAnalysis API and ``TrajectoryDataset.universe.trajectory``. However,
        in this case, the selections of frames/atoms performed at the
        ``TrajectoryDataset`` level are ignored and all frames/atoms are returned.

        Yields
        ------
        ts : MDAnalysis.coordinates.base.Timestep
            The current ``Timestep`` object.

        """
        for sample_idx in range(len(self)):
            yield self.get_timestep(sample_idx)

    def select_atoms(self, selection):
        """Select a subset of atoms.

        Iterating over the dataset after selecting a subset of atoms yields only
        the coordinates of these atoms. A new selection replaces the previous
        one.

        For more information about the selection syntax consult the
        `MDAnalysis documentation <https://docs.mdanalysis.org/stable/documentation_pages/selections.html>`_.

        Parameters
        ----------
        selection : str
            The selection string following the MDAnalysis selection syntax.

        """
        self._selected_atom_group = self.universe.select_atoms(selection)

    def subsample(self, start=None, stop=None, step=None, n_frames=None):
        """Select a subset of trajectory frames by subsampling it at regular intervals.

        This function does not modify the trajectory. Thus
        ``TrajectoryDataset.universe.trajectory`` still has the same number of
        frames. However, when iterating over the ``TrajectoryDataset`` only the
        subsampled frames are returned.

        ``start``, ``stop``, and ``step`` can be given either as number of frames
        or in units of time as ``pint.Quantity``. If the latter, the initial time
        of the simulation t0 is taken into account. Note that this might not be
        zero if, for example, the simulation was resumed. For example, if
        ``start`` is 2 ns and the simulation starts at 1 ns, only the first ns
        of data is discarded.

        If all arguments are ``None``, the current frame selection is left
        untouched.

        Parameters
        ----------
        start : int or pint.Quantity, optional
            The first frame to include in the dataset specified either as a
            frame index or in simulation time. If not provided, the subsampling
            starts from the first frame in the trajectory.
        stop : int or pint.Quantity, optional
            The last frame to include in the dataset specified either as a
            frame index or in simulation time. If not provided, the subsampling
            ends at the last frame in the trajectory.
        step : int or pint.Quantity, optional
            The step used for subsampling specified either as a number of
            frames or in simulation time. Only one between ``step`` and
            ``n_frames`` may be passed.
        n_frames : int, optional
            The total number of frames to include in the dataset. If this is
            passed, the ``step`` will automatically be determined to satisfy this
            requirement. Note that in this case the obtained samples in the
            dataset might not be equally spaced if ``n_frames`` is not an exact
            divisor of the number of frames. Only one between ``step`` and
            ``n_frames`` may be passed.

        """
        # If all are None, there's no need to subsample.
        if all(x is None for x in (start, stop, step, n_frames)):
            return

        # Handle default arguments. step and n_frames are handled by get_subsampled_indices.
        if start is None:
            start = 0
        if stop is None:
            # Stop is the last index included in the subsampled trajectory.
            stop = len(self.universe.trajectory) - 1

        # Look for a compatible unit registry, if given. Quantities created
        # by different registries cannot be mixed so we reuse the user's one.
        ureg = None
        for quantity in (start, stop, step):
            if isinstance(quantity, pint.Quantity):
                ureg = quantity._REGISTRY
                break
        if ureg is None:
            ureg = pint.UnitRegistry()

        # All time quantities in MDAnalysis are in picoseconds.
        ps = ureg.picoseconds

        self.trajectory_sample_indices = get_subsampled_indices(
            dt=self.universe.trajectory.dt * ps,
            stop=stop, start=start, step=step, n_frames=n_frames,
            t0=self.universe.trajectory[0].time * ps)
# =============================================================================
# TRAJECTORY SUBSET
# =============================================================================
[docs]
class TrajectorySubset:
    """A subset of a ``TrajectoryDataset``.

    Provides the same functionality of the PyTorch class ``torch.utils.data.Subset``,
    which is to provide a subset of the main dataset, but for :class:`.TrajectoryDataset`.

    Contrary to ``torch.utils.data.Subset``, ``TrajectorySubset`` can also be
    constructed from a filter function rather than only a list of indices.

    The class exposes the same interface as :class:`.TrajectoryDataset`, with
    the exception of :func:`.TrajectoryDataset.subsample`. The reason for this
    exception is to prevent users from inadvertently leaving an object in an
    undesired state since the indices of ``TrajectorySubset`` might be
    meaningless after the subsampling.

    Parameters
    ----------
    dataset : TrajectoryDataset or TrajectorySubset
        The trajectory dataset.
    indices : array_like
        A list of indices of the ``dataset`` elements forming the subset.

    Examples
    --------

    First we create the main ``TrajectoryDataset``

    >>> import os
    >>> import MDAnalysis
    >>> test_data_dir_path = os.path.join(os.path.dirname(__file__), '..', '..', 'tests', 'data')
    >>> pdb_file_path = os.path.join(test_data_dir_path, 'chloro-fluoromethane.pdb')
    >>> universe = MDAnalysis.Universe(pdb_file_path, dt=5)  # ps
    >>> trajectory_dataset = TrajectoryDataset(universe)

    We can then create a subset of the indices.

    >>> len(trajectory_dataset)
    5
    >>> trajectory_subset = TrajectorySubset(trajectory_dataset, indices=[0, 2, 4])
    >>> len(trajectory_subset)
    3

    Or alternatively from a filter function taking as input the index of the
    sample and an MDAnalysis ``Timestep`` object and returning ``True`` or
    ``False`` whether the sample must be included in the subset or not. The
    following trivial example takes all samples for which the distance between
    two atoms is greater than 3 Angstrom.

    >>> filter_func = lambda idx, ts: np.linalg.norm(ts.positions[1] - ts.positions[0]) > 3
    >>> trajectory_subset = TrajectorySubset.from_filter(trajectory_dataset, filter_func)
    >>> len(trajectory_subset)
    2

    The ``TrajectorySubset`` can be used as a normal ``TrajectoryDataset``.

    >>> trajectory_subset.n_atoms
    6
    >>> trajectory_subset.select_atoms('index 0:2 or index 4')

    """

    def __init__(self, dataset, indices):
        self.dataset = dataset

        # Make sure indices is an array or the subset search won't work.
        indices = np.asarray(indices)
        # An empty list is converted by numpy to a float array, which cannot
        # be used for indexing. Force an integer dtype in this case.
        if indices.size == 0:
            indices = indices.astype(int)
        self.indices = indices

    @classmethod
    def from_filter(cls, dataset, filter_func):
        """Alternative constructor creating a subset based on a boolean filter function.

        Parameters
        ----------
        dataset : TrajectoryDataset or TrajectorySubset
            The trajectory dataset.
        filter_func : Callable
            A function taking as input (in this order) the index of the sample
            in ``dataset`` and the MDAnalysis ``Timestep`` object and
            returning ``True`` or ``False`` if the sample must be included in
            the subset or not.

        Returns
        -------
        subset : TrajectorySubset
            A new ``TrajectorySubset`` object. The subset is empty if no
            sample satisfies the filter.

        """
        indices = [
            idx for idx, ts in enumerate(dataset.iterate_as_timestep())
            if filter_func(idx, ts)
        ]
        return cls(dataset, indices)

    @property
    def universe(self):
        """The MDAnalysis ``Universe`` object encapsulated by the dataset."""
        return self.dataset.universe

    @property
    def return_dataset_sample_index(self):
        """Whether to return the keyword ``"dataset_sample_index"`` in the batch sample."""
        return self.dataset.return_dataset_sample_index

    @property
    def return_trajectory_sample_index(self):
        """Whether to return the keyword ``"trajectory_sample_index"`` in the batch sample."""
        return self.dataset.return_trajectory_sample_index

    @property
    def n_atoms(self):
        """Number of selected atoms in the dataset."""
        return self.dataset.n_atoms

    @property
    def trajectory_sample_indices(self):
        """Indices of the subset samples in the trajectory (before subsampling).

        ``trajectory_sample_indices[i]`` is the trajectory frame index of the
        ``i``-th sample of the subset.
        """
        trajectory_sample_indices = self.dataset.trajectory_sample_indices
        if trajectory_sample_indices is None:
            # The wrapped dataset includes all frames so that dataset and
            # trajectory indices coincide.
            trajectory_sample_indices = np.arange(len(self.dataset))
        return np.asarray(trajectory_sample_indices)[self.indices]

    def __getitem__(self, idx):
        """Implement the ``__getitem__()`` method required for a PyTorch dataset."""
        sample = self.dataset[self.indices[idx]]

        # Update the index if return_dataset_sample_index is True so that it
        # refers to the subset rather than to the wrapped dataset.
        # The trajectory index should already be correct.
        if self.return_dataset_sample_index:
            sample['dataset_sample_index'] = idx

        return sample

    def __len__(self):
        """Number of samples in the dataset (i.e., selected trajectory frames)."""
        return len(self.indices)

    def get_timestep(self, item):
        """Return the MDAnalysis ``Timestep`` object of the frame with the given index.

        ``item`` is the index of the sample in the subset.

        See also :func:`.TrajectoryDataset.get_timestep`.

        """
        return self.dataset.get_timestep(self.indices[item])

    def iterate_as_timestep(self):
        """Iterate over MDAnalysis ``Timestep`` objects.

        See also :func:`.TrajectoryDataset.iterate_as_timestep`.

        """
        for idx in range(len(self)):
            yield self.get_timestep(idx)

    def select_atoms(self, selection):
        """Select a subset of atoms.

        The selection is forwarded to the wrapped dataset and thus it affects
        it as well.

        See also :func:`.TrajectoryDataset.select_atoms`.

        """
        self.dataset.select_atoms(selection)
# =============================================================================
# SUBSAMPLING UTILITIES
# =============================================================================
[docs]
def get_subsampled_indices(
        dt,
        stop,
        start=0,
        step=None,
        n_frames=None,
        t0=0.0,
):
    """Subsamples the trajectory at a constant time interval after discarding an initial equilibration.

    This function returns the indices of the trajectory frames that must be
    selected for subsampling.

    ``start``, ``stop``, and ``step`` can be given either as number of frames
    or in units of time as ``pint.Quantity``. If the latter, the initial time
    of the simulation t0 is taken into account. Note that this might not be
    zero if, for example, the simulation was resumed. For example, if
    ``start`` is 2 ns and the simulation starts at 1 ns, only the first ns
    of data is discarded.

    Parameters
    ----------
    dt : pint.Quantity
        The time interval between two consecutive frames of the trajectory.
    stop : int or pint.Quantity
        The last frame to include in the dataset specified either as a
        frame index or in simulation time.
    start : int or pint.Quantity, optional
        The first frame to include in the dataset specified either as a
        frame index or in simulation time. If not provided, the subsampling
        starts from the first frame in the trajectory.
    step : int or pint.Quantity, optional
        The step used for subsampling specified either as a number of frames
        or in simulation time. If neither ``step`` nor ``n_frames`` is given,
        a step of 1 frame is used. Only one between ``step`` and ``n_frames``
        may be passed.
    n_frames : int, optional
        The total number of frames to include in the dataset. If this is
        passed, the ``step`` will automatically be determined to satisfy this
        requirement. Note that in this case the obtained samples in the
        dataset might not be equally spaced if ``n_frames`` is not an exact
        divisor of the number of frames. Only one between ``step`` and ``n_frames``
        may be passed.
    t0 : pint.Quantity or float, optional
        The time of the first frame in the trajectory to subsample. This might
        not be 0.0 if, for example, the simulation was resumed. A plain number
        is interpreted as picoseconds. ``None`` is equivalent to ``0.0``.

    Returns
    -------
    trajectory_indices : numpy.ndarray
        The indices of the trajectory frames to use for subsampling.

    Raises
    ------
    ValueError
        If both ``step`` and ``n_frames`` are passed, if a time given as a
        ``pint.Quantity`` does not correspond to a frame of the trajectory,
        or if ``n_frames`` is greater than the number of frames between
        ``start`` and ``stop``.

    """
    # Check that only one between step and n_frames is given.
    if (step is not None) and (n_frames is not None):
        raise ValueError("Only one between 'step' and 'n_frames' may be passed.")

    # Make time quantities unitless.
    ureg = dt._REGISTRY
    unit = ureg.picoseconds
    dt = dt.to(unit).magnitude
    if t0 is None:
        t0 = 0.0
    elif isinstance(t0, pint.Quantity):
        t0 = t0.to(unit).magnitude
    else:
        # The default value of t0 is a plain float, which has no .to() method.
        # Unitless numbers are interpreted as picoseconds.
        t0 = float(t0)

    # Convert start, stop, and step to frame indices.
    times = [start, stop, step]
    for i, (t, label) in enumerate(zip(times, ['start', 'stop', 'step'])):
        if not isinstance(t, pint.Quantity):
            continue

        # start and stop are absolute simulation times and must be shifted by
        # t0, while step is a time interval so there is no need to subtract t0.
        offset = 0.0 if label == 'step' else t0
        frame_idx = (t.to(unit).magnitude - offset) / dt

        if not np.isclose(frame_idx, np.round(frame_idx)):
            neighbor_frames = np.array([np.floor(frame_idx), np.ceil(frame_idx)])
            closest_times = (offset + dt * neighbor_frames) * unit
            raise ValueError(f'The time step {dt} is not compatible with {label} time {t}. '
                             f'The closest possible {label} times are {closest_times[0]} or '
                             f'{closest_times[1]}')
        times[i] = int(round(frame_idx))
    start, stop, step = times

    # Check if the step must be instead determined by the number of frames.
    if n_frames is not None:
        # Check that there are enough frames.
        if n_frames > stop - start + 1:
            raise ValueError(f"There are not enough frames to select {n_frames} "
                             f"from a trajectory with the start time {t0 + start*dt} ps"
                             f" and stop time {t0 + stop*dt} ps")

        # Create the frames with a constant number of frames. Non-integer
        # positions are truncated to the previous frame index.
        return np.linspace(start, stop, n_frames).astype(int)

    # Create the frames with a constant step. We include "stop" in the dataset.
    if step is None:
        step = 1
    return np.arange(start, stop+1, step, dtype=int)