Metrics
Classes and associated functionality to compute metrics assessing similarity of inputs.
Large parts of this codebase consider the generic problem of taking an
\(n \times d\) dataset and creating an alternative representation of it in some way.
Having obtained an alternative representation, we can then assess its quality
using an appropriate metric. Such metrics are implemented in this
module, and all of them implement the Metric interface.
- class coreax.metrics.Metric
Base class for calculating metrics.
- abstract compute(reference_data, comparison_data, **kwargs)
Compute the metric/distance between the reference and comparison data.
- Parameters:
  - reference_data (Data) – An instance of the class coreax.data.Data, containing an \(n \times d\) array of data
  - comparison_data (Data) – An instance of the class coreax.data.Data to compare against reference_data, containing an \(m \times d\) array of data
- Return type:
- Returns:
Computed metric as a zero-dimensional array
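The contract above (two datasets in, one zero-dimensional array out) can be illustrated with a minimal sketch using Python's `abc` module. The classes `SketchMetric` and `MeanDistance` are hypothetical: they take plain JAX arrays rather than `coreax.data.Data` instances and are not the coreax implementation. The sketch only shows the shape of the interface that every metric subclass fulfils.

```python
from abc import ABC, abstractmethod

import jax.numpy as jnp
from jax import Array


class SketchMetric(ABC):
    """Hypothetical stand-in for an abstract metric base class."""

    @abstractmethod
    def compute(self, reference_data: Array, comparison_data: Array, **kwargs) -> Array:
        """Return a zero-dimensional array comparing an (n, d) and an (m, d) array."""


class MeanDistance(SketchMetric):
    """Hypothetical metric: Euclidean distance between the two sample means."""

    def compute(self, reference_data: Array, comparison_data: Array, **kwargs) -> Array:
        # Average over the point axis, leaving two length-d vectors.
        diff = reference_data.mean(axis=0) - comparison_data.mean(axis=0)
        # The norm of a 1-D vector is a zero-dimensional array.
        return jnp.linalg.norm(diff)


x = jnp.array([[0.0, 0.0], [2.0, 0.0]])  # n = 2 points in d = 2 dimensions
y = jnp.array([[1.0, 3.0]])  # m = 1 point in d = 2 dimensions
result = MeanDistance().compute(x, y)
print(result.shape, float(result))  # () 3.0
```

Because `compute` is declared abstract, `SketchMetric` itself cannot be instantiated; only concrete subclasses that define `compute` can.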
- class coreax.metrics.MMD(kernel, precision_threshold=1e-12)
Definition and calculation of the (weighted) maximum mean discrepancy metric.
For a dataset \(\mathcal{D}_1\) of \(n\) points in \(d\) dimensions, and another dataset \(\mathcal{D}_2\) of \(m\) points in \(d\) dimensions, the (weighted) maximum mean discrepancy is given by:
\[\text{MMD}^2(\mathcal{D}_1,\mathcal{D}_2) = \mathbb{E}(k(\mathcal{D}_1, \mathcal{D}_1)) + \mathbb{E}(k(\mathcal{D}_2,\mathcal{D}_2)) - 2\mathbb{E}(k(\mathcal{D}_1,\mathcal{D}_2))\]
where \(k\) is the selected kernel, and the expectation is with respect to the normalized data weights.
Common uses of MMD include comparing a reduced representation of a dataset to the original dataset, comparing different original datasets to one another, or comparing reduced representations of different original datasets to one another.
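When both datasets fit in memory, the formula can be evaluated directly: each expectation is a weighted sum over a Gram matrix, \(\mathbb{E}(k(\mathcal{D}_1,\mathcal{D}_2)) = w_1^\top K_{12} w_2\) with normalized weight vectors \(w_1, w_2\). The sketch below is not coreax's implementation. It assumes a Gaussian (squared-exponential) kernel with a hypothetical `length_scale` parameter, takes raw arrays and weight vectors in place of `coreax.data.Data`, and clamps small negative values of \(\text{MMD}^2\) (which can arise from floating-point rounding) to zero before taking the square root; that clamping is an assumption about how one might handle precision, not a description of `precision_threshold`.

```python
import jax.numpy as jnp
from jax import Array


def gaussian_kernel(x: Array, y: Array, length_scale: float = 1.0) -> Array:
    """Gram matrix k(x_i, y_j) = exp(-||x_i - y_j||^2 / (2 * length_scale^2))."""
    sq_dists = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-sq_dists / (2.0 * length_scale**2))


def weighted_mmd(x, y, w_x=None, w_y=None, length_scale=1.0):
    """Weighted MMD between an (n, d) array x and an (m, d) array y."""
    # Default to uniform weights, then normalize so each set of weights sums to one.
    w_x = jnp.ones(x.shape[0]) if w_x is None else w_x
    w_y = jnp.ones(y.shape[0]) if w_y is None else w_y
    w_x = w_x / w_x.sum()
    w_y = w_y / w_y.sum()
    # Each expectation is a weighted sum over a Gram matrix: w^T K w'.
    k_xx = w_x @ gaussian_kernel(x, x, length_scale) @ w_x
    k_yy = w_y @ gaussian_kernel(y, y, length_scale) @ w_y
    k_xy = w_x @ gaussian_kernel(x, y, length_scale) @ w_y
    mmd_squared = k_xx + k_yy - 2.0 * k_xy
    # Clamp tiny negative values caused by rounding before the square root.
    return jnp.sqrt(jnp.maximum(mmd_squared, 0.0))


x = jnp.array([[0.0], [1.0], [2.0]])
print(weighted_mmd(x, x))  # close to 0 for identical datasets
```

Weights act like point multiplicities: giving one point weight 2 and another weight 1 yields the same MMD as listing the first point twice with uniform weights.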
- Parameters:
- compute(reference_data, comparison_data, *, block_size=None, unroll=1, **kwargs)
Compute the (weighted) maximum mean discrepancy.
\[\text{MMD}^2(\mathcal{D}_1,\mathcal{D}_2) = \mathbb{E}(k(\mathcal{D}_1, \mathcal{D}_1)) + \mathbb{E}(k(\mathcal{D}_2,\mathcal{D}_2)) - 2\mathbb{E}(k(\mathcal{D}_1,\mathcal{D}_2))\]
- Parameters:
  - reference_data (Data) – An instance of the class coreax.data.Data, containing an \(n \times d\) array of data
  - comparison_data (Data) – An instance of the class coreax.data.Data to compare against reference_data, containing an \(m \times d\) array of data
  - block_size (Union[int, None, tuple[Optional[int], Optional[int]]]) – Size of matrix blocks to process. A value of None sets \(B_x = n\) and \(B_y = m\), effectively disabling the block accumulation; an integer value \(B\) sets \(B_y = B_x = B\); a tuple allows different sizes to be specified for \(B_x\) and \(B_y\). To reduce overheads, it is often sensible to select the largest block size that does not exhaust the available memory resources
  - unroll (Union[int, bool, tuple[Union[int, bool], Union[int, bool]]]) – Unrolling parameter for the outer and inner jax.lax.scan() calls. It allows trade-offs between compilation and runtime cost; consult the JAX documentation for further information
- Return type:
- Returns:
Maximum mean discrepancy as a zero-dimensional array
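The idea behind `block_size` and `unroll` can be sketched as follows. Each term of \(\text{MMD}^2\) is a weighted kernel sum \(\sum_{i,j} w_i k(x_i, y_j) v_j\), which can be accumulated over \(B_x \times B_y\) tiles with nested `jax.lax.scan` calls so that only one tile of the kernel matrix exists in memory at a time. This is an illustration under simplifying assumptions, not coreax's implementation: it uses a Gaussian kernel, requires \(n\) and \(m\) to be divisible by the block sizes (no padding), and the function `blocked_weighted_kernel_sum` is hypothetical. The `unroll` argument is passed straight through to `jax.lax.scan`.

```python
import jax
import jax.numpy as jnp
from jax import Array


def gaussian_kernel(x: Array, y: Array, length_scale: float = 1.0) -> Array:
    """Gram matrix of a Gaussian kernel between the rows of x and y."""
    sq_dists = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-sq_dists / (2.0 * length_scale**2))


def blocked_weighted_kernel_sum(x, y, w_x, w_y, block_x, block_y, unroll=1):
    """Compute w_x^T K(x, y) w_y one (block_x, block_y) tile at a time."""
    n, d = x.shape
    m = y.shape[0]
    # Assumption: block sizes divide n and m exactly, so reshaping needs no padding.
    xb = x.reshape(n // block_x, block_x, d)
    wxb = w_x.reshape(n // block_x, block_x)
    yb = y.reshape(m // block_y, block_y, d)
    wyb = w_y.reshape(m // block_y, block_y)

    def outer(total, x_block):
        x_tile, wx_tile = x_block

        def inner(acc, y_block):
            y_tile, wy_tile = y_block
            # Only this (block_x, block_y) tile of the kernel matrix is materialized.
            acc = acc + wx_tile @ gaussian_kernel(x_tile, y_tile) @ wy_tile
            return acc, None

        # Inner scan runs over the blocks of y for a fixed block of x.
        row_sum, _ = jax.lax.scan(inner, jnp.zeros(()), (yb, wyb), unroll=unroll)
        return total + row_sum, None

    # Outer scan runs over the blocks of x.
    total, _ = jax.lax.scan(outer, jnp.zeros(()), (xb, wxb), unroll=unroll)
    return total


x = jax.random.normal(jax.random.PRNGKey(0), (8, 2))
y = jax.random.normal(jax.random.PRNGKey(1), (6, 2))
w_x = jnp.full(8, 1.0 / 8)
w_y = jnp.full(6, 1.0 / 6)
blocked = blocked_weighted_kernel_sum(x, y, w_x, w_y, block_x=4, block_y=3, unroll=2)
full = w_x @ gaussian_kernel(x, y) @ w_y
print(blocked, full)  # should agree up to floating-point rounding
```

Larger tiles mean fewer scan iterations and less loop overhead at the cost of more memory per tile, which is why the largest block size that fits in memory is usually a sensible choice; a larger `unroll` typically lengthens compilation in exchange for fewer loop iterations at runtime.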