Source code for tfep.nn.graph

#!/usr/bin/env python


# =============================================================================
# MODULE DOCSTRING
# =============================================================================

"""
Utility functions for graphs.
"""


# =============================================================================
# GLOBAL IMPORTS
# =============================================================================

from collections.abc import Sequence
import itertools
from typing import Optional

import torch

from tfep.utils.misc import ensure_tensor_sequence


# =============================================================================
# UTILITIES
# =============================================================================

[docs] class FixedGraph(torch.nn.Module): """Graph base class with a fixed topology. The class provides utilities to determine the edges between nodes (optionally based on a mask) and cache them. It is assumed that the edges do not change after constructions. The edges must be accessed by the parent class through the ``get_edges()`` method, which takes care of determining the edges compatible with features in the shape ``(batch_size*n_nodes, n_feats_per_node)``. """
[docs] def __init__( self, node_types : Sequence[int], mask : Optional[torch.Tensor] = None, ): """Constructor. Parameters ---------- node_types : Sequence[int] Shape ``(n_nodes,)``. ``node_types[i]`` is the ID of the node type for the i-th node. These are usually used to indicate an atom element. These are encoded into a ``self._node_types_one_hot`` encoding using ``torch.nn.functional.one_hot`` so they should start from 0 and contain only consecutive numbers to limit the size of the encoding (i.e., ``0 <= node_types[i] < n_node_types`` for all ``i``). mask : torch.Tensor, optional Shape ``(n_nodes, n_nodes)``. A (directional) edge from node ``i`` to node ``j`` is created only if ``mask[i, j] != 0``. If ``mask`` is not provided, all nodes are connected to all nodes (excluding self interactions). """ super().__init__() # Encode node types. Convert them from int to floats to ease embedding. node_types = ensure_tensor_sequence(node_types) node_types_encoding = torch.nn.functional.one_hot(node_types).to(torch.get_default_dtype()) self.register_buffer('_node_types_one_hot', node_types_encoding) # We initially build the edges for a batch of size 1 so at this point # self._last_batch_edges will have shape (2, n_edges). However, we'll # use this to cache the number of edges for the latest batch size. self._last_batch_edges = get_all_edges(batch_size=1, n_nodes=self.n_nodes, mask=mask) # The number of edges for a single batch. Needed by get_edges(). self._n_edges = int(self._last_batch_edges.shape[1])
@property def n_nodes(self): """int: Number of nodes in the graph.""" return len(self._node_types_one_hot) @property def n_edges(self): """int: Number of edges in the graph.""" return self._n_edges
[docs] def get_edges(self, batch_size): """Return the edges between nodes for the given batch size. Parameters ---------- batch_size : int The size of the current batch. Returns ------- edges : torch.Tensor Shape ``(2, batch_size*n_edges)``. The ``i``-th edge is created from node ``edges[0][i]`` to ``edges[1][i]``, where ``edges[0][i]`` is a node index in the range ``[0, batch_size*n_nodes]``. Edges are directional so if a message must be passed in both directions, two entries connecting the nodes with inverse order are present. """ # Store new edges so that if the next batch is the same this will be faster. self._last_batch_edges = fix_node_indices_batch_size( node_indices=self._last_batch_edges, new_batch_size=batch_size, n_indices=self._n_edges, n_nodes=self.n_nodes, ) return self._last_batch_edges
[docs] def get_all_edges(batch_size, n_nodes, mask=None): """Return all possible edges between nodes after applying the mask. Parameters ---------- batch_size : int The batch size. n_nodes : int The number of nodes in the graph. mask : torch.Tensor, optional Shape ``(n_nodes, n_nodes)``. A (directional) edge from node ``i`` to node ``j`` is created only if ``mask[i, j] != 0``. If ``mask`` is not provided, all nodes are connected to all nodes (excluding self interactions). Returns ------- edges : torch.Tensor Shape ``(2, batch_size*n_edges)``. The ``i``-th edge is created from node ``edges[0][i]`` to ``edges[1][i]``, where ``edges[0][i]`` is a node index in the range ``[0, batch_size*n_nodes]``. Edges are directional so if a message must be passed in both directions, two entries connecting the nodes with inverse order are present. """ # Determine the edges for a single batch. edges has shape (2, n_edges). if mask is None: if n_nodes == 1: return torch.empty(2, 0) edges = itertools.permutations(range(n_nodes), 2) edges = torch.tensor(list(zip(*edges))) else: if mask.shape != (n_nodes, n_nodes): raise ValueError('mask must have shape (n_nodes, n_nodes)') edges = mask.nonzero() edges = edges.t() # Determine the edges for the given shape. edges = fix_node_indices_batch_size( node_indices=edges, new_batch_size=batch_size, n_indices=edges.shape[-1], n_nodes=n_nodes, ) return edges
[docs] def fix_node_indices_batch_size( node_indices: torch.Tensor, new_batch_size: int, n_indices : int, n_nodes: int, ): """Expand a tensor of node indices to be compatible to shape ``(*, new_batch_size*n_indices)``. This is useful, for example, to update the indices of source/destination nodes in the edges for the last batch, which might be smaller than the normal batch size. Parameters ---------- node_indices : torch.Tensor Shape ``(*, old_batch_size*n_indices)``. new_batch_size : int The output batch size. n_indices : int The number of node indices for batch size = 1. n_nodes : int, optional The total number of nodes in the graph. Returns ------- fixed_indices : torch.Tensor Shape ``(*, new_batch_size*n_indices)``. The new edges after fixing the batch size. """ batch_indices_dim = new_batch_size * n_indices if node_indices.shape[-1] == batch_indices_dim: return node_indices # Check if we have too many batches. if node_indices.shape[-1] > batch_indices_dim: return node_indices[..., :batch_indices_dim] # Otherwise, reduce to 1 batch. node_indices = node_indices[..., :n_indices] # To shape (*, new_batch_size, n_indices). node_indices = node_indices.unsqueeze(-2) size = list(node_indices.shape) size[-2] = new_batch_size node_indices = node_indices.expand(*size) # Now multiply repeated node indices by the batch index. shift has shape (1, new_batch_size, 1). shift = torch.arange(0, new_batch_size*n_nodes, n_nodes).unsqueeze(-1) node_indices = node_indices + shift # From shape (*, new_batch_size, n_indices) to (*, new_batch_size*n_indices) node_indices = node_indices.flatten(start_dim=-2) return node_indices
# TODO: MERGE THIS WITH tfep.utils.geometry.pdist?
[docs] def compute_edge_distances(x, edges, normalize_directions=False, inverse_directions=False): """Return distances between nodes across edges. Parameters ---------- x : torch.Tensor Positions of particles with shape (batch_size*n_nodes, 3). edges : torch.Tensor Shape ``(2, batch_size*n_edges)``. The ``i``-th edge is created from node ``edges[0][i]`` to ``edges[1][i]``, where ``edges[0][i]`` is a node index in the range ``[0, batch_size*n_nodes]``. normalize_directions : bool, optional If ``True``, the returned directions are normalized by their norms. inverse_directions : bool, optional If ``True``, the direction vectors go from the destination node to the source node rather than the opposite. Returns ------- distances : torch.Tensor Shape ``(batch_size*n_edges,)``. ``distances[i]`` is the distance between the nodes of the ``i``-th edge. direction : torch.Tensor Shape ``(batch_size*n_edges, 3)``. ``direction[i]`` is the vector connecting the nodes of the ``i``-th edge. The vectors are normalized by their norms if ``normalize_direction`` is ``True``. """ if inverse_directions: directions = x[edges[0]] - x[edges[1]] else: directions = x[edges[1]] - x[edges[0]] distances = torch.sqrt(torch.sum(directions**2, dim=-1)) if normalize_directions: directions = directions / distances.unsqueeze(-1) return distances, directions
[docs] def prune_long_edges(r_cutoff, edges, distances, *args): """Detect which edges have distances larger than the cutoff and remove them. Parameters ---------- r_cutoff : float The radial cutoff. All edges connecting nodes at distance greater than this cutoff will be pruned. This must be in the same units as the nodes positions. edges : torch.Tensor Shape ``(2, batch_size*n_edges)``. The ``i``-th edge is created from node ``edges[0][i]`` to ``edges[1][i]``, where ``edges[0][i]`` is a node index in the range ``[0, batch_size*n_nodes]``. distances : torch.Tensor Shape ``(batch_size*n_edges,)``. ``distances[i]`` is the distance between the nodes of the ``i``-th edge. *args : Sequence[torch.Tensor] Other tensors of shape ``(batch_size*n_edges, *)`` to prune in the same way. Returns ------- edges : torch.Tensor Shape ``(2, batch_size*n_pruned_edges)``. The edges after the pruning. distances : torch.Tensor, optional Shape ``(batch_size*n_pruned_edges,)``. The distances of the nodes across the edges after the pruning. *other : torch.Tensor, optional Other pruned tensors of shape ``(batch_size*n_pruned_edges, *)``. """ mask = distances <= r_cutoff edges = edges[:, mask] distances = distances[mask] pruned_args = [arg[mask] for arg in args] return edges, distances, *pruned_args
[docs] def unsorted_segment_sum(data, segment_ids, n_segments): """Replicates TensorFlow's tf.math.unsorted_segment_sum in PyTorch.""" segment_sum = data.new_full((n_segments, data.shape[1]), fill_value=0.0) segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.shape[1]) segment_sum.scatter_add_(0, segment_ids, data) return segment_sum