Weights
Classes and associated functionality to optimise weighted representations of data.
Several parts of this codebase take an \(n \times d\) dataset and generate an
alternative representation of it, for example a coreset. How well this alternative
representation approximates the original dataset can be assessed using a metric
of interest; see, for example, Metric.
The quality of the generated representation can be improved by weighting its individual elements. These weights are chosen by optimising the metric of interest, which compares the original \(n \times d\) dataset with its generated representation.
This module provides several methods to calculate such weights.
All methods implement WeightsOptimiser and must provide a
solve() method that, given the original dataset and its representation, returns an array of
weights, one per point of the representation, such that the metric of interest is optimised
when these weights are applied to the representation.
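Here is a minimal sketch of this contract. The class names `WeightsOptimiserSketch` and `UniformWeights` and the `linear_kernel` helper are hypothetical, and this is not coreax's actual base class. An optimiser stores a kernel and exposes `solve(x, y)`, which returns one weight per row of `y`. NumPy is used so the example runs without coreax or JAX installed.

```python
from abc import ABC, abstractmethod

import numpy as np


class WeightsOptimiserSketch(ABC):
    """Hypothetical stand-in for the WeightsOptimiser contract (not coreax's class)."""

    def __init__(self, kernel):
        # Any callable k(a, b) returning the (len(a), len(b)) kernel matrix.
        self.kernel = kernel

    @abstractmethod
    def solve(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
        """Return one weight per row of ``y`` so that ``y`` best represents ``x``."""


class UniformWeights(WeightsOptimiserSketch):
    """Hypothetical baseline that ignores the kernel and weights every point equally."""

    def solve(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
        m = np.asarray(y).shape[0]
        return np.full(m, 1.0 / m)


def linear_kernel(a, b):
    """Hypothetical kernel: plain inner products between rows of a and b."""
    return np.asarray(a) @ np.asarray(b).T


x = np.random.default_rng(0).normal(size=(100, 2))  # original n x d data
y = x[:5]  # m x d representation, e.g. a coreset
weights = UniformWeights(linear_kernel).solve(x, y)
print(weights.shape)  # (5,)
```

Real optimisers replace the uniform rule with one that actually optimises a metric, as the classes below do.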
- coreax.weights._prepare_kernel_system(kernel, x, y, epsilon=1e-10, *, block_size=None, unroll=1)
Return the row mean of \(k(y, x)\) and the Gramian \(k(y, y)\).
- Parameters:
kernel (Kernel) – The kernel \(k\) to evaluate
x (Union[Array, ndarray, bool_, number, bool, int, float, complex, Data]) – The original \(n \times d\) data
y (Union[Array, ndarray, bool_, number, bool, int, float, complex, Data]) – \(m \times d\) representation of x, e.g. a coreset
epsilon (float) – Small positive value to add to the kernel Gram matrix to aid numerical solver computations
block_size (Union[int, None, tuple[Optional[int], Optional[int]]]) – Block size passed to kernel.compute_mean
unroll (Union[int, bool, tuple[Union[int, bool], Union[int, bool]]]) – Unroll parameter passed to kernel.compute_mean
- Return type:
- Returns:
The row mean of \(k(y, x)\) and the epsilon-perturbed Gramian \(k(y, y)\)
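The computation can be sketched in NumPy as follows. The sketch assumes a squared-exponential kernel (the `gaussian_kernel` and `prepare_kernel_system` helpers are hypothetical). It also assumes `epsilon` is added to the diagonal of the Gramian as a jitter term. The `block_size` and `unroll` arguments, which coreax forwards to the kernel's `compute_mean`, are omitted because this sketch computes the mean in one pass.

```python
import numpy as np


def gaussian_kernel(a, b, length_scale=1.0):
    """Hypothetical squared-exponential kernel returning the (len(a), len(b)) matrix."""
    sq_dists = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / (2.0 * length_scale**2))


def prepare_kernel_system(kernel, x, y, epsilon=1e-10):
    """Sketch: row mean of k(y, x) and the epsilon-perturbed Gramian k(y, y)."""
    # (m,) vector: average similarity of each representative point to the data.
    kernel_yx_row_mean = kernel(y, x).mean(axis=1)
    # (m, m) Gramian; assumed here: epsilon is added to the diagonal only.
    kernel_yy = kernel(y, y) + epsilon * np.eye(y.shape[0])
    return kernel_yx_row_mean, kernel_yy


rng = np.random.default_rng(0)
x = rng.normal(size=(200, 3))  # original n x d data
y = x[:10]  # m x d representation
row_mean, gram = prepare_kernel_system(gaussian_kernel, x, y)
print(row_mean.shape, gram.shape)  # (10,) (10, 10)
```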
- class coreax.weights.WeightsOptimiser(kernel)
Base class for calculating weights.
- class coreax.weights.SBQWeightsOptimiser(kernel)
Define the Sequential Bayesian Quadrature (SBQ) optimiser class.
References for this technique can be found in [huszar2016optimally]. Weights determined by SBQ are equivalent to the unconstrained weighted maximum mean discrepancy (MMD) optimum.
The Bayesian quadrature estimate of the integral
\[\int f(x) p(x) \, dx\]
can be viewed as a weighted version of kernel herding. The Bayesian quadrature weights, \(w_{BQ}\), are given by
\[w_{BQ}^{(j)} = \sum_{i=1}^{m} z_i \left(K^{-1}\right)_{ij}, \qquad j = 1, \dots, m,\]
for a dataset \(x\) of \(n\) points and a coreset \(y\) of \(m\) points, where both indices \(i\) and \(j\) run over the coreset. Here, for a given kernel \(k\), \(z_i = \int k(x, y_i) p(x) \, dx\) and \(K = k(y, y)\) is the \(m \times m\) Gram matrix of the coreset, so in matrix form \(w_{BQ} = K^{-1} z\) (using the symmetry of \(K\)). See equation 20 in [huszar2016optimally] for further detail.
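A short derivation shows why these weights coincide with the unconstrained weighted MMD optimum. It treats \(z\) as the vector with entries \(z_i = \int k(x, y_i) p(x) \, dx\) and \(\mathcal{H}\) as the kernel's reproducing kernel Hilbert space:

```latex
\begin{aligned}
\mathrm{MMD}^2(w)
  &= \Big\| \sum_{i=1}^{m} w_i \, k(\cdot, y_i) - \int k(\cdot, x) \, p(x) \, dx \Big\|_{\mathcal{H}}^2 \\
  &= w^{\mathrm{T}} K w - 2 \, w^{\mathrm{T}} z + \iint k(x, x') \, p(x) \, p(x') \, dx \, dx', \\
\nabla_w \mathrm{MMD}^2(w) &= 2 K w - 2 z = 0
  \quad\Longrightarrow\quad w = K^{-1} z .
\end{aligned}
```

Because \(K\) is positive semi-definite, the objective is convex. Assuming \(K\) is invertible, this stationary point is therefore the unique minimiser, with no constraint on the sign or sum of \(w\). The small \(\epsilon\) added to the Gram matrix helps to ensure invertibility.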
- solve(x, y, epsilon=1e-10, *, block_size=None, unroll=1, **solver_kwargs)
Calculate weights from Sequential Bayesian Quadrature (SBQ).
References for this technique can be found in [huszar2016optimally]. These are equivalent to the unconstrained weighted maximum mean discrepancy (MMD) optimum.
Note that weights determined through SBQ do not need to sum to 1, and can be negative.
- Parameters:
x (Union[Array, ndarray, bool_, number, bool, int, float, complex, Data]) – The original \(n \times d\) data
y (Union[Array, ndarray, bool_, number, bool, int, float, complex, Data]) – \(m \times d\) representation of x, e.g. a coreset
epsilon (float) – Small positive value to add to the kernel Gram matrix to aid numerical solver computations
block_size (Union[int, None, tuple[Optional[int], Optional[int]]]) – Block size passed to self.kernel.compute_mean
unroll (Union[int, bool, tuple[Union[int, bool], Union[int, bool]]]) – Unroll parameter passed to self.kernel.compute_mean
solver_kwargs – Additional kwargs passed to jnp.linalg.solve
- Return type:
- Returns:
Optimal weighting of points in y to represent x
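SBQ weighting therefore amounts to a single linear solve. The NumPy sketch below makes three choices:

- It assumes a squared-exponential kernel (the `gaussian_kernel` and `sbq_weights` helpers are hypothetical).
- It estimates \(z\) by the row mean of \(k(y, x)\) over the data.
- It calls `np.linalg.solve`, the NumPy counterpart of `jnp.linalg.solve`, rather than forming \(K^{-1}\) explicitly.

```python
import numpy as np


def gaussian_kernel(a, b, length_scale=1.0):
    """Hypothetical squared-exponential kernel returning the (len(a), len(b)) matrix."""
    sq_dists = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / (2.0 * length_scale**2))


def sbq_weights(kernel, x, y, epsilon=1e-10):
    """Sketch of SBQ weights: solve (K + epsilon * I) w = z for w."""
    # Empirical estimate of z_i = E_p[k(x, y_i)] using the n data points.
    z = kernel(y, x).mean(axis=1)
    # Jittered Gram matrix of the representation (assumed diagonal jitter).
    gram = kernel(y, y) + epsilon * np.eye(y.shape[0])
    # Solving the system is cheaper and more stable than inverting gram.
    return np.linalg.solve(gram, z)


rng = np.random.default_rng(0)
x = rng.normal(size=(500, 2))
y = x[rng.choice(500, size=8, replace=False)]
w = sbq_weights(gaussian_kernel, x, y)
print(w.sum())  # not constrained to equal 1
```

The resulting weights are not constrained to sum to one and may be negative, matching the note on solve() above.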
- class coreax.weights.MMDWeightsOptimiser(kernel)
Define the MMD weights optimiser class.
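This optimiser solves a quadratic program over the probability simplex,
\[\min_{\mathbf{w}} \; \mathbf{w}^{\mathrm{T}} \mathbf{K} \mathbf{w} - 2 \bar{\mathbf{k}}^{\mathrm{T}} \mathbf{w}\]
subject to
\[\mathbf{A}\mathbf{w} = 1, \qquad \mathbf{G}\mathbf{w} \le \mathbf{0},\]
where \(\mathbf{K} = k(y, y)\) is the Gram matrix of the representation and \(\bar{\mathbf{k}}\) is the row mean of \(k(y, x)\). Setting \(\mathbf{A} = \mathbf{1}^{\mathrm{T}}\) and \(\mathbf{G} = -\mathbf{I}\) makes the weights non-negative and sum to one. Up to an additive constant, the objective is the squared MMD between the weighted representation and the original data. The problem is solved using the OSQP quadratic programming solver.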
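The sketch below illustrates a simplex-constrained weighted MMD problem: minimise \(\mathbf{w}^{\mathrm{T}} \mathbf{K} \mathbf{w} - 2 \bar{\mathbf{k}}^{\mathrm{T}} \mathbf{w}\) with non-negative weights summing to one. It does not use OSQP. Instead it swaps in projected gradient descent with a sort-based Euclidean projection onto the simplex, a simpler first-order method chosen so the example needs only NumPy. The kernel, the helper names (`gaussian_kernel`, `project_to_simplex`, `simplex_mmd_weights`) and the fixed iteration count are assumptions for illustration.

```python
import numpy as np


def gaussian_kernel(a, b, length_scale=1.0):
    """Hypothetical squared-exponential kernel returning the (len(a), len(b)) matrix."""
    sq_dists = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / (2.0 * length_scale**2))


def project_to_simplex(v):
    """Euclidean projection of v onto {w : w >= 0, sum(w) = 1} (sort-based method)."""
    u = np.sort(v)[::-1]  # entries in descending order
    css = np.cumsum(u)
    idx = np.arange(1, v.size + 1)
    # Largest index rho for which the shifted entry stays positive.
    rho = np.nonzero(u - (css - 1.0) / idx > 0)[0][-1] + 1
    theta = (css[rho - 1] - 1.0) / rho
    return np.maximum(v - theta, 0.0)


def simplex_mmd_weights(kernel, x, y, epsilon=1e-10, n_iter=2000):
    """Minimise w^T K w - 2 kbar^T w over the simplex by projected gradient descent."""
    kbar = kernel(y, x).mean(axis=1)
    gram = kernel(y, y) + epsilon * np.eye(y.shape[0])
    # Step 1 / L, where L = 2 * largest eigenvalue of gram bounds the gradient's Lipschitz constant.
    step = 1.0 / (2.0 * np.linalg.eigvalsh(gram)[-1])
    w = np.full(y.shape[0], 1.0 / y.shape[0])  # start from uniform weights
    for _ in range(n_iter):
        grad = 2.0 * (gram @ w - kbar)
        w = project_to_simplex(w - step * grad)
    return w


rng = np.random.default_rng(0)
x = rng.normal(size=(500, 2))
y = x[rng.choice(500, size=8, replace=False)]
w = simplex_mmd_weights(gaussian_kernel, x, y)
print(w.min() >= 0, np.isclose(w.sum(), 1.0))  # True True
```

With step size \(1/L\), the objective never increases from the uniform starting point. A dedicated QP solver such as OSQP will typically converge faster and to tighter tolerances.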
- solve(x, y, epsilon=1e-10, *, block_size=None, unroll=1, **solver_kwargs)
Compute optimal weights given the simplex constraint.
- Parameters:
x (Union[Array, ndarray, bool_, number, bool, int, float, complex, Data]) – The original \(n \times d\) data
y (Union[Array, ndarray, bool_, number, bool, int, float, complex, Data]) – \(m \times d\) representation of x, e.g. a coreset
epsilon (float) – Small positive value to add to the kernel Gram matrix to aid numerical solver computations
block_size (Union[int, None, tuple[Optional[int], Optional[int]]]) – Block size passed to self.kernel.compute_mean
unroll (Union[int, bool, tuple[Union[int, bool], Union[int, bool]]]) – Unroll parameter passed to self.kernel.compute_mean
solver_kwargs – Additional kwargs passed to the OSQP quadratic programming solver
- Return type:
- Returns:
Optimal weighting of points in y to represent x
- class coreax.weights.SBQ(kernel)
Deprecated reference to SBQWeightsOptimiser. Will be removed in version 0.3.0.
- Parameters:
kernel (Kernel) – The kernel \(k\) to evaluate
- class coreax.weights.MMD(kernel)
Deprecated reference to MMDWeightsOptimiser. Will be removed in version 0.3.0.
- Parameters:
kernel (Kernel) – The kernel \(k\) to evaluate
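Since both aliases are scheduled for removal, new code should construct SBQWeightsOptimiser or MMDWeightsOptimiser directly. A common Python pattern for such deprecated aliases is a thin subclass that warns on construction. The sketch below illustrates that pattern with hypothetical class names. It is an assumption, not a description of how coreax implements `SBQ` and `MMD`.

```python
import warnings


class SBQWeightsOptimiserSketch:
    """Hypothetical stand-in for the current, non-deprecated optimiser class."""

    def __init__(self, kernel):
        self.kernel = kernel


class SBQSketch(SBQWeightsOptimiserSketch):
    """Hypothetical deprecated alias: same behaviour, but warns when constructed."""

    def __init__(self, kernel):
        warnings.warn(
            "SBQSketch is deprecated; use SBQWeightsOptimiserSketch instead.",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller's line
        )
        super().__init__(kernel)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # record the warning even if filtered by default
    optimiser = SBQSketch(kernel=None)

print(caught[0].category.__name__)  # DeprecationWarning
```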