Source code for coreax.coreset

# © Crown Copyright GCHQ
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for defining coreset data structures."""

from abc import abstractmethod
from typing import (
    TYPE_CHECKING,
    Final,
    Generic,
    TypeVar,
    overload,
)

import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array, Shaped
from typing_extensions import Self, override

from coreax.data import Data, SupervisedData, as_data
from coreax.metrics import Metric
from coreax.weights import WeightsOptimiser

if TYPE_CHECKING:
    from typing import Any  # noqa: F401

# Type variables constrained to the two data containers this module supports.
# pylint: disable=invalid-name
# Invariant variant: used in the ``build`` overloads, where the same type is
# both accepted as an argument and echoed in the return annotation.
_TOriginalData = TypeVar("_TOriginalData", Data, SupervisedData)
# Covariant variants: the ``_co`` suffix is the established naming convention
# for covariant TypeVars, which is why the pylint naming check is suppressed.
_TPointsData_co = TypeVar("_TPointsData_co", Data, SupervisedData, covariant=True)
_TOriginalData_co = TypeVar(
    "_TOriginalData_co", Data, SupervisedData, covariant=True
)
# pylint: enable=invalid-name


class AbstractCoreset(eqx.Module, Generic[_TPointsData_co, _TOriginalData_co]):
    r"""
    Abstract base class for coresets.

    A coreset is a reduced set of :math:`\hat{n}` (potentially weighted) data points,
    :math:`\hat{X} := \{(\hat{x}_i, \hat{w}_i)\}_{i=1}^{\hat{n}}` that, in some sense,
    best represent the "important" properties of a larger set of :math:`n > \hat{n}`
    (potentially weighted) data points :math:`X := \{(x_i, w_i)\}_{i=1}^{n}`.

    :math:`\hat{x}_i, x_i \in \Omega` represent the data points/nodes and
    :math:`\hat{w}_i, w_i \in \mathbb{R}` represent the associated weights.
    """

    @property
    @abstractmethod
    def points(self) -> _TPointsData_co:
        """The coreset points."""

    @property
    @abstractmethod
    def pre_coreset_data(self) -> _TOriginalData_co:
        """The original data that this coreset is based on."""

    @abstractmethod
    def solve_weights(self, solver: WeightsOptimiser[Data], **solver_kwargs) -> Self:
        """
        Return a copy of 'self' with weights solved by 'solver'.

        :param solver: Weights optimiser used to compute the coreset weights
        :param solver_kwargs: Additional keyword arguments forwarded to the solver
        :return: A copy of this coreset carrying the solved weights
        """

    def compute_metric(
        self, metric: Metric[Data], **metric_kwargs
    ) -> Shaped[Array, ""]:
        """
        Return metric-distance between `self.pre_coreset_data` and `self.points`.

        :param metric: Metric whose ``compute`` method is called with the original
            data first and the coreset points second
        :param metric_kwargs: Additional keyword arguments forwarded to
            ``metric.compute``
        :return: The value returned by ``metric.compute``
        """
        return metric.compute(self.pre_coreset_data, self.points, **metric_kwargs)

    def __len__(self) -> int:
        """Return Coreset size/length, i.e. the length of :attr:`points`."""
        return len(self.points)

    def __check_init__(self) -> None:
        """
        Check that the coreset has no more 'nodes' than the 'pre_coreset_data'.

        A coreset of equal size to the original data is permitted; only a strictly
        larger one is rejected.

        :raises ValueError: If ``len(points) > len(pre_coreset_data)``
        """
        if len(self.points) > len(self.pre_coreset_data):
            raise ValueError(
                "'len(points)' cannot be greater than 'len(pre_coreset_data)' "
                "by definition of a Coreset"
            )
class PseudoCoreset(
    AbstractCoreset[Data, _TOriginalData_co], Generic[_TOriginalData_co]
):
    r"""
    Data structure for representing a pseudo-coreset.

    The points of a pseudo-coreset are not necessarily points in the original dataset.

    :param nodes: The (weighted) coreset nodes, :math:`I`; these can be accessed via
        :meth:`PseudoCoreset.points`.
    :param pre_coreset_data: The dataset :math:`X` used to construct the coreset.
    """

    # These aren't _constants_ so much as just _read-only_, so it doesn't make sense
    # for them to be in SCREAMING_SNAKE_CASE. Also, even if they are changed to appease
    # Pylint here, Pylint then just complains when they're assigned to in __init__
    # instead!
    # pylint: disable=invalid-name
    _nodes: Final[Data]
    _pre_coreset_data: Final[_TOriginalData_co]
    # pylint: enable=invalid-name

    def __init__(self, nodes: Data, pre_coreset_data: _TOriginalData_co) -> None:
        """
        Initialise self.

        :param nodes: The (weighted) coreset nodes, already wrapped as ``Data``
        :param pre_coreset_data: The original dataset, as ``Data`` or
            ``SupervisedData``
        :raises TypeError: If either argument is not a ``Data`` instance; raw
            arrays must go through :meth:`PseudoCoreset.build` instead
        """
        if not isinstance(nodes, Data):
            raise TypeError(
                "`nodes` must be of type `Data`. "
                "To use an array, use PseudoCoreset.build() instead."
            )
        if not isinstance(pre_coreset_data, Data):
            raise TypeError(
                "`pre_coreset_data` must be of type `Data` or `SupervisedData`. "
                "To use an array or tuple of arrays, use PseudoCoreset.build() instead."
            )
        self._nodes = nodes
        self._pre_coreset_data = pre_coreset_data

    @classmethod
    @overload
    def build(
        cls, nodes: Data | Array, pre_coreset_data: Array
    ) -> "PseudoCoreset[Data]": ...

    @classmethod
    @overload
    def build(
        cls,
        nodes: Data | Array,
        pre_coreset_data: tuple[Array, Array],
    ) -> "PseudoCoreset[SupervisedData]": ...

    @classmethod
    @overload
    def build(
        cls,
        nodes: Data | Array,
        pre_coreset_data: _TOriginalData,
    ) -> "PseudoCoreset[_TOriginalData]": ...

    @classmethod
    def build(
        cls,
        nodes: Data | Array,
        pre_coreset_data: _TOriginalData | Array | tuple[Array, Array],
    ) -> "PseudoCoreset[Data]\
        | PseudoCoreset[SupervisedData]\
        | PseudoCoreset[_TOriginalData]\
        ":
        """
        Construct a PseudoCoreset from Data or raw Arrays.

        :param nodes: The (weighted) coreset nodes, :math:`I`; these can be accessed
            via :meth:`PseudoCoreset.points`. :class:`jax.Array` instances are
            automatically converted into :class:`~coreax.data.Data`.
        :param pre_coreset_data: The dataset :math:`X` used to construct the coreset.
            :class:`jax.Array` instances are automatically converted into
            :class:`~coreax.data.Data`. :class:`tuple` [:class:`jax.Array`,
            :class:`jax.Array`] is automatically converted into
            :class:`~coreax.data.SupervisedData`.
        :return: A new instance of ``cls`` wrapping the converted inputs
        """
        if isinstance(pre_coreset_data, Array):
            converted_pre_coreset_data = as_data(pre_coreset_data)
        elif isinstance(pre_coreset_data, tuple):
            converted_pre_coreset_data = SupervisedData(*pre_coreset_data)
        else:
            converted_pre_coreset_data = pre_coreset_data
        # Use ``cls`` rather than a hard-coded class so subclasses build themselves.
        return cls(as_data(nodes), converted_pre_coreset_data)

    @property
    @override
    def points(self) -> Data:
        """Materialised coreset (the stored nodes)."""
        return self._nodes

    @property
    @override
    def pre_coreset_data(self) -> _TOriginalData_co:
        """The original data that this coreset is based on."""
        return self._pre_coreset_data

    @override
    def solve_weights(self, solver: WeightsOptimiser[Data], **solver_kwargs) -> Self:
        """
        Return a copy of 'self' with weights solved by 'solver'.

        :param solver: Weights optimiser; its ``solve`` method receives the original
            data and the coreset points
        :param solver_kwargs: Additional keyword arguments forwarded to
            ``solver.solve``
        :return: A copy of this pseudo-coreset whose node weights are replaced by
            the solver output
        """
        weights = solver.solve(self.pre_coreset_data, self.points, **solver_kwargs)
        return eqx.tree_at(lambda x: x.points.weights, self, weights)
class Coresubset(
    AbstractCoreset[_TOriginalData_co, _TOriginalData_co], Generic[_TOriginalData_co]
):
    r"""
    Data structure for representing a coresubset.

    A coresubset is a coreset, with the additional condition that the coreset data
    points/nodes must be a subset of the original data points/nodes, such that

    .. math::
        \hat{x}_i = x_i, \forall i \in I,
        I \subset \{1, \dots, n\}, \text{card}(I) = \hat{n}.

    Thus, a coresubset, unlike a coreset, ensures that feasibility constraints on the
    support of the measure are maintained :cite:`litterer2012recombination`.

    In coresubsets, the dataset reduction can be implicit (setting weights/nodes to zero
    for all :math:`i \notin I`) or explicit (removing entries from the weight/node
    arrays). The implicit approach is useful when input/output array shape stability is
    required (E.G. for some JAX transformations); the explicit approach is more similar
    to a standard coreset.

    :param indices: The (weighted) coresubset node indices, :math:`I`; the materialised
        coresubset nodes should only be accessed via :meth:`Coresubset.points`.
    :param pre_coreset_data: The dataset :math:`X` used to construct the coreset.
    """

    # These aren't _constants_ so much as just _read-only_, so it doesn't make sense
    # for them to be in SCREAMING_SNAKE_CASE. Also, even if they are changed to appease
    # Pylint here, Pylint then just complains when they're assigned to in __init__
    # instead!
    # pylint: disable=invalid-name
    _indices: Final[Data]
    _pre_coreset_data: Final[_TOriginalData_co]
    # pylint: enable=invalid-name

    def __init__(self, indices: Data, pre_coreset_data: _TOriginalData_co) -> None:
        """
        Initialise self.

        :param indices: The (weighted) node indices, already wrapped as ``Data``
        :param pre_coreset_data: The original dataset, as ``Data`` or
            ``SupervisedData``
        :raises TypeError: If either argument is not a ``Data`` instance; raw
            arrays must go through :meth:`Coresubset.build` instead
        """
        if not isinstance(indices, Data):
            raise TypeError(
                "`indices` must be of type `Data`. "
                "To use an array, use Coresubset.build() instead."
            )
        if not isinstance(pre_coreset_data, Data):
            raise TypeError(
                "`pre_coreset_data` must be of type `Data` or `SupervisedData`. "
                "To use an array or tuple of arrays, use Coresubset.build() instead."
            )
        self._indices = indices
        self._pre_coreset_data = pre_coreset_data

    @classmethod
    @overload
    def build(
        cls, indices: Data | Array, pre_coreset_data: Array
    ) -> "Coresubset[Data]": ...

    @classmethod
    @overload
    def build(
        cls,
        indices: Data | Array,
        pre_coreset_data: tuple[Array, Array],
    ) -> "Coresubset[SupervisedData]": ...

    @classmethod
    @overload
    def build(
        cls,
        indices: Data | Array,
        pre_coreset_data: _TOriginalData,
    ) -> "Coresubset[_TOriginalData]": ...

    @classmethod
    def build(
        cls,
        indices: Data | Array,
        pre_coreset_data: _TOriginalData | Array | tuple[Array, Array],
    ) -> "Coresubset[Data] | Coresubset[SupervisedData] | Coresubset[_TOriginalData]":
        """
        Construct a Coresubset from Data or raw Arrays.

        :param indices: The (weighted) coresubset node indices, :math:`I`; the
            materialised coresubset nodes should only be accessed via
            :meth:`Coresubset.points`. :class:`jax.Array` instances are automatically
            converted into :class:`~coreax.data.Data`.
        :param pre_coreset_data: The dataset :math:`X` used to construct the coreset.
            :class:`jax.Array` instances are automatically converted into
            :class:`~coreax.data.Data`. :class:`tuple` [:class:`jax.Array`,
            :class:`jax.Array`] is automatically converted into
            :class:`~coreax.data.SupervisedData`.
        :return: A new instance of ``cls`` wrapping the converted inputs
        """
        if isinstance(pre_coreset_data, Array):
            converted_pre_coreset_data = as_data(pre_coreset_data)
        elif isinstance(pre_coreset_data, tuple):
            converted_pre_coreset_data = SupervisedData(*pre_coreset_data)
        else:
            converted_pre_coreset_data = pre_coreset_data
        # Use ``cls`` rather than a hard-coded class so subclasses build themselves.
        return cls(as_data(indices), converted_pre_coreset_data)

    @property
    @override
    def points(self) -> _TOriginalData_co:
        """
        Materialise the coresubset from the indices and original data.

        The original data is indexed with :attr:`unweighted_indices`, and the
        resulting weights are replaced by the weights stored on the indices.
        """
        coreset_data = self.pre_coreset_data[self.unweighted_indices]
        return eqx.tree_at(lambda x: x.weights, coreset_data, self._indices.weights)

    @property
    def unweighted_indices(self) -> Shaped[Array, " n"]:
        """Unweighted Coresubset indices - attribute access helper."""
        # Squeeze away singleton dimensions, then ensure at least 1d so a single
        # index does not collapse to a 0-d array and cause shape errors.
        return jnp.atleast_1d(jnp.squeeze(self._indices.data))

    @property
    @override
    def pre_coreset_data(self) -> _TOriginalData_co:
        """The original data that this coreset is based on."""
        return self._pre_coreset_data

    @property
    def indices(self) -> Data:
        """The (possibly weighted) Coresubset indices."""
        return self._indices

    @override
    def solve_weights(self, solver: WeightsOptimiser[Data], **solver_kwargs) -> Self:
        """
        Return a copy of 'self' with weights solved by 'solver'.

        :param solver: Weights optimiser; its ``solve`` method receives the original
            data and the materialised coresubset points
        :param solver_kwargs: Additional keyword arguments forwarded to
            ``solver.solve``
        :return: A copy of this coresubset whose index weights are replaced by the
            solver output
        """
        weights = solver.solve(self.pre_coreset_data, self.points, **solver_kwargs)
        # Weights live on the indices; ``points`` re-derives them on access.
        return eqx.tree_at(lambda x: x.indices.weights, self, weights)