#!/usr/bin/env python
# =============================================================================
# MODULE DOCSTRING
# =============================================================================
"""
A PyTorch stateful batch sampler that can be used to correctly resume training mid-epoch.
See documentation of the class :class:`StatefulBatchSampler` for more details.
"""
# =============================================================================
# GLOBAL IMPORTS
# =============================================================================
from typing import Any, Iterator, Optional
import torch
# =============================================================================
# STATEFUL BATCH SAMPLER
# =============================================================================
[docs]
class StatefulBatchSampler(torch.utils.data.Sampler):
"""A PyTorch stateful batch sampler to resume training mid-epoch.
This class can be used with a PyTorch Lightning ``Trainer`` to implement
a correct data checkpointing. If the training is interrupted mid-epoch, and
the ``DataLoader`` uses this batch sampler, the training will resume correctly,
i.e., it will complete the epoch by training only on the data points that
were not previously seen.
Examples
--------
>>> import torch
>>> import lightning
>>> import tfep.io
>>>
>>> # Initialize the trainer. This must be passed to StatefulBatchSampler.
>>> trainer = lightning.Trainer()
>>>
>>> # Initialize the dataset and data loader.
>>> dataset = tfep.io.DictDataset({'a': [0, 1, 2, 3, 4]})
>>> sampler = StatefulBatchSampler(
... dataset,
... batch_size=2,
... shuffle=True,
... drop_last=True,
... trainer=trainer
... )
>>> dataloader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler)
>>>
>>> # Train your model.
>>> trainer.fit(your_model, dataloader) # doctest: +SKIP
"""
[docs]
def __init__(
self,
dataset : torch.utils.data.Dataset,
batch_size : int = 1,
shuffle : bool = False,
drop_last : bool = False,
trainer = None
):
"""Constructor.
Parameters
----------
dataset : torch.utils.data.Dataset
The dataset to sample.
batch_size : int, optional
The batch size.
shuffle : bool, optional
If ``True``, data samples are reshuffled at every epoch.
drop_last : bool, optional
If the dataset size is not divisible by the batch size the last batch
is dropped if this is ``True`` or just smaller if this is ``False``.
trainer : object or None, optional
An object exposing a ``global_step`` attribute that holds the total
number of seen batches during the entire training. This is usually a
PyTorch Lightning ``Trainer`` object. If not given on initialization,
this must be passed later through the :attr:`~StatefulBatchSampler.trainer`
attribute.
"""
try:
super().__init__()
except TypeError:
# PyTorch < 2.1 requires passing a dataset but Pytorch > 2.1 requires not passing it.
super().__init__(dataset)
self._dataset = dataset
self._batch_size = batch_size
self._shuffle = shuffle
self._drop_last = drop_last
# Keeps track of the seed used to shuffle the data in the current epoch.
self._current_epoch_seed = None
#: The trainer object exposing a ``global_step`` attribute with the total number of batches seen during the entire training.
self.trainer : Optional[Any] = trainer
@property
def batch_size(self) -> int:
"""The batch size."""
return self._batch_size
@property
def shuffle(self) -> bool:
"""Whether to reshuffle the data at each new epoch."""
return self._shuffle
@property
def drop_last(self) -> bool:
"""Whether the last incomplete batch is dropped or yielded."""
return self._drop_last
def __len__(self) -> int:
"""The number of batches per epoch."""
if self.drop_last:
return len(self._dataset) // self.batch_size
return (len(self._dataset) + self.batch_size - 1) // self.batch_size
def __iter__(self) -> Iterator[torch.Tensor]:
"""Iterate over batches.
Yields
------
batch_indices : torch.Tensor[int]
A tensor of sample indices forming the batch.
"""
if self.trainer is None:
raise RuntimeError('trainer must be set before starting the training.')
# This is usually called at the start of each epoch but current_batch_idx
# might be != 0 if this is resumed from a mid-epoch checkpoint.
current_batch_idx = self.trainer.global_step % len(self)
# Get the random indices.
if self.shuffle:
# If this is a new epoch, regenerate the seed.
if current_batch_idx == 0:
self._current_epoch_seed = int(torch.empty((), dtype=torch.int64).random_().item())
# Create a random permutation of the sample indices.
generator = torch.Generator()
generator.manual_seed(self._current_epoch_seed)
epoch_indices = torch.randperm(len(self._dataset), generator=generator)
else: # Sequential.
epoch_indices = torch.arange(0, len(self._dataset), dtype=int)
# Yield indices.
for batch_idx in range(current_batch_idx, len(self)):
start = batch_idx * self.batch_size
end = (batch_idx + 1) * self.batch_size
yield epoch_indices[start:end]
[docs]
def state_dict(self) -> dict[str, Any]:
"""Serialize the internal state in dictionary format.
Note that the parameters passed in the constructor are not serialized.
Returns
-------
state_dict : dict[str, Any]
The serialized internal state.
"""
return {'current_epoch_seed': self._current_epoch_seed}
[docs]
def load_state_dict(self, state_dict: dict[str, Any]):
"""Load the internal state from a dictionary.
The dictionary must be generated with :func:`~StatefulBatchSampler.state_dict`.
Note that the parameters passed in the constructor are not serialized,
and thus the object must be initialized with the same arguments to recover
the same object.
Parameters
----------
state_dict : dict[str, Any]
The serialized internal state.
"""
self._current_epoch_seed = state_dict['current_epoch_seed']