tfep.nn.graph.fix_node_indices_batch_size

tfep.nn.graph.fix_node_indices_batch_size(node_indices: Tensor, new_batch_size: int, n_indices: int, n_nodes: int)

Resize a tensor of node indices so that it is compatible with shape (*, new_batch_size*n_indices).

This is useful, for example, to update the source/destination node indices of the edges for the last batch, which might be smaller than the regular batch size.
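To make the last-batch problem concrete, here is a small sketch. It assumes (this is not stated above) that edges are stored as a (2, n_edges) tensor of source/destination indices and that the nodes of batch sample b are numbered b*n_nodes to (b+1)*n_nodes - 1. The helper `batch_edges` and the toy graph are hypothetical and only for illustration.

```python
import torch

# Hypothetical single-graph edges: 3 nodes, edges 0->1 and 1->2.
# Row 0 holds source indices, row 1 holds destination indices (assumed layout).
n_nodes = 3
edges = torch.tensor([[0, 1], [1, 2]])

def batch_edges(edges: torch.Tensor, batch_size: int, n_nodes: int) -> torch.Tensor:
    # ASSUMPTION: the nodes of sample b are numbered b*n_nodes ... (b+1)*n_nodes - 1,
    # so each copy of the edges is shifted by b*n_nodes.
    return torch.cat([edges + b * n_nodes for b in range(batch_size)], dim=-1)

# Edges precomputed once for the regular batch size of 4.
full = batch_edges(edges, 4, n_nodes)

# With a dataset of 10 samples, the last batch holds only 10 % 4 = 2 samples (6 nodes),
# but the precomputed edges still reference node indices up to 11.
last_batch_nodes = (10 % 4) * n_nodes
out_of_range = int(full.max()) >= last_batch_nodes  # True
```

Under these assumptions, reusing the precomputed edges for the smaller final batch would index nodes that do not exist, which is the situation this function is meant to fix.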

Parameters:
  • node_indices (torch.Tensor) – Shape (*, old_batch_size*n_indices).

  • new_batch_size (int) – The output batch size.

  • n_indices (int) – The number of node indices for batch size = 1.

  • n_nodes (int, optional) – The total number of nodes in the graph.

Returns:

fixed_indices – Shape (*, new_batch_size*n_indices). The node indices (e.g., of the edges) after fixing the batch size.

Return type:

torch.Tensor
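The following is a minimal sketch of how such a resize could work. It is not the tfep implementation. It assumes that the indices of batch sample b equal those of sample 0 shifted by b*n_nodes. Under that assumption, the first n_indices entries fully determine the rest, so the same code can both shrink and grow the batch. The function name `fix_node_indices_batch_size_sketch` is hypothetical.

```python
import torch

def fix_node_indices_batch_size_sketch(
    node_indices: torch.Tensor, new_batch_size: int, n_indices: int, n_nodes: int
) -> torch.Tensor:
    # ASSUMPTION: the indices of batch sample b equal those of sample 0
    # shifted by b * n_nodes (nodes of the whole batch are numbered contiguously).
    first = node_indices[..., :n_indices]  # shape (*, n_indices)
    # One offset per batch sample: 0, n_nodes, 2*n_nodes, ...
    offsets = torch.arange(
        new_batch_size, dtype=node_indices.dtype, device=node_indices.device
    ) * n_nodes
    # (*, 1, n_indices) + (new_batch_size, 1) broadcasts to (*, new_batch_size, n_indices).
    fixed = first.unsqueeze(-2) + offsets.unsqueeze(-1)
    # Merge the last two dimensions into shape (*, new_batch_size*n_indices).
    return fixed.flatten(start_dim=-2)

# Edges of a 3-node graph (0->1, 1->2) batched for batch size 2.
edges = torch.tensor([[0, 1, 3, 4], [1, 2, 4, 5]])
grown = fix_node_indices_batch_size_sketch(edges, new_batch_size=3, n_indices=2, n_nodes=3)
shrunk = fix_node_indices_batch_size_sketch(edges, new_batch_size=1, n_indices=2, n_nodes=3)
```

Here `grown` covers three samples and `shrunk` keeps only the first one. Broadcasting replaces an explicit loop over samples, and the leading `*` dimensions (such as the source/destination axis) are preserved.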