tfep.nn.flows.continuous.ContinuousFlow
- class tfep.nn.flows.continuous.ContinuousFlow(dynamics, trace_estimator='hutchinson', solver='dopri5', solver_options=None, n_hutchinson_samples=1, adjoint=True, regularization=True, vmap=False, requires_backward=True)[source]
Bases: Module
Continuous normalizing flow.
This implements continuous normalizing flows as proposed in [1]. The trace of the Jacobian of the dynamics can be estimated using Hutchinson’s stochastic estimator [2] at the cost of one backpropagation, or computed exactly using D backpropagations, where D is the dimension of each sample.
Optionally, the flow can also return a regularization term, as proposed in [3], that can be added to the loss to keep the ODE dynamics used for the flow smoother.
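The two ways of obtaining the trace described above can be sketched in plain PyTorch. This is a minimal illustration, not the code used inside ContinuousFlow: the helper names exact_trace and hutchinson_trace, the toy linear map, and the matrix A are all assumptions made for the example.

```python
import torch


def exact_trace(f, x):
    """Trace of df/dx per sample, using one backward pass per dimension."""
    x = x.detach().requires_grad_(True)
    y = f(x)
    trace = x.new_zeros(x.shape[0])
    for i in range(x.shape[1]):
        # Gradient of the i-th output w.r.t. all inputs; keep the i-th entry.
        (grad_i,) = torch.autograd.grad(y[:, i].sum(), x, retain_graph=True)
        trace = trace + grad_i[:, i]
    return trace


def hutchinson_trace(f, x, n_samples=1):
    """Stochastic estimate of the trace: the average of v^T (df/dx) v."""
    x = x.detach().requires_grad_(True)
    y = f(x)
    estimate = x.new_zeros(x.shape[0])
    for _ in range(n_samples):
        v = torch.randn_like(x)  # normally distributed probe vectors
        # One backward pass gives the vector-Jacobian product v^T (df/dx).
        (vjp,) = torch.autograd.grad(y, x, grad_outputs=v, retain_graph=True)
        estimate = estimate + (vjp * v).sum(dim=1)
    return estimate / n_samples


# Toy linear map (hypothetical): its Jacobian is A for every sample.
A = torch.tensor([[1.0, 2.0, 0.0],
                  [0.0, -1.0, 1.0],
                  [0.5, 0.0, 3.0]])


def linear_map(x):
    return x @ A.T


x = torch.randn(4, 3)
exact = exact_trace(linear_map, x)
estimate = hutchinson_trace(linear_map, x, n_samples=4000)
```

With a single probe vector the estimate is noisy; averaging over more samples reduces the variance at the cost of one extra backward pass per sample.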
- Parameters:
dynamics (torch.nn.Module) – The neural network taking a time tensor (shape (1)) and the current positions (shape (batch_size, n_particles*3)) and returning the velocity of the dynamics (shape (batch_size, n_particles*3)).
trace_estimator ('exact' or 'hutchinson', optional) – Whether the trace (and the Frobenius norm if regularization is True) of the Jacobian is computed exactly with n_particles*3 backpropagation passes or using the Hutchinson estimates described in [3] with n_hutchinson_samples backpropagation passes. The random variable is sampled from a normal distribution.
solver (str, optional) – One of the solvers supported by the torchdiffeq package.
solver_options (dict, optional) – A dictionary of solver options to pass to torchdiffeq.odeint.
n_hutchinson_samples (int, optional) – The number of normally distributed samples to draw for the Hutchinson estimate of the trace. If trace_estimator == 'exact' this is ignored.
adjoint (bool, optional) – If True, backpropagation is performed using the adjoint method as described in [1]. Otherwise, automatic differentiation is used.
regularization (bool, optional) – If True, forward() also returns a regularization term, which is the sum of the velocity norm and the Frobenius norm of the Jacobian as described in [3].
vmap (bool, optional) – If True, the estimation of the trace and of the Frobenius norm is performed using the experimental vectorization features of torch.autograd.grad (which are currently only in the unreleased development version).
requires_backward (bool, optional) – If False, the autograd calls used to compute the trace and regularization terms will not create a graph for differentiation. This means that backpropagation (even with the adjoint method) will not take into account the contribution from these terms.
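As an illustration of the interface expected from dynamics, the following is a minimal sketch of a module that takes a time tensor and a batch of positions and returns velocities of the same shape. The class name ToyDynamics, its architecture, and the hidden size are hypothetical. The commented-out constructor call only uses arguments from the signature above and assumes tfep is installed.

```python
import torch


class ToyDynamics(torch.nn.Module):
    """Hypothetical velocity field v(t, x) built from a small MLP."""

    def __init__(self, n_features, hidden=32):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(n_features + 1, hidden),
            torch.nn.Tanh(),
            torch.nn.Linear(hidden, n_features),
        )

    def forward(self, t, x):
        # Repeat the single time value as one extra input column per sample.
        t_col = t.reshape(1, 1).expand(x.shape[0], 1)
        return self.net(torch.cat([x, t_col], dim=1))


n_particles = 4
dynamics = ToyDynamics(n_features=n_particles * 3)

x = torch.randn(8, n_particles * 3)         # (batch_size, n_particles*3)
velocity = dynamics(torch.tensor([0.0]), x)  # same shape as x

# Assuming tfep is installed, the module could then be wrapped as:
# from tfep.nn.flows.continuous import ContinuousFlow
# flow = ContinuousFlow(dynamics, trace_estimator='hutchinson', n_hutchinson_samples=1)
```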
References
- [1] Chen RT, Rubanova Y, Bettencourt J, Duvenaud D. Neural ordinary differential
equations. arXiv preprint arXiv:1806.07366. 2018 Jun 19.
- [2] Grathwohl W, Chen RT, Bettencourt J, Sutskever I, Duvenaud D. FFJORD:
Free-form continuous dynamics for scalable reversible generative models. arXiv preprint arXiv:1810.01367. 2018 Oct 2.
- [3] Finlay C, Jacobsen JH, Nurbekyan L, Oberman A. How to train your neural
ODE: the world of Jacobian and kinetic regularization. In International Conference on Machine Learning 2020 Nov 21 (pp. 3154-3164). PMLR.
- __init__(dynamics, trace_estimator='hutchinson', solver='dopri5', solver_options=None, n_hutchinson_samples=1, adjoint=True, regularization=True, vmap=False, requires_backward=True)[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Methods
__init__(dynamics[, trace_estimator, ...]) – Initialize internal Module state, shared by both nn.Module and ScriptModule.
add_module(name, module) – Add a child module to the current module.
apply(fn) – Apply fn recursively to every submodule (as returned by .children()) as well as self.
bfloat16() – Casts all floating point parameters and buffers to bfloat16 datatype.
buffers([recurse]) – Return an iterator over module buffers.
children() – Return an iterator over immediate children modules.
compile(*args, **kwargs) – Compile this Module's forward using torch.compile().
cpu() – Move all model parameters and buffers to the CPU.
cuda([device]) – Move all model parameters and buffers to the GPU.
double() – Casts all floating point parameters and buffers to double datatype.
eval() – Set the module in evaluation mode.
extra_repr() – Set the extra representation of the module.
float() – Casts all floating point parameters and buffers to float datatype.
forward(x) – Map the input data.
get_buffer(target) – Return the buffer given by target if it exists, otherwise throw an error.
get_extra_state() – Return any extra state to include in the module's state_dict.
get_parameter(target) – Return the parameter given by target if it exists, otherwise throw an error.
get_submodule(target) – Return the submodule given by target if it exists, otherwise throw an error.
half() – Casts all floating point parameters and buffers to half datatype.
inverse(y)
ipu([device]) – Move all model parameters and buffers to the IPU.
load_state_dict(state_dict[, strict, assign]) – Copy parameters and buffers from state_dict into this module and its descendants.
modules() – Return an iterator over all modules in the network.
mtia([device]) – Move all model parameters and buffers to the MTIA.
named_buffers([prefix, recurse, ...]) – Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
named_children() – Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
named_modules([memo, prefix, remove_duplicate]) – Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
named_parameters([prefix, recurse, ...]) – Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
parameters([recurse]) – Return an iterator over module parameters.
register_backward_hook(hook) – Register a backward hook on the module.
register_buffer(name, tensor[, persistent]) – Add a buffer to the module.
register_forward_hook(hook, *[, prepend, ...]) – Register a forward hook on the module.
register_forward_pre_hook(hook, *[, ...]) – Register a forward pre-hook on the module.
register_full_backward_hook(hook[, prepend]) – Register a backward hook on the module.
register_full_backward_pre_hook(hook[, prepend]) – Register a backward pre-hook on the module.
register_load_state_dict_post_hook(hook) – Register a post-hook to be run after module's load_state_dict() is called.
register_load_state_dict_pre_hook(hook) – Register a pre-hook to be run before module's load_state_dict() is called.
register_module(name, module) – Alias for add_module().
register_parameter(name, param) – Add a parameter to the module.
register_state_dict_post_hook(hook) – Register a post-hook for the state_dict() method.
register_state_dict_pre_hook(hook) – Register a pre-hook for the state_dict() method.
requires_grad_([requires_grad]) – Change if autograd should record operations on parameters in this module.
set_extra_state(state) – Set extra state contained in the loaded state_dict.
set_submodule(target, module) – Set the submodule given by target if it exists, otherwise throw an error.
share_memory() – See torch.Tensor.share_memory_().
state_dict(*args[, destination, prefix, ...]) – Return a dictionary containing references to the whole state of the module.
to(*args, **kwargs) – Move and/or cast the parameters and buffers.
to_empty(*, device[, recurse]) – Move the parameters and buffers to the specified device without copying storage.
train([mode]) – Set the module in training mode.
type(dst_type) – Casts all parameters and buffers to dst_type.
xpu([device]) – Move all model parameters and buffers to the XPU.
zero_grad([set_to_none]) – Reset gradients of all model parameters.
Attributes
T_destination
call_super_init
dump_patches
training
- forward(x)[source]
Map the input data.
- Parameters:
x (torch.Tensor) – An input batch of data of shape (batch_size, dimension_in).
- Returns:
y (torch.Tensor) – The mapped data of shape (batch_size, dimension_in).
trace (torch.Tensor) – The log absolute determinant of the Jacobian of the flow (obtained from the trace of the Jacobian of the dynamics) as a tensor of shape (batch_size,).
reg (torch.Tensor, optional) – A regularization term of shape (batch_size,) that can be included in the loss for regularization. This is returned only if self.regularization is True.
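One way the returned values might enter a training loss is sketched below. It is only an illustration under stated assumptions: the base distribution is taken to be a standard normal, the trace output is taken to be the log absolute Jacobian determinant of the map from x to y, and both the function name nll_with_regularization and the weight reg_weight are hypothetical. The source does not prescribe a particular loss.

```python
import math

import torch


def nll_with_regularization(y, log_det_j, reg=None, reg_weight=0.01):
    """Mean negative log-likelihood under a standard normal base (assumed),
    plus an optional weighted regularization term."""
    dim = y.shape[1]
    # Log-density of the mapped samples under N(0, I).
    log_base = -0.5 * (y ** 2).sum(dim=1) - 0.5 * dim * math.log(2 * math.pi)
    # Change of variables: log p(x) = log p_base(y) + log|det dy/dx|.
    loss = -(log_base + log_det_j)
    if reg is not None:
        loss = loss + reg_weight * reg
    return loss.mean()


# Stand-in tensors with the shapes documented for forward().
batch_size, dimension = 5, 2
y = torch.zeros(batch_size, dimension)
log_det_j = torch.zeros(batch_size)
reg = torch.ones(batch_size)

loss_plain = nll_with_regularization(y, log_det_j)
loss_reg = nll_with_regularization(y, log_det_j, reg, reg_weight=0.5)
```

With a flow built with regularization=True, the three tensors would instead come from a call such as y, trace, reg = flow(x).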