# Source code for tfep.tests.app

"""
Test objects and functions in the package ``tfep.app``.
"""


# =============================================================================
# GLOBAL IMPORTS
# =============================================================================

from collections.abc import Sequence
from typing import List, Optional, Type, Union

import tfep.app.base


# =============================================================================
# SHARED APP TEST UTILITIES
# =============================================================================

def check_atom_groups(
        tfep_map_cls: Type[tfep.app.base.TFEPMapBase],
        fix_origin: bool,
        fix_orientation: bool,
        mapped_atoms: Optional[Union[Sequence[int], str]],
        conditioning_atoms: Optional[Union[Sequence[int], str]],
        expected_mapped: Optional[List[int]],
        expected_conditioning: Optional[List[int]],
        expected_fixed: Optional[List[int]],
        expected_mapped_fixed_removed: Optional[List[int]],
        expected_conditioning_fixed_removed: Optional[List[int]],
        round_trip: bool = True,
        **kwargs,
):
    """Test selection of mapped, conditioning, fixed, and reference frame atoms.

    This also tests:
    - That a forward-inverse round trip yields the original input.
    - That the conditioning atoms are not changed but affect the output.
    - That the fixed atoms are not changed and do not affect the output.

    Parameters
    ----------
    tfep_map_cls : Type[tfep.app.base.TFEPMapBase]
        The map class under test. It is instantiated with ``mapped_atoms``,
        ``conditioning_atoms``, and ``**kwargs`` inside a temporary working
        directory and trained for a single step with a ``lightning.Trainer``
        so that the map is not the identity.
    fix_origin : bool
        If ``True``, a random atom of ``expected_conditioning`` is passed to
        the map as ``origin_atom`` and the test checks that it is not moved.
        The test is skipped if ``expected_conditioning`` is ``None``.
    fix_orientation : bool
        If ``True``, two distinct random atoms among ``expected_mapped`` and
        ``expected_conditioning`` (excluding the origin atom) are passed to
        the map as ``axes_atoms`` and the test checks the orientation
        constraints. The test is skipped if fewer than two such atoms exist.
    mapped_atoms : Sequence[int] or str or None
        Passed unchanged to the map constructor.
    conditioning_atoms : Sequence[int] or str or None
        Passed unchanged to the map constructor.
    expected_mapped : List[int] or None
        Atom indices whose positions the trained map is expected to change.
    expected_conditioning : List[int] or None
        Atom indices expected to be left unchanged by the map while still
        affecting its output.
    expected_fixed : List[int] or None
        Expected value of ``tfep_map._fixed_atom_indices``. These atoms are
        expected to be unchanged and not to affect the output.
    expected_mapped_fixed_removed : List[int] or None
        Expected value of
        ``tfep_map.get_mapped_indices(idx_type='atom', remove_fixed=True)``.
    expected_conditioning_fixed_removed : List[int] or None
        Expected value of
        ``tfep_map.get_conditioning_indices(idx_type='atom', remove_fixed=True)``.
    round_trip : bool, optional
        If ``False``, this will not check that a forward-inverse round trip
        yields the original input.
    **kwargs
        Other keyword arguments forwarded to the map constructor.

    Raises
    ------
    RuntimeError
        If building/training the map raises a non-"collinear"
        ``RuntimeError``, or if it keeps raising a "collinear" error for
        every one of the random reference-atom selections attempted.

    """
    import lightning
    import numpy as np
    import pytest
    import tempfile
    import torch

    from tfep.utils.misc import atom_to_flattened, flattened_to_atom, temporary_cd
    from tfep.utils.geometry import batchwise_dot

    # Since we select randomly the reference atoms, there's a chance to pick
    # collinear atoms. We repeat the selection until the error is not thrown.
    max_n_attempts = 10
    # The name bound by "except ... as err" is deleted when the except clause
    # ends, so the last error must be stored in a separate variable to be
    # re-raised from the loop's else clause.
    last_collinear_err = None
    for _ in range(max_n_attempts):
        try:
            # Select a random conditioning atom as the origin to fix the
            # translational degrees of freedom.
            if fix_origin:
                if expected_conditioning is None:
                    pytest.skip('fixing the translational DOFs require the presence of a conditioning atom')
                origin_atom = np.random.choice(expected_conditioning)
                kwargs['origin_atom'] = origin_atom

            # Select axis and plane atoms among the remaining atoms.
            if fix_orientation:
                remaining = []
                for atom_group in (expected_mapped, expected_conditioning):
                    if atom_group is not None:
                        remaining.extend(atom_group)
                remaining = sorted([i for i in remaining if i != kwargs.get('origin_atom', None)])
                if len(remaining) < 2:
                    pytest.skip('fixing the orientation of the reference frame requires at least 2 mapped or conditioning atoms.')
                axes_atoms = np.random.choice(remaining, size=2, replace=False).tolist()
                kwargs['axes_atoms'] = axes_atoms

            with tempfile.TemporaryDirectory() as tmp_dir_path:
                with temporary_cd(tmp_dir_path):
                    # Initialize the map.
                    tfep_map = tfep_map_cls(
                        mapped_atoms=mapped_atoms,
                        conditioning_atoms=conditioning_atoms,
                        **kwargs,
                    )

                    # Train for one step to make sure that the map is not the identity.
                    trainer = lightning.Trainer(
                        max_steps=1,
                        logger=False,
                        enable_checkpointing=False,
                        enable_progress_bar=False,
                        enable_model_summary=False,
                    )
                    trainer.fit(tfep_map)
            break
        except RuntimeError as err:
            if 'collinear' not in str(err):
                raise
            last_collinear_err = err
    else:
        raise RuntimeError(
            f'Could not select non-collinear reference atoms in {max_n_attempts} attempts.'
        ) from last_collinear_err

    # Compare expected indices.
    for expected_indices, tfep_indices in zip(
            [
                expected_mapped_fixed_removed,
                expected_conditioning_fixed_removed,
                expected_fixed,
            ],
            [
                tfep_map.get_mapped_indices(idx_type='atom', remove_fixed=True),
                tfep_map.get_conditioning_indices(idx_type='atom', remove_fixed=True),
                tfep_map._fixed_atom_indices,
            ],
    ):
        if expected_indices is None:
            assert tfep_indices is None
        else:
            # Check the shape explicitly so that broadcasting in the
            # element-wise comparison cannot hide a length mismatch.
            expected_indices = torch.as_tensor(expected_indices)
            tfep_indices = torch.as_tensor(tfep_indices)
            assert tfep_indices.shape == expected_indices.shape
            assert torch.all(tfep_indices == expected_indices)

    # Create position input from the first batch_size frames. Each frame is
    # copied because a trajectory reader (presumably MDAnalysis here) may
    # overwrite the same positions buffer in place when moving to another
    # frame, which would make all batch elements identical.
    trajectory = tfep_map.dataset.universe.trajectory
    positions = np.stack([
        trajectory[i].positions.copy() for i in range(tfep_map.hparams.batch_size)
    ])
    x = torch.tensor(positions, dtype=torch.get_default_dtype())
    x = atom_to_flattened(x)
    x.requires_grad = True

    # Test forward and inverse.
    result = tfep_map({'positions': x})
    assert 'log_det_J' in result
    y = result['positions']
    result = tfep_map.inverse({'positions': y})
    if round_trip:
        assert 'log_det_J' in result
        assert torch.allclose(x, result['positions'])

    # Compute gradients w.r.t. the input. Because the loss is the plain sum of
    # the output, an input DOF that is copied unchanged to the output and does
    # not influence any other output DOF has a gradient of exactly 1.
    loss = y.sum()
    loss.backward()
    x_grad = flattened_to_atom(x.grad)

    # The flow must take care of mapped and conditioning, while the fixed atoms
    # are handled automatically.
    x = flattened_to_atom(x)
    y = flattened_to_atom(y)

    # Check that the map is not the identity or this test doesn't make sense.
    assert not torch.allclose(x[:, expected_mapped], y[:, expected_mapped])

    # The flow doesn't alter but still depends on the conditioning DOFs.
    if expected_conditioning is not None:
        assert torch.allclose(x[:, expected_conditioning], y[:, expected_conditioning])
        conditioning_grad = x_grad[:, expected_conditioning]
        assert torch.all(~torch.isclose(conditioning_grad, torch.ones_like(conditioning_grad), rtol=0.0))

    # The flow doesn't alter and doesn't depend on the fixed DOFs.
    if expected_fixed is not None:
        assert torch.allclose(x[:, expected_fixed], y[:, expected_fixed])
        # The output does not depend on the fixed DOFs.
        fixed_grad = x_grad[:, expected_fixed]
        assert torch.allclose(fixed_grad, torch.ones_like(fixed_grad))

    # The origin atom should be left untouched.
    if fix_origin:
        assert torch.allclose(x[:, origin_atom], y[:, origin_atom])

    # Check rotational frame of reference.
    if fix_orientation:
        if fix_origin:
            origin = x[:, origin_atom]
        else:
            origin = torch.zeros(tfep_map.hparams.batch_size, 3)

        # The direction origin-axis atom should be the same (up to a flip).
        dir_01_x = torch.nn.functional.normalize(x[:, axes_atoms[0]] - origin)
        dir_01_y = torch.nn.functional.normalize(y[:, axes_atoms[0]] - origin)
        sign = torch.sign(batchwise_dot(dir_01_x, dir_01_y)).unsqueeze(1)
        assert torch.allclose(sign*dir_01_x, dir_01_y)

        # The mapped plane atom should be orthogonal to the normal of the
        # plane through the origin and the two axes atoms.
        dir_02_x = torch.nn.functional.normalize(x[:, axes_atoms[1]] - origin)
        # Pass dim explicitly: without it torch.cross uses the first dimension
        # of size 3, which is the batch dimension when batch_size == 3.
        plane_x = torch.cross(dir_01_x, dir_02_x, dim=-1)
        dir_02_y = torch.nn.functional.normalize(y[:, axes_atoms[1]] - origin)
        assert torch.allclose(batchwise_dot(plane_x, dir_02_y), torch.zeros(len(dir_02_y)))