Networks

Classes and associated functionality to define neural networks.

Neural networks are used throughout the codebase as function approximators.

class coreax.networks.ScoreNetwork(hidden_dims, output_dim, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

A feed-forward neural network for use in sliced score matching.

See SlicedScoreMatching for an example usage of this class.

Parameters:
  • hidden_dims (Sequence) – Sequence of hidden dimension layer sizes. Each element of the sequence corresponds to one hidden layer.

  • output_dim (int) – Number of output layer nodes.

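  • parent (Type[Module] | Scope | Type[_Sentinel] | None) – Parent module or scope. This is managed by Flax and is not normally set by the user.

  • name (str | None) – Optional name for the module instance. If omitted, Flax assigns one automatically when the module is used as a submodule.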

hidden_dims: Sequence

coreax.networks.create_train_state(random_key, module, learning_rate, data_dimension, optimiser)

Create a Flax TrainState, which holds the module parameters and optimiser state used during training.

Parameters:
Return type:

TrainState

Returns:

TrainState object