#!/usr/bin/env python
# =============================================================================
# MODULE DOCSTRING
# =============================================================================
"""
Loss functions to train PyTorch normalizing flows for reweighting.
"""
# =============================================================================
# GLOBAL IMPORTS
# =============================================================================
from typing import Optional
import torch
# =============================================================================
# LOSS FUNCTIONS
# =============================================================================
[docs]
class BoltzmannKLDivLoss(torch.nn.Module):
"""KL divergence between two Boltzmann distributions.
The loss function assumes the sampling is done in the reference distribution
A. The KL divergence between two Boltzmann distribution is then given by
:math:`D_{KL}[p_A||p_B] = \int p_A(x) \Delta u_{AB}(x) dx - \Delta f_{AB}`
where :math:`p_A(x)` is distribution A, :math:`\Delta u_{AB}(x) = u_B(x) - u_A(x)`
is the difference between the reduced potential energies B and A for configuration
x (in units of :math:`k_B T`), and :math:`\Delta f_{AB} = f_B - f_A` is the
reduced free energy difference (also in units of :math:`k_B T`).
In TFEP, the KL divergence of interest is between A and the mapped
distribution B', whose potential energy includes the logarithm of the
absolute value of the Jacobian of the map M
:math:`u_{B'}(x) = u_B(x) - log|det J_M(x)|`
Moreover, because the free energy difference and reference potential energies
do not depend on the map, they can be ignored, and the loss function can be
optimized by minimizing
:math:`\frac{1}{N} \sum_i u_{B'}(x_i)`
Finally, if the samples were not sampled from A, the mean must be weighted.
If log-weights are passed to the function, the loss is
:math:`\frac{1}{N} \sum_i \frac{e^{w_i}}{\sum_i e^{w_i}} u_{B'}(x_i)`
where :math:`w_i` is the log-weight for the i-th sample, and correspond to
potential energy difference between the sampled and A distributions.
"""
[docs]
def __init__(self, ignore_nan: bool = False):
"""Constructor.
Parameters
----------
ignore_nan : bool, optional
Whether to ignore NaNs when computing the loss or not (which will
cause the loss value to be NaN as well).
"""
super().__init__()
#: Whether to ignore NaNs when computing the loss or not.
self.ignore_nan = ignore_nan
[docs]
def forward(
self,
target_potentials: torch.Tensor,
log_det_J: Optional[torch.Tensor] = None,
log_weights: Optional[torch.Tensor] = None,
ref_potentials: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Compute the loss.
.. warning::
Because ``Tensor``s are unit-less you need to make sure all arguments
are passed using consistent units.
Typically, the ``log_det_J`` obtained as output of the normalizing
flow will be in units of :math:`k_BT` so potentials and log-weights
should be divided by :math:`k_BT` as well.
Parameters
----------
target_potentials : torch.Tensor
``target_potentials[i]`` is the reduced potential energy of the i-th
(mapped) sample in units of kT evaluated using target potential B.
The shape is ``(batch_size,)``.
log_det_J : torch.Tensor or None, optional
``log_det_J[i]`` is the logarithm of the absolute value of the
determinant of the Jacobian of the map (in units of kT) for the i-th
sample. The shape is ``(batch_size,)``.
If not passed, it is assumed the samples were not mapped or,
equivalently, that the Jacobian contribution has been already included
in ``potentials_B``.
log_weights : torch.Tensor or None, optional
``log_weights[i]`` is the log-weight for the i-th sample (in units
of kT) that can be used to reweight the loss function if the samples
were not sampled from A. The shape is ``(batch_size,)``.
ref_potentials : torch.Tensor or None, optional
``ref_potentials_A[i]`` is the reduced potential energy of the i-th
sample in units of kT evaluated using the reference potential A. The
shape is ``(batch_size,)``.
This is optional since it does not affect the optimization but only
the value returned by the loss function.
Returns
-------
loss : torch.Tensor
The value of the loss function.
"""
reduced_work = target_potentials
if log_det_J is not None:
reduced_work = reduced_work - log_det_J
if ref_potentials is not None:
reduced_work = reduced_work - ref_potentials
# Check if this must be a weighted or unweighted mean.
if log_weights is not None:
weights = torch.nn.functional.softmax(log_weights)
if self.ignore_nan:
return torch.nansum(weights * reduced_work)
return torch.sum(weights * reduced_work)
if self.ignore_nan:
return torch.nanmean(reduced_work)
return torch.mean(reduced_work)