Weights

Classes and associated functionality to optimise weighted representations of data.

Several aspects of this codebase take an \(n \times d\) dataset and generate an alternative representation of it, for example a coreset. How well this representation approximates the original dataset can be assessed with a metric of interest; see Metric for an example.

The quality of the generated representation can be improved by weighting its individual elements. These weights are determined by optimising the metric of interest, which compares the original \(n \times d\) dataset with the generated representation.
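As an illustration of how weights enter such a metric, the following NumPy sketch computes a weighted squared maximum mean discrepancy (MMD) between a dataset and a coreset. The helper names `rbf_kernel` and `weighted_mmd_squared`, the squared-exponential kernel and the "first rows" coreset are assumptions made for this example and are not part of this module.

```python
import numpy as np


def rbf_kernel(x, y, length_scale=1.0):
    """Squared-exponential kernel matrix between the rows of x and y."""
    sq_dists = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / (2.0 * length_scale**2))


def weighted_mmd_squared(dataset, coreset, weights):
    """Squared MMD between a uniformly weighted dataset and a weighted coreset."""
    k_xx = rbf_kernel(dataset, dataset).mean()
    # Mean of k(y_a, x_i) over the dataset, one value per coreset point.
    z = rbf_kernel(coreset, dataset).mean(axis=1)
    k_yy = rbf_kernel(coreset, coreset)
    return k_xx - 2.0 * weights @ z + weights @ k_yy @ weights


rng = np.random.default_rng(0)
dataset = rng.normal(size=(200, 2))  # n x d
coreset = dataset[:10]               # m x d, a naive illustrative coreset
uniform = np.full(10, 1.0 / 10)      # equal weight on every coreset point
mmd_uniform = weighted_mmd_squared(dataset, coreset, uniform)
```

Any other weight vector of length \(m\) can be passed in place of `uniform`; the optimisers in this module look for the vector that makes a quantity of this kind as small as possible.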

This module provides several methods for calculating such weights. Each method subclasses WeightsOptimiser and must implement a solve() method that, given a dataset and a coreset, returns an array of weights that optimises the metric of interest when applied to the coreset.

coreax.weights.solve_qp(kernel_mm, gramian_row_mean, **osqp_kwargs)

Solve quadratic programs with the jaxopt.OSQP solver.

Solves simplex weight problems of the form:

\[\mathbf{w}^{\mathrm{T}} \mathbf{k} \mathbf{w} + \bar{\mathbf{k}}^{\mathrm{T}} \mathbf{w} = 0\]

subject to

\[\mathbf{Aw} = \mathbf{1}, \qquad \mathbf{Gw} \le 0.\]
Parameters:
  • kernel_mm (Shaped[Array, 'm m']) – \(m \times m\) coreset Gram matrix

  • gramian_row_mean (Shaped[Array, 'm 1']) – \(m \times 1\) array of Gram matrix means

Return type:

Shaped[Array, 'm']

Returns:

Optimised solution for the quadratic program
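A minimal NumPy sketch of a problem of this kind, reading it as minimising \(\frac{1}{2}\mathbf{w}^{\mathrm{T}} \mathbf{k} \mathbf{w} - \bar{\mathbf{k}}^{\mathrm{T}} \mathbf{w}\) over weights that are non-negative and sum to one (that is, \(\mathbf{A}\) a row of ones and \(\mathbf{G} = -\mathbf{I}\)). This reading, the helper names and the projected-gradient method are assumptions for illustration; the documented function uses jaxopt.OSQP rather than the solver shown here.

```python
import numpy as np


def project_to_simplex(v):
    """Euclidean projection of v onto {w : w >= 0, sum(w) = 1}."""
    u = np.sort(v)[::-1]  # entries in descending order
    cumulative = np.cumsum(u) - 1.0
    indices = np.arange(1, v.size + 1)
    # Last position at which the shifted entry is still positive.
    rho = np.nonzero(u - cumulative / indices > 0)[0][-1]
    theta = cumulative[rho] / (rho + 1.0)
    return np.maximum(v - theta, 0.0)


def simplex_qp(kernel_mm, gramian_row_mean, steps=2000):
    """Minimise 0.5 w^T K w - kbar^T w over the simplex by projected gradient."""
    k = np.asarray(kernel_mm, dtype=float)
    kbar = np.asarray(gramian_row_mean, dtype=float).reshape(-1)
    # Step size 1 / (largest singular value of K), which equals
    # 1 / (largest eigenvalue) when K is symmetric positive semi-definite.
    step = 1.0 / np.linalg.norm(k, 2)
    w = np.full(kbar.size, 1.0 / kbar.size)  # start from uniform weights
    for _ in range(steps):
        w = project_to_simplex(w - step * (k @ w - kbar))
    return w


kernel_mm = np.array([[1.0, 0.5], [0.5, 1.0]])  # toy 2 x 2 Gram matrix
gramian_row_mean = np.array([[0.9], [0.3]])     # toy 2 x 1 mean vector
weights = simplex_qp(kernel_mm, gramian_row_mean)
```

For this toy input the unconstrained minimiser is \((1, -0.2)\), which has a negative entry, so the simplex constraint moves the solution to the vertex \((1, 0)\).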

class coreax.weights.WeightsOptimiser

Bases: Module, Generic[_Data]

Base class for optimising weights.

abstractmethod solve(dataset, coreset, epsilon=1e-10)

Solve the optimisation problem, return the optimal weights.

Parameters:
  • dataset (Data) – Data instance consisting of an \(n \times d\) data array or SupervisedData instance consisting of an \(n \times d\) data array paired with an \(n \times p\) supervision array

  • coreset (Data) – Data instance consisting of an \(m \times d\) data array or SupervisedData instance consisting of an \(m \times d\) data array paired with an \(m \times p\) supervision array, representing a coreset

  • epsilon (float) – Small positive value to add to the matrices to aid numerical solver computations

Return type:

Shaped[Array, 'm']

Returns:

Optimal weighting of points in coreset to represent dataset
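The contract described above can be sketched in plain Python. The classes `WeightsOptimiserSketch` and `UniformWeights` below are hypothetical stand-ins that only mirror the documented solve() signature and return shape; they do not use coreax, Data instances or the Module base class, and plain arrays take the place of the dataset and coreset.

```python
from abc import ABC, abstractmethod

import numpy as np


class WeightsOptimiserSketch(ABC):
    """Hypothetical stand-in mirroring the documented solve() signature."""

    @abstractmethod
    def solve(self, dataset, coreset, epsilon=1e-10):
        """Return one weight per coreset point, as an array of shape (m,)."""


class UniformWeights(WeightsOptimiserSketch):
    """Trivial baseline: every coreset point receives weight 1 / m."""

    def solve(self, dataset, coreset, epsilon=1e-10):
        del dataset, epsilon  # not needed by this baseline
        m = np.asarray(coreset).shape[0]
        return np.full(m, 1.0 / m)


# An n x d dataset (n = 20) and an m x d coreset (m = 5), as plain arrays.
weights = UniformWeights().solve(np.zeros((20, 3)), np.zeros((5, 3)))
```

A subclass that does not implement solve() cannot be instantiated, which is the behaviour the abstract method is intended to enforce.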

class coreax.weights.SBQWeightsOptimiser(kernel)

Bases: WeightsOptimiser[_Data]

Define the Sequential Bayesian Quadrature (SBQ) optimiser class.

References for this technique can be found in [huszar2016optimally]. Weights determined by SBQ are equivalent to the unconstrained weighted maximum mean discrepancy (MMD) optimum.

The Bayesian quadrature estimate of the integral

\[\int f(x) p(x) dx\]

can be viewed as a weighted version of kernel herding. The Bayesian quadrature weights, \(w_{BQ}\), are given by

\[w_{BQ}^{(n)} = \sum_m z_m^T K_{mn}^{-1}\]

for a dataset \(x\) with \(n\) points, and coreset \(y\) of \(m\) points. Here, for given kernel \(k\), we have \(z = \int k(x, y)p(x) dx\) and \(K = k(y, y)\) in the above expression. See equation 20 in [huszar2016optimally] for further detail.
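
A short sketch of why these weights coincide with the unconstrained weighted MMD optimum, assuming \(p\) is the empirical distribution of the dataset, so that \(z_a = \frac{1}{n} \sum_{i} k(x_i, y_a)\). This derivation is an illustration and is not quoted from [huszar2016optimally]. For a weight vector \(w\) on the coreset,

\[\mathrm{MMD}^2(w) = \frac{1}{n^2} \sum_{i, j} k(x_i, x_j) - 2 w^T z + w^T K w.\]

The first term does not depend on \(w\). Setting the gradient to zero gives

\[\nabla_w \mathrm{MMD}^2(w) = 2 K w - 2 z = 0 \quad \Rightarrow \quad w = K^{-1} z,\]

which matches the expression for \(w_{BQ}\) above because \(K\) is symmetric. No constraint is imposed on \(w\) in this minimisation, so the resulting weights may be negative and need not sum to 1.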

Parameters:

kernel (ScalarValuedKernel) – ScalarValuedKernel instance implementing a kernel \(k: \mathbb{R}^d \times \mathbb{R}^d \rightarrow \mathbb{R}\)

kernel: ScalarValuedKernel
solve(dataset, coreset, epsilon=1e-10, *, block_size=None, unroll=1, **solver_kwargs)

Calculate weights from Sequential Bayesian Quadrature (SBQ).

References for this technique can be found in [huszar2016optimally]. These are equivalent to the unconstrained weighted maximum mean discrepancy (MMD) optimum.

Note that weights determined through SBQ do not need to sum to 1, and can be negative.

Parameters:
  • dataset (Data) – Data instance consisting of an \(n \times d\) data array

  • coreset (Data) – Data instance consisting of an \(m \times d\) data array, representing a coreset

  • epsilon (float) – Small positive value to add to the kernel Gram matrix to aid numerical solver computations

  • block_size (Union[int, None, tuple[Optional[int], Optional[int]]]) – Block size passed to self.kernel.compute_mean

  • unroll (Union[int, bool, tuple[Union[int, bool], Union[int, bool]]]) – Unroll parameter passed to self.kernel.compute_mean

  • solver_kwargs – Additional kwargs passed to jnp.linalg.solve

Return type:

Shaped[Array, 'm']

Returns:

Optimal weighting of points in coreset to represent dataset
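
The calculation described for SBQ can be sketched with NumPy as a single linear solve. The helpers `rbf_kernel`, `sbq_weights` and `mmd_objective` are hypothetical names for this example, the kernel choice is arbitrary, and adding epsilon on the diagonal of the Gram matrix is an assumption about how the regularisation is applied; the documented method works on Data instances and a ScalarValuedKernel instead.

```python
import numpy as np


def rbf_kernel(x, y, length_scale=1.0):
    """Squared-exponential kernel matrix between the rows of x and y."""
    sq_dists = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / (2.0 * length_scale**2))


def sbq_weights(dataset, coreset, epsilon=1e-10):
    """Solve (K + epsilon * I) w = z for the coreset weights w."""
    m = coreset.shape[0]
    k_yy = rbf_kernel(coreset, coreset) + epsilon * np.eye(m)
    # Empirical kernel mean: average of k(y_a, x_i) over the dataset.
    z = rbf_kernel(coreset, dataset).mean(axis=1)
    return np.linalg.solve(k_yy, z)


def mmd_objective(dataset, coreset, weights):
    """Weight-dependent part of the squared MMD: w^T K w - 2 w^T z."""
    k_yy = rbf_kernel(coreset, coreset)
    z = rbf_kernel(coreset, dataset).mean(axis=1)
    return weights @ k_yy @ weights - 2.0 * weights @ z


rng = np.random.default_rng(0)
dataset = rng.normal(size=(200, 2))  # n x d
coreset = dataset[:10]               # m x d, a naive illustrative coreset
weights = sbq_weights(dataset, coreset)
uniform = np.full(10, 1.0 / 10)
```

Comparing `mmd_objective(dataset, coreset, weights)` with `mmd_objective(dataset, coreset, uniform)` shows the effect of the optimisation; the entries of `weights` are not constrained in sign or sum.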

class coreax.weights.MMDWeightsOptimiser(kernel)

Bases: WeightsOptimiser[_Data]

Define the MMD weights optimiser class.

This optimiser solves a simplex weight problem of the form:

\[\mathbf{w}^{\mathrm{T}} \mathbf{k} \mathbf{w} + \bar{\mathbf{k}}^{\mathrm{T}} \mathbf{w} = 0\]

subject to

\[\mathbf{Aw} = \mathbf{1}, \qquad \mathbf{Gw} \le 0.\]

using the OSQP quadratic programming solver.

Parameters:

kernel (ScalarValuedKernel) – ScalarValuedKernel instance implementing a kernel function \(k: \mathbb{R}^d \times \mathbb{R}^d \rightarrow \mathbb{R}\)

kernel: ScalarValuedKernel
solve(dataset, coreset, epsilon=1e-10, *, block_size=None, unroll=1, **solver_kwargs)

Compute optimal weights given the simplex constraint.

Parameters:
  • dataset (Data) – Data instance consisting of an \(n \times d\) data array

  • coreset (Data) – Data instance consisting of an \(m \times d\) data array, representing a coreset

  • epsilon (float) – Small positive value to add to the kernel Gram matrix to aid numerical solver computations

  • block_size (Union[int, None, tuple[Optional[int], Optional[int]]]) – Block size passed to self.kernel.compute_mean

  • unroll (Union[int, bool, tuple[Union[int, bool], Union[int, bool]]]) – Unroll parameter passed to self.kernel.compute_mean

  • solver_kwargs – Additional kwargs passed to solve_qp()

Return type:

Shaped[Array, 'm']

Returns:

Optimal weighting of points in coreset to represent dataset
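
To make the shapes involved concrete, the sketch below assembles the two arrays that a quadratic program of the form above needs: an \(m \times m\) coreset Gram matrix and an \(m \times 1\) array of kernel means taken over the dataset. The helper names, the kernel and the placement of epsilon on the diagonal are assumptions for illustration and are not taken from the library.

```python
import numpy as np


def rbf_kernel(x, y, length_scale=1.0):
    """Squared-exponential kernel matrix between the rows of x and y."""
    sq_dists = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / (2.0 * length_scale**2))


def qp_inputs(dataset, coreset, epsilon=1e-10):
    """Build an (m, m) Gram matrix and an (m, 1) kernel-mean array."""
    m = coreset.shape[0]
    # Gram matrix of the coreset, with epsilon added on the diagonal.
    kernel_mm = rbf_kernel(coreset, coreset) + epsilon * np.eye(m)
    # Mean of k(y_a, x_i) over the n dataset points, kept as a column.
    gramian_row_mean = rbf_kernel(coreset, dataset).mean(axis=1, keepdims=True)
    return kernel_mm, gramian_row_mean


rng = np.random.default_rng(1)
dataset = rng.normal(size=(100, 3))  # n x d
coreset = dataset[:8]                # m x d, a naive illustrative coreset
kernel_mm, gramian_row_mean = qp_inputs(dataset, coreset)
```

Arrays of these shapes correspond to the kernel_mm and gramian_row_mean parameters documented for solve_qp().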