class ScoreNetwork(nn.Module):
    """
    A feed-forward neural network for use in sliced score matching.

    The network is a multi-layer perceptron. Each entry of ``hidden_dims`` adds
    one :class:`~flax.linen.Dense` layer of that width followed by a softplus
    activation. These are followed by a final linear :class:`~flax.linen.Dense`
    layer with ``output_dim`` units.

    See :class:`~coreax.score_matching.SlicedScoreMatching` for an example usage of
    this class.

    :param hidden_dims: Sequence of hidden dimension layer sizes. Each element of the
        sequence corresponds to one hidden layer.
    :param output_dim: Number of output layer nodes.
    """

    hidden_dims: Sequence[int]
    output_dim: int

    @nn.compact
    def __call__(
        self, x: Shaped[Array, " *batch d"]
    ) -> Shaped[Array, " *batch output_dim"]:
        r"""
        Compute the forward pass through the network.

        Each hidden layer is a dense (affine) map followed by a softplus
        activation. The final layer is a dense map with no activation, so the
        network has ``len(self.hidden_dims) + 1`` dense layers in total.

        Dense layers act only on the last axis of ``x``. Any leading (batch) axes
        are therefore carried through unchanged.

        :param x: Input data whose last axis has size :math:`d`, e.g. a batch of
            shape :math:`b \times d`
        :return: Network output with the same leading axes as ``x`` and a final
            axis of size ``self.output_dim``
        """
        for width in self.hidden_dims:
            x = nn.Dense(width)(x)
            x = nn.softplus(x)
        # Output layer is linear: no activation, so outputs are unbounded.
        return nn.Dense(self.output_dim)(x)
def create_train_state(
    random_key: KeyArrayLike,
    module: Module,
    learning_rate: float,
    data_dimension: int,
    optimiser: _LearningRateOptimiser,
) -> TrainState:
    """
    Build a flax :class:`~flax.training.train_state.TrainState` ready for training.

    The module is initialised on a single all-ones input of shape
    ``(1, data_dimension)``. Its ``"params"`` collection is combined with the
    optimiser built from ``learning_rate``.

    :param random_key: Key for random number generation, passed to
        ``module.init`` to initialise the parameters
    :param module: Subclass of :class:`~flax.linen.Module` to be trained
    :param learning_rate: Optimiser learning rate
    :param data_dimension: Data dimension, i.e. the size of the final axis of the
        dummy input used for initialisation
    :param optimiser: optax optimiser factory, e.g. :func:`~optax.adam`; it is
        called with ``learning_rate``
    :return: :class:`~flax.training.train_state.TrainState` holding
        ``module.apply``, the initialised parameters and the optimiser
    """
    # The dummy input is used only for parameter initialisation.
    dummy_input = jnp.ones((1, data_dimension))
    variables = module.init(random_key, dummy_input)
    transformation = optimiser(learning_rate)
    return TrainState.create(
        apply_fn=module.apply,
        params=variables["params"],
        tx=transformation,
    )