Source code for tfep.nn.transformers.sos

#!/usr/bin/env python


# =============================================================================
# MODULE DOCSTRING
# =============================================================================

"""
Sum-of-squares polynomial transformer for autoregressive normalizing flows.
"""


# =============================================================================
# GLOBAL IMPORTS
# =============================================================================

import numpy as np
import torch
import torch.autograd

from tfep.nn.transformers.transformer import MAFTransformer


# =============================================================================
# SUM-OF-SQUARES POLYNOMIAL TRANSFORMER
# =============================================================================

[docs] class SOSPolynomialTransformer(MAFTransformer): """Sum-of-squares polynomial transformer module for autoregressive normalizing flows. This is an implementation of the polynomial transformer proposed in [1]. :math:`y_i = a_0 + \int_0^{x_i} \sum_{k=1}^K \left( \sum_{l=0}^L a_{kl} z^l \right)^2 dz` where :math:`K` and :math:`L` are the total number and degree of the polynomials respectively, and :math:`a_X` represent the parameters of the transformer. Only sums of squared first-degree polynomials (i.e., L=1) are currently supported as they are the only one with an analytic inverse and sum of zeroth degree polynomials (i.e., L=0) are equivalent to affine transformer. See Also -------- nets.functions.transformer.sos_polynomial_transformer References ---------- [1] Jaini P, Selby KA, Yu Y. Sum-of-Squares Polynomial Flow. arXiv preprint arXiv:1905.02325. 2019 May 7. """
[docs] def __init__(self, n_polynomials=2): """Constructor. Parameters ---------- n_polynomials : int The functional form of this transformer is a sum of squared polynomials. This is the number of such polynomials, which must be greater than 1. The more polynomials, the greater the number of parameters. Default is 2. """ super().__init__() if n_polynomials < 2: raise ValueError('n_polynomials must be strictly greater than 1.') self.n_polynomials = n_polynomials
@property def degree_polynomials(self): """The degree of each squared polynomial.""" return 1 @property def parameters_per_polynomial(self): """Numer of parameters needed by the transformer for each squared polynomial.""" return self.degree_polynomials + 1 @property def n_parameters_per_feature(self): """Number of parameters needed by the transformer for each input dimension.""" return self.parameters_per_polynomial * self.n_polynomials + 1
[docs] def forward(self, x: torch.Tensor, parameters: torch.Tensor) -> tuple[torch.Tensor]: """Apply the transformation to the input. Parameters ---------- x : torch.Tensor Shape ``(batch_size, n_features)``. Input tensor. parameters : torch.Tensor Shape ``(batch_size, (1 + K*L)*n_features)``. The coefficients of the squared polynomials obtained from the conditioner. The coefficients are ordered by polynomial so that ``parameters[:,0]`` is :math:`a_0` followed by :math:`a_{10}, a_{11}, ..., a_{K0}, a_{K1}`. Returns ------- y : torch.Tensor Shape ``(batch_size, n_vectors*dimension)``. The transformed vectors. log_det_J : torch.Tensor Shape ``(batch_size,)``. The logarithm of the absolute value of the Jacobian determinant ``dy / dx``. """ # From (batch, n_parameters*n_features) to (batch, n_parameters, n_features). batch_size = parameters.shape[0] parameters = parameters.reshape(batch_size, self.n_parameters_per_feature, -1) return sos_polynomial_transformer(x, parameters)
[docs] def inverse(self, y: torch.Tensor, parameters: torch.Tensor) -> tuple[torch.Tensor]: """Currently not implemented.""" raise NotImplementedError( 'Inversion of SOS polynomial transformer has not been implemented yet.')
[docs] def get_identity_parameters(self, n_features: int) -> torch.Tensor: """Return the value of the parameters that makes this the identity function. This can be used to initialize the normalizing flow to perform the identity transformation. Parameters ---------- n_features : int The dimension of the input vector of the transformer. Returns ------- parameters : torch.Tensor Shape ``(1+K*L, n_features)`` where ``K`` and ``L`` are the number and degree of the polynomials. """ id_conditioner = torch.zeros(size=(self.n_parameters_per_feature, n_features)) # The sum of the squared linear parameters must be 1. id_conditioner[1::self.parameters_per_polynomial].fill_(np.sqrt(1 / self.n_polynomials)) return id_conditioner.flatten()
[docs] def get_degrees_out(self, degrees_in: torch.Tensor) -> torch.Tensor: """Returns the degrees associated to the conditioner's output. Parameters ---------- degrees_in : torch.Tensor Shape ``(n_transformed_features,)``. The autoregressive degrees associated to the features provided as input to the transformer. Returns ------- degrees_out : torch.Tensor Shape ``(n_parameters,)``. The autoregressive degrees associated to each output of the conditioner that will be fed to the transformer as parameters. """ return degrees_in.tile((self.n_parameters_per_feature,))
# ============================================================================= # FUNCTIONAL API # =============================================================================
[docs] class SOSPolynomialTransformerFunc(torch.autograd.Function): r"""Implement the sum-of-squares polynomial transformer for triangular maps. This provides a functional API for the :class:`~tfep.nn.transformers.SOSPolynomialTransformer` layer. It implements the polynomial transformer proposed in [1]. :math:`y_i = a_0 + \int_0^{x_i} \sum_{k=1}^K \left( \sum_{l=0}^L a_{kl} z^l \right)^2 dz` where :math:`K` and :math:`L` are the total number and degree of the polynomials respectively, and :math:`a_X` represent the parameters of the transformer. The function returns the transformed feature as a ``Tensor`` of shape ``(batch_size, n_features)`` and the log absolute determinant of its Jacobian as a ``Tensor`` of shape ``(batch_size,)``. Only sums of squared first-degree polynomials (i.e., L=1) are currently supported as they are the only one with an analytic inverse and sum of zeroth degree polynomials (i.e., L=0) are equivalent to affine transformer. Parameters ---------- x : torch.Tensor Shape ``(batch_size, n_features)``. Input tensor x. parameters : torch.Tensor Shape ``(batch_size, 1+K*L, n_features)``. The coefficients of the squared polynomials obtained from the conditioner. The coefficients are ordered by polynomial so that ``parameters[:,0]`` is :math:`a_0` followed by :math:`a_{10}, a_{11}, ..., a_{K0}, a_{K1}`. Returns ------- y : torch.Tensor Output tensor of shape ``(batch_size, n_features)``. log_det_J : torch.Tensor The logarithm of the absolute value of the determinant of the Jacobian of the transformation with shape ``(batch_size,)``. References ---------- [1] Jaini P, Selby KA, Yu Y. Sum-of-Squares Polynomial Flow. arXiv preprint arXiv:1905.02325. 2019 May 7. """
[docs] @staticmethod def forward(ctx, x, parameters): # Compute the parameters of the sos polynomial. sos_degree_coefficients = SOSPolynomialTransformerFunc.get_sos_poly_coefficients(parameters) # Compute the power of x. x_powers = [x, x*x] # Compute y and the gradient of y w.r.t. x. y = sos_degree_coefficients[1].clone() grad_x = sos_degree_coefficients[1].clone() for degree, coef in enumerate(sos_degree_coefficients[2:]): term = coef * x_powers[degree] y += term grad_x += (degree+2) * term y *= x y += sos_degree_coefficients[0] log_det_J = torch.sum(torch.log(grad_x), dim=1) # Save tensor used for backward() before returning. ctx.save_for_backward(grad_x, parameters, *x_powers) # We don't need to compute gradients of log_det_J. ctx.mark_non_differentiable(log_det_J) return y, log_det_J
[docs] @staticmethod def backward(ctx, grad_y, grad_log_det_J): saved_grad_x, parameters, x, x2 = ctx.saved_tensors grad_x = grad_parameters = None batch_size, n_features = saved_grad_x.shape # Compute gradients w.r.t. input parameters. if ctx.needs_input_grad[0]: grad_x = saved_grad_x * grad_y if ctx.needs_input_grad[1]: grad_parameters = torch.empty_like(parameters) # The first coefficient is the constant term. grad_parameters[:, 0] = torch.ones( size=(batch_size, n_features), dtype=saved_grad_x.dtype) # Zeroth and first degree terms of the inner polynomials. zeroth_degree_terms = parameters[:, 1::2] first_degree_terms = parameters[:, 2::2] # We need to add a dimension corresponding to the number of # coefficients in the power of x for them to be broadcastable. x = x.unsqueeze(1) x2 = x2.unsqueeze(1) x3 = x2 * x grad_parameters[:, 1::2] = first_degree_terms*x2 + 2*zeroth_degree_terms*x grad_parameters[:, 2::2] = 2/3*first_degree_terms*x3 + zeroth_degree_terms*x2 grad_parameters = grad_parameters * grad_y.unsqueeze(1) return grad_x, grad_parameters
[docs] @staticmethod def get_sos_poly_coefficients(parameters): """Compute the coefficient of the SOS polynomial. Parameters ---------- parameters : torch.Tensor The coefficients of the squared polynomials obtained from the conditioner. Each ``Tensor`` has shape ``(batch_size, 1+K*L, n_features)``. The coefficients are ordered by polynomial so that ``parameters[:,0]`` is :math:`a_0` followed by :math:`a_{10}, a_{11}, ..., a_{K0}, a_{K1}`. Returns ------- sos_poly_coefficients : List[torch.Tensor] ``sos_poly_coefficients[i]`` is a tensor of shape ``(batch_size, n_features)`` with the coefficients of the term of the SOS polynomial of degree ``i``. """ # We support only L=1 for now. Number of coefficients in # each summed polynomials include also the constant term. coeff_per_inner_poly = 2 batch_size, _, n_features = parameters.shape # inner_degree_parameters[d][b][p] is the parameter for the term of # the p-th inner polynomial of degree d for the b-th batch sample. inner_degree_coefficients = [] for degree in range(coeff_per_inner_poly): inner_degree_coefficients.append(parameters[:, 1+degree::coeff_per_inner_poly]) # Find the coefficients of the integrated polynomial. sos_degree_coefficients = [parameters[:, 0]] sos_degree_coefficients.append(torch.sum(inner_degree_coefficients[0]**2, dim=1)) sos_degree_coefficients.append(torch.sum(inner_degree_coefficients[0]*inner_degree_coefficients[1], dim=1)) sos_degree_coefficients.append(torch.sum(inner_degree_coefficients[1]**2, dim=1) / 3) return sos_degree_coefficients
# Functional notation. sos_polynomial_transformer = SOSPolynomialTransformerFunc.apply