tfep.tests.nn.create_random_input

tfep.tests.nn.create_random_input(batch_size: int, n_features: int, n_parameters: int = 0, dtype: type | None = None, seed: int | None = None, x_func: Callable = torch.randn, par_func: Callable = torch.randn) → Tensor | tuple[Tensor]

Create random input and parameters.

Parameters:
  • batch_size (int) – The number of samples in the batch, i.e. the first dimension of the returned tensors.

  • n_features (int) – The number of features in x.

  • n_parameters (int, optional) – The number of parameters. With the default of 0, no parameters are returned.

  • dtype (type, optional) – The data type of the generated tensors.

  • seed (int, optional) – The seed used for the random number generation.

  • x_func (Callable, optional) – The random function used to generate x. Default is torch.randn.

  • par_func (Callable, optional) – The random function used to generate parameters. Default is torch.randn.

Returns:

  • x (torch.Tensor) – Shape (batch_size, n_features). The random input.

  • parameters (torch.Tensor, optional) – Shape (batch_size, n_parameters). The random parameters. This is returned only if n_parameters > 0.
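
The return contract above can be illustrated with a minimal sketch. The function sketch_random_input below is a hypothetical stand-in written for this example, not the tfep implementation; in particular, seeding through a local torch.Generator that is passed to x_func and par_func is an assumption. It only mirrors the documented shapes and the rule that parameters are returned only if n_parameters > 0.

```python
import torch


def sketch_random_input(batch_size, n_features, n_parameters=0, dtype=None,
                        seed=None, x_func=torch.randn, par_func=torch.randn):
    """Hypothetical stand-in that mirrors the documented return contract."""
    # Assumption: the seed is applied to a local generator so that the
    # global random state is left untouched.
    generator = torch.Generator()
    if seed is not None:
        generator.manual_seed(seed)

    # The random input has shape (batch_size, n_features).
    x = x_func(batch_size, n_features, dtype=dtype, generator=generator)
    if n_parameters == 0:
        return x

    # The random parameters have shape (batch_size, n_parameters).
    parameters = par_func(batch_size, n_parameters, dtype=dtype,
                          generator=generator)
    return x, parameters


# Only x is returned when n_parameters is left at 0.
x_only = sketch_random_input(batch_size=4, n_features=3, seed=0)

# Both tensors are returned when n_parameters > 0. Here the parameters are
# drawn with torch.rand (uniform in [0, 1)) instead of torch.randn.
x, parameters = sketch_random_input(batch_size=4, n_features=3,
                                    n_parameters=2, seed=0,
                                    par_func=torch.rand)
```

Any callable that accepts the tensor sizes as positional arguments, as torch.randn and torch.rand do, could be passed as x_func or par_func in this sketch.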