tfep.nn.graph.prune_long_edges

tfep.nn.graph.prune_long_edges(r_cutoff, edges, distances, *args)

Detect which edges have distances larger than the cutoff and remove them.

Parameters:
  • r_cutoff (float) – The radial cutoff. All edges connecting nodes at a distance greater than this cutoff will be pruned. This must be in the same units as the node positions.

  • edges (torch.Tensor) – Shape (2, batch_size*n_edges). The i-th edge connects node edges[0][i] to node edges[1][i], where both are node indices in the range [0, batch_size*n_nodes - 1].

  • distances (torch.Tensor) – Shape (batch_size*n_edges,). distances[i] is the distance between the nodes of the i-th edge.

  • *args (Sequence[torch.Tensor]) – Other tensors of shape (batch_size*n_edges, *) to prune in the same way.
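To make the flattened indexing convention concrete, here is a sketch that builds inputs with the documented shapes from a batch of node positions. The helper make_batched_edges is hypothetical and not part of tfep; it assumes positions of shape (batch_size, n_nodes, 3), fully connects the nodes within each sample, and offsets the node indices of sample b by b*n_nodes so that every index falls in [0, batch_size*n_nodes - 1].

```python
import torch


def make_batched_edges(positions):
    """Hypothetical helper (not part of tfep).

    Fully connects the nodes within each sample of a batch. Assumes positions
    has shape (batch_size, n_nodes, 3). Returns edges of shape
    (2, batch_size*n_edges) and distances of shape (batch_size*n_edges,).
    """
    batch_size, n_nodes, _ = positions.shape
    # All ordered pairs (i, j) with i != j inside a single sample.
    src, dst = torch.where(~torch.eye(n_nodes).bool())
    n_edges = src.numel()
    # Sample b uses global node indices b*n_nodes, ..., (b+1)*n_nodes - 1.
    offsets = torch.arange(batch_size).repeat_interleave(n_edges) * n_nodes
    edges = torch.stack([src.repeat(batch_size), dst.repeat(batch_size)]) + offsets
    # Look up positions with the global indices and take Euclidean norms.
    flat_positions = positions.reshape(batch_size * n_nodes, -1)
    distances = torch.linalg.norm(
        flat_positions[edges[0]] - flat_positions[edges[1]], dim=-1
    )
    return edges, distances


# Two samples with three nodes each.
positions = torch.tensor([
    [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [3.0, 0.0, 0.0]],
    [[0.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 0.5]],
])
edges, distances = make_batched_edges(positions)
print(edges.shape)      # torch.Size([2, 12])
print(distances.shape)  # torch.Size([12])
```

Here the first sample uses node indices 0 to 2 and the second uses 3 to 5, which is the layout the edges parameter describes.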

Returns:

  • edges (torch.Tensor) – Shape (2, batch_size*n_pruned_edges). The edges remaining after pruning.

  • distances (torch.Tensor, optional) – Shape (batch_size*n_pruned_edges,). The distances between the nodes of each remaining edge.

  • *other (torch.Tensor, optional) – Other pruned tensors of shape (batch_size*n_pruned_edges, *).
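The documented behavior amounts to applying a single boolean mask along the flattened edge dimension. The following is a minimal sketch under that assumption, not tfep's actual source; prune_long_edges_sketch is a hypothetical name. It keeps edges whose distance is less than or equal to r_cutoff, matching the description that only edges longer than the cutoff are pruned.

```python
import torch


def prune_long_edges_sketch(r_cutoff, edges, distances, *args):
    """Minimal sketch of the documented behavior (not tfep's source).

    Keeps edges whose distance is <= r_cutoff and applies the same mask
    to distances and to every extra per-edge tensor in args.
    """
    # Boolean mask over the flattened edge dimension.
    keep = distances <= r_cutoff
    pruned = [edges[:, keep], distances[keep]]
    # Extra tensors are indexed along their first (edge) dimension.
    pruned.extend(x[keep] for x in args)
    return tuple(pruned)


edges = torch.tensor([[0, 1, 3, 4],
                      [1, 2, 4, 5]])
distances = torch.tensor([1.0, 2.5, 0.5, 3.25])
extra = torch.arange(12.0).reshape(4, 3)  # e.g. one 3-vector per edge

new_edges, new_distances, new_extra = prune_long_edges_sketch(
    2.0, edges, distances, extra
)
print(new_edges.tolist())      # [[0, 3], [1, 4]]
print(new_distances.tolist())  # [1.0, 0.5]
print(new_extra.tolist())      # [[0.0, 1.0, 2.0], [6.0, 7.0, 8.0]]
```

Because the mask in this sketch acts on the flattened edge dimension, different samples in a batch can keep different numbers of edges, so batch_size*n_pruned_edges is best read as the total number of remaining edges.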