Analytical example with RPCholesky

This example walks step by step through the RPCholesky algorithm (Algorithm 1 in [chen2023randomly]) on a small dataset of 3 data points in 2 dimensions, with a coreset of size 2, i.e., \(N=3, m=2\).

In this example, we have the following data:

\[\begin{split}X = \begin{pmatrix} 0.5 & 0.2 \\ 0.4 & 0.6 \\ 0.8 & 0.3 \end{pmatrix}\end{split}\]

We choose a SquaredExponentialKernel with a length_scale of \(\ell = \frac{1}{\sqrt{2}}\). The kernel has the form \(k(x, y) = e^{-||x - y||^2 / (2 \ell^2)}\), so with this length scale \(2 \ell^2 = 1\) and, for two points \(x, y \in X\), \(k(x, y) = e^{-||x - y||^2}\). We now compute the Gram matrix, \(A\), of the dataset \(X\) with respect to the kernel \(k\) as \(A_{ij} = k(X_i, X_j)\):

\[\begin{split}A = \begin{pmatrix} 1.0 & 0.84366477 & 0.90483737 \\ 0.84366477 & 1.0 & 0.7788007 \\ 0.90483737 & 0.7788007 & 1.0 \end{pmatrix}\end{split}\]
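The Gram matrix above can be checked by hand. The following is a minimal sketch in plain Python (standard library only), assuming the simplified kernel \(k(x, y) = e^{-||x - y||^2}\) stated above. The helper name `k` is chosen here for illustration and is not part of coreax.

```python
import math

# The three data points from the example
X = [(0.5, 0.2), (0.4, 0.6), (0.8, 0.3)]


def k(x, y):
    """Squared exponential kernel with length_scale = 1/sqrt(2)."""
    squared_distance = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-squared_distance)


# Gram matrix: A[i][j] = k(X[i], X[j])
A = [[k(x_i, x_j) for x_j in X] for x_i in X]

for row in A:
    print([round(value, 8) for value in row])
```

The printed entries agree with the matrix above up to the single-precision rounding used in the text.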

Note that, in practice, we do not need to precompute the full Gram matrix: the algorithm only needs to evaluate the pivot column at each iteration.

To apply the RPCholesky algorithm, we first initialise the residual diagonal \(d = \text{diag}(A)\) and the approximation matrix \(F = \mathbf{0}_{N \times m}\), where \(N = 3, m = 2\) in our case.

We now build a coreset iteratively by applying the following steps at each iteration i:
  • Sample a datapoint index (called a pivot) with probability proportional to \(d\)

  • Compute/extract column \(g\) corresponding to the pivot index from \(A\)

  • Remove the overlap with previously selected columns from \(g\)

  • Normalise the column by the square root of its pivot entry and add it to the approximation matrix \(F\)

  • Update the residual diagonal: \(d = d - |F[:,i]|^2\), where the square is taken element-wise
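The steps above can be sketched as a short loop. This is an illustrative plain-Python version, not the coreax implementation. The names `rp_cholesky`, `choose_pivot` and `sample_proportional` are hypothetical, and clipping the residual diagonal at zero is an addition made here to guard against rounding. The pivot rule is passed in as a function so that the pivots used in this walkthrough can be forced.

```python
import math
import random


def rp_cholesky(A, m, choose_pivot):
    """Run m iterations on the Gram matrix A (a list of rows).

    Returns the factor F, the residual diagonal d and the pivot list S.
    """
    n = len(A)
    d = [A[j][j] for j in range(n)]    # residual diagonal, d = diag(A)
    F = [[0.0] * m for _ in range(n)]  # approximation matrix, N x m
    S = []                             # selected pivot indices
    for i in range(m):
        s = choose_pivot(d)
        S.append(s)
        # Column of A at the pivot, minus the overlap with earlier columns
        g = [
            A[j][s] - sum(F[j][t] * F[s][t] for t in range(i))
            for j in range(n)
        ]
        # Normalise by the square root of the pivot entry
        scale = math.sqrt(g[s])
        for j in range(n):
            F[j][i] = g[j] / scale
        # Clip at zero to guard against rounding (an addition in this sketch)
        d = [max(d[j] - F[j][i] ** 2, 0.0) for j in range(n)]
    return F, d, S


def sample_proportional(d):
    """Draw an index with probability proportional to d."""
    return random.choices(range(len(d)), weights=d, k=1)[0]


A = [
    [1.0, 0.84366477, 0.90483737],
    [0.84366477, 1.0, 0.7788007],
    [0.90483737, 0.7788007, 1.0],
]

# Force the pivots used in this walkthrough instead of sampling them
forced_pivots = iter([2, 1])
F, d, S = rp_cholesky(A, 2, lambda _d: next(forced_pivots))
print(S)
print(d)
```

Passing `sample_proportional` as `choose_pivot` instead would draw the pivots at random, as the algorithm prescribes.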

For the first iteration (i=0):

1. We sample a pivot index with probability proportional to its value on the residual diagonal. Since \(d\) is initialised as \((1, 1, 1)\) in our case, all choices are equally likely, so let us suppose we choose the pivot with index = 2.

2. We now compute \(g\), the column at index 2, as:

\[\begin{split}g = \begin{pmatrix} 0.90483737 \\ 0.7788007 \\ 1.0 \end{pmatrix}\end{split}\]

3. Remove overlap with previously chosen columns (not needed on the first iteration).

4. Update the approximation matrix:

\[\begin{split}F[:, 0] = g / \sqrt{(g[2])} = \begin{pmatrix} 0.90483737 \\ 0.7788007 \\ 1.0 \end{pmatrix}\end{split}\]

5. Update the residual diagonal:

\[\begin{split}d = d - |F[:,0]|^2 = \begin{pmatrix} 0.18126933 \\ 0.39346947 \\ 0 \end{pmatrix}\end{split}\]
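To make the next sampling step concrete, the residual diagonal after the first iteration can be turned into pivot probabilities by dividing by its sum. A small sketch in plain Python, where the variable names are illustrative only:

```python
# Residual diagonal after the first iteration
d = [0.18126933, 0.39346947, 0.0]

# Pivot probabilities are the residual diagonal normalised to sum to one
total = sum(d)
probabilities = [value / total for value in d]

# Index with the highest probability of being drawn
most_likely = max(range(len(d)), key=lambda j: probabilities[j])

print(probabilities)
print(most_likely)
```

Index 2 now has probability zero, so the pivot already chosen cannot be drawn again. Indices 0 and 1 have probabilities of roughly 0.32 and 0.68.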

For the second iteration (i=1):

1. We again sample a pivot index with probability proportional to its value on the updated residual diagonal, \(d\). Let’s suppose we choose the most likely pivot here (index=1).

2. We now compute \(g\), the column at index 1, as:

\[\begin{split}g = \begin{pmatrix} 0.84366477 \\ 1.0 \\ 0.7788007 \end{pmatrix}\end{split}\]

3. Remove overlap with previously chosen columns:

\[\begin{split}g = g - F[:, 0] F[1, 0]^T = \begin{pmatrix} 0.13897679 \\ 0.39346947 \\ 0 \end{pmatrix}\end{split}\]

4. Update the approximation matrix:

\[\begin{split}F[:, 1] = g / \sqrt{(g[1])} = \begin{pmatrix} 0.22155766 \\ 0.62727145 \\ 0 \end{pmatrix}\end{split}\]

5. Update the residual diagonal:

\[\begin{split}d = d - |F[:,1]|^2 = \begin{pmatrix} 0.13218154 \\ 0 \\ 0 \end{pmatrix}\end{split}\]

After this iteration, the final state is:

\[\begin{split}F = \begin{pmatrix} 0.90483737 & 0.22155766 \\ 0.7788007 & 0.62727145 \\ 1.0 & 0 \end{pmatrix}, \quad d = \begin{pmatrix} 0.13218154 \\ 0 \\ 0 \end{pmatrix}, \quad S = \{2, 1\} \, .\end{split}\]

This completes the coreset of size \(m = 2\). We can also use \(F\) to compute an approximation to the original Gram matrix:

\[\begin{split}F \cdot F^T = \begin{pmatrix} 0.86781846 & 0.84366477 & 0.90483737 \\ 0.84366477 & 1.0 & 0.7788007 \\ 0.90483737 & 0.7788007 & 1.0 \end{pmatrix}\end{split}\]

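Note that we have recovered the original matrix except for \(A_{00}\), because index 0 was never selected as a pivot. The error in that entry, \(A_{00} - (F \cdot F^T)_{00} = 0.13218154\), is exactly the remaining residual diagonal value \(d_0\).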
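The approximation and its error can be checked directly from the final \(F\). The sketch below uses plain Python and the rounded values quoted in this example. The variable names `approximation` and `residual` are illustrative only.

```python
# Final factor and original Gram matrix, as quoted in the example
F = [
    [0.90483737, 0.22155766],
    [0.7788007, 0.62727145],
    [1.0, 0.0],
]
A = [
    [1.0, 0.84366477, 0.90483737],
    [0.84366477, 1.0, 0.7788007],
    [0.90483737, 0.7788007, 1.0],
]

n, m = len(F), len(F[0])

# Low-rank approximation: (F F^T)[i][j] = sum_t F[i][t] * F[j][t]
approximation = [
    [sum(F[i][t] * F[j][t] for t in range(m)) for j in range(n)]
    for i in range(n)
]

# Residual A - F F^T; its diagonal is the residual diagonal d
residual = [
    [A[i][j] - approximation[i][j] for j in range(n)] for i in range(n)
]

for row in residual:
    print([round(value, 6) for value in row])
```

Only the top-left entry of the residual is noticeably non-zero, and it matches \(d_0\) from the final state.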

from unittest.mock import patch

import jax.numpy as jnp
import jax.random as jr

from coreax import Data, SquaredExponentialKernel
from coreax.solvers import RPCholesky

# Setup example data
coreset_size = 2
x = jnp.array(
    [
        [0.5, 0.2],
        [0.4, 0.6],
        [0.8, 0.3],
    ]
)

# Define a kernel
length_scale = 1.0 / jnp.sqrt(2)
kernel = SquaredExponentialKernel(length_scale=length_scale)

# Create a mock for the random choice function
def deterministic_choice(*_, p, **__):
    """
    Return the index of the largest element of p.

    If there is a tie, return the largest index.
    This replaces random sampling with a deterministic rule so that the
    example is reproducible.
    """
    # Boolean mask of the entries equal to the maximum
    is_max = p == p.max()
    # Keep the index where the mask is True and -1 elsewhere, then take the
    # maximum to get the highest index that attains the maximum
    indices = jnp.arange(p.shape[0])
    return jnp.where(is_max, indices, -1).max()


# Generate the coreset
data = Data(x)
solver = RPCholesky(
    coreset_size=coreset_size,
    random_key=jr.PRNGKey(0),  # Fixed seed for reproducibility
    kernel=kernel,
    unique=True,
)

# Mock the random choice function
with patch("jax.random.choice", deterministic_choice):
    coreset, solver_state = solver.reduce(data)

# Independently computed gramian diagonal
expected_gramian_diagonal = jnp.array([0.13218154, 0.0, 0.0])

# Coreset indices forced by our mock choice function
expected_coreset_indices = jnp.array([2, 1])

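# Inspect results and compare them with the expected values
print("Chosen coreset:")
print(coreset.unweighted_indices)  # The coreset indices
print(coreset.points.data)  # The data points in the coreset
print("Expected coreset indices:")
print(expected_coreset_indices)
print("Residual diagonal:")
print(solver_state.gramian_diagonal)
print("Expected residual diagonal:")
print(expected_gramian_diagonal)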