tfep.nn.graph

Utility functions for graphs.

Functions

compute_edge_distances(x, edges[, ...])

Return distances between nodes across edges.
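The idea can be illustrated with a minimal pure-Python sketch. It is not the tfep implementation and does not use its signature. It assumes that `x` is a list of node coordinates and that `edges` is a pair `(sources, targets)` of equal-length index lists. The function name `edge_distances_sketch` is hypothetical.

```python
import math

def edge_distances_sketch(x, edges):
    """Hypothetical sketch: Euclidean length of each edge.

    Assumed layout: x[i] is the coordinate tuple of node i, and
    edges = (sources, targets) lists the endpoints of each edge.
    """
    sources, targets = edges
    # One distance per edge, in the same order as the edge lists.
    return [math.dist(x[i], x[j]) for i, j in zip(sources, targets)]

x = [(0.0, 0.0, 0.0), (3.0, 4.0, 0.0), (0.0, 0.0, 1.0)]
edges = ([0, 0, 1], [1, 2, 2])
d = edge_distances_sketch(x, edges)
```

In PyTorch the same computation would typically be vectorized by indexing the coordinate tensor with the source and target index tensors and taking the norm of the difference.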

fix_node_indices_batch_size(node_indices, ...)

Expand a tensor of node indices to be compatible with shape (*, new_batch_size*n_indices).
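A common way to make per-sample node indices valid for a batch is to offset them. This happens when the nodes of all samples are flattened into one dimension. The sketch below shows that general technique under this assumption: node `i` of sample `b` has flat index `b * n_nodes + i`. The name `batch_node_indices_sketch` and its arguments are hypothetical. The exact behavior of `fix_node_indices_batch_size` is not specified here.

```python
def batch_node_indices_sketch(node_indices, n_nodes, batch_size):
    """Hypothetical sketch: repeat per-sample indices for each batch sample.

    Assumes nodes are stored contiguously per sample, so node i of
    sample b lives at flat position b * n_nodes + i. The output has
    length batch_size * len(node_indices).
    """
    return [b * n_nodes + i for b in range(batch_size) for i in node_indices]

flat = batch_node_indices_sketch([0, 2], n_nodes=3, batch_size=2)
```

Here `flat` is `[0, 2, 3, 5]`. Its length is `batch_size * n_indices`, which matches the last dimension of the shape quoted above.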

get_all_edges(batch_size, n_nodes[, mask])

Return all possible edges between nodes after applying the mask.
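The following pure-Python sketch shows one way to build such an edge list. It is not the tfep code, and every convention in it is an assumption. It includes both directions (i, j) and (j, i), excludes self-loops, and numbers the nodes of sample `b` as `b * n_nodes + i`. It also treats `mask` as an optional per-node boolean list in which `False` removes the node from all edges. The name `all_edges_sketch` is hypothetical.

```python
def all_edges_sketch(batch_size, n_nodes, mask=None):
    """Hypothetical sketch: directed edges between all distinct node pairs.

    Assumptions: no self-loops, both directions included, and `mask`
    (if given) is a per-node boolean list applied to every sample.
    """
    if mask is None:
        mask = [True] * n_nodes
    sources, targets = [], []
    for b in range(batch_size):
        offset = b * n_nodes  # flat index of the first node of sample b
        for i in range(n_nodes):
            for j in range(n_nodes):
                if i != j and mask[i] and mask[j]:
                    sources.append(offset + i)
                    targets.append(offset + j)
    return sources, targets

src, dst = all_edges_sketch(batch_size=2, n_nodes=3, mask=[True, True, False])
```

Without a mask, each sample contributes `n_nodes * (n_nodes - 1)` directed edges.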

prune_long_edges(r_cutoff, edges, distances, ...)

Detect which edges have distances larger than the cutoff and remove them.
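The pruning rule can be sketched in a few lines of pure Python. This assumes `edges` is a `(sources, targets)` pair aligned with `distances`. It keeps edges whose distance does not exceed `r_cutoff` and drops the rest. The name `prune_long_edges_sketch` is hypothetical, and the sketch omits the extra arguments elided in the real signature.

```python
def prune_long_edges_sketch(r_cutoff, edges, distances):
    """Hypothetical sketch: drop edges longer than r_cutoff.

    Returns the filtered (sources, targets) pair and the matching
    distances, preserving the original edge order.
    """
    sources, targets = edges
    keep = [k for k, d in enumerate(distances) if d <= r_cutoff]
    pruned_edges = ([sources[k] for k in keep], [targets[k] for k in keep])
    return pruned_edges, [distances[k] for k in keep]

edges = ([0, 0, 1], [1, 2, 2])
distances = [5.0, 1.0, 5.1]
pruned, kept = prune_long_edges_sketch(2.0, edges, distances)
```

With tensors, the same effect is usually achieved with a boolean mask such as `distances <= r_cutoff`, which is then used to index both the edge and distance tensors.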

unsorted_segment_sum(data, segment_ids, ...)

Replicates TensorFlow's tf.math.unsorted_segment_sum in PyTorch.
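In TensorFlow, `tf.math.unsorted_segment_sum(data, segment_ids, num_segments)` sums the entries of `data` that share a segment ID. Segments that receive no entries are zero. The pure-Python sketch below reproduces these semantics for 1-D data. It is an illustration, not the tfep implementation, and the name `segment_sum_sketch` is hypothetical.

```python
def segment_sum_sketch(data, segment_ids, num_segments):
    """Hypothetical sketch of unsorted segment sum for 1-D data.

    out[s] is the sum of data[k] over all k with segment_ids[k] == s.
    IDs are assumed to lie in [0, num_segments); empty segments stay 0.
    """
    out = [0.0] * num_segments
    for value, seg in zip(data, segment_ids):
        out[seg] += value
    return out

result = segment_sum_sketch([1.0, 2.0, 3.0, 4.0], [0, 2, 0, 1], num_segments=4)
```

Here `result` is `[4.0, 4.0, 2.0, 0.0]`. The last segment receives no entries, so it stays at zero. In PyTorch, a comparable result for IDs along the first dimension can be obtained with `Tensor.index_add_(0, segment_ids, data)` on a zero tensor of shape `(num_segments, *data.shape[1:])`. Whether tfep uses this approach is not stated here.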

Classes

FixedGraph(node_types[, mask])

Graph base class with a fixed topology.