# Source code for tfep.nn.conditioners.conditioner

#!/usr/bin/env python


# =============================================================================
# MODULE DOCSTRING
# =============================================================================

"""
Base conditioner classes for autoregressive flows.
"""


# =============================================================================
# GLOBAL IMPORTS
# =============================================================================

import abc

import torch


# =============================================================================
# BASE CLASSES
# =============================================================================

class Conditioner(abc.ABC, torch.nn.Module):
    """Abstract base for the conditioner of an autoregressive flow.

    Specifies the interface a conditioner layer must expose so that it can
    be plugged into an :class:`.AutoregressiveFlow`. Concrete subclasses
    must implement :meth:`set_output` and are expected to override
    :meth:`forward`.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map the input features to the parameters of the transformer.

        Parameters
        ----------
        x : torch.Tensor
            Shape ``(batch_size, n_features)``. The input features.

        Returns
        -------
        parameters : torch.Tensor
            Shape ``(batch_size, n_parameters)``. The parameters consumed by
            the transformer.

        Raises
        ------
        NotImplementedError
            When a subclass does not provide its own ``forward``.
        """
        # Defer to the next class in the MRO rather than naming
        # torch.nn.Module explicitly, so cooperative multiple inheritance
        # keeps working. For this hierarchy that is torch.nn.Module, whose
        # default forward raises NotImplementedError.
        return super().forward(x)

    @abc.abstractmethod
    def set_output(self, output: torch.Tensor):
        """Configure the conditioner so that it always returns ``output``.

        Autoregressive flows use this at initialization to make the flow
        start out as the identity function.

        Parameters
        ----------
        output : torch.Tensor
            Shape ``(n_parameters,)``. The constant output the conditioner
            must produce.
        """
        # Abstract hook: concrete conditioners supply the implementation.
        ...