Metrics
Classes and associated functionality to compute metrics assessing similarity of inputs.
Large parts of this codebase consider the generic problem of taking an
\(n \times d\) dataset and creating an alternative representation of it in some way.
Having obtained an alternative representation, we can then assess its quality
using an appropriate metric. Such metrics are implemented within this
module, and all of them implement Metric.
- class coreax.metrics.Metric
Bases: Module, Generic[_Data]
Base class for calculating metrics.
- abstract compute(reference_data, comparison_data, **kwargs)
Compute the metric/distance between the reference and comparison data.
- Parameters:
reference_data (Data) – An instance of the class coreax.data.Data, containing an \(n \times d\) array of data
comparison_data (Data) – An instance of the class coreax.data.Data to compare against reference_data, containing an \(m \times d\) array of data
- Return type:
Shaped[Array, '']
- Returns:
Computed metric as a zero-dimensional array
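The interface above amounts to a single abstract compute method that returns a zero-dimensional array. The sketch below illustrates only that contract, under several assumptions: it uses Python's abc module instead of the documented Module and Generic bases, plain JAX arrays instead of coreax.data.Data, and a hypothetical toy metric (MeanDifference, the Euclidean distance between dataset means) that is not part of coreax.

```python
from abc import ABC, abstractmethod

import jax.numpy as jnp


class SketchMetric(ABC):
    """Hypothetical stand-in for the Metric interface (not coreax's class)."""

    @abstractmethod
    def compute(self, reference_data, comparison_data, **kwargs):
        """Return the metric as a zero-dimensional array."""


class MeanDifference(SketchMetric):
    """Hypothetical toy metric: Euclidean distance between dataset means."""

    def compute(self, reference_data, comparison_data, **kwargs):
        # reference_data is (n, d), comparison_data is (m, d); only d must agree
        reference_mean = jnp.mean(reference_data, axis=0)
        comparison_mean = jnp.mean(comparison_data, axis=0)
        # The norm of a length-d vector is a zero-dimensional array
        return jnp.linalg.norm(reference_mean - comparison_mean)


value = MeanDifference().compute(jnp.zeros((3, 2)), jnp.ones((4, 2)))
```

Note that the two datasets may have different numbers of points, matching the \(n \times d\) and \(m \times d\) shapes in the parameter descriptions.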
- class coreax.metrics.MMD(kernel, precision_threshold=1e-12)
Definition and calculation of the (weighted) maximum mean discrepancy metric.
For a dataset \(\mathcal{D}_1\) of \(n\) points in \(d\) dimensions, and another dataset \(\mathcal{D}_2\) of \(m\) points in \(d\) dimensions, the (weighted) maximum mean discrepancy is given by:
\[\text{MMD}^2(\mathcal{D}_1,\mathcal{D}_2) = \mathbb{E}(k(\mathcal{D}_1, \mathcal{D}_1)) + \mathbb{E}(k(\mathcal{D}_2,\mathcal{D}_2)) - 2\mathbb{E}(k(\mathcal{D}_1,\mathcal{D}_2))\]
where \(k\) is the selected kernel, and the expectation is with respect to the normalized data weights.
Common uses of MMD include comparing a reduced representation of a dataset to the original dataset, comparing different original datasets to one another, or comparing reduced representations of different original datasets to one another.
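As a rough illustration of the formula above, the sketch below evaluates the weighted squared MMD directly with dense Gram matrices. It assumes a Gaussian (RBF) kernel with a fixed length scale and plain JAX arrays for the data and weights; the helper names rbf_gram and weighted_mmd_squared are hypothetical and are not part of coreax.

```python
import jax
import jax.numpy as jnp


def rbf_gram(x, y, length_scale=1.0):
    """Gaussian kernel matrix with entries k(x_i, y_j) (hypothetical helper)."""
    sq_dists = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-sq_dists / (2 * length_scale**2))


def weighted_mmd_squared(x, y, wx=None, wy=None, length_scale=1.0):
    """Squared MMD between an (n, d) array x and an (m, d) array y."""
    # Uniform weights by default; weights are normalized to sum to one
    wx = jnp.ones(x.shape[0]) if wx is None else wx
    wy = jnp.ones(y.shape[0]) if wy is None else wy
    wx, wy = wx / jnp.sum(wx), wy / jnp.sum(wy)
    # Each term is a weighted expectation of the kernel over pairs of points
    kxx = wx @ rbf_gram(x, x, length_scale) @ wx
    kyy = wy @ rbf_gram(y, y, length_scale) @ wy
    kxy = wx @ rbf_gram(x, y, length_scale) @ wy
    return kxx + kyy - 2 * kxy


key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (50, 2))
y = jax.random.normal(key_y, (40, 2)) + 1.0
```

Because every term is an expectation under the normalized weights, multiplying all weights of one dataset by the same positive constant leaves the result unchanged.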
- Parameters:
kernel (ScalarValuedKernel) – Kernel object with a compute method defining the mapping \(k: \mathbb{R}^d \times \mathbb{R}^d \rightarrow \mathbb{R}\)
precision_threshold (float) – Negative values of the squared MMD whose magnitude is below this threshold are rounded to zero (accommodates precision loss)
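Because the three expectations are computed separately in floating point, their combination can come out slightly negative even when the true squared MMD is zero, and the square root of such a value would be NaN. The sketch below shows one way to apply a tolerance like precision_threshold before taking the square root. It assumes that the returned MMD is the square root of the squared MMD, and the function name mmd_from_squared is hypothetical.

```python
import jax.numpy as jnp


def mmd_from_squared(squared_mmd, precision_threshold=1e-12):
    """Hypothetical helper: square root of a squared MMD with a round-off tolerance."""
    squared_mmd = jnp.asarray(squared_mmd)
    # Negative values within precision_threshold of zero are treated as round-off
    is_round_off = (squared_mmd < 0) & (squared_mmd > -precision_threshold)
    cleaned = jnp.where(is_round_off, 0.0, squared_mmd)
    # Larger negative values are left alone, so in this sketch they surface as NaN
    return jnp.sqrt(cleaned)
```

Leaving larger negative values untouched is a design choice of this sketch: a clearly negative squared MMD points to a genuine problem rather than round-off, so it is not silently hidden.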
- kernel: ScalarValuedKernel
- compute(reference_data, comparison_data, *, block_size=None, unroll=1, **kwargs)
Compute the (weighted) maximum mean discrepancy.
\[\text{MMD}^2(\mathcal{D}_1,\mathcal{D}_2) = \mathbb{E}(k(\mathcal{D}_1, \mathcal{D}_1)) + \mathbb{E}(k(\mathcal{D}_2,\mathcal{D}_2)) - 2\mathbb{E}(k(\mathcal{D}_1,\mathcal{D}_2))\]
- Parameters:
reference_data (Data) – An instance of the class coreax.data.Data, containing an \(n \times d\) array of data
comparison_data (Data) – An instance of the class coreax.data.Data to compare against reference_data, containing an \(m \times d\) array of data
block_size (Union[int, None, tuple[Optional[int], Optional[int]]]) – Size of matrix blocks to process; a value of None sets \(B_x = n\) and \(B_y = m\), effectively disabling the block accumulation; an integer value B sets \(B_x = B_y = B\); a tuple allows different sizes to be specified for \(B_x\) and \(B_y\); to reduce overheads, it is often sensible to select the largest block size that does not exhaust the available memory resources
unroll (Union[int, bool, tuple[Union[int, bool], Union[int, bool]]]) – Unrolling parameter for the outer and inner jax.lax.scan() calls; allows trade-offs between compilation and runtime cost; consult the JAX documentation for further information
- Return type:
Shaped[Array, '']
- Returns:
Maximum mean discrepancy as a 0-dimensional array
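The block_size and unroll parameters suggest that the kernel expectations are accumulated block by block with nested jax.lax.scan calls, so that only a \(B_x \times B_y\) slice of the kernel matrix is held in memory at a time. The sketch below illustrates that idea for a single weighted term \(\sum_{i,j} w_i v_j k(x_i, y_j)\). It assumes a Gaussian kernel, requires \(n\) and \(m\) to be multiples of the block sizes (real code would need padding or a remainder block), and uses hypothetical helper names; it is not coreax's implementation.

```python
import jax
import jax.numpy as jnp


def rbf_gram(x, y, length_scale=1.0):
    """Gaussian kernel matrix with entries k(x_i, y_j) (hypothetical helper)."""
    sq_dists = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-sq_dists / (2 * length_scale**2))


def blocked_weighted_kernel_sum(x, y, wx, wy, block_x, block_y, unroll=1):
    """Sum of wx[i] * wy[j] * k(x[i], y[j]) computed one block at a time."""
    n, m, d = x.shape[0], y.shape[0], x.shape[1]
    # Assumption: n % block_x == 0 and m % block_y == 0 (no remainder handling)
    x_blocks = (x.reshape(n // block_x, block_x, d), wx.reshape(n // block_x, block_x))
    y_blocks = (y.reshape(m // block_y, block_y, d), wy.reshape(m // block_y, block_y))
    zero = jnp.zeros((), dtype=x.dtype)

    def outer_step(total, x_block):
        xs, ws = x_block

        def inner_step(partial, y_block):
            ys, vs = y_block
            # Only a (block_x, block_y) kernel matrix exists at any one time
            return partial + ws @ rbf_gram(xs, ys) @ vs, None

        block_total, _ = jax.lax.scan(inner_step, zero, y_blocks, unroll=unroll)
        return total + block_total, None

    total, _ = jax.lax.scan(outer_step, zero, x_blocks, unroll=unroll)
    return total


key_x, key_y = jax.random.split(jax.random.PRNGKey(1))
x = jax.random.normal(key_x, (8, 3))
y = jax.random.normal(key_y, (6, 3))
wx = jnp.full(8, 1 / 8)
wy = jnp.full(6, 1 / 6)
blocked = blocked_weighted_kernel_sum(x, y, wx, wy, block_x=4, block_y=2)
dense = wx @ rbf_gram(x, y) @ wy
```

The blocked and dense results agree up to floating-point rounding; larger blocks mean fewer scan iterations (less overhead) at the cost of more memory per iteration, which is the trade-off described for block_size.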
- class coreax.metrics.KSD(kernel, score_matching=None, precision_threshold=1e-12)
Computation of the (regularised) (Laplace-corrected) kernel Stein discrepancy (KSD).
For a set of \(n\) i.i.d. samples in \(d\) dimensions \(\mathcal{D}_1 \sim \mathbb{P}\) and another set of \(m\) i.i.d. samples in \(d\) dimensions \(\mathcal{D}_2 \sim \mathbb{Q}\), the regularised Laplace-corrected kernel Stein discrepancy is given by:
\[KSD_{\lambda}^2(\mathbb{P}, \mathbb{Q}) = \frac{1}{m^2}\sum_{i \neq j}^m k_{\mathbb{P}}(x_i, x_j) + \frac{1}{m^2}\sum_{i = 1}^m [k_{\mathbb{P}}(x_i, x_i) + \Delta^+ \log(\mathbb{P}(x_i))] - \lambda \frac{1}{m}\sum_{i = 1}^m \log(\mathbb{P}(x_i))\]
where \(x_i \in \mathcal{D}_2\) are the samples from \(\mathbb{Q}\), \(k_{\mathbb{P}}\) is the Stein kernel induced by a base kernel and estimated with samples from \(\mathbb{P}\), and \(\lambda\) is the regularisation strength. The first term together with the \(k_{\mathbb{P}}(x_i, x_i)\) part of the second term is the vanilla KSD, the \(\Delta^+\) part of the second term implements the Laplace correction, and the third term enforces entropic regularisation. See [benard2023kernel] for a discussion of the need for, and effects of, Laplace correction and entropic regularisation.
Common uses of KSD include comparing a reduced representation of a dataset to the original dataset, comparing different original datasets to one another, or comparing reduced representations of different original datasets to one another.
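With a base kernel \(k\) and the score \(s = \nabla \log \mathbb{P}\), the standard Langevin Stein kernel is \(k_{\mathbb{P}}(x, y) = \nabla_x \cdot \nabla_y k(x, y) + \nabla_x k(x, y) \cdot s(y) + \nabla_y k(x, y) \cdot s(x) + k(x, y)\, s(x) \cdot s(y)\). The sketch below uses it to compute the vanilla KSD, \(\frac{1}{m^2}\sum_{i, j} k_{\mathbb{P}}(x_i, x_j)\) over all pairs including the diagonal. It assumes a Gaussian base kernel with unit length scale and a known score, that of a standard normal \(\mathbb{P}\) with \(s(x) = -x\), rather than a score estimated from samples; all function names are hypothetical and not part of coreax.

```python
import jax
import jax.numpy as jnp


def rbf(x, y, length_scale=1.0):
    """Gaussian base kernel k(x, y) on single points (hypothetical helper)."""
    return jnp.exp(-jnp.sum((x - y) ** 2) / (2 * length_scale**2))


def standard_normal_score(x):
    """Score of a standard normal: the gradient of its log density is -x."""
    return -x


def stein_kernel(x, y, score=standard_normal_score):
    """Langevin Stein kernel built from rbf and a score function."""
    k = rbf(x, y)
    grad_x = jax.grad(rbf, argnums=0)(x, y)
    grad_y = jax.grad(rbf, argnums=1)(x, y)
    # Divergence term: trace of the matrix of mixed partials d^2 k / dx_i dy_j
    mixed = jax.jacfwd(jax.grad(rbf, argnums=0), argnums=1)(x, y)
    sx, sy = score(x), score(y)
    return jnp.trace(mixed) + grad_x @ sy + grad_y @ sx + k * (sx @ sy)


def vanilla_ksd_squared(samples):
    """(1 / m^2) times the sum of k_P over all pairs, diagonal included."""
    gram = jax.vmap(lambda a: jax.vmap(lambda b: stein_kernel(a, b))(samples))(samples)
    return jnp.mean(gram)


key = jax.random.PRNGKey(2)
on_target = jax.random.normal(key, (200, 2))
off_target = on_target + 3.0
```

Samples drawn from \(\mathbb{P}\) itself give a small value, while the shifted samples give a much larger one, which is the behaviour the discrepancy is meant to capture.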
Note
The kernel Stein discrepancy is not a metric like
coreax.metrics.MMD. It is instead a divergence, a kind of statistical distance that differs from a metric in a few ways. In particular, divergences need not be symmetric, i.e. in general \(KSD_{\lambda}(\mathbb{P}, \mathbb{Q}) \neq KSD_{\lambda}(\mathbb{Q}, \mathbb{P})\), and they generalise the concept of squared distance, so they do not in general satisfy the triangle inequality.
- Parameters:
kernel (ScalarValuedKernel) – ScalarValuedKernel instance implementing a kernel function \(k: \mathbb{R}^d \times \mathbb{R}^d \rightarrow \mathbb{R}\); if kernel is a SteinKernel and score_matching is not None, a new instance of the kernel will be generated where the score function is given by score_matching.match(...)
score_matching (Optional[ScoreMatching]) – Specifies or overwrites the score function of the implied or passed SteinKernel; if None, defaults to KernelDensityMatching unless kernel is a SteinKernel, in which case the kernel's existing score function is used
precision_threshold (float) – Negative values of the squared KSD whose magnitude is below this threshold are rounded to zero (accommodates precision loss)
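When no SteinKernel is passed, the score defaults to KernelDensityMatching, i.e. a score estimated from samples rather than given by a known formula. The sketch below shows the general idea using the gradient of the log of a Gaussian kernel density estimate. It is an assumed, simplified stand-in with a hypothetical name (kde_score) and a fixed bandwidth, not coreax's KernelDensityMatching.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def kde_score(x, samples, bandwidth=1.0):
    """Gradient of the log of a Gaussian KDE fitted to samples, evaluated at x."""

    def log_density(z):
        sq_dists = jnp.sum((samples - z) ** 2, axis=1)
        # Normalising constants are dropped: they do not change the gradient
        return logsumexp(-sq_dists / (2 * bandwidth**2))

    return jax.grad(log_density)(x)


single = jnp.array([[2.0]])
symmetric = jnp.array([[-1.0], [1.0]])
```

With a single sample at 2 the estimated score at 0 points towards that sample, and with samples placed symmetrically about 0 it vanishes at 0.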
- kernel: ScalarValuedKernel
- score_matching: Optional[ScoreMatching] = None
- compute(reference_data, comparison_data, *, laplace_correct=False, regularise=False, block_size=None, unroll=1, **kwargs)
Compute the (regularised) (Laplace-corrected) kernel Stein discrepancy.
\[KSD_{\lambda}^2(\mathbb{P}, \mathbb{Q}) = \frac{1}{m^2}\sum_{i \neq j}^m k_{\mathbb{P}}(x_i, x_j) + \frac{1}{m^2}\sum_{i = 1}^m [k_{\mathbb{P}}(x_i, x_i) + \Delta^+ \log(\mathbb{P}(x_i))] - \lambda \frac{1}{m}\sum_{i = 1}^m \log(\mathbb{P}(x_i))\]
- Parameters:
reference_data (Data) – An instance of the class coreax.data.Data, containing an \(n \times d\) array of data sampled from \(\mathbb{P}\)
comparison_data (Data) – An instance of the class coreax.data.Data to compare against reference_data, containing an \(m \times d\) array of data sampled from \(\mathbb{Q}\)
laplace_correct (bool) – Boolean that enforces Laplace correction, see Section 3.1 of [benard2023kernel]
regularise (bool) – Boolean that enforces entropic regularisation; if True, uses the regularisation strength \(\lambda = \frac{1}{m}\) suggested in [benard2023kernel]
block_size (Optional[int]) – Size of matrix blocks to process; a value of None sets block_size \(= n\), effectively disabling the block accumulation; an integer value B sets block_size \(= B\); it is often sensible to select the largest block size that does not exhaust the available memory resources
unroll (Union[int, bool, tuple[Union[int, bool], Union[int, bool]]]) – Unrolling parameter for the outer and inner jax.lax.scan() calls; allows trade-offs between compilation and runtime cost; consult the JAX documentation for further information
- Return type:
Shaped[Array, '']
- Returns:
Kernel Stein discrepancy as a 0-dimensional array
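The laplace_correct and regularise flags add the remaining terms of the formula above to the vanilla estimate. The sketch below takes a precomputed \(m \times m\) Stein kernel matrix and a log-density function and assembles the terms. It assumes \(\Delta^+ \log \mathbb{P}(x)\) is the sum of the positive parts of the diagonal second derivatives \(\partial^2 \log \mathbb{P} / \partial x_j^2\), uses \(\lambda = \frac{1}{m}\) as documented for regularise, and uses hypothetical function names; it is not coreax's implementation.

```python
import jax
import jax.numpy as jnp


def positive_laplacian(log_density, x):
    """Assumed Delta^+: sum of positive parts of the Hessian diagonal of log_density at x."""
    second_derivatives = jnp.diag(jax.hessian(log_density)(x))
    return jnp.sum(jnp.maximum(second_derivatives, 0.0))


def corrected_ksd_squared(stein_gram, samples, log_density, laplace_correct=False, regularise=False):
    """Assemble the KSD^2 formula from an (m, m) Stein kernel matrix over samples."""
    m = samples.shape[0]
    # Vanilla KSD^2: all pairs, diagonal included, scaled by 1 / m^2
    value = jnp.sum(stein_gram) / m**2
    if laplace_correct:
        corrections = jax.vmap(lambda z: positive_laplacian(log_density, z))(samples)
        value = value + jnp.sum(corrections) / m**2
    if regularise:
        # Entropic regularisation with strength lambda = 1 / m
        lam = 1.0 / m
        value = value - lam * jnp.mean(jax.vmap(log_density)(samples))
    return value


def standard_normal_log_density(z):
    """Log density of a standard normal, up to an additive constant (example target)."""
    return -0.5 * jnp.sum(z**2)
```

For a standard normal target every second derivative of the log density is negative, so the Laplace correction contributes nothing; it only activates where the log density is locally convex in some coordinate.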