tfep.nn.graph.get_all_edges

tfep.nn.graph.get_all_edges(batch_size, n_nodes, mask=None)

Return all possible edges between the nodes of a batch of graphs after applying the mask.

Parameters:
  • batch_size (int) – The batch size.

  • n_nodes (int) – The number of nodes in the graph.

  • mask (torch.Tensor, optional) – Shape (n_nodes, n_nodes). A directional edge from node i to node j is created only if mask[i, j] != 0. The same mask is applied to every graph in the batch. If mask is not provided, every node is connected to every other node (self-loops are excluded).
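The mask semantics above can be illustrated for a single graph with a short PyTorch sketch. It assumes that edges are read off the nonzero entries of mask in row-major order (this ordering is an assumption, not something documented here), and single_graph_edges is a hypothetical helper, not part of tfep.

```python
import torch


def single_graph_edges(n_nodes, mask=None):
    """Hypothetical helper: edges of ONE graph as a (2, n_edges) tensor."""
    if mask is None:
        # Default: fully connected graph without self-loops (zero diagonal).
        mask = torch.ones(n_nodes, n_nodes) - torch.eye(n_nodes)
    # A nonzero mask[i, j] means a directional edge i -> j.
    # torch.nonzero returns an (n_edges, 2) tensor of [i, j] pairs in
    # row-major order; transposing gives the (2, n_edges) layout.
    return torch.nonzero(mask).t()


print(single_graph_edges(3))
# tensor([[0, 0, 1, 1, 2, 2],
#         [1, 2, 0, 2, 0, 1]])

# A directed chain 0 -> 1 -> 2 keeps only the two masked-in edges.
chain = torch.tensor([[0, 1, 0], [0, 0, 1], [0, 0, 0]])
print(single_graph_edges(3, chain))
# tensor([[0, 1],
#         [1, 2]])
```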

Returns:

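edges – Shape (2, batch_size*n_edges), where n_edges is the number of edges in a single graph (the number of nonzero entries in mask, or n_nodes*(n_nodes-1) if mask is not provided). The i-th edge goes from node edges[0][i] to node edges[1][i], where both entries are node indices in the range [0, batch_size*n_nodes - 1].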

Edges are directional: if messages must be passed in both directions between two nodes, the tensor contains two entries for that pair, one from i to j and one from j to i.

Return type:

torch.Tensor
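The batched layout described under Returns can be sketched in PyTorch as follows. This is a minimal sketch, not the tfep implementation: get_all_edges_sketch is a hypothetical name, and it assumes that the nodes of the b-th graph occupy indices b*n_nodes to (b+1)*n_nodes - 1 and that the output groups edges graph by graph.

```python
import torch


def get_all_edges_sketch(batch_size, n_nodes, mask=None):
    """Hypothetical re-implementation returning a (2, batch_size*n_edges) tensor."""
    if mask is None:
        # Fully connected graph without self-loops.
        mask = torch.ones(n_nodes, n_nodes) - torch.eye(n_nodes)
    # (2, n_edges) edges of a single graph; mask[i, j] != 0 gives edge i -> j.
    edges = torch.nonzero(mask).t()
    n_edges = edges.shape[1]
    # Assumed index offset of the b-th graph in the flattened batch.
    offsets = torch.arange(batch_size) * n_nodes
    # Broadcast (2, 1, n_edges) + (1, batch_size, 1) -> (2, batch_size, n_edges).
    batched = edges.unsqueeze(1) + offsets.view(1, batch_size, 1)
    # Flatten so that edges of graph 0 come first, then graph 1, and so on.
    return batched.reshape(2, batch_size * n_edges)
```

For example, get_all_edges_sketch(2, 2) returns tensor([[0, 1, 2, 3], [1, 0, 3, 2]]): in each graph the two nodes are linked by two entries, one per direction, and the second graph's nodes are shifted by n_nodes = 2.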