# © Crown Copyright GCHQ
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Classes and associated functionality to perform score matching.
The score function of some data is the derivative of the log-PDF. Score matching
aims to determine a model by matching the score function of the model to that
of the data. Exactly how the score function is modelled is specific to each
child class of the abstract base class :class:`ScoreMatching`.
An example use of score matching arises when trying to work with a
:class:`~coreax.kernels.SteinKernel`, which requires as an input a score function. If
this is known analytically, one can provide an exact score function. In other cases,
approximations to the score function are required, which can be determined using
:class:`ScoreMatching`.
When using :class:`SlicedScoreMatching`, the score function is approximated using a
neural network, whereas in :class:`KernelDensityMatching`, it is approximated by fitting
and then differentiating a kernel density estimate to the data.
"""
from abc import abstractmethod
from collections.abc import Callable, Sequence
from functools import partial
from typing import overload
import equinox as eqx
import numpy as np
from flax.training.train_state import TrainState
from jax import (
Array,
jvp,
numpy as jnp,
random,
value_and_grad,
vmap,
)
from jax.lax import cond, fori_loop
from jax.typing import DTypeLike
from jaxtyping import Shaped
from optax import adamw
from tqdm import tqdm
from typing_extensions import override
from coreax.kernels import ScalarValuedKernel, SquaredExponentialKernel, SteinKernel
from coreax.networks import ScoreNetwork, _LearningRateOptimiser, create_train_state
from coreax.util import KeyArrayLike
_RandomGenerator = Callable[[KeyArrayLike, Sequence[int], DTypeLike], Array]
[docs]
class ScoreMatching(eqx.Module):
"""
Abstract base class for score matching algorithms.
The score function of some data is the derivative of the log-PDF. Score matching
aims to determine a model by 'matching' the score function of the model to that
of the data. Exactly how the score function is modelled is specific to each
child class of this base class.
This class should only contain abstract methods. Subclasses must implement all
abstract methods to create concrete score matching algorithms.
"""
@abstractmethod
@overload
def match(
self, x: Shaped[Array, " 1 1"] | Shaped[Array, ""] | float | int
) -> Callable[
[Shaped[Array, " 1 1"] | Shaped[Array, ""] | float | int],
Shaped[Array, " 1 1"],
]: ...
@abstractmethod
@overload
def match( # pyright: ignore[reportOverlappingOverload]
self, x: Shaped[Array, " n d"]
) -> Callable[[Shaped[Array, " n d"]], Shaped[Array, " n d"]]: ...
[docs]
@abstractmethod
def match(
self, x: Shaped[Array, " n d"] | Shaped[Array, ""] | float | int
) -> (
Callable[[Shaped[Array, " n d"]], Shaped[Array, " n d"]]
| Callable[
[Shaped[Array, " 1 1"] | Shaped[Array, ""] | float | int],
Shaped[Array, " 1 1"],
]
):
r"""
Match some model score function to dataset :math:`X\in\mathbb{R}^{n \times d}`.
:param x: The :math:`n \times d` data vectors
"""
# pylint: disable=too-many-instance-attributes
[docs]
class SlicedScoreMatching(ScoreMatching):
r"""
Implementation of slice score matching, defined in :cite:`song2020ssm`.
The score function of some data is the derivative of the log-PDF. Score matching
aims to determine a model by 'matching' the score function of the model to that
of the data. Exactly how the score function is modelled is specific to each
child class of :class:`ScoreMatching`.
With sliced score matching, we train a neural network to directly approximate
the score function of the data. The approach is outlined in detail in
:cite:`song2020ssm`.
.. note::
The inputs `num_random_vectors` and `num_noise_models` are set to 1 if they are
given any smaller than this.
:param random_key: Key for random number generation
:param random_generator: Distribution sampler (``key``, ``shape``, ``dtype``)
:math:`\rightarrow` :class:`~jax.Array`, e.g. distributions in
:mod:`~jax.random`
:param noise_conditioning: Use the noise conditioning version of score matching.
Defaults to :data:`True`.
:param use_analytic: Use the analytic (reduced variance) objective or not.
Defaults to :data:`False`.
:param num_random_vectors: The number of random vectors to use per data vector.
Defaults to 1.
:param learning_rate: Optimiser learning rate. Defaults to 1e-3.
:param num_epochs: Number of epochs for training. Defaults to 10.
:param batch_size: Size of mini-batch. Defaults to 64.
:param hidden_dims: Sequence of ScoreNetwork hidden layer sizes. Defaults to
[128, 128, 128] denoting 3 hidden layers each composed of 128 nodes.
:param optimiser: The optax optimiser to use. Defaults to optax.adam.
:param num_noise_models: Number of noise models to use in noise
conditional score matching. Defaults to 100.
:param sigma: Initial noise standard deviation for noise geometric progression
in noise conditional score matching. Defaults to 1.
:param gamma: Geometric progression ratio. Defaults to 0.95.
:param progress_bar: Boolean indicating whether or not to write a progress bar
tracking the training of the neural network. Defaults to :data:`False`.
"""
random_key: KeyArrayLike
random_generator: _RandomGenerator
noise_conditioning: bool
use_analytic: bool
num_random_vectors: int
learning_rate: float
num_epochs: int
batch_size: int
hidden_dims: Sequence[int]
optimiser: _LearningRateOptimiser
num_noise_models: int
sigma: float
gamma: float
progress_bar: bool
# TODO: refactor this to require use of keyword arguments
# https://github.com/gchq/coreax/issues/782
# pylint: disable=too-many-arguments, too-many-positional-arguments
def __init__( # noqa: PLR0913, PLR0917
self,
random_key: KeyArrayLike,
random_generator: _RandomGenerator,
noise_conditioning: bool = True,
use_analytic: bool = False,
num_random_vectors: int = 1,
learning_rate: float = 1e-3,
num_epochs: int = 10,
batch_size: int = 64,
hidden_dims: Sequence[int] = (128, 128, 128),
optimiser: _LearningRateOptimiser = adamw,
num_noise_models: int = 100,
sigma: float = 1.0,
gamma: float = 0.95,
progress_bar: bool = False,
):
"""Define a sliced score matching class and update invalid inputs."""
# JAX will not error if we have num_random_vectors set to 0, but this approach
# is fundamentally about projecting along random vectors, so we cap the lower
# value for this at 1. Similarly, there must be at-least one noise model for
# the code to do the projections.
num_random_vectors = max(num_random_vectors, 1)
num_noise_models = max(num_noise_models, 1)
# Assign all inputs
self.random_key = random_key
self.random_generator = random_generator
self.noise_conditioning = noise_conditioning
self.use_analytic = use_analytic
self.num_random_vectors = num_random_vectors
self.learning_rate = learning_rate
self.num_epochs = num_epochs
self.batch_size = batch_size
self.hidden_dims = hidden_dims
self.optimiser = optimiser
self.num_noise_models = num_noise_models
self.sigma = sigma
self.gamma = gamma
self.progress_bar = progress_bar
# pylint: enable=too-many-arguments
def _objective_function(
self,
random_direction_vector: Shaped[Array, " d"],
grad_score_times_random_direction_matrix: Shaped[Array, " d"],
score_matrix: Shaped[Array, " d"],
) -> float:
"""
Compute the score matching loss function.
Two objectives are proposed in :cite:`song2020ssm`, a general objective, and a
simplification with reduced variance that holds for particular assumptions. The
choice between the two is determined by the boolean ``use_analytic`` defined
when the class is initiated.
:param random_direction_vector: :math:`d`-dimensional random vector
:param grad_score_times_random_direction_matrix: Product of the gradient of
score_matrix (w.r.t. ``x``) and the random_direction_vector
:param score_matrix: Gradients of log-density
:return: Evaluation of score matching objective, see equations 7 and 8 in
:cite:`song2020ssm`
"""
return cond(
self.use_analytic,
self._analytic_objective,
self._general_objective,
random_direction_vector,
grad_score_times_random_direction_matrix,
score_matrix,
)
@staticmethod
def _analytic_objective(
random_direction_vector: Shaped[Array, " d"],
grad_score_times_random_direction_matrix: Shaped[Array, " d"],
score_matrix: Shaped[Array, " d"],
) -> Shaped[Array, ""]:
"""
Compute reduced variance score matching loss function.
This is for use with certain random measures, e.g. normal and Rademacher. If
this assumption is not true, then
:meth:`SlicedScoreMatching._general_objective` should be used instead.
:param random_direction_vector: :math:`d`-dimensional random vector
:param grad_score_times_random_direction_matrix: Product of the gradient of
score_matrix (w.r.t. ``x``) and the random_direction_vector
:param score_matrix: Gradients of log-density
:return: Evaluation of score matching objective, see equation 8 in
:cite:`song2020ssm`
"""
result = (
random_direction_vector @ grad_score_times_random_direction_matrix
+ 0.5 * score_matrix @ score_matrix
)
return result
@staticmethod
def _general_objective(
random_direction_vector: Shaped[Array, " d"],
grad_score_times_random_direction_matrix: Shaped[Array, " d"],
score_matrix: Shaped[Array, " d"],
) -> Shaped[Array, ""]:
"""
Compute general score matching loss function.
This is to be used when one cannot assume normal or Rademacher random measures
when using score matching, but has higher variance than
:meth:`SlicedScoreMatching._analytic_objective` if these assumptions hold.
:param random_direction_vector: :math:`d`-dimensional random vector
:param grad_score_times_random_direction_matrix: Product of the gradient of
score_matrix (w.r.t. ``x``) and the random_direction_vector
:param score_matrix: Gradients of log-density
:return: Evaluation of score matching objective, see equation 7 in
:cite:`song2020ssm`
"""
result = (
random_direction_vector @ grad_score_times_random_direction_matrix
+ 0.5 * (random_direction_vector @ score_matrix) ** 2
)
return result
def _loss_element(
self, x: Shaped[Array, " d"], v: Shaped[Array, " d"], score_network: Callable
) -> float:
"""
Compute element-wise loss function.
Computes the loss function from Section 3.2 of Song el al.'s paper on sliced
score matching :cite:`song2020ssm`.
:param x: :math:`d`-dimensional data vector
:param v: :math:`d`-dimensional random vector
:param score_network: Function that calls the neural network on ``x``
:return: Objective function output for single ``x`` and ``v`` inputs
"""
s, u = jvp(score_network, (x,), (v,))
return self._objective_function(v, u, s)
def _loss(self, score_network: Callable) -> Callable:
"""
Compute vector mapped loss function for arbitrary many ``X`` and ``V`` vectors.
In the context of score matching, we expect to call the objective function on
the data vector ``x``, random vectors ``v`` and using the score neural
network.
:param score_network: Function that calls the neural network on ``x``
:return: Callable vectorised sliced score matching loss function
"""
inner = vmap(
lambda x, v: self._loss_element(x, v, score_network),
(None, 0),
0,
)
return vmap(inner, (0, 0), 0)
@eqx.filter_jit
def _train_step(
self,
state: TrainState,
x: Shaped[Array, " n d"],
random_vectors: Shaped[Array, " n m d"],
) -> tuple[TrainState, float]:
r"""
Apply a single training step that updates model parameters using loss gradient.
:param state: The :class:`~flax.training.train_state.TrainState` object
:param x: The :math:`n \times d` data vectors
:param random_vectors: The :math:`n \times m \times d` random vectors
:return: The updated :class:`~flax.training.train_state.TrainState` object
"""
def loss(params):
return self._loss(lambda x_: state.apply_fn({"params": params}, x_))(
x, random_vectors
).mean()
val, grads = value_and_grad(loss)(state.params)
state = state.apply_gradients(grads=grads)
return state, val
def _noise_conditional_loop_body(
self,
i: int,
obj: float,
state: TrainState,
params: dict,
x: Shaped[Array, " n d"],
random_vectors: Shaped[Array, " n m d"],
sigmas: Shaped[Array, " num_noise_models"],
) -> float:
r"""
Sum objective function with noise perturbations.
Inputs are perturbed by Gaussian random noise to improve performance of score
matching. See :cite:`song2020improved_sgm` for details.
:param i: Loop index
:param obj: Running objective, i.e. the current partial sum
:param state: The :class:`~flax.training.train_state.TrainState` object
:param params: The current iterate parameter settings
:param x: The :math:`n \times d` data vectors
:param random_vectors: The :math:`n \times m \times d` random vectors
:param sigmas: The geometric progression of noise standard deviations
:return: The updated objective, i.e. partial sum
"""
# This will generate the same set of random numbers on each function call. We
# might want to replace this with random.key(i) to get a unique set each
# time.
# Perturb the inputs with Gaussian noise
x_perturbed = x + sigmas[i] * random.normal(random.key(0), x.shape)
obj += (
sigmas[i] ** 2
* self._loss(lambda x_: state.apply_fn({"params": params}, x_))(
x_perturbed, random_vectors
).mean()
)
return obj
@eqx.filter_jit
def _noise_conditional_train_step(
self,
state: TrainState,
x: Shaped[Array, " n d"],
random_vectors: Shaped[Array, " n m d"],
sigmas: Shaped[Array, " num_noise_models"],
) -> tuple[TrainState, float]:
r"""
Apply a single training step that updates model parameters using loss gradient.
:param state: The :class:`~flax.training.train_state.TrainState` object
:param x: The :math:`n \times d` data vectors
:param random_vectors: The :math:`n \times m \times d` random vectors
:param sigmas: Array of noise standard deviations to use in objective function
:return: The updated :class:`~flax.training.train_state.TrainState` object
"""
def loss(params):
body = partial(
self._noise_conditional_loop_body,
state=state,
params=params,
x=x,
random_vectors=random_vectors,
sigmas=sigmas,
)
return fori_loop(0, self.num_noise_models, body, 0.0)
val, grads = value_and_grad(loss)(state.params)
state = state.apply_gradients(grads=grads)
return state, val
[docs]
@override
def match(self, x): # noqa: C901, PLR0912
r"""
Learn a sliced score matching function via :cite:`song2020ssm`.
We currently use the :class:`~coreax.networks.ScoreNetwork` neural network to
approximate the score function. Alternative network architectures can be
considered.
:param x: The :math:`n \times d` data vectors
:return: A function that applies the learned score function to input ``x``
"""
# Check format of input array. We use atleast_2d from JAX to perform
# conversions here which provides the desired handling of 1 dimensional arrays,
# whereas this handling differs if we instead used the custom function
# _atleast_2d_consistent in coreax.data.
x = jnp.atleast_2d(x)
# Setup neural network that will approximate the score function
num_points, data_dimension = x.shape
score_network = ScoreNetwork(self.hidden_dims, data_dimension)
# Define what a training step consists of - dependent on if we want to include
# noise perturbations
if self.noise_conditioning:
gammas = self.gamma ** jnp.arange(self.num_noise_models)
sigmas = self.sigma * gammas
train_step = partial(self._noise_conditional_train_step, sigmas=sigmas)
else:
train_step = self._train_step
# Define random projection vectors
generator_key, state_key, batch_key = random.split(self.random_key, 3)
try:
random_vectors = self.random_generator(
generator_key,
(num_points, self.num_random_vectors, data_dimension),
float,
)
except TypeError as exception:
if isinstance(self.num_random_vectors, float):
raise ValueError("num_random_vectors must be an integer") from exception
raise
# Define a training state
state = create_train_state(
state_key, score_network, self.learning_rate, data_dimension, self.optimiser
)
try:
loop_keys = random.split(batch_key, self.num_epochs)
except TypeError as exception:
if self.num_epochs < 0:
raise ValueError("num_epochs must be a positive integer") from exception
if isinstance(self.num_epochs, float):
raise TypeError("num_epochs must be a positive integer") from exception
raise
# Carry out main training loop to fit the neural network
tqdm_progress_bar = tqdm(range(self.num_epochs), disable=not self.progress_bar)
for i in tqdm_progress_bar:
# Sample some data-points to pass for this step
try:
idx = random.randint(loop_keys[i], (self.batch_size,), 0, num_points)
except TypeError as exception:
if self.batch_size < 0:
raise ValueError(
"batch_size must be a positive integer"
) from exception
if isinstance(self.batch_size, float):
raise TypeError(
"batch_size must be a positive integer"
) from exception
raise
# Apply training step
state, val = train_step(state, x[idx, :], random_vectors[idx, :])
# Print progress (limited to avoid excessive output)
if i % 10 == 0 and self.progress_bar:
tqdm_progress_bar.write(f"{i:>6}/{self.num_epochs}: loss {val:<.5f}")
# Return the learned score function, which is a callable
return lambda x_: state.apply_fn({"params": state.params}, x_)
# pylint: enable=too-many-instance-attributes
[docs]
class KernelDensityMatching(ScoreMatching):
r"""
Implementation of a kernel density estimate to determine a score function.
The score function of some data is the derivative of the log-PDF. Score matching
aims to determine a model by 'matching' the score function of the model to that
of the data. Exactly how the score function is modelled is specific to each
child class of this base class.
With kernel density matching, we approximate the underlying distribution function
from a dataset using kernel density estimation, and then differentiate this to
compute an estimate of the score function. A Gaussian kernel is used to construct
the kernel density estimate.
:param length_scale: Kernel ``length_scale`` to use when fitting the kernel density
estimate
"""
kernel: ScalarValuedKernel
def __init__(self, length_scale: float):
"""Define the kernel density matching class."""
# Define a normalised Gaussian kernel (which is a special cases of the squared
# exponential kernel) to construct the kernel density estimate
self.kernel = SquaredExponentialKernel(
length_scale=length_scale,
output_scale=1.0 / (np.sqrt(2 * np.pi) * length_scale),
)
super().__init__()
[docs]
@override
def match(self, x):
r"""
Learn a score function using kernel density estimation to model a distribution.
For the kernel density matching approach, the score function is determined by
fitting a kernel density estimate to samples from the underlying distribution
and then differentiating this. Therefore, learning in this context refers to
simply defining the score function and kernel density estimate given some
samples we wish to evaluate the score function at, and the data used to build
the kernel density estimate.
:param x: Set of :math:`n \times d` samples from the underlying distribution
that are used to build the kernel density estimate
:return: A function that applies the learned score function to input ``x``
"""
kde_data = x
@overload
def score_function(
x_: Shaped[Array, " 1 1"] | Shaped[Array, ""] | float | int,
) -> Shaped[Array, " 1 1"]: ...
@overload
def score_function( # pyright: ignore[reportOverlappingOverload]
x_: Shaped[Array, " n d"],
) -> Shaped[Array, " n d"]: ...
def score_function(
x_: Shaped[Array, " n d"] | Shaped[Array, ""] | float | int,
) -> Shaped[Array, " n d"] | Shaped[Array, " 1 1"]:
r"""
Compute the score function using a kernel density estimation.
The score function is determined by fitting a kernel density estimate to
samples from the underlying distribution and then differentiating this. The
kernel density estimate is create using a Gaussian kernel.
:param x_: The :math:`n \times d` data vectors we wish to evaluate the score
function at
"""
# Check format of input array. We use atleast_2d from JAX to perform
# conversions here. If we instead used the custom function
# _atleast_2d_consistent in coreax.data, we would require more
# processing when calling the methods on the kernel and the output values
# from these methods can differ from the expected outputs.
original_number_of_dimensions = jnp.asarray(x_).ndim
x_ = jnp.atleast_2d(x_)
# Get the gram matrix row-mean
gram_matrix_row_means = self.kernel.compute_mean(x_, kde_data, axis=1)
# Compute gradients with respect to x
gradients = self.kernel.grad_x(x_, kde_data).mean(axis=1)
# Compute final evaluation of the score function
score_result = gradients / gram_matrix_row_means[:, None]
# Ensure output format accounts for 1-dimensional inputs as-well as
# multi-dimensional ones
if original_number_of_dimensions == 1:
score_result = score_result[0, :]
return score_result
return score_function
[docs]
def convert_stein_kernel(
x: Shaped[Array, " n d"],
kernel: ScalarValuedKernel,
score_matching: ScoreMatching | None,
) -> SteinKernel:
r"""
Convert the kernel to a :class:`~coreax.kernels.SteinKernel`.
:param x: The data used to call `score_matching.match(x)`
:param kernel: :class:`~coreax.kernels.ScalarValuedKernel` instance implementing a
kernel function
:math:`k: \mathbb{R}^d \times \mathbb{R}^d \rightarrow \mathbb{R}`; if 'kernel'
is a :class:`~coreax.kernels.SteinKernel` and :code:`score_matching is not
data:`None`, a new instance of the kernel will be generated where the score
function is given by :code:`score_matching.match(x)`
:param score_matching: Specifies/overwrite the score function of the implied/passed
:class:`~coreax.kernels.SteinKernel`; if :data:`None`, default to
:class:`~coreax.score_matching.KernelDensityMatching` unless 'kernel' is a
:class:`~coreax.kernels.SteinKernel`, in which case the kernel's existing score
function is used.
:return: The (potentially) converted/updated :class:`~coreax.kernels.SteinKernel`.
"""
if isinstance(kernel, SteinKernel):
if score_matching is not None:
_kernel = eqx.tree_at(
lambda x: x.score_function, kernel, score_matching.match(x)
)
else:
_kernel = kernel
else:
if score_matching is None:
length_scale = getattr(kernel, "length_scale", 1.0)
score_matching = KernelDensityMatching(length_scale)
_kernel = SteinKernel(kernel, score_function=score_matching.match(x))
return _kernel