Metrics

Classes and associated functionality to compute metrics that assess the similarity of inputs.

Large parts of this codebase consider the generic problem of taking an \(n \times d\) dataset and creating an alternative representation of it in some way. Having obtained an alternative representation, we can then assess its quality using an appropriate metric. This module implements such metrics, each of which implements the Metric interface.

class coreax.metrics.Metric

Base class for calculating metrics.

abstract compute(reference_data, comparison_data, **kwargs)

Compute the metric/distance between the reference and comparison data.

Parameters:
  • reference_data (Data) – An instance of the class coreax.data.Data, containing an \(n \times d\) array of data

  • comparison_data (Data) – An instance of the class coreax.data.Data to compare against reference_data, containing an \(m \times d\) array of data

Return type:

Array

Returns:

Computed metric as a zero-dimensional array
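
The contract above (an abstract compute method taking reference and comparison data and returning a zero-dimensional array) can be illustrated with a minimal, hypothetical sketch. SketchMetric and WeightedMeanDistance are illustrative names, not part of coreax, and each dataset is passed as a plain (points, weights) tuple instead of a coreax.data.Data instance.

```python
from abc import ABC, abstractmethod

import numpy as np


class SketchMetric(ABC):
    """Hypothetical stand-in for coreax.metrics.Metric (not the library class)."""

    @abstractmethod
    def compute(self, reference_data, comparison_data, **kwargs):
        """Return the metric as a zero-dimensional array."""


class WeightedMeanDistance(SketchMetric):
    """Toy metric: Euclidean distance between the weighted means of two datasets.

    Each dataset is a (points, weights) tuple here, in place of coreax.data.Data.
    """

    def compute(self, reference_data, comparison_data, **kwargs):
        x, w_x = reference_data
        y, w_y = comparison_data
        # np.average normalizes the weights internally.
        mean_x = np.average(x, axis=0, weights=w_x)
        mean_y = np.average(y, axis=0, weights=w_y)
        # Wrap the scalar so the result is a zero-dimensional array.
        return np.asarray(np.linalg.norm(mean_x - mean_y))


# Usage: the weighted means are (1, 0) and (1, 1), one unit apart.
x = np.array([[0.0, 0.0], [2.0, 0.0]])
y = np.array([[1.0, 1.0]])
result = WeightedMeanDistance().compute((x, np.ones(2)), (y, np.ones(1)))
print(result.shape, float(result))  # () 1.0
```

Because compute is abstract, the base class cannot be instantiated directly; only concrete subclasses that define compute can be used.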

class coreax.metrics.MMD(kernel, precision_threshold=1e-12)

Definition and calculation of the (weighted) maximum mean discrepancy metric.

For a dataset \(\mathcal{D}_1\) of \(n\) points in \(d\) dimensions, and another dataset \(\mathcal{D}_2\) of \(m\) points in \(d\) dimensions, the square of the (weighted) maximum mean discrepancy is given by:

\[\text{MMD}^2(\mathcal{D}_1,\mathcal{D}_2) = \mathbb{E}(k(\mathcal{D}_1, \mathcal{D}_1)) + \mathbb{E}(k(\mathcal{D}_2,\mathcal{D}_2)) - 2\mathbb{E}(k(\mathcal{D}_1,\mathcal{D}_2))\]

where \(k\) is the selected kernel, and each expectation is taken over pairs of points drawn independently from the indicated datasets according to their normalized weights.
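
The formula above can be evaluated directly by writing each expectation as a weighted sum over a Gram matrix, \(\mathbb{E}(k(\mathcal{D}_1,\mathcal{D}_2)) = w_1^\top K_{12} w_2\). The following NumPy sketch assumes a Gaussian (squared exponential) kernel \(k(x, y) = \exp(-\lVert x - y \rVert^2 / 2\ell^2)\) and weights normalized to sum to one; the function names are hypothetical and this is not the coreax implementation.

```python
import numpy as np


def gaussian_kernel_matrix(x, y, length_scale=1.0):
    """Gram matrix K[i, j] = exp(-||x_i - y_j||^2 / (2 * length_scale^2))."""
    sq_dists = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / (2.0 * length_scale**2))


def weighted_squared_mmd(x, y, w_x=None, w_y=None, length_scale=1.0):
    """Squared MMD between x (n x d) and y (m x d) under normalized weights."""
    # Uniform weights by default; normalize so each weight vector sums to one.
    w_x = np.ones(len(x)) if w_x is None else np.asarray(w_x, dtype=float)
    w_y = np.ones(len(y)) if w_y is None else np.asarray(w_y, dtype=float)
    w_x = w_x / w_x.sum()
    w_y = w_y / w_y.sum()
    # Each expectation is a weighted sum over the corresponding Gram matrix.
    k_xx = w_x @ gaussian_kernel_matrix(x, x, length_scale) @ w_x
    k_yy = w_y @ gaussian_kernel_matrix(y, y, length_scale) @ w_y
    k_xy = w_x @ gaussian_kernel_matrix(x, y, length_scale) @ w_y
    return k_xx + k_yy - 2.0 * k_xy


rng = np.random.default_rng(0)
data = rng.normal(size=(200, 2))
print(weighted_squared_mmd(data, data))        # 0.0 for identical datasets
print(weighted_squared_mmd(data, data + 3.0))  # clearly positive after a shift
```

Weights act like point multiplicities: a point carrying weight 2 contributes exactly as two copies of that point with weight 1 would.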

Common uses of MMD include comparing a reduced representation of a dataset to the original dataset, comparing different original datasets to one another, or comparing reduced representations of different original datasets to one another.

Parameters:
  • kernel (Kernel) – Kernel object whose compute method defines the mapping \(k: \mathbb{R}^d \times \mathbb{R}^d \rightarrow \mathbb{R}\)

  • precision_threshold (float) – Positive threshold; negative values of the squared MMD that lie within this threshold of zero are rounded to zero (accommodates precision loss)
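
Rounding in floating point can make the computed squared MMD slightly negative even when the true value is zero, and the square root of a negative number is undefined. A minimal sketch of the thresholding step follows, assuming that only values in \((-\text{precision\_threshold}, 0)\) are rounded to zero and that more negative values are passed through unchanged; the function names are hypothetical.

```python
import numpy as np


def apply_precision_threshold(squared_mmd, precision_threshold=1e-12):
    """Round squared_mmd to 0.0 if it lies in (-precision_threshold, 0.0).

    Assumption: values below -precision_threshold are left unchanged.
    """
    value = np.asarray(squared_mmd, dtype=float)
    within = (value > -abs(precision_threshold)) & (value < 0.0)
    return np.where(within, 0.0, value)


def mmd_from_squared(squared_mmd, precision_threshold=1e-12):
    """Return the MMD (square root after thresholding) as a 0-d array."""
    return np.asarray(np.sqrt(apply_precision_threshold(squared_mmd, precision_threshold)))


print(mmd_from_squared(-3e-16))  # 0.0
print(mmd_from_squared(0.25))    # 0.5
```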

compute(reference_data, comparison_data, *, block_size=None, unroll=1, **kwargs)

Compute the (weighted) maximum mean discrepancy, whose square is given by:

\[\text{MMD}^2(\mathcal{D}_1,\mathcal{D}_2) = \mathbb{E}(k(\mathcal{D}_1, \mathcal{D}_1)) + \mathbb{E}(k(\mathcal{D}_2,\mathcal{D}_2)) - 2\mathbb{E}(k(\mathcal{D}_1,\mathcal{D}_2))\]

Parameters:
  • reference_data (Data) – An instance of the class coreax.data.Data, containing an \(n \times d\) array of data

  • comparison_data (Data) – An instance of the class coreax.data.Data to compare against reference_data, containing an \(m \times d\) array of data

  • block_size (Union[int, None, tuple[Optional[int], Optional[int]]]) – Size of matrix blocks to process; a value of None sets \(B_x = n\) and \(B_y = m\), effectively disabling block accumulation; an integer value \(B\) sets \(B_x = B_y = B\); a tuple allows different sizes to be specified for \(B_x\) and \(B_y\). To reduce overheads, it is often sensible to select the largest block size that does not exhaust the available memory

  • unroll (Union[int, bool, tuple[Union[int, bool], Union[int, bool]]]) – Unrolling parameter for the outer and inner jax.lax.scan() calls, allowing a trade-off between compilation cost and runtime cost; consult the JAX documentation for further information
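
Block accumulation works because each expectation is a sum over kernel evaluations, so it can be accumulated one \(B_x \times B_y\) tile of the kernel matrix at a time, holding only one tile in memory. The NumPy sketch below illustrates this for the cross term \(\mathbb{E}(k(\mathcal{D}_1,\mathcal{D}_2))\) with a Gaussian kernel. Assumptions: the names are hypothetical, a tuple is read as \((B_x, B_y)\), a None entry inside a tuple is taken to mean the full size along that axis, and the two Python loops merely stand in for the outer and inner scans; this is not the coreax implementation.

```python
import numpy as np


def gaussian_kernel_matrix(x, y, length_scale=1.0):
    """Gram matrix K[i, j] = exp(-||x_i - y_j||^2 / (2 * length_scale^2))."""
    sq_dists = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / (2.0 * length_scale**2))


def blocked_weighted_kernel_mean(x, y, w_x, w_y, block_size=None, length_scale=1.0):
    """Compute sum_ij w_x[i] k(x_i, y_j) w_y[j] one (B_x x B_y) tile at a time."""
    n, m = len(x), len(y)
    # Resolve block sizes following the documented convention (tuple order assumed).
    if block_size is None:
        b_x, b_y = n, m
    elif isinstance(block_size, int):
        b_x = b_y = block_size
    else:
        b_x, b_y = block_size
        b_x = n if b_x is None else b_x
        b_y = m if b_y is None else b_y
    total = 0.0
    for i in range(0, n, b_x):  # outer loop over blocks of x
        for j in range(0, m, b_y):  # inner loop over blocks of y
            tile = gaussian_kernel_matrix(x[i:i + b_x], y[j:j + b_y], length_scale)
            total += w_x[i:i + b_x] @ tile @ w_y[j:j + b_y]
    return total


rng = np.random.default_rng(1)
x, y = rng.normal(size=(10, 3)), rng.normal(size=(7, 3))
w_x, w_y = np.full(10, 1 / 10), np.full(7, 1 / 7)
full = blocked_weighted_kernel_mean(x, y, w_x, w_y)
tiled = blocked_weighted_kernel_mean(x, y, w_x, w_y, block_size=(4, 2))
print(np.isclose(full, tiled))  # True: tiling changes memory use, not the result
```

Larger tiles mean fewer loop iterations (lower overhead) at the cost of more memory per tile, which is why the largest block size that fits in memory is usually a sensible choice.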

Return type:

Array

Returns:

Maximum mean discrepancy as a zero-dimensional array