Quickstart

Here are some of the most commonly used classes and methods in the library.

Kernel herding

Kernel herding is one (greedy) approach to coreset construction. A kernel herding solver can be created by supplying a ScalarValuedKernel object, such as a SquaredExponentialKernel. A coreset is then generated by calling the solver's reduce() method on the original dataset.
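The following minimal NumPy sketch illustrates the greedy selection. It is not Coreax's implementation. The function names (`squared_exponential_kernel`, `kernel_herding`) are hypothetical. It also assumes the common form of the herding objective: at step \(t\), pick the unselected point that maximises its mean kernel value over the data minus its average kernel similarity to the points already chosen.

```python
import numpy as np


def squared_exponential_kernel(x, y, length_scale):
    """Gram matrix with entries exp(-||x_i - y_j||^2 / (2 * length_scale^2))."""
    sq_dists = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / (2.0 * length_scale**2))


def kernel_herding(x, coreset_size, length_scale):
    """Greedily select coreset_size distinct row indices of x (hypothetical helper)."""
    gram = squared_exponential_kernel(x, x, length_scale)
    kernel_mean = gram.mean(axis=1)  # (1/n) * sum_j k(x_i, x_j) for each candidate i
    penalty = np.zeros(len(x))  # sum of k(x_i, c) over the coreset points c chosen so far
    chosen = np.zeros(len(x), dtype=bool)
    selected = []
    for t in range(coreset_size):
        # Reward similarity to the data, penalise similarity to the current coreset
        objective = kernel_mean - penalty / (t + 1)
        objective[chosen] = -np.inf  # never pick the same point twice
        idx = int(np.argmax(objective))
        selected.append(idx)
        chosen[idx] = True
        penalty += gram[:, idx]
    return np.array(selected)


rng = np.random.default_rng(0)
points = rng.normal(size=(500, 2))
indices = kernel_herding(points, coreset_size=20, length_scale=1.0)
```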

Note that block versions of herding are available throughout the codebase for working within memory constraints. These methods partition the data into blocks before carrying out the coreset algorithm. This limits the size of the largest arrays held in memory at any one time.
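The main quantity herding needs is the kernel mean \(\frac{1}{n}\sum_j k(\mathbf{x}_i, \mathbf{x}_j)\) for every point. The sketch below shows how blocking keeps memory bounded when computing it: only an \(n \times\) `block_size` slice of the Gram matrix exists at once. The function and the `block_size` argument are hypothetical illustrations and are not Coreax's API.

```python
import numpy as np


def squared_exponential_kernel(x, y, length_scale):
    """Gram matrix with entries exp(-||x_i - y_j||^2 / (2 * length_scale^2))."""
    sq_dists = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / (2.0 * length_scale**2))


def blocked_kernel_mean(x, length_scale, block_size):
    """(1/n) * sum_j k(x_i, x_j) for every i, computed one block of columns at a time."""
    total = np.zeros(len(x))
    for start in range(0, len(x), block_size):
        block = x[start : start + block_size]
        # Only an (n, block_size) slice of the Gram matrix exists at any one time
        total += squared_exponential_kernel(x, block, length_scale).sum(axis=1)
    return total / len(x)


rng = np.random.default_rng(0)
points = rng.normal(size=(1_000, 3))
blocked = blocked_kernel_mean(points, length_scale=1.0, block_size=128)
```

The result matches the unblocked computation. Only the peak memory changes.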

# © Crown Copyright GCHQ
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from sklearn.datasets import make_blobs

from coreax.data import Data
from coreax.kernels import SquaredExponentialKernel, median_heuristic
from coreax.solvers import KernelHerding

# Generate some data
num_data_points = 10_000
num_features = 2
num_cluster_centers = 6
random_seed = 1989
x, *_ = make_blobs(
    num_data_points,
    n_features=num_features,
    centers=num_cluster_centers,
    random_state=random_seed,
)

# Request 100 coreset points
coreset_size = 100

# Setup the original data object
data = Data(x)

# Set the bandwidth parameter of the kernel using a median heuristic derived from
# at most 1000 random samples in the data.
num_samples_length_scale = min(num_data_points, 1_000)
generator = np.random.default_rng(random_seed)
idx = generator.choice(num_data_points, num_samples_length_scale, replace=False)
length_scale = median_heuristic(x[idx])

# Compute a coresubset using kernel herding with a squared exponential kernel.
herding_solver = KernelHerding(
    coreset_size, kernel=SquaredExponentialKernel(length_scale=length_scale)
)
herding_coreset, _ = herding_solver.reduce(data)

# We can now print the selected coresubset indices and the materialized coresubset
print(herding_coreset.unweighted_indices)
print(herding_coreset.coreset)
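The median heuristic used above sets the length scale from typical pairwise distances in the data. The sketch below implements one common convention, \(\ell = \sqrt{\operatorname{median}_{i<j} \lVert \mathbf{x}_i - \mathbf{x}_j \rVert^2 / 2}\). We have not checked that coreax.kernels.median_heuristic uses exactly this scaling, so treat the helper name and the factor of 2 as assumptions.

```python
import numpy as np


def median_heuristic_sketch(x):
    """Length scale sqrt(median of pairwise squared distances / 2) (hypothetical helper)."""
    sq_dists = np.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
    # Keep each distinct pair once and drop the zero self-distances on the diagonal
    pairs = np.triu_indices(len(x), k=1)
    return float(np.sqrt(np.median(sq_dists[pairs]) / 2.0))


# Squared distances 1, 4 and 5 have median 4, giving sqrt(4 / 2) = sqrt(2)
example_length_scale = median_heuristic_sketch(
    np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])
)
```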

Kernel herding with weighting

A coreset can be weighted, giving a so-called weighted coreset, to attribute importance to each point and so better approximate the underlying data distribution. Optimal weights can be determined by using a WeightsOptimiser, such as MMDWeightsOptimiser.
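The sketch below illustrates what "MMD-optimal" weights mean. The squared MMD between the data and a weighted coreset is a quadratic in the weights \(w\): \(\text{const} - 2 w^\top z + w^\top K w\). Here \(K\) is the coreset Gram matrix and \(z_i\) is the mean kernel value between coreset point \(i\) and the data. Minimising without constraints means solving \(K w = z\). This is a simplified version with hypothetical helper names. MMDWeightsOptimiser may impose constraints (for example, non-negative weights) or use a different solver. The unconstrained solution shown here can contain negative weights.

```python
import numpy as np


def squared_exponential_kernel(x, y, length_scale):
    """Gram matrix with entries exp(-||x_i - y_j||^2 / (2 * length_scale^2))."""
    sq_dists = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / (2.0 * length_scale**2))


def weighted_mmd_squared(x, coreset, weights, length_scale):
    """Squared MMD between the uniform measure on x and the weighted coreset."""
    k_xx = squared_exponential_kernel(x, x, length_scale).mean()
    z = squared_exponential_kernel(coreset, x, length_scale).mean(axis=1)
    k_cc = squared_exponential_kernel(coreset, coreset, length_scale)
    return float(k_xx - 2.0 * weights @ z + weights @ k_cc @ weights)


def mmd_optimal_weights(x, coreset, length_scale, ridge=1e-8):
    """Unconstrained minimiser of the squared MMD: solve (K + ridge * I) w = z."""
    k_cc = squared_exponential_kernel(coreset, coreset, length_scale)
    z = squared_exponential_kernel(coreset, x, length_scale).mean(axis=1)
    # The small ridge keeps the solve stable when coreset points are nearly duplicated
    return np.linalg.solve(k_cc + ridge * np.eye(len(coreset)), z)


rng = np.random.default_rng(0)
points = rng.normal(size=(300, 2))
coreset = points[:10]
uniform_weights = np.full(10, 0.1)
optimal_weights = mmd_optimal_weights(points, coreset, length_scale=1.0)
```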

from coreax import SquaredExponentialKernel
from coreax.solvers import KernelHerding
from coreax.weights import MMDWeightsOptimiser

# Define a kernel
kernel = SquaredExponentialKernel(length_scale=length_scale)

# Define a weights optimiser to learn optimal weights for the coreset after creation
weights_optimiser = MMDWeightsOptimiser(kernel=kernel)

# Compute a coreset using kernel herding with a squared exponential kernel.
herding_solver = KernelHerding(coreset_size, kernel=kernel)
herding_coreset, _ = herding_solver.reduce(data)

# Determine optimal weights for the coreset
re_weighted_herding_coreset = herding_coreset.solve_weights(weights_optimiser)

Kernel herding with refine

To improve the quality of a coreset, a refine step can be executed. Refinement works by substituting points in the coreset with points from the original dataset so that some metric decreases. This metric measures how well the coreset represents the original data, so a lower value means the refined coreset better captures the underlying distribution.

There are several approaches to refining a coreset, implemented by the subclasses of RefinementSolver. In the example below, we instantiate a RefinementSolver, specifically a coreax.solvers.KernelHerding solver. We then call its refine() method on the solution returned by the earlier call to reduce().
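The swapping idea can be sketched in a few lines of NumPy. This is a simplified illustration, not the algorithm behind KernelHerding.refine(). For each coreset position in turn, it tries every data point in that slot and keeps whichever gives the lowest squared MMD to the data. The current point is always one of the candidates, so the MMD can never increase. The helper names are hypothetical. This brute-force version may also select duplicate points and scales poorly, so it only suits small examples.

```python
import numpy as np


def squared_exponential_kernel(x, y, length_scale):
    """Gram matrix with entries exp(-||x_i - y_j||^2 / (2 * length_scale^2))."""
    sq_dists = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / (2.0 * length_scale**2))


def mmd_objective(gram, kernel_mean, indices):
    """Squared MMD to the data for a uniformly weighted coresubset, minus a constant."""
    coreset_term = gram[np.ix_(indices, indices)].mean()
    return coreset_term - 2.0 * kernel_mean[indices].mean()


def greedy_refine(x, indices, length_scale):
    """Replace each coreset point in turn by the data point that most lowers the MMD."""
    gram = squared_exponential_kernel(x, x, length_scale)
    kernel_mean = gram.mean(axis=1)
    refined = np.array(indices, copy=True)
    for position in range(len(refined)):
        best_index = refined[position]
        best_value = mmd_objective(gram, kernel_mean, refined)
        for candidate in range(len(x)):
            trial = refined.copy()
            trial[position] = candidate
            value = mmd_objective(gram, kernel_mean, trial)
            if value < best_value:
                best_index, best_value = candidate, value
        refined[position] = best_index
    return refined


rng = np.random.default_rng(0)
points = rng.normal(size=(100, 2))
initial = np.arange(5)  # a deliberately naive starting coresubset
refined = greedy_refine(points, initial, length_scale=1.0)
```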

from coreax import SquaredExponentialKernel
from coreax.solvers import KernelHerding

# Compute a coreset using kernel herding with a squared exponential kernel.
herding_solver = KernelHerding(
    coreset_size,
    kernel=SquaredExponentialKernel(length_scale=length_scale),
)
herding_coreset, _ = herding_solver.reduce(data)

# Refine the coreset to improve quality
refined_herding_coreset = herding_solver.refine(herding_coreset)

# We can now print the selected coresubset indices and the materialized coresubset
print(refined_herding_coreset.unweighted_indices)
print(refined_herding_coreset.coreset)

Scalable herding

For large \(n\) (the number of data points) or \(d\) (the data dimension), you may run into time or memory issues. The MapReduce class uses partitioning to compute an approximate coreset tractably, trading some coreset quality for a substantial reduction in computation time. To use it, compose a Solver from the previous examples with MapReduce and set the leaf_size parameter in line with your memory constraints.
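The partition-and-merge idea can be sketched as follows. First, split the data into leaves of at most leaf_size points and run the base solver on each leaf. Then merge the per-leaf coresets and repeat until everything fits in one leaf. Finally, run the base solver once more on what remains. This is a hypothetical NumPy illustration with our own helper names. For simplicity, it partitions into contiguous chunks; Coreax's MapReduce may partition the data differently. The sketch requires leaf_size to exceed coreset_size so that each pass strictly shrinks the candidate set.

```python
import numpy as np


def squared_exponential_kernel(x, y, length_scale):
    """Gram matrix with entries exp(-||x_i - y_j||^2 / (2 * length_scale^2))."""
    sq_dists = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / (2.0 * length_scale**2))


def herd(x, coreset_size, length_scale):
    """Plain greedy kernel herding on x; returns distinct row indices into x."""
    gram = squared_exponential_kernel(x, x, length_scale)
    kernel_mean = gram.mean(axis=1)
    penalty = np.zeros(len(x))
    chosen = np.zeros(len(x), dtype=bool)
    for t in range(coreset_size):
        objective = np.where(chosen, -np.inf, kernel_mean - penalty / (t + 1))
        idx = int(np.argmax(objective))
        chosen[idx] = True
        penalty += gram[:, idx]
    return np.flatnonzero(chosen)


def map_reduce(x, coreset_size, leaf_size, length_scale):
    """Approximate herding coreset of x built leaf by leaf (hypothetical helper)."""
    if leaf_size <= coreset_size:
        raise ValueError("leaf_size must be larger than coreset_size")
    candidates = np.arange(len(x))  # indices into the original data
    while len(candidates) > leaf_size:
        reduced = []
        # Map: herd each leaf of at most leaf_size candidates independently
        for start in range(0, len(candidates), leaf_size):
            leaf = candidates[start : start + leaf_size]
            size = min(coreset_size, len(leaf))
            reduced.append(leaf[herd(x[leaf], size, length_scale)])
        # Reduce: merge the per-leaf coresets into the next, smaller candidate set
        candidates = np.concatenate(reduced)
    # Final pass once the candidates fit in a single leaf
    return candidates[herd(x[candidates], coreset_size, length_scale)]


rng = np.random.default_rng(0)
points = rng.normal(size=(1_000, 2))
indices = map_reduce(points, coreset_size=20, leaf_size=100, length_scale=1.0)
```

No single kernel computation here ever sees more than leaf_size points. This is where the memory saving, and the loss of quality relative to herding on the full data, comes from.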

from coreax.kernels import SquaredExponentialKernel
from coreax.solvers import KernelHerding, MapReduce

# Compute a coreset using kernel herding with a squared exponential kernel.
herding_solver = KernelHerding(
    coreset_size,
    kernel=SquaredExponentialKernel(length_scale=length_scale),
)
mapped_herding_solver = MapReduce(herding_solver, leaf_size=200)
mapped_herding_coreset, _ = mapped_herding_solver.reduce(data)

For large \(d\), it is usually worth first reducing the dimensionality with principal component analysis (PCA). See examples.pounce_map_reduce for an example.
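One way to perform this reduction is to project onto the leading principal components. The NumPy sketch below computes them from a singular value decomposition of the centred data. The helper name is hypothetical; scikit-learn's sklearn.decomposition.PCA offers the same projection with more options. How many components to keep is a judgement call, for example based on the explained-variance ratio returned here.

```python
import numpy as np


def pca_project(x, num_components):
    """Project x onto its leading principal components (hypothetical helper)."""
    centred = x - x.mean(axis=0)
    # Rows of vt are the principal directions, ordered by decreasing singular value
    _, singular_values, vt = np.linalg.svd(centred, full_matrices=False)
    explained_ratio = singular_values**2 / np.sum(singular_values**2)
    return centred @ vt[:num_components].T, explained_ratio[:num_components]


rng = np.random.default_rng(0)
# 20-dimensional data that mostly varies along 2 hidden directions
latent = rng.normal(size=(500, 2))
high_dim = latent @ rng.normal(size=(2, 20)) + 0.01 * rng.normal(size=(500, 20))
projected, explained = pca_project(high_dim, num_components=2)
```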

Stein kernel herding

We have implemented a version of kernel herding that uses a Stein kernel, which targets the kernelised Stein discrepancy (KSD) [liu2016kernelized] rather than MMD. In practice, this can often give better integration error, but it can be slower than using a simpler kernel targeting MMD. Stein kernel herding requires one of two things: a continuous approximation to the discrete measure (for example, from kernel density estimation (KDE)), or an estimate of the score function \(\nabla \log f_X(\mathbf{x})\) of a continuous PDF obtained from a finite set of samples. In this example, we use a Stein kernel with a squared exponential base kernel and obtain the score function from a kernel density estimate fitted to a subset of the data.
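To show what the Stein kernel computes, the sketch below implements two ingredients in NumPy. The first is the score of a Gaussian kernel density estimate, \(\nabla \log \hat f(\mathbf{x}) = \sum_j w_j(\mathbf{x}) (\mathbf{x}_j - \mathbf{x}) / h^2\), with softmax weights \(w_j\). The second is the standard Langevin Stein kernel built on a squared exponential base kernel \(k\) with length scale \(\ell\): \(k_p(\mathbf{x}, \mathbf{y}) = \nabla_{\mathbf{x}} \cdot \nabla_{\mathbf{y}} k + \nabla_{\mathbf{x}} k \cdot s(\mathbf{y}) + \nabla_{\mathbf{y}} k \cdot s(\mathbf{x}) + k\, s(\mathbf{x}) \cdot s(\mathbf{y})\), where \(s\) is the score. The helper names are hypothetical. We have not checked that KernelDensityMatching or SteinKernel use identical bandwidth conventions or parameterisations.

```python
import numpy as np


def kde_score(points, data, bandwidth):
    """Score of a Gaussian KDE fitted to data, evaluated at each row of points."""
    diffs = data[None, :, :] - points[:, None, :]  # (m, n, d): x_j - x
    log_w = -np.sum(diffs**2, axis=-1) / (2.0 * bandwidth**2)
    log_w -= log_w.max(axis=1, keepdims=True)  # stabilise the softmax
    weights = np.exp(log_w)
    weights /= weights.sum(axis=1, keepdims=True)
    return np.einsum("mn,mnd->md", weights, diffs) / bandwidth**2


def stein_kernel(x, y, score_x, score_y, length_scale):
    """Langevin Stein kernel matrix with a squared exponential base kernel."""
    dim = x.shape[1]
    diff = x[:, None, :] - y[None, :, :]  # (m, p, d): x - y
    sq_dists = np.sum(diff**2, axis=-1)
    base = np.exp(-sq_dists / (2.0 * length_scale**2))
    # Divergence term: grad_x . grad_y k = k * (d / l^2 - ||x - y||^2 / l^4)
    divergence = dim / length_scale**2 - sq_dists / length_scale**4
    # Cross terms: grad_x k . s(y) + grad_y k . s(x) = k * (x - y) . (s(x) - s(y)) / l^2
    cross = (
        np.einsum("mpd,md->mp", diff, score_x) - np.einsum("mpd,pd->mp", diff, score_y)
    ) / length_scale**2
    return base * (divergence + cross + score_x @ score_y.T)


rng = np.random.default_rng(0)
samples = rng.normal(size=(40, 2))
scores = kde_score(samples, samples, bandwidth=0.5)
stein_gram = stein_kernel(samples, samples, scores, scores, length_scale=1.0)
```

The Stein kernel is a valid positive semi-definite kernel for any score function, so its Gram matrix should have no meaningfully negative eigenvalues.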

import numpy as np

from coreax import (
    KernelDensityMatching,
    SquaredExponentialKernel,
    SteinKernel,
)
from coreax.solvers import KernelHerding

# Select a random subset of the data from which to learn the score function.
# The subset size here is an illustrative choice.
subset_size = min(num_data_points, 1_000)
generator = np.random.default_rng(random_seed)
idx = generator.choice(num_data_points, subset_size, replace=False)
data_subset = x[idx]

# Learn a score function by kernel density estimation on the data subset
kernel_density_score_matcher = KernelDensityMatching(length_scale=length_scale)
score_function = kernel_density_score_matcher.match(data_subset)

# Define a kernel to use for herding
herding_kernel = SteinKernel(
    SquaredExponentialKernel(length_scale=length_scale),
    score_function=score_function,
)

# Compute a coreset using kernel herding with a Stein kernel
herding_solver = KernelHerding(coreset_size, kernel=herding_kernel)
herding_coreset, _ = herding_solver.reduce(data)

Score matching example

The score function, \(\nabla \log f_X(\mathbf{x})\), of a distribution is the gradient of its log-density function. This function is required when evaluating Stein kernels, but in practice it can be difficult to specify analytically.

To resolve this, we have implemented an approximation of the score function using a neural network [song2020ssm]. This approximate score function can then be passed directly to a Stein kernel, removing any need for an analytical derivation. More details on the implemented score matching methods can be found in coreax.score_matching.
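To illustrate the objective that sliced score matching minimises, the NumPy sketch below replaces the neural network with a linear score model \(s(\mathbf{x}) = A\mathbf{x} + \mathbf{b}\), whose Jacobian is simply \(A\). The sliced loss averages \(\mathbf{v}^\top A \mathbf{v} + \tfrac12 (\mathbf{v}^\top s(\mathbf{x}))^2\) over the data points and over random Rademacher projection vectors \(\mathbf{v}\). The variance-reduced form replaces the second term by \(\tfrac12 \lVert s(\mathbf{x}) \rVert^2\), its expectation over \(\mathbf{v}\). We assume this is roughly what the use_analytic flag in the example below toggles. The function and its arguments are hypothetical, and they omit details such as noise conditioning. For Gaussian data, the loss is minimised (in expectation over \(\mathbf{v}\)) by \(A = -\Sigma^{-1}\), \(\mathbf{b} = \Sigma^{-1}\boldsymbol{\mu}\), which is exactly the Gaussian score.

```python
import numpy as np


def sliced_score_matching_loss(a, b, x, projections, use_analytic=True):
    """Sliced score matching loss for the linear score model s(x) = A x + b."""
    scores = x @ a.T + b  # (n, d) model scores at the data points
    # v^T (ds/dx) v, averaged over the projection vectors; ds/dx = A for this model
    jacobian_term = np.mean(np.einsum("kd,de,ke->k", projections, a, projections))
    if use_analytic:
        # Variance-reduced form: for Rademacher v, E_v[(v^T s)^2] = ||s||^2
        norm_term = 0.5 * np.mean(np.sum(scores**2, axis=1))
    else:
        norm_term = 0.5 * np.mean((scores @ projections.T) ** 2)
    return float(jacobian_term + norm_term)


rng = np.random.default_rng(0)
gaussian_samples = rng.multivariate_normal(
    [1.0, -1.0], [[2.0, 0.5], [0.5, 1.0]], size=2_000
)
projections = rng.choice([-1.0, 1.0], size=(64, 2))  # Rademacher vectors

# Closed-form minimiser for Gaussian data: the Gaussian score -Sigma^{-1} (x - mu)
covariance = np.cov(gaussian_samples, rowvar=False, bias=True)
a_opt = -np.linalg.inv(covariance)
b_opt = -a_opt @ gaussian_samples.mean(axis=0)

loss_opt = sliced_score_matching_loss(a_opt, b_opt, gaussian_samples, projections)
loss_shifted = sliced_score_matching_loss(a_opt, b_opt + 1.0, gaussian_samples, projections)
loss_scaled = sliced_score_matching_loss(
    a_opt + 0.5 * np.eye(2), b_opt, gaussian_samples, projections
)
```

Perturbing either parameter away from the Gaussian score increases the loss. A neural network trained on the same objective can represent non-Gaussian scores as well.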

import jax
import numpy as np

from coreax import SlicedScoreMatching, SteinKernel
from coreax.kernels import PCIMQKernel

# Select a random subset of the data from which to learn the score function.
# The subset size here is an illustrative choice.
subset_size = min(num_data_points, 1_000)
generator = np.random.default_rng(random_seed)
idx = generator.choice(num_data_points, subset_size, replace=False)
data_subset = x[idx]

# Learn a score function by training a neural network on the data subset
score_key = jax.random.key(random_seed)
sliced_score_matcher = SlicedScoreMatching(
    score_key,
    random_generator=jax.random.rademacher,
    use_analytic=True,
    num_epochs=10,
    num_random_vectors=1,
    sigma=1.0,
    gamma=0.95,
)
score_function = sliced_score_matcher.match(data_subset)

# Define a kernel to use for herding
herding_kernel = SteinKernel(
    PCIMQKernel(length_scale=length_scale),
    score_function=score_function,
)

JIT compilation

JAX enables us to perform just-in-time (JIT) compilation of our code (provided certain mild conditions are met), which can significantly improve runtime performance. For this reason, the example scripts always call eqx.filter_jit(...) on the reduction method. Because Equinox is already a dependency, we use its more convenient filter_jit variant of the JIT transformation rather than jax.jit(). See the Equinox documentation <https://docs.kidger.site/equinox/api/transformations/#equinox.filter_jit> for more information.

We can apply the JIT transformation to most methods within Coreax, but we do not apply it by default. We typically want to apply the JIT transformation at the highest possible level (scope), which gives the compiler the best chance of performing all possible optimisations, at the cost of potentially long compile times.
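The sketch below illustrates applying JIT at the outermost level. It wraps a whole computation (a squared exponential Gram matrix followed by its row means) in a single jax.jit call, so the compiler sees both steps together rather than as separately compiled pieces. The function names are hypothetical, and the sketch uses plain jax.jit rather than eqx.filter_jit. The first call triggers compilation, so it is slower than later calls with the same input shapes.

```python
import jax
import jax.numpy as jnp


def squared_exponential_gram(x, y, length_scale):
    """Gram matrix with entries exp(-||x_i - y_j||^2 / (2 * length_scale^2))."""
    sq_dists = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-sq_dists / (2.0 * length_scale**2))


def kernel_mean(x, length_scale):
    """End-to-end computation: Gram matrix, then its row means."""
    return squared_exponential_gram(x, x, length_scale).mean(axis=1)


# JIT the outermost function so the compiler can optimise both steps together
kernel_mean_jit = jax.jit(kernel_mean)

points = jax.random.normal(jax.random.PRNGKey(0), (256, 2))
compiled_result = kernel_mean_jit(points, 1.0)  # first call compiles, later calls reuse it
eager_result = kernel_mean(points, 1.0)
```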