Analytical example with kernel herding
Step-by-step usage of kernel herding on an analytical example, enforcing a unique coreset.
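In this example, we have data of:
\[x = \begin{bmatrix} 0.3 & 0.25 \\ 0.4 & 0.2 \\ 0.5 & 0.125 \end{bmatrix}\]
and choose a length_scale of \(\frac{1}{\sqrt{2}}\) to simplify computations
with the SquaredExponentialKernel, which then becomes:
\[k(x, y) = e^{-\|x - y\|^2}\]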
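To see why this length_scale simplifies the arithmetic, here is a minimal plain-Python sketch. It assumes the squared exponential kernel takes the common form \(e^{-\|x - y\|^2 / (2 \ell^2)}\) with unit output scale; the function name squared_exponential is illustrative and is not part of coreax.

```python
import math


def squared_exponential(x, y, length_scale):
    """Assumed kernel form: exp(-||x - y||^2 / (2 * length_scale^2))."""
    squared_distance = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-squared_distance / (2.0 * length_scale**2))


length_scale = 1.0 / math.sqrt(2)
x0, x1 = (0.3, 0.25), (0.4, 0.2)

# With length_scale = 1 / sqrt(2), the denominator 2 * length_scale^2 equals 1,
# so the kernel reduces to exp(-||x - y||^2).
simplified = math.exp(-((0.3 - 0.4) ** 2 + (0.25 - 0.2) ** 2))
general = squared_exponential(x0, x1, length_scale)
print(general, simplified)
```

Both printed values should agree up to floating-point rounding, which is why the hand computations in the rest of this example only need \(e^{-\|x - y\|^2}\).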
Kernel herding should do as follows:
- Compute the Gramian row mean, that is, for each data-point \(x\), the average of the kernel over all data-points \(x'\) (including \(x\) itself), \(\frac{1}{N} \sum_{x'} k(x, x')\), where we have \(N\) data-points in total.
- Select the first coreset point \(x_{1}\) as the data-point where the Gramian row mean is highest.
- Compute all future coreset points as \(x_{T+1} = \arg\max_{x} \left( \mathbb{E}[k(x, x')] - \frac{1}{T+1}\sum_{t=1}^T k(x, x_t) \right)\), where we currently have \(T\) points in the coreset.
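The three steps above can be sketched in plain Python as follows. This is an illustrative sketch for checking the example by hand, not the coreax implementation; the names kernel and kernel_herding are hypothetical, and the kernel is hard-coded to \(e^{-\|x - y\|^2}\) as in this example.

```python
import math


def kernel(x, y):
    # Squared exponential kernel with length_scale = 1 / sqrt(2): exp(-||x - y||^2)
    return math.exp(-sum((a - b) ** 2 for a, b in zip(x, y)))


def kernel_herding(points, coreset_size, unique=True):
    """Greedily select coreset indices (illustrative sketch only)."""
    n = len(points)
    # Gramian row mean, averaged over all data-points (including x itself)
    row_mean = [sum(kernel(x, y) for y in points) / n for x in points]
    selected = []
    # Running sum of k(x, x_t) over the coreset points chosen so far
    penalty_sum = [0.0] * n
    for num_selected in range(coreset_size):
        # Row mean minus the penalty scaled by 1 / (T + 1); the penalty is zero
        # when the coreset is empty, so the first pick maximises the row mean
        scores = [row_mean[i] - penalty_sum[i] / (num_selected + 1) for i in range(n)]
        if unique:
            # Exclude data-points that are already in the coreset
            for i in selected:
                scores[i] = -math.inf
        best = max(range(n), key=scores.__getitem__)
        selected.append(best)
        penalty_sum = [penalty_sum[i] + kernel(points[i], points[best]) for i in range(n)]
    return selected


points = [(0.3, 0.25), (0.4, 0.2), (0.5, 0.125)]
print(kernel_herding(points, coreset_size=2, unique=True))
print(kernel_herding(points, coreset_size=2, unique=False))
```

Keeping a running penalty sum means each new coreset point costs only one extra kernel evaluation per data-point, rather than a fresh sum over the whole coreset.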
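We ask for a coreset of size 2 in this example. With an empty coreset, we first compute \(\mathbb{E}[k(x, x')]\) as:
\[\mathbb{E}[k(x, x')] = \frac{1}{3} \left( e^{-\|x - [0.3, 0.25]\|^2} + e^{-\|x - [0.4, 0.2]\|^2} + e^{-\|x - [0.5, 0.125]\|^2} \right)\]
resulting in, for the three data-points in order (rounded to six decimal places):
\[[0.977824, 0.990691, 0.976797]\]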
The largest value in this array is 0.9906914124997632, so we expect the first
coreset point to be [0.4, 0.2], that is the data-point at index 1 in the
dataset. At this point we have coreset_indices as [1, ?].
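As a hand-check of this first step, the following plain-Python sketch builds the full Gramian for the three data-points and takes the row means. The variable names are illustrative and independent of coreax.

```python
import math

points = [(0.3, 0.25), (0.4, 0.2), (0.5, 0.125)]


def kernel(x, y):
    # exp(-||x - y||^2), the kernel of this example
    return math.exp(-sum((a - b) ** 2 for a, b in zip(x, y)))


# 3 x 3 Gramian: entry (i, j) is k(points[i], points[j]), with ones on the diagonal
gramian = [[kernel(x, y) for y in points] for x in points]

# Row mean over all columns, diagonal included
gramian_row_mean = [sum(row) / len(row) for row in gramian]

# The first coreset point is the data-point with the largest row mean
first_index = max(range(len(points)), key=gramian_row_mean.__getitem__)
print(gramian_row_mean)
print(first_index, points[first_index])
```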
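We then compute the penalty update term \(\frac{1}{T+1}\sum_{t=1}^T k(x, x_t)\) with \(T = 1\):
\[\frac{1}{2} k(x, x_1) = \frac{1}{2} e^{-\|x - [0.4, 0.2]\|^2}\]
which evaluates to (rounded to six decimal places):
\[[0.493789, 0.5, 0.492248]\]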
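We now select the data-point that maximises \(\mathbb{E}[k(x, x')] - \frac{1}{T+1}\sum_{t=1}^T k(x, x_t)\), which evaluates to:
\[[0.977824 - 0.493789, 0.990691 - 0.5, 0.976797 - 0.492248]\]
giving a final result of (rounded to six decimal places):
\[[0.484035, 0.490691, 0.484549]\]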
The largest value in this array is at index 1, which would mean choosing the
point [0.4, 0.2] for the coreset again. However, in this example we enforce the
coreset to be unique, that is, we do not select the same data-point twice, so
we take the next highest value in the above result to include in our coreset
instead. This happens to be 0.4845485203517275, the data-point at index 2.
This means our final coreset_indices should be [1, 2].
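The second selection, including the uniqueness constraint, can be checked with the sketch below. It is a plain-Python illustration with hypothetical variable names; it models uniqueness by simply restricting the argmax to indices not yet chosen, which may differ from how coreax enforces it internally.

```python
import math

points = [(0.3, 0.25), (0.4, 0.2), (0.5, 0.125)]


def kernel(x, y):
    # exp(-||x - y||^2), the kernel of this example
    return math.exp(-sum((a - b) ** 2 for a, b in zip(x, y)))


n = len(points)
gramian_row_mean = [sum(kernel(x, y) for y in points) / n for x in points]

coreset_indices = [1]  # first coreset point, selected in the previous step
num_selected = len(coreset_indices)  # T in the text

# Penalty term: (1 / (T + 1)) * sum over coreset points of k(x, x_t)
penalty = [
    sum(kernel(x, points[t]) for t in coreset_indices) / (num_selected + 1)
    for x in points
]
scores = [mean - pen for mean, pen in zip(gramian_row_mean, penalty)]

# Without the uniqueness constraint, index 1 would be selected again
unconstrained_index = max(range(n), key=scores.__getitem__)

# With the constraint, only indices outside the coreset are candidates
candidates = [i for i in range(n) if i not in coreset_indices]
next_index = max(candidates, key=scores.__getitem__)
print(scores)
print(unconstrained_index, next_index)
```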
Finally, the solver state tracks variables we need not compute repeatedly. In the case of kernel herding, we don’t need to recompute \(\mathbb{E}[k(x, x')]\) at every single step, so the solver state returned by the coreset reduce method should store the Gramian row mean, that is, a gramian_row_mean of (rounded to six decimal places):
\[[0.977824, 0.990691, 0.976797]\]
This example would be run in coreax using:
from coreax import Data, SquaredExponentialKernel, KernelHerding
import equinox as eqx
import jax.numpy as jnp
# Define the data
coreset_size = 2
length_scale = 1.0 / jnp.sqrt(2)
x = jnp.array([
    [0.3, 0.25],
    [0.4, 0.2],
    [0.5, 0.125],
])
# Define a kernel
kernel = SquaredExponentialKernel(length_scale=length_scale)
# Generate the coreset, using equinox to JIT compile the code and speed up
# generation for larger datasets
data = Data(x)
solver = KernelHerding(coreset_size=coreset_size, kernel=kernel, unique=True)
coreset, solver_state = eqx.filter_jit(solver.reduce)(data)
# Inspect results
print(coreset.unweighted_indices) # The coreset_indices
print(coreset.coreset.data) # The data-points in the coreset
print(solver_state.gramian_row_mean) # The stored gramian_row_mean
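Coreax also supports weighted data. If we have the same data as described above, but weights of:
\[w = [0.8, 0.1, 0.1]\]
we would expect a different resulting coreset. The computation of the Gramian row mean, \(\mathbb{E}[k(x, x')]\), becomes:
\[\mathbb{E}[k(x, x')] = 0.8 e^{-\|x - [0.3, 0.25]\|^2} + 0.1 e^{-\|x - [0.4, 0.2]\|^2} + 0.1 e^{-\|x - [0.5, 0.125]\|^2}\]
resulting in (rounded to six decimal places):
\[[0.993347, 0.988512, 0.955165]\]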
The largest value in this array is 0.9933471580051769, so we expect the first coreset
point to be [0.3, 0.25], that is the data-point at index 0 in the dataset. At this point
we have coreset_indices as [0, ?].
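The weighted row mean can be checked by hand with the sketch below. It assumes the weighted Gramian row mean is the weight-averaged kernel sum with weights normalised to sum to one, which is consistent with the value 0.9933471580051769 quoted above; the variable names are illustrative and not part of coreax.

```python
import math

points = [(0.3, 0.25), (0.4, 0.2), (0.5, 0.125)]
weights = [0.8, 0.1, 0.1]


def kernel(x, y):
    # exp(-||x - y||^2), the kernel of this example
    return math.exp(-sum((a - b) ** 2 for a, b in zip(x, y)))


# Normalise so the weights sum to one (they already do here, up to rounding)
total_weight = sum(weights)
normalised_weights = [w / total_weight for w in weights]

# Weighted row mean: sum over x' of w(x') * k(x, x')
weighted_row_mean = [
    sum(w * kernel(x, y) for w, y in zip(normalised_weights, points))
    for x in points
]
first_index = max(range(len(points)), key=weighted_row_mean.__getitem__)
print(weighted_row_mean)
print(first_index, points[first_index])
```

Because 80% of the weight sits on the data-point at index 0, its own kernel value of 1 dominates its row, which is what moves the first selection from index 1 to index 0.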
We then compute the penalty update term \(\frac{1}{T+1}\sum_{t=1}^T k(x, x_t)\) with \(T = 1\) and get (rounded to six decimal places):
\[\frac{1}{2} k(x, x_1) \approx [0.5, 0.493789, 0.472947]\]
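Finally, we select the next coreset point to maximise:
\[\mathbb{E}[k(x, x')] - \frac{1}{2} k(x, x_1) \approx [0.493347, 0.494723, 0.482218]\]
The largest value is at index 1, a data-point not yet in the coreset,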
which means our final coreset_indices should be [0, 1]. In coreax, this example
would be run as:
from coreax import Data, SquaredExponentialKernel, KernelHerding
import equinox as eqx
import jax.numpy as jnp
# Define the data
coreset_size = 2
length_scale = 1.0 / jnp.sqrt(2)
x = jnp.array([
    [0.3, 0.25],
    [0.4, 0.2],
    [0.5, 0.125],
])
weights = jnp.array([0.8, 0.1, 0.1])
# Define a kernel
kernel = SquaredExponentialKernel(length_scale=length_scale)
# Generate the coreset, using equinox to JIT compile the code and speed up
# generation for larger datasets
data = Data(x, weights=weights)
solver = KernelHerding(coreset_size=coreset_size, kernel=kernel, unique=True)
coreset, solver_state = eqx.filter_jit(solver.reduce)(data)
# Inspect results
print(coreset.unweighted_indices) # The coreset_indices
print(coreset.coreset.data) # The data-points in the coreset
print(solver_state.gramian_row_mean) # The stored gramian_row_mean
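The expected weighted result of [0, 1] can also be cross-checked without coreax. The plain-Python sketch below repeats the second selection step by hand, assuming (as in the walk-through above) that the weights enter only through the Gramian row mean and that the penalty term is unweighted; all names are illustrative.

```python
import math

points = [(0.3, 0.25), (0.4, 0.2), (0.5, 0.125)]
weights = [0.8, 0.1, 0.1]  # already sum to one, so no normalisation is applied


def kernel(x, y):
    # exp(-||x - y||^2), the kernel of this example
    return math.exp(-sum((a - b) ** 2 for a, b in zip(x, y)))


# Weighted Gramian row mean: sum over x' of w(x') * k(x, x')
weighted_row_mean = [
    sum(w * kernel(x, y) for w, y in zip(weights, points)) for x in points
]

coreset_indices = [0]  # first coreset point in the weighted example
num_selected = len(coreset_indices)  # T in the text

# Unweighted penalty term: (1 / (T + 1)) * sum over coreset points of k(x, x_t)
penalty = [
    sum(kernel(x, points[t]) for t in coreset_indices) / (num_selected + 1)
    for x in points
]
scores = [mean - pen for mean, pen in zip(weighted_row_mean, penalty)]

# Restrict the argmax to data-points not yet in the coreset (unique=True)
candidates = [i for i in range(len(points)) if i not in coreset_indices]
coreset_indices.append(max(candidates, key=scores.__getitem__))
print(scores)
print(coreset_indices)
```

Here the uniqueness constraint does not change the outcome, since the largest score is already at an index outside the coreset.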