Score Matching
Classes and associated functionality to perform score matching.
The score function of a distribution is the gradient of its log-PDF with respect
to the data, \(\nabla_x \log p(x)\). Score matching fits a model by matching the
model's score function to that of the data. Exactly how the score function is
modelled is specific to each child class of the abstract base class ScoreMatching.
One use of score matching is constructing a SteinKernel, which requires a score
function as input. If the score function is known analytically, it can be provided
exactly. Otherwise it must be approximated, which is the role of the ScoreMatching
subclasses.
When using SlicedScoreMatching, the score function is approximated using a
neural network, whereas in KernelDensityMatching, it is approximated by fitting
and then differentiating a kernel density estimate to the data.
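As a concrete illustration of the definition above, the sketch below checks that the score of a one-dimensional Gaussian, \(-(x - \mu) / \sigma^2\), agrees with a finite-difference derivative of its log-PDF. The function names are hypothetical and are not part of coreax; this is the "known analytically" case in which no score matching is needed.

```python
import math


def gaussian_log_pdf(x, mean, std):
    """Log-density of a one-dimensional Gaussian."""
    return -0.5 * ((x - mean) / std) ** 2 - math.log(std * math.sqrt(2.0 * math.pi))


def analytic_score(x, mean, std):
    """Exact score: derivative of the Gaussian log-PDF with respect to x."""
    return -(x - mean) / std**2


def finite_difference_score(x, mean, std, step=1e-5):
    """Central-difference approximation to the derivative of the log-PDF."""
    upper = gaussian_log_pdf(x + step, mean, std)
    lower = gaussian_log_pdf(x - step, mean, std)
    return (upper - lower) / (2.0 * step)


exact = analytic_score(1.5, mean=0.0, std=2.0)
approx = finite_difference_score(1.5, mean=0.0, std=2.0)
```

The score points back towards the mean: it is negative to the right of the mean and positive to the left.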
- class coreax.score_matching.ScoreMatching
Bases: Module
Abstract base class for score matching algorithms.
The score function of some data is the derivative of the log-PDF. Score matching aims to determine a model by ‘matching’ the score function of the model to that of the data. Exactly how the score function is modelled is specific to each child class of this base class.
This class should only contain abstract methods. Subclasses must implement all abstract methods to create concrete score matching algorithms.
- class coreax.score_matching.SlicedScoreMatching(random_key, random_generator, noise_conditioning=True, use_analytic=False, num_random_vectors=1, learning_rate=0.001, num_epochs=10, batch_size=64, hidden_dims=(128, 128, 128), optimiser=<function adamw>, num_noise_models=100, sigma=1.0, gamma=0.95, progress_bar=False)
Bases: ScoreMatching
Implementation of sliced score matching, defined in [song2020ssm].
The score function of some data is the derivative of the log-PDF. Score matching aims to determine a model by ‘matching’ the score function of the model to that of the data. Exactly how the score function is modelled is specific to each child class of ScoreMatching.
With sliced score matching, we train a neural network to directly approximate the score function of the data. The approach is outlined in detail in [song2020ssm].
Note
The inputs num_random_vectors and num_noise_models are set to 1 if a value smaller than 1 is given.
- Parameters:
random_key (Array) – Key for random number generation.
random_generator (Callable[[Array, Sequence[int], Union[str, type[Any], dtype, SupportsDType]], Array]) – Distribution sampler (key, shape, dtype) \(\rightarrow\) Array, e.g. distributions in jax.random.
noise_conditioning (bool) – Use the noise conditioning version of score matching. Defaults to True.
use_analytic (bool) – Whether to use the analytic (reduced variance) objective. Defaults to False.
num_random_vectors (int) – The number of random vectors to use per data vector. Defaults to 1.
learning_rate (float) – Optimiser learning rate. Defaults to 1e-3.
num_epochs (int) – Number of epochs for training. Defaults to 10.
batch_size (int) – Size of mini-batch. Defaults to 64.
hidden_dims (Sequence[int]) – Sequence of ScoreNetwork hidden layer sizes. Defaults to (128, 128, 128), denoting three hidden layers of 128 nodes each.
optimiser (Callable[[float], GradientTransformation]) – The optax optimiser to use. Defaults to optax.adamw, as shown in the class signature.
num_noise_models (int) – Number of noise models to use in noise conditional score matching. Defaults to 100.
sigma (float) – Initial noise standard deviation for the geometric progression of noise levels in noise conditional score matching. Defaults to 1.
gamma (float) – Geometric progression ratio. Defaults to 0.95.
progress_bar (bool) – Whether to display a progress bar tracking the training of the neural network. Defaults to False.
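To make the roles of sigma, gamma and num_noise_models concrete, here is a minimal sketch of a geometric progression of noise levels. It assumes the i-th noise standard deviation is sigma * gamma**i, which is one natural reading of the parameter descriptions above rather than a statement of what coreax computes internally. The helper name noise_scales is hypothetical.

```python
def noise_scales(sigma=1.0, gamma=0.95, num_noise_models=100):
    """Return a geometric progression of noise standard deviations.

    Assumption (not confirmed by the documentation): level i is sigma * gamma**i.
    """
    # Mirror the documented behaviour that values below 1 are replaced by 1.
    num_noise_models = max(1, num_noise_models)
    return [sigma * gamma**i for i in range(num_noise_models)]


scales = noise_scales()
```

Under this assumption and the default arguments, the noise levels shrink from 1.0 to roughly 0.006 over the 100 noise models.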
- random_generator: Callable[[Array, Sequence[int], Union[str, type[Any], dtype, SupportsDType]], Array]
- optimiser: Callable[[float], GradientTransformation]
- progress_bar: Union[type[tqdm], type[SilentTQDM]]
- match(x)
Learn a sliced score matching function via [song2020ssm].
We currently use the ScoreNetwork neural network to approximate the score function. Alternative network architectures can be considered.
- Parameters:
x – The \(n \times d\) data vectors
- Returns:
A function that applies the learned score function to input x
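The objective that sliced score matching minimises can be illustrated without a neural network. The sketch below is not coreax's implementation. It assumes a hypothetical one-parameter model score s(x) = -x / theta (the score of a zero-mean isotropic Gaussian with variance theta), Rademacher random vectors v, and a grid search in place of an optax optimiser. It evaluates the sliced score matching objective of [song2020ssm], the sample average of \(v^\top \nabla_x s(x)\, v + \tfrac{1}{2} (v^\top s(x))^2\), and picks the theta that minimises it.

```python
import random


def sliced_score_matching_loss(theta, data, directions):
    """Sliced score matching loss for the toy model score s(x) = -x / theta."""
    total = 0.0
    for x, v in zip(data, directions):
        score = [-xi / theta for xi in x]
        # The Jacobian of s is -I / theta, so v^T J v = -|v|^2 / theta.
        v_jac_v = -sum(vi * vi for vi in v) / theta
        v_dot_score = sum(vi * si for vi, si in zip(v, score))
        total += v_jac_v + 0.5 * v_dot_score**2
    return total / len(data)


rng = random.Random(0)
true_variance = 4.0
dim = 2
num_samples = 5000

# Synthetic data from a zero-mean isotropic Gaussian with known variance.
data = [
    [rng.gauss(0.0, true_variance**0.5) for _ in range(dim)]
    for _ in range(num_samples)
]
# One Rademacher (+1/-1) random vector per data vector.
directions = [
    [rng.choice((-1.0, 1.0)) for _ in range(dim)] for _ in range(num_samples)
]

# Grid search over candidate variances stands in for gradient-based training.
candidates = [0.5 + 0.1 * i for i in range(100)]
best_theta = min(
    candidates, key=lambda t: sliced_score_matching_loss(t, data, directions)
)
```

The minimiser lands close to the true variance even though the loss never evaluates the data density, which is the property that lets a network be trained on samples alone.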
- class coreax.score_matching.KernelDensityMatching(length_scale)
Bases: ScoreMatching
Implementation of a kernel density estimate to determine a score function.
The score function of some data is the derivative of the log-PDF. Score matching aims to determine a model by ‘matching’ the score function of the model to that of the data. Exactly how the score function is modelled is specific to each child class of ScoreMatching.
With kernel density matching, we approximate the underlying distribution function from a dataset using kernel density estimation, and then differentiate this to compute an estimate of the score function. A Gaussian kernel is used to construct the kernel density estimate.
- Parameters:
length_scale (float) – Kernel length_scale to use when fitting the kernel density estimate
- kernel: ScalarValuedKernel
- match(x)
Learn a score function using kernel density estimation to model a distribution.
For the kernel density matching approach, the score function is determined by fitting a kernel density estimate to samples from the underlying distribution and then differentiating it. Learning in this context therefore involves no iterative training: it simply defines the kernel density estimate from the supplied data and returns the corresponding score function, which can then be evaluated at any points of interest.
- Parameters:
x – Set of \(n \times d\) samples from the underlying distribution that are used to build the kernel density estimate
- Returns:
A function that applies the learned score function to input x
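A minimal one-dimensional sketch of this idea follows. It assumes a Gaussian kernel of the form \(\exp(-(x - x_i)^2 / (2 \ell^2))\) with length scale \(\ell\). Differentiating the log of the kernel density estimate then gives a weighted average of \((x_i - x) / \ell^2\), and the kernel's normalising constant cancels. The function kde_score is hypothetical and is not the coreax implementation.

```python
import math


def kde_score(x, data, length_scale):
    """Score of a 1D Gaussian kernel density estimate evaluated at x."""
    # Unnormalised Gaussian kernel weights; normalisation cancels in the ratio.
    weights = [
        math.exp(-((x - xi) ** 2) / (2.0 * length_scale**2)) for xi in data
    ]
    # d/dx log(sum_i k(x, x_i)) = sum_i k(x, x_i) (x_i - x) / (l^2 sum_i k(x, x_i))
    numerator = sum(w * (xi - x) for w, xi in zip(weights, data))
    return numerator / (length_scale**2 * sum(weights))


score_at_centre = kde_score(0.0, [-1.0, 1.0], length_scale=0.5)
score_single_point = kde_score(0.3, [1.0], length_scale=0.5)
```

With data placed symmetrically about the evaluation point the score vanishes, and with a single data point it reduces to the Gaussian score \((x_0 - x) / \ell^2\).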
- coreax.score_matching.convert_stein_kernel(x, kernel, score_matching)
Convert the kernel to a SteinKernel.
- Parameters:
x (Shaped[Array, 'n d']) – The data used to call score_matching.match(x)
kernel (ScalarValuedKernel) – ScalarValuedKernel instance implementing a kernel function \(k: \mathbb{R}^d \times \mathbb{R}^d \rightarrow \mathbb{R}\); if kernel is a SteinKernel and score_matching is not None, a new instance of the kernel will be generated where the score function is given by score_matching.match(x)
score_matching (Optional[ScoreMatching]) – Specifies or overwrites the score function of the implied or passed SteinKernel; if None, defaults to KernelDensityMatching unless kernel is a SteinKernel, in which case the kernel’s existing score function is used.
- Return type:
SteinKernel
- Returns:
The (potentially) converted/updated SteinKernel.
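The parameter descriptions above imply three cases for where the resulting SteinKernel gets its score function. The sketch below restates that decision table in plain Python. The function score_function_source and its boolean arguments are hypothetical stand-ins used only for illustration; it does not call coreax or reproduce its code.

```python
def score_function_source(kernel_is_stein_kernel, score_matching_given):
    """Describe which score function the returned SteinKernel would use."""
    if score_matching_given:
        # An explicit matcher is used whether or not the kernel is already
        # a SteinKernel (in which case it overwrites the existing score).
        return "score_matching.match(x)"
    if kernel_is_stein_kernel:
        # No matcher supplied: keep the score function the kernel already has.
        return "existing score function of the passed SteinKernel"
    # No matcher and a plain kernel: fall back to kernel density matching.
    return "default KernelDensityMatching"


cases = {
    (is_stein, given): score_function_source(is_stein, given)
    for is_stein in (False, True)
    for given in (False, True)
}
```

Only the combination of a plain kernel and no matcher triggers the KernelDensityMatching default.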