# © Crown Copyright GCHQ
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Classes and associated functionality to optimise weighted representations of data.

Several aspects of this codebase take an :math:`n \times d` dataset and generate an
alternative representation of it, for example a coreset. The quality of this alternative
representation in approximating the original dataset can be assessed using some metric
of interest, for example see :class:`~coreax.metrics.Metric`.

One can improve the quality of the representation generated by weighting the individual
elements of it. These weights are determined by optimising the metric of interest, which
compares the original :math:`n \times d` dataset and the generated representation of it.

This module provides functionality to calculate such weights, through various methods.
All methods implement :class:`WeightsOptimiser` and must have a
:meth:`~WeightsOptimiser.solve` method that, given two datasets, returns an array of
weights such that a metric of interest is optimised when these weights are applied to
the representation (for example, the coreset).
"""
from abc import abstractmethod
from typing import Generic, TypeVar
import equinox as eqx
import jax.numpy as jnp
from jax import Array
from jaxopt import OSQP
from jaxtyping import Shaped
from coreax.data import Data, as_data
from coreax.kernels import ScalarValuedKernel
from coreax.util import apply_negative_precision_threshold
# Type variable bound to ``Data``; the optimisers below are ``Generic`` over it so
# that ``solve`` can be annotated with the kind of dataset each optimiser accepts.
_Data = TypeVar("_Data", bound=Data)
# Message for the ValueError raised when no implemented solver supports the given
# kernel / data pairing (see ``_prepare_kernel_system``).
INVALID_KERNEL_DATA_COMBINATION = (
    "Invalid combination of 'kernel' and 'dataset' or 'coreset'; "
    "if solving weights for unsupervised 'Data', "
    "one must pass child of 'Kernel'. "
    "'SupervisedData' is not currently supported by any implemented solvers."
)
def solve_qp(
    kernel_mm: Shaped[Array, "m m"],
    gramian_row_mean: Shaped[Array, " m"] | Shaped[Array, " m 1"],
    **osqp_kwargs,
) -> Shaped[Array, " m"]:
    r"""
    Solve quadratic programs with the :class:`jaxopt.OSQP` solver.

    Solves simplex weight problems of the form:

    .. math::

        \min_{\mathbf{w}} \; \frac{1}{2} \mathbf{w}^{\mathrm{T}} \mathbf{K} \mathbf{w}
        - \bar{\mathbf{k}}^{\mathrm{T}} \mathbf{w}

    subject to

    .. math::

        \mathbf{A}\mathbf{w} = \mathbf{1}, \qquad \mathbf{G}\mathbf{w} \le \mathbf{0},

    where :math:`\mathbf{A}` is a :math:`1 \times m` row of ones (the weights sum to
    one) and :math:`\mathbf{G} = -\mathbf{I}` (the weights are non-negative).

    OSQP only solves the problem approximately, so the raw primal solution is passed
    through ``apply_negative_precision_threshold`` with an infinite threshold
    (presumably zeroing every negative entry) and then renormalised to sum to one.

    :param kernel_mm: :math:`m \times m` coreset Gram matrix :math:`\mathbf{K}`
    :param gramian_row_mean: Gram matrix means :math:`\bar{\mathbf{k}}`, one per
        coreset point, given either as a length :math:`m` vector or as an
        :math:`m \times 1` column; it is flattened before use
    :param osqp_kwargs: Additional keyword arguments passed to the
        :class:`jaxopt.OSQP` constructor
    :return: Optimised solution for the quadratic program
    """
    # Setup optimisation problem - all variable names are consistent with the OSQP
    # terminology. Begin with the objective parameters. The linear term ``c`` must
    # have the same (m,) shape as the primal variable, so flatten the means; this
    # accepts both (m,) and (m, 1) inputs and is a no-op for (m,).
    q_array = jnp.asarray(kernel_mm)
    c = -jnp.ravel(jnp.asarray(gramian_row_mean))
    # Define the equality constraint parameters: the weights sum to one.
    num_points = q_array.shape[0]
    a_array = jnp.ones((1, num_points))
    b = jnp.array([1.0])
    # Define the inequality constraint parameters: -w <= 0, i.e. w >= 0.
    g_array = jnp.eye(num_points) * -1.0
    h = jnp.zeros(num_points)
    # Define solver object and run solver
    qp = OSQP(**osqp_kwargs)
    sol = qp.run(
        params_obj=(q_array, c), params_eq=(a_array, b), params_ineq=(g_array, h)
    ).params
    # Ensure conditions of solution are met: the approximate solver may leave small
    # negative entries and a sum slightly off one.
    solution = apply_negative_precision_threshold(sol.primal, jnp.inf)
    return solution / jnp.sum(solution)
# Disable this so we can type check specifically the parent class
# pylint: disable = unidiomatic-typecheck
def _prepare_kernel_system(
    kernel: ScalarValuedKernel,
    dataset: Data,
    coreset: Data,
    epsilon: float = 1e-10,
    *,
    block_size: int | None | tuple[int | None, int | None] = None,
    unroll: int | bool | tuple[int | bool, int | bool] = 1,
) -> tuple[Array, Array]:
    r"""
    Return the row mean of :math:`k(coreset, dataset)` and the coreset Gramian.

    Only plain :class:`~coreax.data.Data` is accepted for both ``dataset`` and
    ``coreset``: the check is an exact ``type(...) is Data`` comparison, so subclasses
    of ``Data`` are rejected.

    :param kernel: :class:`~coreax.kernels.ScalarValuedKernel` instance implementing a
        kernel function
        :math:`k: \mathbb{R}^d \times \mathbb{R}^d \rightarrow \mathbb{R}` if solving
        with unsupervised data.
    :param dataset: :class:`~coreax.data.Data` instance consisting of a
        :math:`n \times d` data array
    :param coreset: :class:`~coreax.data.Data` instance consisting of a
        :math:`m \times d` data array
    :param epsilon: Small positive value added to the diagonal of the kernel Gram
        matrix to aid numerical solver computations
    :param block_size: Block size passed to ``kernel.compute_mean``
    :param unroll: Unroll parameter passed to ``kernel.compute_mean``
    :return: Tuple of the row mean of :math:`k(coreset, dataset)` and the epsilon
        perturbed coreset Gramian
    :raises ValueError: If ``dataset`` or ``coreset`` is not exactly ``Data``, or
        ``kernel`` is not a ``ScalarValuedKernel``
    """
    if (
        type(dataset) is not Data
        or type(coreset) is not Data
        or not isinstance(kernel, ScalarValuedKernel)
    ):
        raise ValueError(INVALID_KERNEL_DATA_COMBINATION)

    full_points = dataset.data
    reduced_points = coreset.data
    # Average k(dataset, coreset) over the dataset axis, giving one value per coreset
    # point (the row mean of k(coreset, dataset)).
    mean_similarity = kernel.compute_mean(
        full_points, reduced_points, axis=0, block_size=block_size, unroll=unroll
    )
    # Diagonal jitter keeps the coreset Gramian better conditioned for the solvers.
    coreset_gram = kernel.compute(reduced_points, reduced_points)
    jitter = epsilon * jnp.eye(len(coreset))
    return mean_similarity, coreset_gram + jitter
# pylint: enable = unidiomatic-typecheck
class WeightsOptimiser(eqx.Module, Generic[_Data]):
    r"""
    Base class for optimising weights.

    Subclasses implement :meth:`solve`, which returns one weight per point of a coreset
    so that the weighted coreset represents the original dataset with respect to some
    metric of interest.
    """

    @abstractmethod
    def solve(
        self,
        dataset: _Data,
        coreset: _Data,
        epsilon: float = 1e-10,
    ) -> Shaped[Array, " m"]:
        r"""
        Solve the optimisation problem, return the optimal weights.

        :param dataset: :class:`~coreax.data.Data` instance consisting of a
            :math:`n \times d` data array or :class:`~coreax.data.SupervisedData`
            instance consisting of :math:`n \times d` data array paired with
            :math:`n \times p` supervision array
        :param coreset: :class:`~coreax.data.Data` instance consisting of a
            :math:`m \times d` data array or :class:`~coreax.data.SupervisedData`
            instance consisting of :math:`m \times d` data array paired with
            :math:`m \times p` supervision array, representing a coreset
        :param epsilon: Small positive value to add to the matrices to aid numerical
            solver computations
        :return: Optimal weighting of points in ``coreset`` to represent ``dataset``
        """
class SBQWeightsOptimiser(WeightsOptimiser[_Data]):
    r"""
    Define the Sequential Bayesian Quadrature (SBQ) optimiser class.

    References for this technique can be found in :cite:`huszar2016optimally`.
    Weights determined by SBQ are equivalent to the unconstrained weighted maximum mean
    discrepancy (MMD) optimum.

    The Bayesian quadrature estimate of the integral

    .. math::

        \int f(x) p(x) dx

    can be viewed as a weighted version of kernel herding. For a dataset :math:`x` of
    :math:`n` points and a coreset :math:`y` of :math:`m` points, the Bayesian
    quadrature weights :math:`w_{BQ} \in \mathbb{R}^m` solve the linear system

    .. math::

        K \, w_{BQ} = z,

    where :math:`K = k(y, y)` is the :math:`m \times m` coreset Gram matrix (with a
    small ``epsilon`` added to its diagonal) and :math:`z`, an estimate of
    :math:`\int k(x, y) p(x) dx`, is the mean of the kernel between each coreset point
    and the dataset. See equation 20 in :cite:`huszar2016optimally` for further detail.

    :param kernel: :class:`~coreax.kernels.ScalarValuedKernel` instance implementing a
        kernel :math:`k: \mathbb{R}^d \times \mathbb{R}^d \rightarrow \mathbb{R}`
    """

    # Kernel used to build both the coreset Gramian and the dataset/coreset means.
    kernel: ScalarValuedKernel

    def solve(
        self,
        dataset: _Data,
        coreset: _Data,
        epsilon: float = 1e-10,
        *,
        block_size: int | None | tuple[int | None, int | None] = None,
        unroll: int | bool | tuple[int | bool, int | bool] = 1,
        **solver_kwargs,
    ) -> Shaped[Array, " m"]:
        r"""
        Calculate weights from Sequential Bayesian Quadrature (SBQ).

        References for this technique can be found in
        :cite:`huszar2016optimally`. These are equivalent to the unconstrained
        weighted maximum mean discrepancy (MMD) optimum.

        Note that weights determined through SBQ do not need to sum to 1, and can be
        negative.

        :param dataset: :class:`~coreax.data.Data` instance consisting of a
            :math:`n \times d` data array
        :param coreset: :class:`~coreax.data.Data` instance consisting of a
            :math:`m \times d` data array, representing a coreset
        :param epsilon: Small positive value to add to the kernel Gram matrix to aid
            numerical solver computations
        :param block_size: Block size passed to the ``self.kernel.compute_mean``
        :param unroll: Unroll parameter passed to ``self.kernel.compute_mean``
        :param solver_kwargs: Additional kwargs passed to ``jnp.linalg.solve``
        :return: Optimal weighting of points in ``coreset`` to represent ``dataset``
        """
        full_data = as_data(dataset)
        reduced_data = as_data(coreset)
        embedding_estimate, coreset_gram = _prepare_kernel_system(
            self.kernel,
            full_data,
            reduced_data,
            epsilon,
            block_size=block_size,
            unroll=unroll,
        )
        # Unconstrained optimum: solve K w = z directly (no simplex constraint).
        weights = jnp.linalg.solve(coreset_gram, embedding_estimate, **solver_kwargs)
        return weights
class MMDWeightsOptimiser(WeightsOptimiser[_Data]):
    r"""
    Define the MMD weights optimiser class.

    This optimiser finds simplex-constrained coreset weights by solving the quadratic
    program

    .. math::

        \min_{\mathbf{w}} \; \frac{1}{2} \mathbf{w}^{\mathrm{T}} \mathbf{K} \mathbf{w}
        - \bar{\mathbf{k}}^{\mathrm{T}} \mathbf{w}

    subject to

    .. math::

        \mathbf{1}^{\mathrm{T}} \mathbf{w} = 1, \qquad \mathbf{w} \ge \mathbf{0},

    where :math:`\mathbf{K}` is the coreset Gram matrix (with a small ``epsilon`` added
    to its diagonal) and :math:`\bar{\mathbf{k}}` holds, for each coreset point, the
    mean kernel value against the dataset. The problem is solved with the OSQP
    quadratic programming solver via :func:`solve_qp`.

    :param kernel: :class:`~coreax.kernels.ScalarValuedKernel` instance implementing a
        kernel function
        :math:`k: \mathbb{R}^d \times \mathbb{R}^d \rightarrow \mathbb{R}`
    """

    # Kernel used to build both the coreset Gramian and the dataset/coreset means.
    kernel: ScalarValuedKernel

    def solve(
        self,
        dataset: _Data,
        coreset: _Data,
        epsilon: float = 1e-10,
        *,
        block_size: int | None | tuple[int | None, int | None] = None,
        unroll: int | bool | tuple[int | bool, int | bool] = 1,
        **solver_kwargs,
    ) -> Shaped[Array, " m"]:
        r"""
        Compute optimal weights given the simplex constraint.

        :param dataset: :class:`~coreax.data.Data` instance consisting of a
            :math:`n \times d` data array
        :param coreset: :class:`~coreax.data.Data` instance consisting of a
            :math:`m \times d` data array, representing a coreset
        :param epsilon: Small positive value to add to the kernel Gram matrix to aid
            numerical solver computations
        :param block_size: Block size passed to the ``self.kernel.compute_mean``
        :param unroll: Unroll parameter passed to ``self.kernel.compute_mean``
        :param solver_kwargs: Additional kwargs passed to :func:`solve_qp`
        :return: Optimal weighting of points in ``coreset`` to represent ``dataset``
        """
        full_data = as_data(dataset)
        reduced_data = as_data(coreset)
        mean_similarity, coreset_gram = _prepare_kernel_system(
            self.kernel,
            full_data,
            reduced_data,
            epsilon,
            block_size=block_size,
            unroll=unroll,
        )
        # Hand the Gramian and the per-point means to the simplex-constrained QP.
        weights = solve_qp(coreset_gram, mean_similarity, **solver_kwargs)
        return weights