Source code for coreax.data

# © Crown Copyright GCHQ
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Data-structures for representing weighted and/or supervised data."""

from typing import Any, Optional

import equinox as eqx
import jax.numpy as jnp
import jax.tree_util as jtu
from jaxtyping import Array, ArrayLike, Shaped
from typing_extensions import Self


class Data(eqx.Module):
    r"""
    Class for representing unsupervised data.

    A dataset of size `n` consists of a set of pairs :math:`\{(x_i, w_i)\}_{i=1}^n`
    where :math:`x_i` are the features or inputs and :math:`w_i` are weights.

    :param data: An :math:`n \times d` array defining the features of the
        unsupervised dataset
    :param weights: An :math:`n`-vector of weights where each element of the weights
        vector is paired with the corresponding index of the data array, forming the
        pair :math:`(x_i, w_i)`; if passed a scalar weight, it will be broadcast to
        an :math:`n`-vector; the default value of :data:`None` sets the weights to
        the ones vector (implies a scalar weight of one)
    """

    data: Shaped[Array, " n *d"]
    weights: Shaped[Array, " n"]

    def __init__(
        self,
        data: Shaped[ArrayLike, " n *d"],
        weights: Optional[Shaped[ArrayLike, " n"]] = None,
    ):
        """Initialise Data class."""
        self.data = jnp.asarray(data)
        # Slice the shape (rather than index it) so that the broadcast target is
        # a tuple, as 'jnp.broadcast_to' expects.
        n = self.data.shape[:1]
        self.weights = jnp.broadcast_to(1 if weights is None else weights, n)

    def __getitem__(self, key) -> Self:
        """
        Support Array style indexing of 'Data' objects.

        :param key: Index, slice or mask applied to every array leaf of 'self'
        :return: A copy of 'self' in which each array leaf has been indexed by 'key'
        """
        # Mapping over the pytree indexes all array fields with the same key, so
        # entries that were paired before indexing remain paired afterwards.
        return jtu.tree_map(lambda x: x[key], self)

    def __jax_array__(self) -> Shaped[Array, " n *d"]:
        """Register ArrayLike behaviour - return value for `jnp.asarray(Data(...))`."""
        return self.data

    def __len__(self) -> int:
        """Return data length."""
        return len(self.data)

    def normalize(self, *, preserve_zeros: bool = False) -> Self:
        """
        Return a copy of 'self' with 'weights' that sum to one.

        :param preserve_zeros: If to preserve zero valued weights; when all weights
            are zero valued, the 'normalized' copy will **sum to zero, not one**;
            when :data:`False`, the non-finite values produced by dividing by a
            zero valued sum are left in the result
        :return: A copy of 'self' with normalized 'weights'
        """
        normalized_weights = self.weights / jnp.sum(self.weights)
        if preserve_zeros:
            # Replace the NaNs arising from '0 / 0' with zeros.
            normalized_weights = jnp.nan_to_num(normalized_weights)
        # 'eqx.Module' instances are immutable; 'tree_at' builds an updated copy.
        return eqx.tree_at(lambda x: x.weights, self, normalized_weights)
class SupervisedData(Data):
    r"""
    Class for representing supervised data.

    A supervised dataset of size `n` consists of a set of triples
    :math:`\{(x_i, y_i, w_i)\}_{i=1}^n` where :math:`x_i` are the features or
    inputs, :math:`y_i` are the responses or outputs, and :math:`w_i` are weights
    which correspond to the pairs :math:`(x_i, y_i)`.

    :param data: An :math:`n \times d` array defining the features of the
        supervised dataset paired with the corresponding index of the supervision
    :param supervision: An :math:`n \times p` array defining the responses of the
        supervised dataset paired with the corresponding index of the data; an
        :math:`n`-vector is treated as :math:`n` scalar responses, i.e. as an
        :math:`n \times 1` array
    :param weights: An :math:`n`-vector of weights where each element of the weights
        vector is paired with the corresponding index of the data and supervision
        array, forming the triple :math:`(x_i, y_i, w_i)`; if passed a scalar
        weight, it will be broadcast to an :math:`n`-vector; the default value of
        :data:`None` sets the weights to the ones vector (implies a scalar weight
        of one)
    :raises ValueError: If the leading dimensions of 'supervision' and 'data'
        differ
    """

    supervision: Shaped[Array, " n *p"] = eqx.field(converter=jnp.atleast_2d)

    def __init__(
        self,
        data: Shaped[ArrayLike, " n *d"],
        supervision: Shaped[ArrayLike, " n *p"],
        weights: Optional[Shaped[ArrayLike, " n"]] = None,
    ):
        """Initialise SupervisedData class."""
        super().__init__(data, weights)
        supervision = jnp.asarray(supervision)
        # 'jnp.atleast_2d' promotes an n-vector to shape (1, n), which would fail
        # the leading dimension check for any n > 1. A vector with one entry per
        # data point is therefore reshaped to a column of scalar responses; every
        # other input is left for the field converter to handle as before.
        if supervision.ndim == 1 and supervision.shape == self.data.shape[:1]:
            supervision = supervision[:, None]
        self.supervision = supervision

    def __check_init__(self):
        """Check leading dimensions of supervision and data match."""
        if self.supervision.shape[0] != self.data.shape[0]:
            raise ValueError(
                "Leading dimensions of 'supervision' and 'data' must be equal"
            )
def as_data(x: Any) -> Data:
    """
    Cast 'x' to a data instance.

    :param x: Object to cast; returned unchanged if it is already a 'Data' instance
    :return: 'x' itself when it is a 'Data' instance, otherwise 'Data(x)'
    """
    if isinstance(x, Data):
        return x
    return Data(x)
def is_data(x: Any) -> bool:
    """
    Return boolean indicating if 'x' is an instance of 'coreax.data.Data'.

    :param x: Object to test
    :return: :data:`True` if 'x' is a 'Data' instance (subclasses included),
        :data:`False` otherwise
    """
    x_is_data = isinstance(x, Data)
    return x_is_data