Source code for tfep.io.dataset.dict

#!/usr/bin/env python


# =============================================================================
# MODULE DOCSTRING
# =============================================================================

r"""
Utility class to create a map-style PyTorch ``Dataset``\ s from a dictionary of tensors.

For usage examples see the documentation of :class:`.DictDataset`.

"""


# =============================================================================
# GLOBAL IMPORTS
# =============================================================================

from collections.abc import Sequence

import torch.utils.data


# =============================================================================
# DICTIONARY DATASET
# =============================================================================

[docs] class DictDataset(torch.utils.data.Dataset): r"""Utility class to create a map-style PyTorch ``Dataset``\ s from a dictionary of tensors. The class automatically converts non-tensor dictionary values into tensors. Examples -------- >>> import torch >>> data = {'a': torch.tensor([1.0, 2.0]), 'b': [3, 4]} >>> dict_dataset = DictDataset(data) >>> dict_dataset[1] {'a': tensor(2.), 'b': tensor(4)} """
[docs] def __init__(self, tensor_dict : dict[str, Sequence]): """Constructor. Parameters ---------- tensor_dict : dict[str, torch.Tensor] A dictionary of named tensors. """ # Check that all the columns have the same lengths. lengths = set(len(v) for v in tensor_dict.values()) if len(lengths) > 1: raise ValueError('The values of tensor_dict must all have the same length.') # Convert all values to tensors. self._tensor_dict = {k: torch.as_tensor(v) for k, v in tensor_dict.items()}
def __getitem__(self, item): """Retrieve a dataset sample. Parameters ---------- item : int or slice The index (or slice) of the sample(s). Returns ------- samples : dict[str, torch.Tensor] A dictionary of named tensor. """ return {k: v[item] for k, v in self._tensor_dict.items()} def __len__(self): """The number of samples in the dataset.""" for vals in self._tensor_dict.values(): return len(vals)