Weights
Classes and associated functionality to optimise weighted representations of data.
Several parts of this codebase take an \(n \times d\) dataset and generate an
alternative representation of it, for example a coreset. How well this alternative
representation approximates the original dataset can be assessed using a metric
of interest; see Metric for an example.
The quality of the generated representation can be improved by weighting its individual elements. These weights are determined by optimising the metric of interest, which compares the original \(n \times d\) dataset with the generated representation.
This module provides several methods for calculating such weights.
Every method implements WeightsOptimiser and must provide a
solve() method that, given a dataset and a coreset, returns an array of
weights such that a metric of interest is optimised when these weights are applied to
the coreset.
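To illustrate how weights enter a metric, the sketch below computes a weighted squared maximum mean discrepancy (MMD) between a dataset and a coreset in plain NumPy. The RBF kernel, the helper names `rbf_kernel` and `weighted_mmd2`, and the random data are assumptions made for illustration; they are not part of coreax.

```python
import numpy as np


def rbf_kernel(a, b, length_scale=1.0):
    """Squared-exponential kernel matrix between the rows of a and the rows of b."""
    sq_dists = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-sq_dists / (2.0 * length_scale**2))


def weighted_mmd2(dataset, coreset, weights):
    """Squared MMD between a uniformly weighted dataset and a weighted coreset."""
    k_xx = rbf_kernel(dataset, dataset).mean()
    k_xy = rbf_kernel(dataset, coreset).mean(axis=0)  # length-m kernel mean
    k_yy = rbf_kernel(coreset, coreset)
    return k_xx - 2.0 * weights @ k_xy + weights @ k_yy @ weights


rng = np.random.default_rng(0)
dataset = rng.normal(size=(200, 2))
coreset = dataset[:10]

# A uniformly weighted subset generally differs from the full dataset.
mmd_subset = weighted_mmd2(dataset, coreset, np.full(10, 1.0 / 10))
# The full dataset with uniform weights reproduces itself exactly.
mmd_full = weighted_mmd2(dataset, dataset, np.full(200, 1.0 / 200))
```

Because the metric is a quadratic function of the weight vector, choosing the weights amounts to solving a small linear or quadratic problem in \(m\) unknowns.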
- coreax.weights.solve_qp(kernel_mm, gramian_row_mean, **osqp_kwargs)
Solve quadratic programs with the jaxopt.OSQP solver.
Solves simplex weight problems of the form:
\[\mathbf{w}^{\mathrm{T}} \mathbf{k} \mathbf{w} + \bar{\mathbf{k}}^{\mathrm{T}} \mathbf{w} = 0\]
subject to
\[\mathbf{Aw} = \mathbf{1}, \qquad \mathbf{Gw} \le 0.\]
- Parameters:
kernel_mm (Shaped[Array, 'm m']) – \(m \times m\) coreset Gram matrix
gramian_row_mean (Shaped[Array, 'm 1']) – \(m \times 1\) array of Gram matrix means
- Return type:
Shaped[Array, 'm']
- Returns:
Optimised solution for the quadratic program
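The constraint pair above can encode the probability simplex. The NumPy sketch below assumes that \(\mathbf{A}\) is a single row of ones and \(\mathbf{G}\) is the negated identity, so that the weights sum to one and are non-negative. This encoding and the helper names `simplex_constraints` and `is_feasible` are illustrative assumptions, not taken from the text above.

```python
import numpy as np


def simplex_constraints(m):
    """Return (A, b, G, h) such that A w = b and G w <= h describe the simplex."""
    a_matrix = np.ones((1, m))  # sum of the weights ...
    b_vector = np.ones(1)       # ... equals one
    g_matrix = -np.eye(m)       # -w_i <= 0, i.e. w_i >= 0
    h_vector = np.zeros(m)
    return a_matrix, b_vector, g_matrix, h_vector


def is_feasible(weights, tol=1e-9):
    """Check whether a weight vector satisfies both constraints."""
    a_matrix, b_vector, g_matrix, h_vector = simplex_constraints(len(weights))
    equality_ok = np.allclose(a_matrix @ weights, b_vector, atol=tol)
    inequality_ok = np.all(g_matrix @ weights <= h_vector + tol)
    return bool(equality_ok and inequality_ok)


on_simplex = is_feasible(np.array([0.2, 0.3, 0.5]))
negative_entry = is_feasible(np.array([1.2, -0.2, 0.0]))
wrong_sum = is_feasible(np.array([0.5, 0.6, 0.2]))
```

Here only the first vector is feasible: the second has a negative entry and the third does not sum to one.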
- class coreax.weights.WeightsOptimiser
Base class for optimising weights.
- abstractmethod solve(dataset, coreset, epsilon=1e-10)
Solve the optimisation problem, return the optimal weights.
- Parameters:
dataset (Data) – Data instance consisting of an \(n \times d\) data array, or SupervisedData instance consisting of an \(n \times d\) data array paired with an \(n \times p\) supervision array
coreset (Data) – Data instance consisting of an \(m \times d\) data array, or SupervisedData instance consisting of an \(m \times d\) data array paired with an \(m \times p\) supervision array, representing a coreset
epsilon (float) – Small positive value to add to the matrices to aid numerical solver computations
- Return type:
Shaped[Array, 'm']
- Returns:
Optimal weighting of points in coreset to represent dataset
- class coreax.weights.SBQWeightsOptimiser(kernel)
Bases: WeightsOptimiser[_Data]
Define the Sequential Bayesian Quadrature (SBQ) optimiser class.
References for this technique can be found in [huszar2016optimally]. Weights determined by SBQ are equivalent to the unconstrained weighted maximum mean discrepancy (MMD) optimum.
The Bayesian quadrature estimate of the integral
\[\int f(x) p(x) dx\]
can be viewed as a weighted version of kernel herding. The Bayesian quadrature weights, \(w_{BQ}\), are given by
\[w_{BQ}^{(n)} = \sum_m z_m^T K_{mn}^{-1}\]
for a dataset \(x\) with \(n\) points, and coreset \(y\) of \(m\) points. Here, for a given kernel \(k\), we have \(z = \int k(x, y)p(x) dx\) and \(K = k(y, y)\) in the above expression. See equation 20 in [huszar2016optimally] for further detail.
- Parameters:
kernel (ScalarValuedKernel) – ScalarValuedKernel instance implementing a kernel \(k: \mathbb{R}^d \times \mathbb{R}^d \rightarrow \mathbb{R}\)
- kernel: ScalarValuedKernel
- solve(dataset, coreset, epsilon=1e-10, *, block_size=None, unroll=1, **solver_kwargs)
Calculate weights from Sequential Bayesian Quadrature (SBQ).
References for this technique can be found in [huszar2016optimally]. These are equivalent to the unconstrained weighted maximum mean discrepancy (MMD) optimum.
Note that weights determined through SBQ do not need to sum to 1, and can be negative.
- Parameters:
dataset (Data) – Data instance consisting of an \(n \times d\) data array
coreset (Data) – Data instance consisting of an \(m \times d\) data array, representing a coreset
epsilon (float) – Small positive value to add to the kernel Gram matrix to aid numerical solver computations
block_size (Union[int, None, tuple[Optional[int], Optional[int]]]) – Block size passed to self.kernel.compute_mean
unroll (Union[int, bool, tuple[Union[int, bool], Union[int, bool]]]) – Unroll parameter passed to self.kernel.compute_mean
solver_kwargs – Additional kwargs passed to jnp.linalg.solve
- Return type:
Shaped[Array, 'm']
- Returns:
Optimal weighting of points in coreset to represent dataset
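The SBQ weight formula can be sketched in plain NumPy as a single linear solve. The example assumes an RBF kernel, approximates \(z\) by the empirical kernel mean over the dataset, and adds epsilon to the diagonal of the Gram matrix. The helper names `rbf_kernel` and `sbq_weights` are hypothetical, and this is not the coreax implementation.

```python
import numpy as np


def rbf_kernel(a, b, length_scale=1.0):
    """Squared-exponential kernel matrix between the rows of a and the rows of b."""
    sq_dists = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-sq_dists / (2.0 * length_scale**2))


def sbq_weights(dataset, coreset, epsilon=1e-10):
    """Solve (K + epsilon * I) w = z for the coreset weights w."""
    # Empirical stand-in for z = int k(x, y) p(x) dx, one entry per coreset point.
    kernel_mean = rbf_kernel(coreset, dataset).mean(axis=1)
    # Coreset Gram matrix K = k(y, y), regularised on the diagonal (an assumption).
    gram = rbf_kernel(coreset, coreset) + epsilon * np.eye(len(coreset))
    return np.linalg.solve(gram, kernel_mean)


rng = np.random.default_rng(0)
dataset = rng.normal(size=(100, 2))
coreset = dataset[::10]
weights = sbq_weights(dataset, coreset)  # one weight per coreset point

# When the coreset is the whole dataset, the weights are uniform.
grid = np.arange(4.0)[:, None]
identity_case = sbq_weights(grid, grid)
```

Nothing in the solve constrains the result, which is why these weights may be negative and need not sum to 1.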
- class coreax.weights.MMDWeightsOptimiser(kernel)
Bases: WeightsOptimiser[_Data]
Define the MMD weights optimiser class.
This optimiser solves a simplex weight problem of the form:
\[\mathbf{w}^{\mathrm{T}} \mathbf{k} \mathbf{w} + \bar{\mathbf{k}}^{\mathrm{T}} \mathbf{w} = 0\]
subject to
\[\mathbf{Aw} = \mathbf{1}, \qquad \mathbf{Gw} \le 0.\]
using the OSQP quadratic programming solver.
- Parameters:
kernel (ScalarValuedKernel) – ScalarValuedKernel instance implementing a kernel function \(k: \mathbb{R}^d \times \mathbb{R}^d \rightarrow \mathbb{R}\)
- kernel: ScalarValuedKernel
- solve(dataset, coreset, epsilon=1e-10, *, block_size=None, unroll=1, **solver_kwargs)
Compute optimal weights given the simplex constraint.
- Parameters:
dataset (Data) – Data instance consisting of an \(n \times d\) data array
coreset (Data) – Data instance consisting of an \(m \times d\) data array, representing a coreset
epsilon (float) – Small positive value to add to the kernel Gram matrix to aid numerical solver computations
block_size (Union[int, None, tuple[Optional[int], Optional[int]]]) – Block size passed to self.kernel.compute_mean
unroll (Union[int, bool, tuple[Union[int, bool], Union[int, bool]]]) – Unroll parameter passed to self.kernel.compute_mean
solver_kwargs – Additional kwargs passed to solve_qp()
- Return type:
Shaped[Array, 'm']
- Returns:
Optimal weighting of points in coreset to represent dataset
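As an illustration of a simplex-constrained weight problem, the sketch below uses projected gradient descent in plain NumPy in place of OSQP. It assumes the objective is the weighted MMD quadratic \(\tfrac{1}{2}\mathbf{w}^{\mathrm{T}} K \mathbf{w} - z^{\mathrm{T}} \mathbf{w}\), with \(K\) the coreset Gram matrix and \(z\) the kernel mean of the dataset at the coreset points. That sign convention, the RBF kernel, and all helper names are assumptions made for illustration and are not part of coreax.

```python
import numpy as np


def rbf_kernel(a, b, length_scale=1.0):
    """Squared-exponential kernel matrix between the rows of a and the rows of b."""
    sq_dists = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-sq_dists / (2.0 * length_scale**2))


def project_to_simplex(v):
    """Euclidean projection of v onto {w : w >= 0, sum(w) = 1} (sort-based)."""
    sorted_v = np.sort(v)[::-1]
    cumulative = np.cumsum(sorted_v) - 1.0
    counts = np.arange(1, len(v) + 1)
    rho = np.nonzero(sorted_v - cumulative / counts > 0)[0][-1]
    threshold = cumulative[rho] / (rho + 1.0)
    return np.maximum(v - threshold, 0.0)


def objective(weights, gram, kernel_mean):
    """Assumed quadratic objective: 0.5 * w^T K w - z^T w."""
    return 0.5 * weights @ gram @ weights - kernel_mean @ weights


def simplex_weights(gram, kernel_mean, num_steps=500):
    """Projected gradient descent started from uniform weights."""
    m = len(kernel_mean)
    weights = np.full(m, 1.0 / m)
    # Step size 1 / L, where L is the largest eigenvalue of the Gram matrix.
    step = 1.0 / np.linalg.eigvalsh(gram)[-1]
    for _ in range(num_steps):
        gradient = gram @ weights - kernel_mean
        weights = project_to_simplex(weights - step * gradient)
    return weights


rng = np.random.default_rng(0)
dataset = rng.normal(size=(100, 2))
coreset = dataset[:8]
gram = rbf_kernel(coreset, coreset) + 1e-10 * np.eye(8)
kernel_mean = rbf_kernel(coreset, dataset).mean(axis=1)
weights = simplex_weights(gram, kernel_mean)
uniform = np.full(8, 1.0 / 8)
```

Unlike the SBQ weights, the result is non-negative and sums to 1, so the weighted coreset can still be read as a probability distribution.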