tfep.nn.graph.get_all_edges
- tfep.nn.graph.get_all_edges(batch_size, n_nodes, mask=None)
Return all possible edges between nodes after applying the mask.
- Parameters:
batch_size (int) – The batch size.
n_nodes (int) – The number of nodes in the graph.
mask (torch.Tensor, optional) – Shape (n_nodes, n_nodes). A (directional) edge from node i to node j is created only if mask[i, j] != 0. If mask is not provided, every node is connected to every other node (self-interactions are excluded).
- Returns:
edges – Shape (2, batch_size*n_edges), where n_edges is the number of edges in a single graph. The i-th edge goes from node edges[0][i] to node edges[1][i], where node indices lie in the range [0, batch_size*n_nodes - 1]. Edges are directional, so if a message must be passed in both directions, two entries connecting the nodes in opposite order are present.
- Return type:
torch.Tensor
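The behaviour described above can be illustrated with a short plain-PyTorch sketch. This is not the tfep implementation: the function name get_all_edges_sketch is hypothetical, and the sketch assumes that the nodes of batch element b occupy the indices b*n_nodes to (b+1)*n_nodes - 1, which is consistent with the documented index range. The order in which tfep lists the edges may differ.

```python
import torch


def get_all_edges_sketch(batch_size, n_nodes, mask=None):
    """Hypothetical re-implementation of the documented behaviour (not tfep's code)."""
    if mask is None:
        # Fully connected graph without self-interactions: zeros on the diagonal.
        mask = torch.ones(n_nodes, n_nodes) - torch.eye(n_nodes)

    # Directional edges (i, j) allowed by the mask within a single graph.
    src, dst = torch.nonzero(mask != 0, as_tuple=True)

    # Assumed layout: shift the node indices of batch element b by b * n_nodes.
    offsets = torch.arange(batch_size).unsqueeze(1) * n_nodes  # (batch_size, 1)
    src = (src.unsqueeze(0) + offsets).flatten()  # (batch_size * n_edges,)
    dst = (dst.unsqueeze(0) + offsets).flatten()

    # Row 0 holds the source nodes, row 1 the destination nodes.
    return torch.stack([src, dst])


# Two fully connected graphs with 3 nodes each: 6 directional edges per graph.
edges = get_all_edges_sketch(batch_size=2, n_nodes=3)
print(edges.shape)  # torch.Size([2, 12])

# A mask allowing only the edge 0 -> 1 in each of three 2-node graphs.
mask = torch.tensor([[0, 1], [0, 0]])
print(get_all_edges_sketch(batch_size=3, n_nodes=2, mask=mask).tolist())
# [[0, 2, 4], [1, 3, 5]]
```

Because the mask is directional, the unmasked case yields both 0 -> 1 and 1 -> 0 for every pair of nodes, which is why two entries in opposite order appear whenever messages must flow in both directions.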