tfep.loss.BoltzmannKLDivLoss
- class tfep.loss.BoltzmannKLDivLoss(ignore_nan: bool = False)[source]
Bases: Module
KL divergence between two Boltzmann distributions.
The loss function assumes the sampling is done in the reference distribution A. The KL divergence between two Boltzmann distributions is then given by
\(D_{KL}[p_A||p_B] = \int p_A(x) \Delta u_{AB}(x) dx - \Delta f_{AB}\)
where \(p_A(x)\) is distribution A, \(\Delta u_{AB}(x) = u_B(x) - u_A(x)\) is the difference between the reduced potential energies B and A for configuration x (in units of \(k_B T\)), and \(\Delta f_{AB} = f_B - f_A\) is the reduced free energy difference (also in units of \(k_B T\)).
In TFEP, the KL divergence of interest is between A and the mapped distribution B’, whose potential energy includes the logarithm of the absolute value of the Jacobian of the map M
\(u_{B'}(x) = u_B(x) - \log|\det J_M(x)|\)
Moreover, because the free energy difference and reference potential energies do not depend on the map, they can be ignored, and the loss function can be optimized by minimizing
\(\frac{1}{N} \sum_i u_{B'}(x_i)\)
Finally, if the samples were not sampled from A, the mean must be weighted. If log-weights are passed to the function, the loss is
\(\frac{1}{N} \sum_i \frac{e^{w_i}}{\sum_j e^{w_j}} u_{B'}(x_i)\)
where \(w_i\) is the log-weight for the i-th sample, which corresponds to the potential energy difference between the sampled distribution and A.
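The expressions above can be sketched in plain PyTorch. This is a minimal illustration, not tfep's implementation: the function name `boltzmann_kl_loss_sketch` is hypothetical, and all inputs are assumed to be 1-D tensors in units of \(k_B T\). The weighted branch uses the self-normalized weights directly. It differs from the weighted expression above only by the constant factor 1/N, so it has the same minimizer.

```python
import torch


def boltzmann_kl_loss_sketch(target_potentials, log_det_J=None, log_weights=None):
    """Hypothetical sketch of the map-dependent part of the KL divergence."""
    # u_{B'}(x_i) = u_B(x_i) - log|det J_M(x_i)|
    mapped_potentials = target_potentials
    if log_det_J is not None:
        mapped_potentials = mapped_potentials - log_det_J

    if log_weights is None:
        # Samples drawn from A: plain average over the batch.
        return mapped_potentials.mean()

    # Samples drawn from another distribution: softmax evaluates
    # e^{w_i} / sum_j e^{w_j} in a numerically stable way.
    weights = torch.softmax(log_weights, dim=0)
    return torch.sum(weights * mapped_potentials)


u_B = torch.tensor([1.0, 2.0, 3.0])
log_det = torch.tensor([0.5, 0.0, -0.5])

unweighted = boltzmann_kl_loss_sketch(u_B, log_det)
uniform = boltzmann_kl_loss_sketch(u_B, log_det, log_weights=torch.zeros(3))
```

With uniform log-weights the self-normalized average reduces to the plain batch mean, so `unweighted` and `uniform` agree.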
- __init__(ignore_nan: bool = False)[source]
Constructor.
- Parameters:
ignore_nan (bool, optional) – Whether to ignore NaNs when computing the loss. If False, NaNs are not ignored and the loss value will be NaN as well.
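As an illustration of what ignoring NaNs can mean for a batch-averaged loss, the sketch below averages only over the entries that are not NaN. It shows the general idea only and is not a description of how tfep implements `ignore_nan`; the helper name `mean_ignoring_nan` is hypothetical.

```python
import torch


def mean_ignoring_nan(values):
    """Hypothetical helper: average only the entries that are not NaN."""
    mask = ~torch.isnan(values)
    return values[mask].mean()


per_sample = torch.tensor([1.0, float("nan"), 3.0])

plain_mean = per_sample.mean()               # the NaN propagates to the result
masked_mean = mean_ignoring_nan(per_sample)  # averages 1.0 and 3.0 only
```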
Methods
__init__([ignore_nan]) – Constructor.
add_module(name, module) – Add a child module to the current module.
apply(fn) – Apply fn recursively to every submodule (as returned by .children()) as well as self.
bfloat16() – Casts all floating point parameters and buffers to bfloat16 datatype.
buffers([recurse]) – Return an iterator over module buffers.
children() – Return an iterator over immediate children modules.
compile(*args, **kwargs) – Compile this Module's forward using torch.compile().
cpu() – Move all model parameters and buffers to the CPU.
cuda([device]) – Move all model parameters and buffers to the GPU.
double() – Casts all floating point parameters and buffers to double datatype.
eval() – Set the module in evaluation mode.
extra_repr() – Set the extra representation of the module.
float() – Casts all floating point parameters and buffers to float datatype.
forward(target_potentials[, log_det_J, ...]) – Compute the loss.
get_buffer(target) – Return the buffer given by target if it exists, otherwise throw an error.
get_extra_state() – Return any extra state to include in the module's state_dict.
get_parameter(target) – Return the parameter given by target if it exists, otherwise throw an error.
get_submodule(target) – Return the submodule given by target if it exists, otherwise throw an error.
half() – Casts all floating point parameters and buffers to half datatype.
ipu([device]) – Move all model parameters and buffers to the IPU.
load_state_dict(state_dict[, strict, assign]) – Copy parameters and buffers from state_dict into this module and its descendants.
modules() – Return an iterator over all modules in the network.
mtia([device]) – Move all model parameters and buffers to the MTIA.
named_buffers([prefix, recurse, ...]) – Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
named_children() – Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
named_modules([memo, prefix, remove_duplicate]) – Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
named_parameters([prefix, recurse, ...]) – Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
parameters([recurse]) – Return an iterator over module parameters.
register_backward_hook(hook) – Register a backward hook on the module.
register_buffer(name, tensor[, persistent]) – Add a buffer to the module.
register_forward_hook(hook, *[, prepend, ...]) – Register a forward hook on the module.
register_forward_pre_hook(hook, *[, ...]) – Register a forward pre-hook on the module.
register_full_backward_hook(hook[, prepend]) – Register a backward hook on the module.
register_full_backward_pre_hook(hook[, prepend]) – Register a backward pre-hook on the module.
register_load_state_dict_post_hook(hook) – Register a post-hook to be run after module's load_state_dict() is called.
register_load_state_dict_pre_hook(hook) – Register a pre-hook to be run before module's load_state_dict() is called.
register_module(name, module) – Alias for add_module().
register_parameter(name, param) – Add a parameter to the module.
register_state_dict_post_hook(hook) – Register a post-hook for the state_dict() method.
register_state_dict_pre_hook(hook) – Register a pre-hook for the state_dict() method.
requires_grad_([requires_grad]) – Change if autograd should record operations on parameters in this module.
set_extra_state(state) – Set extra state contained in the loaded state_dict.
set_submodule(target, module) – Set the submodule given by target if it exists, otherwise throw an error.
share_memory() – See torch.Tensor.share_memory_().
state_dict(*args[, destination, prefix, ...]) – Return a dictionary containing references to the whole state of the module.
to(*args, **kwargs) – Move and/or cast the parameters and buffers.
to_empty(*, device[, recurse]) – Move the parameters and buffers to the specified device without copying storage.
train([mode]) – Set the module in training mode.
type(dst_type) – Casts all parameters and buffers to dst_type.
xpu([device]) – Move all model parameters and buffers to the XPU.
zero_grad([set_to_none]) – Reset gradients of all model parameters.
Attributes
T_destination
call_super_init
dump_patches
ignore_nan – Whether to ignore NaNs when computing the loss or not.
training
- forward(target_potentials: Tensor, log_det_J: Tensor | None = None, log_weights: Tensor | None = None, ref_potentials: Tensor | None = None) → Tensor[source]
Compute the loss.
Warning
Because Tensors are unit-less, you need to make sure all arguments are passed using consistent units.
Typically, the log_det_J obtained as output of the normalizing flow will be in units of \(k_BT\), so potentials and log-weights should be divided by \(k_BT\) as well.
- Parameters:
target_potentials (torch.Tensor) – target_potentials[i] is the reduced potential energy of the i-th (mapped) sample in units of kT evaluated using target potential B. The shape is (batch_size,).
log_det_J (torch.Tensor or None, optional) – log_det_J[i] is the logarithm of the absolute value of the determinant of the Jacobian of the map (in units of kT) for the i-th sample. The shape is (batch_size,). If not passed, it is assumed the samples were not mapped or, equivalently, that the Jacobian contribution has already been included in target_potentials.
log_weights (torch.Tensor or None, optional) – log_weights[i] is the log-weight for the i-th sample (in units of kT) that can be used to reweight the loss function if the samples were not sampled from A. The shape is (batch_size,).
ref_potentials (torch.Tensor or None, optional) – ref_potentials[i] is the reduced potential energy of the i-th sample in units of kT evaluated using the reference potential A. The shape is (batch_size,). This is optional since it does not affect the optimization but only the value returned by the loss function.
- Returns:
loss – The value of the loss function.
- Return type:
torch.Tensor
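The remark that ref_potentials changes only the returned value, not the optimization, can be checked on a toy problem. The sketch below does not call tfep. It assumes a hypothetical one-dimensional affine map \(M(x) = a x + b\) with trainable `a` and `b`, harmonic reduced potentials for A and B, and a loss written out by hand as the batch mean of \(u_B(M(x)) - \log|\det J_M(x)|\). The reference potential \(u_A(x)\) is optionally subtracted, as in \(\Delta u_{AB}\).

```python
import torch

torch.manual_seed(0)

# Hypothetical affine map parameters (a stand-in for a normalizing flow).
a = torch.tensor(1.5, requires_grad=True)
b = torch.tensor(-0.3, requires_grad=True)

x = torch.randn(64)                                # samples from A (standard normal)
ref_potentials = 0.5 * x**2                        # u_A(x), in units of kT
y = a * x + b                                      # mapped samples M(x)
log_det_J = torch.log(torch.abs(a)).expand_as(x)   # log|dM/dx| for each sample
target_potentials = 0.5 * ((y - 1.0) / 2.0) ** 2   # u_B(M(x)), in units of kT

loss_without_ref = (target_potentials - log_det_J).mean()
loss_with_ref = (target_potentials - log_det_J - ref_potentials).mean()

# Gradients with respect to the map parameters for both variants.
grads_without = torch.autograd.grad(loss_without_ref, (a, b), retain_graph=True)
grads_with = torch.autograd.grad(loss_with_ref, (a, b))
```

The two gradient tuples match because \(u_A(x)\) does not depend on `a` or `b`; only the loss values differ, by the batch mean of `ref_potentials`.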
- ignore_nan
Whether to ignore NaNs when computing the loss or not.