tfep.nn.graph.fix_node_indices_batch_size
- tfep.nn.graph.fix_node_indices_batch_size(node_indices: Tensor, new_batch_size: int, n_indices: int, n_nodes: int)
Expand a tensor of node indices to be compatible with shape
(*, new_batch_size*n_indices). This is useful, for example, to update the indices of the source/destination nodes of the edges for the last batch, which might be smaller than the normal batch size.
- Parameters:
node_indices (torch.Tensor) – Shape (*, old_batch_size*n_indices).
new_batch_size (int) – The output batch size.
n_indices (int) – The number of node indices per batch sample (i.e., for batch size 1).
n_nodes (int, optional) – The total number of nodes in the graph.
- Returns:
fixed_indices – Shape (*, new_batch_size*n_indices). The node indices (e.g., of the edges' source/destination nodes) after fixing the batch size.
- Return type:
torch.Tensor
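The batched-index layout this function works with can be illustrated with a minimal sketch. It assumes (this page does not state it) that the indices of batch sample b are those of sample 0 shifted by b * n_nodes, as in a block-diagonal batched graph. The function below is a hypothetical, dependency-free re-implementation on nested Python lists with a single leading dimension, not the tfep source; the name fix_node_indices_batch_size_sketch is made up for illustration.

```python
def fix_node_indices_batch_size_sketch(node_indices, new_batch_size, n_indices, n_nodes):
    """Hypothetical list-based re-implementation (not the tfep source).

    node_indices is a list of rows, each of length old_batch_size * n_indices.
    ASSUMPTION: the indices of batch sample b equal those of sample 0
    shifted by b * n_nodes, as in a block-diagonal batched graph.
    """
    fixed = []
    for row in node_indices:
        # Indices of the first batch sample, i.e., for batch size 1.
        first_sample = row[:n_indices]
        new_row = []
        for b in range(new_batch_size):
            # Shift the first sample's indices into the block of sample b.
            new_row.extend(i + b * n_nodes for i in first_sample)
        fixed.append(new_row)
    return fixed


# Edges (source row, destination row) of a 3-node chain 0-1-2, batched 4 times.
n_nodes, n_indices, batch_size = 3, 2, 4
src = [0, 1]
dst = [1, 2]
edges = [
    [i + b * n_nodes for b in range(batch_size) for i in src],
    [i + b * n_nodes for b in range(batch_size) for i in dst],
]

# The last batch of the epoch only has 2 samples.
last_edges = fix_node_indices_batch_size_sketch(edges, 2, n_indices, n_nodes)
print(last_edges)  # [[0, 1, 3, 4], [1, 2, 4, 5]]
```

Because the sketch rebuilds every row from the first sample's indices, the same logic covers both shrinking (the last-batch case above) and expanding to a larger batch size.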