Networks

Classes and associated functionality to define neural networks.

Neural networks are used throughout the codebase as function approximators.

class coreax.networks.ScoreNetwork(hidden_dims, output_dim, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

A feed-forward neural network for use in sliced score matching.

See SlicedScoreMatching for an example of how this class is used.

Parameters:
  • hidden_dims (Sequence) – Sizes of the hidden layers. Each element of the sequence gives the number of nodes in one hidden layer.

  • output_dim (int) – Number of output layer nodes.

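  • parent (Module | Scope | _Sentinel | None) – Parent module or scope. This is managed by flax and is not normally set by the user.

  • name (str | None) – Optional name for the module. This is managed by flax.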

hidden_dims: Sequence

coreax.networks.create_train_state(random_key, module, learning_rate, data_dimension, optimiser)

Create a flax TrainState with which to train the given module.

Parameters:
  • random_key

  • module

  • learning_rate

  • data_dimension

  • optimiser

Return type:

TrainState

Returns:

TrainState object