tfep.utils.misc.remove_and_shift_sorted_indices

tfep.utils.misc.remove_and_shift_sorted_indices(indices: Tensor, removed_indices: Tensor, remove: bool = True, shift: bool = True) → Tensor

Remove from indices the elements whose values appear in removed_indices (matching is by value, not by position).

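Both indices and removed_indices must be sorted tensors of non-negative integers. Optionally, the remaining indices are also shifted so that the result can be used to index an array from which the elements at removed_indices have been deleted; that is, each index is decreased by the number of elements of removed_indices that are smaller than it.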
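The removal and shifting steps can be sketched with plain PyTorch operations. This is a minimal, hypothetical re-implementation for illustration only: the name remove_and_shift_sketch is not part of tfep, the actual tfep implementation may differ, and the sketch assumes 1-D integer tensors, a PyTorch version that provides torch.isin, and that removed_indices is sorted as required above.

```python
import torch


def remove_and_shift_sketch(indices, removed_indices, remove=True, shift=True):
    """Hypothetical sketch of remove-and-shift; not tfep's actual code."""
    if remove:
        # Keep only the elements whose value does not appear in removed_indices.
        indices = indices[~torch.isin(indices, removed_indices)]
    if shift:
        # With the default left-side search, searchsorted returns, for each
        # index, how many elements of the sorted removed_indices are strictly
        # smaller than it; subtracting that count gives the position in the
        # array where those elements have been deleted.
        indices = indices - torch.searchsorted(removed_indices, indices)
    return indices


indices = torch.tensor([0, 3, 9, 13])
removed = torch.tensor([3, 12])
print(remove_and_shift_sketch(indices, removed, shift=False).tolist())  # [0, 9, 13]
print(remove_and_shift_sketch(indices, removed, shift=True).tolist())   # [0, 8, 11]
```

Using torch.searchsorted keeps the shift fully vectorized and relies directly on removed_indices being sorted, which is one reason a sorted input is a natural requirement for this kind of operation.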

Parameters:
  • indices (torch.Tensor) – The tensor from which to remove indices.

  • removed_indices (torch.Tensor) – The indices that must be removed from indices.

  • remove (bool, optional) – If False, the removal step is skipped and only shifting is performed. This can be used when indices and removed_indices are known not to overlap. Default True.

  • shift (bool, optional) – If False, the indices are not shifted. Default True.

Returns:

out – The indices tensor after removing and/or shifting the elements, as requested by remove and shift.

Return type:

torch.Tensor

Examples

>>> import torch
>>> from tfep.utils.misc import remove_and_shift_sorted_indices
>>> remove_and_shift_sorted_indices(
...     indices=torch.tensor([0, 3, 9, 13]),
...     removed_indices=torch.tensor([3, 12]),
...     shift=False,
... ).tolist()
[0, 9, 13]
>>> remove_and_shift_sorted_indices(
...     indices=torch.tensor([0, 3, 9, 13]),
...     removed_indices=torch.tensor([3, 12]),
...     shift=True,
... ).tolist()
[0, 8, 11]
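To see why the shifted output is useful, the following standalone snippet (plain PyTorch, no tfep import) deletes the elements at removed_indices from an arbitrary 15-element array and checks that the shift=False result indexes the original array exactly as the shift=True result indexes the reduced one. The data array is made up for illustration; the two index tensors are copied from the example outputs above.

```python
import torch

# Hypothetical data array; the values are arbitrary labels.
data = torch.arange(15) * 10
removed_indices = torch.tensor([3, 12])

# Build the reduced array by deleting the elements at removed_indices.
keep = torch.ones(len(data), dtype=torch.bool)
keep[removed_indices] = False
reduced = data[keep]

# Outputs of the shift=False and shift=True examples above.
original = torch.tensor([0, 9, 13])
shifted = torch.tensor([0, 8, 11])

# Both index sets select the same elements.
print(data[original].tolist())    # [0, 90, 130]
print(reduced[shifted].tolist())  # [0, 90, 130]
```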