Score Matching
Classes and associated functionality to perform score matching.
The score function of some data is the derivative of the log-PDF. Score matching
aims to determine a model by matching the score function of the model to that
of the data. Exactly how the score function is modelled is specific to each
child class of the abstract base class ScoreMatching.
An example use of score matching arises when trying to work with a
SteinKernel, which requires as an input a score function. If
this is known analytically, one can provide an exact score function. In other cases,
approximations to the score function are required, which can be determined using
ScoreMatching.
When using SlicedScoreMatching, the score function is approximated using a
neural network, whereas in KernelDensityMatching, it is approximated by fitting
and then differentiating a kernel density estimate to the data.
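When the log-PDF is available in closed form, the score function can be obtained by automatic differentiation. The following is a minimal sketch, assuming a one-dimensional Gaussian; the name log_pdf and its mean and standard deviation are illustrative and not part of coreax.

```python
import jax
from jax.scipy.stats import norm


def log_pdf(x, mean=1.0, std=2.0):
    """Log-PDF of a one-dimensional Gaussian (illustrative choice)."""
    return norm.logpdf(x, loc=mean, scale=std)


# The score function is the derivative of the log-PDF with respect to x.
score = jax.grad(log_pdf)

x = 3.0
numerical = score(x)
# For a Gaussian, the score is -(x - mean) / std**2.
analytic = -(x - 1.0) / 2.0**2
```

Here numerical and analytic should agree up to floating-point error. An exact score of this kind is what one would supply when it is known analytically; the classes below cover the case where it is not.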
- class coreax.score_matching.ScoreMatching
Base class for score matching algorithms.
The score function of some data is the derivative of the log-PDF. Score matching aims to determine a model by ‘matching’ the score function of the model to that of the data. Exactly how the score function is modelled is specific to each child class of this base class.
- class coreax.score_matching.SlicedScoreMatching(random_key, random_generator, noise_conditioning=True, use_analytic=False, num_random_vectors=1, learning_rate=0.001, num_epochs=10, batch_size=64, hidden_dims=(128, 128, 128), optimiser=<function adamw>, num_noise_models=100, sigma=1.0, gamma=0.95)
Implementation of sliced score matching, defined in [ssm].
The score function of some data is the derivative of the log-PDF. Score matching aims to determine a model by ‘matching’ the score function of the model to that of the data. Exactly how the score function is modelled is specific to each child class of ScoreMatching.
With sliced score matching, we train a neural network to directly approximate the score function of the data. The approach is outlined in detail in [ssm].
Note
The inputs num_random_vectors and num_noise_models are set to 1 if a value smaller than 1 is given.
- Parameters:
random_key (ArrayLike) – Key for random number generation
random_generator (Callable[[ArrayLike, Sequence[int], Union[str, type[Any], dtype, SupportsDType]], Array]) – Distribution sampler (key, shape, dtype) \(\rightarrow\) Array, e.g. distributions in random
noise_conditioning (bool) – Use the noise conditioning version of score matching. Defaults to True.
use_analytic (bool) – Use the analytic (reduced variance) objective or not. Defaults to False.
num_random_vectors (int) – The number of random vectors to use per data vector. Defaults to 1.
learning_rate (float) – Optimiser learning rate. Defaults to 1e-3.
num_epochs (int) – Number of epochs for training. Defaults to 10.
batch_size (int) – Size of mini-batch. Defaults to 64.
hidden_dims (Sequence[int]) – Sequence of ScoreNetwork hidden layer sizes. Defaults to (128, 128, 128), denoting 3 hidden layers each composed of 128 nodes.
optimiser (Callable[[float], GradientTransformation]) – The optax optimiser to use. Defaults to optax.adamw.
num_noise_models (int) – Number of noise models to use in noise conditional score matching. Defaults to 100.
sigma (float) – Initial noise standard deviation for the geometric progression of noise levels in noise conditional score matching. Defaults to 1.
gamma (float) – Geometric progression ratio. Defaults to 0.95.
- random_generator: Callable[[ArrayLike, Sequence[int], Union[str, type[Any], dtype, SupportsDType]], Array]
- optimiser: Callable[[float], GradientTransformation]
- _objective_function(random_direction_vector, grad_score_times_random_direction_matrix, score_matrix)
Compute the score matching loss function.
Two objectives are proposed in [ssm]: a general objective, and a simplification with reduced variance that holds under particular assumptions. The choice between the two is determined by the boolean use_analytic, defined when the class is initialised.
- Parameters:
- Returns:
Evaluation of score matching objective, see equations 7 and 8 in [ssm]
- static _analytic_objective(random_direction_vector, grad_score_times_random_direction_matrix, score_matrix)
Compute reduced variance score matching loss function.
This is for use with certain random measures, e.g. normal and Rademacher. If this assumption does not hold, then SlicedScoreMatching._general_objective() should be used instead.
- Parameters:
- Return type:
- Returns:
Evaluation of score matching objective, see equation 8 in [ssm]
- static _general_objective(random_direction_vector, grad_score_times_random_direction_matrix, score_matrix)
Compute general score matching loss function.
This is to be used when one cannot assume normal or Rademacher random measures when using score matching. It has higher variance than SlicedScoreMatching._analytic_objective() when those assumptions do hold.
- Parameters:
- Return type:
- Returns:
Evaluation of score matching objective, see equation 7 in [ssm]
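The two objectives can be compared on a single data point and a single random direction. The following is a minimal sketch, assuming the general objective takes the form v^T (∇s) v + ½ (v^T s)² and the reduced variance objective takes the form v^T (∇s) v + ½ ‖s‖², which is how equations 7 and 8 of [ssm] are usually stated. The functions toy_score, general_objective and analytic_objective are hypothetical names for illustration, not the coreax implementation, and toy_score stands in for a trained score network.

```python
import jax
import jax.numpy as jnp


def toy_score(x):
    """Score of a standard normal, standing in for a score network (assumption)."""
    return -x


def general_objective(x, v, score_fn):
    # jax.jvp returns the score s(x) and the Jacobian-vector product (ds/dx) v
    # without forming the full Jacobian.
    s, jac_v = jax.jvp(score_fn, (x,), (v,))
    return v @ jac_v + 0.5 * (v @ s) ** 2


def analytic_objective(x, v, score_fn):
    # Same first term; the second term replaces (v^T s)^2 by ||s||^2.
    s, jac_v = jax.jvp(score_fn, (x,), (v,))
    return v @ jac_v + 0.5 * (s @ s)


x = jnp.array([1.0, -2.0])
v = jnp.array([1.0, 1.0])  # a Rademacher-style direction with entries of +/-1
general = general_objective(x, v, toy_score)
analytic = analytic_objective(x, v, toy_score)
```

In terms of the argument names above, v plays the role of random_direction_vector, jac_v of grad_score_times_random_direction_matrix, and s of score_matrix. The two values differ for a single direction; under the stated assumptions on the random vectors they share the same expectation over v.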
- _loss_element(x, v, score_network)
Compute element-wise loss function.
Computes the loss function from Section 3.2 of Song et al.’s paper on sliced score matching [ssm].
- _loss(score_network)
Compute vector mapped loss function for arbitrarily many X and V vectors.
In the context of score matching, we expect to call the objective function on the data vector x and random vectors v, using the score neural network.
- _train_step(state, x, random_vectors)
Apply a single training step that updates model parameters using the loss gradient.
- Parameters:
state (TrainState) – The TrainState object
x (ArrayLike) – The \(n \times d\) data vectors
random_vectors (ArrayLike) – The \(n \times m \times d\) random vectors
- Return type:
- Returns:
The updated TrainState object
- _noise_conditional_loop_body(i, obj, state, params, x, random_vectors, sigmas)
Sum objective function with noise perturbations.
Inputs are perturbed by Gaussian random noise to improve performance of score matching. See [improved_sgm] for details.
- Parameters:
i (int) – Loop index
obj (float) – Running objective, i.e. the current partial sum
state (TrainState) – The TrainState object
params (dict) – The current iterate parameter settings
x (Array) – The \(n \times d\) data vectors
random_vectors (ArrayLike) – The \(n \times m \times d\) random vectors
sigmas (Array) – The geometric progression of noise standard deviations
- Return type:
- Returns:
The updated objective, i.e. partial sum
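The geometric progression of noise standard deviations can be illustrated with the default values of sigma, gamma and num_noise_models. The following is a minimal sketch, assuming the progression starts at sigma and multiplies by gamma at each step; the exact construction used inside SlicedScoreMatching is not shown here and may differ.

```python
import jax
import jax.numpy as jnp

sigma, gamma, num_noise_models = 1.0, 0.95, 100

# Assumed schedule: sigma, sigma * gamma, sigma * gamma**2, ...
sigmas = sigma * gamma ** jnp.arange(num_noise_models)

# Perturb an n x d data array with Gaussian noise at one noise level.
key = jax.random.PRNGKey(0)
x = jnp.zeros((5, 2))
noise = jax.random.normal(key, x.shape)
x_perturbed = x + sigmas[10] * noise
```

Summing the objective over perturbed copies of the data, one per noise level, gives the kind of partial sum that the loop body above accumulates.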
- _noise_conditional_train_step(state, x, random_vectors, sigmas)
Apply a single training step that updates model parameters using the loss gradient.
- Parameters:
state (TrainState) – The TrainState object
x (Array) – The \(n \times d\) data vectors
random_vectors (Array) – The \(n \times m \times d\) random vectors
sigmas (Array) – Array of noise standard deviations to use in objective function
- Return type:
- Returns:
The updated TrainState object
- match(x)
Learn a sliced score matching function from Song et al.’s paper [ssm].
We currently use the ScoreNetwork neural network to approximate the score function. Alternative network architectures can be considered.
- class coreax.score_matching.KernelDensityMatching(length_scale)
Implementation of a kernel density estimate to determine a score function.
The score function of some data is the derivative of the log-PDF. Score matching aims to determine a model by ‘matching’ the score function of the model to that of the data. Exactly how the score function is modelled is specific to each child class of ScoreMatching.
With kernel density matching, we approximate the underlying distribution function from a dataset using kernel density estimation, and then differentiate this to compute an estimate of the score function. A Gaussian kernel is used to construct the kernel density estimate.
- Parameters:
length_scale (float) – Kernel length_scale to use when fitting the kernel density estimate
- match(x)
Learn a score function using kernel density estimation to model a distribution.
For the kernel density matching approach, the score function is determined by fitting a kernel density estimate to samples from the underlying distribution and then differentiating this. Therefore, learning in this context refers to simply defining the score function and kernel density estimate given some samples we wish to evaluate the score function at, and the data used to build the kernel density estimate.
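The idea of fitting a Gaussian kernel density estimate and differentiating it can be sketched directly in JAX. The following is a minimal illustration, not the coreax implementation; the functions kde_log_density and kde_score are hypothetical names, and constant factors of the density are dropped because they do not affect the gradient of its logarithm.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def kde_log_density(point, data, length_scale):
    """Log of a Gaussian kernel density estimate, up to an additive constant."""
    sq_dists = jnp.sum((data - point) ** 2, axis=1)
    return logsumexp(-sq_dists / (2.0 * length_scale**2))


def kde_score(point, data, length_scale):
    """Estimated score: gradient of the log density with respect to point."""
    return jax.grad(kde_log_density)(point, data, length_scale)


# With a single data point at the origin, the estimate reduces to the
# score of one Gaussian: -(point - data) / length_scale**2.
data = jnp.array([[0.0, 0.0]])
point = jnp.array([1.0, -1.0])
estimate = kde_score(point, data, length_scale=1.0)
```

A smaller length_scale makes the estimate follow individual samples more closely, while a larger one smooths it.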
- coreax.score_matching.convert_stein_kernel(x, kernel, score_matching)
Convert the kernel to a SteinKernel.
- Parameters:
x (ArrayLike) – The data used to call score_matching.match(x)
kernel (Kernel) – Kernel instance implementing a kernel function \(k: \mathbb{R}^d \times \mathbb{R}^d \rightarrow \mathbb{R}\); if kernel is a SteinKernel and score_matching is not None, a new instance of the kernel will be generated where the score function is given by score_matching.match(x)
score_matching (Optional[ScoreMatching]) – Specifies or overwrites the score function of the implied or passed SteinKernel; if None, defaults to KernelDensityMatching unless kernel is a SteinKernel, in which case the kernel’s existing score function is used.
- Return type:
- Returns:
The (potentially) converted/updated SteinKernel.