#!/usr/bin/env python
# =============================================================================
# MODULE DOCSTRING
# =============================================================================
r"""
Utility class to merge multiple PyTorch ``Dataset``\ s.
For usage examples see the documentation of :class:`.MergedDataset`.
"""
# =============================================================================
# GLOBAL IMPORTS
# =============================================================================
import torch.utils.data
# =============================================================================
# MERGED DATASET
# =============================================================================
[docs]
class MergedDataset(torch.utils.data.Dataset):
r"""Dataset merging multiple ``Dataset``\ s.
The dataset constructs a batch by merging batches from the wrapped datasets.
Currently, it supports only map-style datasets that return samples in
dictionary format.
Examples
--------
>>> from tfep.io.dataset import DictDataset
>>> dataset1 = DictDataset({'a': [1., 2.]})
>>> dataset2 = DictDataset({'b': [3, 4]})
>>> merged = MergedDataset(dataset1, dataset2)
>>> merged[1]
{'a': tensor(2.), 'b': tensor(4)}
"""
[docs]
def __init__(self, *datasets):
"""Constructor.
Parameters
----------
*datasets : torch.utils.data.Dataset
The map-style datasets to be merged. These must all have the same
number of samples and return samples in dictionary format.
"""
super().__init__()
# Check that the datasets all have the same number of samples.
for i in range(len(datasets)-1):
if len(datasets[i]) != len(datasets[i+1]):
raise ValueError(f'Datasets {i} and {i+1} have different numbers '
f'of samples ({len(datasets[i])} and {len(datasets[i+1])})')
# Check that the datasets have different keys so that no data is overridden.
n_keys = 0
all_keys = set()
for dataset in datasets:
keys = list(dataset[0].keys())
n_keys += len(keys)
all_keys.update(keys)
if len(all_keys) != n_keys:
raise ValueError(f'The merged datasets have overlapping keys.')
# We save the datasets as an internal attribute because we don't perform
# any other checks if the datasets are modified.
self._datasets = datasets
def __getitem__(self, item):
sample = {}
for dataset in self._datasets:
# We have already calculated in __init__ that keys in different
# datasets do not overlap.
sample.update(dataset[item])
return sample
def __len__(self):
return len(self._datasets[0])