Source code for coreax.approximation

# © Crown Copyright GCHQ
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""
Classes and associated functionality to approximate kernels.

When a dataset is very large, methods which have to evaluate all pairwise combinations
of the data, such as :meth:`~coreax.kernels.ScalarValuedKernel.gramian_row_mean`, can
become prohibitively expensive. To reduce this computational cost, such methods can
instead be approximated (providing suitable approximation error can be achieved).

The :class:`ApproximateKernel`\ s in this module provide the functionality required to
override specific methods of a ``base_kernel`` with their approximate counterparts.
Because :class:`ApproximateKernel`\ s inherit from
:class:`~coreax.kernels.ScalarValuedKernel`, with all functionality provided through
composition with a ``base_kernel``, they can be freely used in any place where a
standard :class:`~coreax.kernels.ScalarValuedKernel` is expected.
"""

from collections.abc import Callable
from functools import partial
from typing import TYPE_CHECKING, Any, Union

import jax
import jax.numpy as jnp
import jax.random as jr
from jax import Array
from jaxtyping import Shaped
from typing_extensions import Literal, override

from coreax.data import Data, _atleast_2d_consistent
from coreax.kernels import UniCompositeKernel
from coreax.util import KeyArrayLike

if TYPE_CHECKING:
    from coreax.kernels import ScalarValuedKernel  # noqa: F401


def _random_indices(
    key: KeyArrayLike,
    num_data_points: int,
    num_select: int,
    mode: Literal["kernel", "train"] = "kernel",
):
    """
    Draw ``num_select`` distinct indices uniformly at random.

    :param key: RNG key for seeding the random selection
    :param num_data_points: The total number of indexable data points
    :param num_select: The number of indices to select
    :param mode: Names the offending attribute (``num_<mode>_points``) in the error
        message raised when more indices are requested than there are data points
    :raises ValueError: If ``num_select`` is larger than ``num_data_points``; any
        other :class:`ValueError` raised by the sampler is propagated unchanged
    :return: A randomly selected subset of indices, of size ``num_select``, for a
        dataset with ``num_data_points`` indexable entries.
    """
    try:
        # Sampling without replacement, so no index is returned twice.
        return jr.choice(key, num_data_points, shape=(num_select,), replace=False)
    except ValueError as sampling_error:
        if not num_select > num_data_points:
            # Not the over-sampling case: surface the sampler's own error as-is.
            raise
        message = (
            f"'num_{mode}_points' must be no larger than the number of points in "
            "the provided data"
        )
        raise ValueError(message) from sampling_error


def _random_least_squares(
    key: KeyArrayLike,
    data: Shaped[Array, " n p"],
    features: Shaped[Array, " n m"],
    num_indices: int,
    target_map: Callable[[Shaped[Array, " n p"]], Shaped[Array, " n p"]] = lambda x: x,
) -> Shaped[Array, " n p"]:
    r"""
    Solve the least-square problem on a random subset of the system.

    A linear system :math:`AX = B`, solved via least-squares as :math:`X = A^+ B`, can
    be approximated by random least-squares as
    :math:`X \approx \hat{X} = \hat{A}^+ \hat{B}`, where
    :math:`\hat{A} = A_{I\cdot}` and :math:`\hat{B} = \phi(Z_{I\cdot})`.
    :math:`I` is a random subset of row indices of the original system of equations,
    drawn without replacement.

    :param key: RNG key for seeding the random selection
    :param data: The data :math:`Z \in \mathbb{R}^{n \times p}`; the rows selected by
        :math:`I` yield :math:`\hat{B}` when pushed through the target map
    :param features: The feature matrix :math:`A \in \mathbb{R}^{n \times m}`, with one
        row per data point and one column per feature
    :param num_indices: The size of the random subset of indices :math:`I`
    :param target_map: The target map :math:`\phi`, applied to the selected rows of
        ``data`` to produce the regression targets; defaults to the identity
    :raises ValueError: If ``num_indices`` is larger than the number of rows in
        ``data``
    :return: The push-forward of the approximate solution :math:`A\hat{X}`
    """
    num_data_points = len(data)
    train_idx = _random_indices(key, num_data_points, num_indices, mode="train")
    # The target map is only ever evaluated on the selected training rows.
    target = target_map(data[train_idx])
    # Only the solution is needed; residuals, rank and singular values are discarded.
    approximate_solution, _, _, _ = jnp.linalg.lstsq(features[train_idx], target)
    # Apply the solution fitted on the subset to every row of the feature matrix.
    return features @ approximate_solution


class ApproximateKernel(UniCompositeKernel):
    """
    Base class for approximated kernels.

    Wraps a ``base_kernel`` so that selected methods can be replaced by approximate
    counterparts in subclasses. The element-wise methods defined here are forwarded
    to the ``base_kernel`` unchanged.

    The :meth:`~coreax.kernels.ScalarValuedKernel.gramian_row_mean` method is
    particularly amenable to approximation, with significant performance improvements
    possible depending on the acceptable levels of error.

    :param base_kernel: a :class:`~coreax.kernels.ScalarValuedKernel` whose
        attributes/methods are to be approximated
    """

    @override
    def compute_elementwise(self, x, y):
        """Forward the element-wise kernel evaluation to ``base_kernel``."""
        wrapped = self.base_kernel
        return wrapped.compute_elementwise(x, y)

    @override
    def grad_x_elementwise(self, x, y):
        """Forward the element-wise gradient with respect to ``x``."""
        wrapped = self.base_kernel
        return wrapped.grad_x_elementwise(x, y)

    @override
    def grad_y_elementwise(self, x, y):
        """Forward the element-wise gradient with respect to ``y``."""
        wrapped = self.base_kernel
        return wrapped.grad_y_elementwise(x, y)

    @override
    def divergence_x_grad_y_elementwise(self, x, y):
        """Forward the element-wise divergence computation to ``base_kernel``."""
        wrapped = self.base_kernel
        return wrapped.divergence_x_grad_y_elementwise(x, y)
class RandomRegressionKernel(ApproximateKernel):
    """
    An approximate kernel that carries the attributes needed for random regression.

    :param base_kernel: a :class:`~coreax.kernels.ScalarValuedKernel` whose
        attributes/methods are to be approximated
    :param random_key: Key for random number generation
    :param num_kernel_points: Number of kernel evaluation points
    :param num_train_points: Number of training points used to fit kernel regression
    """

    random_key: KeyArrayLike
    num_kernel_points: int = 10_000
    num_train_points: int = 10_000

    def __check_init__(self):
        """
        Reject non-positive point counts.

        :raises ValueError: If ``num_kernel_points`` or ``num_train_points`` is less
            than or equal to zero; ``num_kernel_points`` is checked first
        """
        requirements = (
            (self.num_kernel_points, "'num_kernel_points' must be a positive integer"),
            (self.num_train_points, "'num_train_points' must be a positive integer"),
        )
        for count, complaint in requirements:
            if count <= 0:
                raise ValueError(complaint)
class MonteCarloApproximateKernel(RandomRegressionKernel):
    """
    Approximate a base kernel via random subset selection.

    Only the Gramian row-mean is approximated here, all other methods are inherited
    directly from the ``base_kernel``.

    :param base_kernel: a :class:`~coreax.kernels.ScalarValuedKernel` whose
        attributes/methods are to be approximated
    :param random_key: Key for random number generation
    :param num_kernel_points: Number of kernel evaluation points
    :param num_train_points: Number of training points used to fit kernel regression
    """

    def gramian_row_mean(
        self,
        x: Union[
            Shaped[Array, " n d"],
            Shaped[Array, " d"],
            Shaped[Array, ""],
            float,
            int,
            Data,
        ],
        **kwargs: Any,
    ) -> Shaped[Array, " n"]:
        r"""
        Approximate the Gramian row-mean by Monte-Carlo sampling.

        A uniform random subset of ``x``, of size ``num_kernel_points``, provides the
        kernel evaluation points (feature columns). A least-squares fit on
        ``num_train_points`` randomly selected rows, whose targets are the base
        kernel's mean against all of ``x``, is then applied to every row.

        :param x: Data matrix, :math:`n \times d`
        :param kwargs: Ignored; accepted only for signature compatibility
        :raises ValueError: If ``num_kernel_points`` or ``num_train_points`` is larger
            than the number of points in ``x``
        :return: Approximation of the base kernel's Gramian row-mean
        """
        del kwargs
        # This method does not support weighted computation of the mean, therefore
        # we need to handle the case where `x` is passed as a `Data` instance
        if isinstance(x, Data):
            x = x.data
        x = _atleast_2d_consistent(x)
        num_data_points = len(x)
        key = self.random_key
        # Draw exactly `num_kernel_points` evaluation points, as documented. Drawing
        # one fewer left a single-point configuration with no feature columns at all
        # and let `num_kernel_points == num_data_points + 1` slip past validation.
        features_idx = _random_indices(key, num_data_points, self.num_kernel_points)
        features = self.base_kernel.compute(x, x[features_idx])
        # NOTE(review): the same key seeds both the feature and the training
        # selection, so the two draws are not independent - confirm this is intended.
        return _random_least_squares(
            key,
            x,
            features,
            self.num_train_points,
            partial(self.base_kernel.compute_mean, x, axis=0),
        )
class ANNchorApproximateKernel(RandomRegressionKernel):
    r"""
    Approximate a base kernel via random kernel regression on ANNchor selected points.

    Only the base kernel's Gramian row-mean is approximated here, all other methods
    are inherited directly from the ``base_kernel``.

    :param base_kernel: a :class:`~coreax.kernels.ScalarValuedKernel` whose
        attributes/methods are to be approximated
    :param random_key: Key for random number generation
    :param num_kernel_points: Number of kernel evaluation points
    :param num_train_points: Number of training points used to fit kernel regression
    """

    def gramian_row_mean(
        self,
        x: Union[
            Shaped[Array, " n d"],
            Shaped[Array, " d"],
            Shaped[Array, ""],
            float,
            int,
            Data,
        ],
        **kwargs: Any,
    ) -> Shaped[Array, " n"]:
        r"""
        Approximate the Gramian row-mean by random regression on ANNchor points.

        A subset of ``x`` is selected via the ANNchor approach and random kernel
        regression used to approximate the base kernel's Gramian row-mean. The ANNchor
        implementation used can be found `here <https://github.com/gchq/annchor>`_.

        :param x: Data matrix, :math:`n \times d`
        :param kwargs: Ignored; accepted only for signature compatibility
        :return: Approximation of the base kernel's Gramian row-mean
        """
        del kwargs
        # Weighted means are not supported, so a `Data` instance is reduced to its
        # raw array before any further processing.
        if isinstance(x, Data):
            x = x.data
        x = _atleast_2d_consistent(x)
        kernel = self.base_kernel
        num_anchors = self.num_kernel_points

        def _kernel_column(anchor_index) -> Shaped[Array, " n"]:
            """Evaluate the base kernel between every row of ``x`` and one anchor."""
            return kernel.compute(x, x[anchor_index])[:, 0]

        # The first anchor is always the first data point; the remaining columns
        # start at zero and are filled in one at a time by the loop below.
        anchor_features = jnp.zeros((len(x), num_anchors))
        anchor_features = anchor_features.at[:, 0].set(_kernel_column(0))

        def _annchor_body(
            idx: int, _features: Shaped[Array, " n num_kernel_points"]
        ) -> Shaped[Array, " n num_kernel_points"]:
            r"""
            Fill one more anchor column of the feature matrix.

            The next anchor is the row whose largest entry across all columns of
            ``_features`` (including columns that are still zero) is smallest.

            :param idx: Loop counter; the column of ``_features`` to fill
            :param _features: Feature matrix built so far
            :return: Feature matrix with column ``idx`` populated
            """
            next_anchor = _features.max(axis=1).argmin()
            return _features.at[:, idx].set(_kernel_column(next_anchor))

        anchor_features = jax.lax.fori_loop(
            1, num_anchors, _annchor_body, anchor_features
        )
        return _random_least_squares(
            self.random_key,
            x,
            anchor_features,
            self.num_train_points,
            partial(kernel.compute_mean, x, axis=0),
        )
class NystromApproximateKernel(RandomRegressionKernel):
    """
    Approximate a base kernel via Nystrom approximation.

    Only the base kernel's Gramian row-mean is approximated here, all other methods
    are inherited directly from the ``base_kernel``.

    :param base_kernel: a :class:`~coreax.kernels.ScalarValuedKernel` whose
        attributes/methods are to be approximated
    :param random_key: Key for random number generation
    :param num_kernel_points: Number of kernel evaluation points
    :param num_train_points: Number of training points used to fit kernel regression
    """

    def gramian_row_mean(
        self,
        x: Union[
            Shaped[Array, " n d"],
            Shaped[Array, " d"],
            Shaped[Array, ""],
            float,
            int,
            Data,
        ],
        **kwargs: Any,
    ) -> Shaped[Array, " n"]:
        r"""
        Approximate the Gramian row-mean by Nystrom approximation.

        We consider a :math:`n \times d` dataset, and wish to use an :math:`m \times d`
        subset of this to approximate the base kernel's Gramian row-mean. The ``m``
        points are selected uniformly at random, and the Nystrom estimator, as defined
        in :cite:`chatalic2022nystrom` is computed using this subset.

        :param x: Data matrix, :math:`n \times d`
        :param kwargs: Ignored; accepted only for signature compatibility
        :return: Approximation of the base kernel's Gramian row-mean
        """
        del kwargs
        # Weighted means are not supported, so a `Data` instance is reduced to its
        # raw array before any further processing.
        if isinstance(x, Data):
            x = x.data
        x = _atleast_2d_consistent(x)
        shared_key = self.random_key
        landmark_idx = _random_indices(shared_key, len(x), self.num_kernel_points)
        landmark_features = self.base_kernel.compute(x, x[landmark_idx])
        # The key is deliberately reused for the training selection so that the
        # training indices coincide with the landmark (feature) indices.
        return _random_least_squares(
            shared_key,
            x,
            landmark_features,
            self.num_train_points,
            self.base_kernel.gramian_row_mean,
        )