tfep.io.sampler.StatefulBatchSampler
- class tfep.io.sampler.StatefulBatchSampler(dataset: Dataset, batch_size: int = 1, shuffle: bool = False, drop_last: bool = False, trainer=None)[source]
Bases: Sampler
A PyTorch stateful batch sampler to resume training mid-epoch.
This class can be used with a PyTorch Lightning Trainer to implement correct data checkpointing. If training is interrupted mid-epoch and the DataLoader uses this batch sampler, training will resume correctly, i.e., it will complete the epoch by training only on the data points that were not previously seen.
Examples
>>> import torch
>>> import lightning
>>> import tfep.io
>>> from tfep.io.sampler import StatefulBatchSampler
>>>
>>> # Initialize the trainer. This must be passed to StatefulBatchSampler.
>>> trainer = lightning.Trainer()
>>>
>>> # Initialize the dataset and data loader.
>>> dataset = tfep.io.DictDataset({'a': [0, 1, 2, 3, 4]})
>>> sampler = StatefulBatchSampler(
...     dataset,
...     batch_size=2,
...     shuffle=True,
...     drop_last=True,
...     trainer=trainer
... )
>>> dataloader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler)
>>>
>>> # Train your model (your_model is a placeholder for a LightningModule).
>>> trainer.fit(your_model, dataloader)
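The resume behaviour described above can be illustrated with a small, library-free sketch. It is not the tfep implementation: the helper names `epoch_batches` and `remaining_batches`, the per-epoch seeding scheme, and the assumption that `global_step` increases by exactly one per batch are all hypothetical choices made for illustration.

```python
import random


def epoch_batches(n_samples, batch_size, epoch, shuffle=True, drop_last=True, base_seed=0):
    """Return the batches of indices for one epoch (hypothetical helper)."""
    indices = list(range(n_samples))
    if shuffle:
        # Assumption: seeding with the epoch number makes the permutation
        # reproducible after a restart.
        random.Random(base_seed + epoch).shuffle(indices)
    batches = [indices[i:i + batch_size] for i in range(0, n_samples, batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches


def remaining_batches(n_samples, batch_size, global_step, **kwargs):
    """Return the batches still to be seen in the interrupted epoch."""
    per_epoch = len(epoch_batches(n_samples, batch_size, 0, **kwargs))
    if per_epoch == 0:
        raise ValueError("the dataset yields no batches with these settings")
    # Assumption: global_step counts one step per batch over the whole training.
    epoch, done_in_epoch = divmod(global_step, per_epoch)
    return epoch_batches(n_samples, batch_size, epoch, **kwargs)[done_in_epoch:]


# Training stopped after 1 batch of the first epoch (5 samples, batch size 2).
full_epoch = epoch_batches(5, 2, epoch=0)
resumed = remaining_batches(5, 2, global_step=1)
print(full_epoch[:1], resumed)  # the seen batch, then the unseen remainder
```

Because the permutation depends only on the epoch number in this sketch, the resumed run regenerates the same order and skips the batches already consumed, so no data point is visited twice within the epoch.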
- __init__(dataset: Dataset, batch_size: int = 1, shuffle: bool = False, drop_last: bool = False, trainer=None)[source]
Constructor.
- Parameters:
dataset (torch.utils.data.Dataset) – The dataset to sample.
batch_size (int, optional) – The batch size.
shuffle (bool, optional) – If True, data samples are reshuffled at every epoch.
drop_last (bool, optional) – If the dataset size is not divisible by the batch size, the last batch is dropped if this is True, or kept as a smaller batch if this is False.
trainer (object or None, optional) – An object exposing a global_step attribute that holds the total number of batches seen during the entire training. This is usually a PyTorch Lightning Trainer object. If not given on initialization, it must be set later through the trainer attribute.
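As a quick check of how batch_size and drop_last interact, the number of batches per epoch can be computed by hand. The function `num_batches` below is a hypothetical helper written for this illustration, not part of tfep.

```python
def num_batches(n_samples: int, batch_size: int, drop_last: bool) -> int:
    """Number of batches yielded per epoch (illustrative helper)."""
    if drop_last:
        # The incomplete trailing batch, if any, is discarded.
        return n_samples // batch_size
    # The incomplete trailing batch is kept, so round up.
    return (n_samples + batch_size - 1) // batch_size


# With the 5-sample dataset and batch_size=2 used in the example above:
print(num_batches(5, 2, drop_last=True))   # 2
print(num_batches(5, 2, drop_last=False))  # 3 (the last batch holds 1 sample)
```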
Methods
__init__(dataset[, batch_size, shuffle, ...]) – Constructor.
load_state_dict(state_dict) – Load the internal state from a dictionary.
state_dict() – Serialize the internal state in dictionary format.
Attributes
batch_size – The batch size.
drop_last – Whether the last incomplete batch is dropped or yielded.
shuffle – Whether to reshuffle the data at each new epoch.
trainer – The trainer object exposing a global_step attribute with the total number of batches seen during the entire training.
- property batch_size: int
The batch size.
- property drop_last: bool
Whether the last incomplete batch is dropped or yielded.
- load_state_dict(state_dict: dict[str, Any])[source]
Load the internal state from a dictionary.
The dictionary must be generated with state_dict(). Note that the parameters passed to the constructor are not serialized, so the object must be initialized with the same arguments to recover the same object.
- Parameters:
state_dict (dict[str, Any]) – The serialized internal state.
- property shuffle: bool
Whether to reshuffle the data at each new epoch.
- state_dict() → dict[str, Any][source]
Serialize the internal state in dictionary format.
Note that the parameters passed in the constructor are not serialized.
- Returns:
state_dict – The serialized internal state.
- Return type:
dict[str, Any]
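The note that constructor parameters are not serialized has a practical consequence, sketched below with a toy class. `ToyStatefulSampler` is a hypothetical stand-in invented for this illustration; its internal state (a single shuffling seed) is an assumption and does not reflect what tfep actually stores.

```python
import random


class ToyStatefulSampler:
    """Hypothetical stand-in that only mimics the state_dict pattern."""

    def __init__(self, n_samples, batch_size=1):
        # Constructor arguments: NOT part of the serialized state.
        self.n_samples = n_samples
        self.batch_size = batch_size
        # Internal state (assumed here to be just a shuffling seed).
        self._seed = random.randrange(2**32)

    def state_dict(self):
        return {"seed": self._seed}

    def load_state_dict(self, state_dict):
        self._seed = state_dict["seed"]

    def batches(self):
        indices = list(range(self.n_samples))
        random.Random(self._seed).shuffle(indices)
        step = self.batch_size
        return [indices[i:i + step] for i in range(0, self.n_samples, step)]


original = ToyStatefulSampler(n_samples=5, batch_size=2)
checkpoint = original.state_dict()

# Same constructor arguments + loaded state: the same object is recovered.
restored = ToyStatefulSampler(n_samples=5, batch_size=2)
restored.load_state_dict(checkpoint)
print(restored.batches() == original.batches())

# Different constructor arguments: the loaded state cannot compensate.
mismatched = ToyStatefulSampler(n_samples=5, batch_size=3)
mismatched.load_state_dict(checkpoint)
print(mismatched.batches() == original.batches())
```

In a checkpointing workflow this means the code that rebuilds the sampler should reuse the same dataset, batch_size, shuffle, and drop_last values before calling load_state_dict().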
- trainer: Any | None
The trainer object exposing a global_step attribute with the total number of batches seen during the entire training.