# Source code for coreax.util

# © Crown Copyright GCHQ
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Functionality to perform simple, generic tasks and operations.

The functions within this module are simple solutions to various problems or
requirements that are sufficiently generic to be useful across multiple areas of the
codebase. Examples of this include computation of squared distances, definition of
class factories and checks for numerical precision.
"""

import logging
import sys
import time
from collections.abc import Callable, Sequence
from functools import partial, wraps
from math import log10
from typing import (
    Any,
    NamedTuple,
    TypeAlias,
)

import equinox as eqx
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
from jax import Array, block_until_ready, jit, vmap
from jaxtyping import Shaped

# Module-level logger; ``speed_comparison_test`` reports its timing summaries here.
_logger = logging.getLogger(__name__)
# NOTE(review): this runs at import time and configures the *root* logger (INFO
# level, stdout) for any application importing this module, unless the root logger
# already has handlers. Libraries usually leave handler configuration to the
# application — consider whether this side effect is intended.
logging.basicConfig(level=logging.INFO, stream=sys.stdout)

# Readability-only aliases for PyTree annotations; both are ``Any`` and so provide
# no static type checking.
PyTreeDef: TypeAlias = Any
Leaf: TypeAlias = Any

#: JAX random key type annotations.
KeyArray: TypeAlias = Array
# jax.random functions crash if passed a scalar, so can't use ArrayLike
KeyArrayLike: TypeAlias = Array


class NotCalculatedError(Exception):
    """
    Signal an attempt to use a value before it has been computed.

    Raise this when code reads a variable whose calculation has not yet been
    carried out.
    """
[docs] class JITCompilableFunction(NamedTuple): """ Parameters for :func:`jit_test`. :param fn: JIT-compilable function callable to test :param fn_args: Arguments passed during the calls to the passed function :param fn_kwargs: Keyword arguments passed during the calls to the passed function :param jit_kwargs: Keyword arguments that are partially applied to :func:`jax.jit` before being called to compile the passed function """ fn: Callable fn_args: tuple = () fn_kwargs: dict[str, Any] | None = None jit_kwargs: dict[str, Any] | None = None name: str | None = None
[docs] def without_name( self, ) -> tuple[Callable, tuple, dict[str, Any] | None, dict[str, Any] | None]: """Return the tuple (fn, fn_args, fn_kwargs, jit_kwargs).""" return self.fn, self.fn_args, self.fn_kwargs, self.jit_kwargs
class InvalidKernel:
    """
    Deliberately incomplete kernel stand-in for tests.

    Instances expose no ``compute`` method. Passing one wherever a kernel is
    expected lets tests confirm that invalid inputs are caught.
    """

    def __init__(self, x: float):
        """
        Keep the supplied value.

        :param x: Value stored unchanged as the ``x`` attribute
        """
        self.x = x
def tree_leaves_repeat(tree: PyTreeDef, length: int = 2) -> list[Leaf]:
    """
    Flatten a PyTree to its leaves and (potentially) repeat the trailing leaf.

    The PyTree 'tree' is flattened, but unlike the standard flattening, :data:`None`
    is treated as a valid leaf.

    - If fewer than 'length' leaves are produced, the trailing leaf is repeated until
      there are exactly 'length' leaves.
    - If there are already 'length' or more leaves, they are returned unchanged;
      nothing is truncated.
    - A tree with no leaves at all yields an empty list.

    :param tree: The PyTree to flatten and whose trailing leaf to (potentially) repeat
    :param length: The minimum number of leaves to return; the result has length
        ``max(len(tree_leaves), length)`` (or zero for a leafless tree)
    :return: The PyTree leaves, with the trailing leaf repeated as many times as
        required for the collection of leaves to have length 'length'
    """
    leaves = jtu.tree_leaves(tree, is_leaf=lambda x: x is None)
    # Never repeat a negative number of times; when the tree already has enough
    # leaves no padding is required.
    num_repeats = max(length - len(leaves), 0)
    # ``leaves[-1:]`` is empty for a leafless tree, so nothing is repeated then.
    return leaves + leaves[-1:] * num_repeats
def tree_zero_pad_leading_axis(tree: PyTreeDef, pad_width: int) -> PyTreeDef:
    """
    Pad the leading axis of each array leaf of 'tree' with trailing zeros.

    Every leaf selected by :func:`equinox.is_array` has 'pad_width' zeros appended
    to the end of its leading (first) axis. All remaining axes are left unpadded.
    Non-array leaves are returned unchanged.

    :param tree: The PyTree whose array leaves to pad with trailing zeros
    :param pad_width: The number of zeros to append along the leading axis; must be
        a non-negative integer (zero leaves the array shapes unchanged)
    :return: A copy of the original PyTree with the array leaves padded
    :raises ValueError: If 'pad_width' is negative
    """
    width = int(pad_width)
    if width < 0:
        raise ValueError("'pad_width' must be a non-negative integer")
    leaves_to_pad, leaves_to_keep = eqx.partition(tree, eqx.is_array)

    # NOTE(review): the annotations only name the leading axis; arrays with further
    # trailing axes are also accepted (they receive (0, 0) padding). 0-d array
    # leaves have no leading axis to pad — confirm callers never pass them.
    def _pad(x: Shaped[Array, " n"]) -> Shaped[Array, " n + pad_width"]:
        # Append zeros only at the end of axis 0; every other axis gets (0, 0).
        padding = (0, width)
        skip_padding = ((0, 0),) * (jnp.ndim(x) - 1)
        return jnp.pad(x, (padding, *skip_padding))

    padded_leaves = jtu.tree_map(_pad, leaves_to_pad)
    return eqx.combine(padded_leaves, leaves_to_keep)
def apply_negative_precision_threshold(
    x: Shaped[Array, ""] | float | int, precision_threshold: float = 1e-8
) -> Shaped[Array, ""]:
    """
    Snap tiny negative values to exactly 0.0.

    Values strictly inside ``(-|precision_threshold|, 0)`` are treated as numerical
    noise and replaced by 0.0. Every other value passes through unchanged.

    :param x: Scalar value to compare against 0.0
    :param precision_threshold: Tolerance below zero that still counts as 0.0; its
        sign is ignored
    :return: ``x`` as a JAX array, with near-zero negative values replaced by 0.0
    """
    value = jnp.asarray(x)
    tolerance = jnp.abs(precision_threshold)
    is_tiny_negative = jnp.logical_and(value > -tolerance, value < 0.0)
    return jnp.where(is_tiny_negative, 0.0, value)
def pairwise(
    fn: Callable[
        [
            Shaped[Array, " d"] | Shaped[Array, ""] | float | int,
            Shaped[Array, " d"] | Shaped[Array, ""] | float | int,
        ],
        Shaped[Array, " *d"],
    ],
) -> Callable[
    [
        Shaped[Array, " n d"] | Shaped[Array, " d"] | Shaped[Array, ""] | float | int,
        Shaped[Array, " m d"] | Shaped[Array, " d"] | Shaped[Array, ""] | float | int,
    ],
    Shaped[Array, " n m *d"],
]:
    """
    Lift a binary function so it is evaluated over every pair of inputs.

    :param fn: The binary function to lift.
    :return: A function ``g(x, y)`` with ``g(x, y)[i, j] == fn(x[i], y[j])``. Both
        inputs are first promoted with :func:`jax.numpy.atleast_2d`, so a single
        vector or scalar is treated as one row.
    """
    # Map over the rows of ``x`` (result axis 0) ...
    over_rows_of_x = vmap(fn, in_axes=(0, None), out_axes=0)
    # ... then over the rows of ``y`` (result axis 1).
    over_all_pairs = vmap(over_rows_of_x, in_axes=(None, 0), out_axes=1)

    @wraps(fn)
    def pairwise_fn(
        x: Shaped[Array, " n d"] | Shaped[Array, " d"] | Shaped[Array, ""] | float | int,
        y: Shaped[Array, " m d"] | Shaped[Array, " d"] | Shaped[Array, ""] | float | int,
    ) -> Shaped[Array, " n m *d"]:
        return over_all_pairs(jnp.atleast_2d(x), jnp.atleast_2d(y))

    return pairwise_fn
@jit
def squared_distance(
    x: Shaped[Array, " d"] | Shaped[Array, ""] | float | int,
    y: Shaped[Array, " d"] | Shaped[Array, ""] | float | int,
) -> Shaped[Array, ""]:
    """
    Compute the squared Euclidean distance between two vectors.

    Scalars are promoted to length-1 vectors before the computation.

    :param x: First vector argument
    :param y: Second vector argument
    :return: The dot product of ``x - y`` with itself, i.e. the squared distance
        between ``x`` and ``y``
    """
    delta = jnp.subtract(jnp.atleast_1d(x), jnp.atleast_1d(y))
    return jnp.dot(delta, delta)
@jit
def difference(
    x: Shaped[Array, " d"] | Shaped[Array, ""] | float | int,
    y: Shaped[Array, " d"] | Shaped[Array, ""] | float | int,
) -> Shaped[Array, " d"]:
    """
    Calculate vector difference for a pair of vectors.

    Scalar inputs are promoted to length-1 vectors, so the result is always at
    least one-dimensional (never a 0-d scalar).

    :param x: First vector
    :param y: Second vector
    :return: Vector difference ``x - y``
    """
    x = jnp.atleast_1d(x)
    y = jnp.atleast_1d(y)
    return x - y
def sample_batch_indices(
    random_key: KeyArrayLike,
    max_index: int,
    batch_size: int,
    num_batches: int,
) -> Shaped[Array, " num_batches batch_size"]:
    """
    Sample an array of indices of size `num_batches` x `batch_size`.

    Each row (batch) is taken from an independent random permutation of
    ``0, 1, ..., max_index - 1``, so it contains unique elements. Indices may repeat
    across different rows.

    :param random_key: Key for random number generation
    :param max_index: Exclusive upper bound on the sampled indices; indices are
        drawn from ``[0, max_index)``
    :param batch_size: Size of the batch we wish to sample
    :param num_batches: Number of batches to sample
    :return: Array of batch indices of size `num_batches` x `batch_size`
    :raises ValueError: If 'batch_size' is negative, or if 'max_index' is smaller
        than 'batch_size' (so a batch of unique indices cannot be formed)
    """
    # Check the sign first, so that a negative 'batch_size' is reported as such
    # rather than misleadingly as a 'max_index' problem.
    if batch_size < 0:
        raise ValueError("'batch_size' must be non-negative")
    if max_index < batch_size:
        raise ValueError("'max_index' must be greater than or equal to 'batch_size'")
    batch_keys = jr.split(random_key, num_batches)
    # One independent permutation of range(max_index) per key; keep the first
    # 'batch_size' entries of each.
    batch_permutation = vmap(jr.permutation, in_axes=(0, None))
    return batch_permutation(batch_keys, max_index)[:, :batch_size]
[docs] def jit_test( fn: Callable, fn_args: tuple = (), fn_kwargs: dict | None = None, jit_kwargs: dict | None = None, ) -> tuple[float, float]: """ Measure execution times of two runs of a JIT-compilable function. The function is called with supplied arguments twice, and timed for each run. These timings are returned in a 2-tuple. These timings can help verify the JIT performance by comparing timings of a before and after run of a function. :param fn: JIT-compilable function callable to test :param fn_args: Arguments passed during the calls to the passed function :param fn_kwargs: Keyword arguments passed during the calls to the passed function :param jit_kwargs: Keyword arguments that are partially applied to :func:`jax.jit` before being called to compile the passed function :return: (First run time, Second run time), in seconds """ # Avoid dangerous default values - Pylint W0102 if fn_kwargs is None: fn_kwargs = {} if jit_kwargs is None: jit_kwargs = {} @partial(jit, **jit_kwargs) def _fn(*args, **kwargs): return fn(*args, **kwargs) start_time = time.perf_counter() block_until_ready(_fn(*fn_args, **fn_kwargs)) end_time = time.perf_counter() pre_delta = end_time - start_time start_time = time.perf_counter() block_until_ready(_fn(*fn_args, **fn_kwargs)) end_time = time.perf_counter() post_delta = end_time - start_time return pre_delta, post_delta
def format_time(num: float) -> str:
    """
    Render a duration in seconds using a human-friendly unit.

    The value is rescaled according to its order of magnitude and rounded to two
    decimal places, e.g. 0.4531 -> "453.1 ms". Durations of 100 s or more are
    shown in minutes, and zero is rendered as "0 s".

    :param num: Duration in seconds
    :return: Formatted time as a string
    """
    try:
        magnitude = log10(abs(num))
    except ValueError:
        # log10 rejects zero.
        return "0 s"

    if magnitude >= 2:  # noqa: PLR2004
        return f"{round(num / 60, 2)} mins"

    # Checked in order: the first exclusive upper bound that the magnitude falls
    # below selects the unit.
    sub_second_units = (
        (-9, 1e12, "ps"),
        (-6, 1e9, "ns"),
        (-3, 1e6, "\u03bcs"),
        (0, 1e3, "ms"),
    )
    for upper_bound, factor, unit in sub_second_units:
        if magnitude < upper_bound:
            return f"{round(factor * num, 2)} {unit}"

    return f"{round(num, 2)} s"
def _log_timing_summary(
    mean: Array, std: Array, num_runs: int, normalised: bool
) -> None:
    """
    Log mean ± std. dev. of compilation (index 0) and execution (index 1) times.

    :param mean: Length-2 array of mean (compilation, execution) times
    :param std: Length-2 array of standard deviations of the same quantities
    :param num_runs: Number of runs the statistics were computed over
    :param normalised: If :data:`True`, values are in normalised "units" and are
        printed with ``%.4g``; otherwise they are seconds formatted by
        :func:`format_time`
    """
    if normalised:
        _logger.info(
            "Compilation time: %.4g units ± %.4g units per run "
            "(mean ± std. dev. of %s runs)",
            mean[0].item(),
            std[0].item(),
            num_runs,
        )
        _logger.info(
            "Execution time: %.4g units ± %.4g units per run "
            "(mean ± std. dev. of %s runs)",
            mean[1].item(),
            std[1].item(),
            num_runs,
        )
    else:
        _logger.info(
            "Compilation time: %s ± %s per run (mean ± std. dev. of %s runs)",
            format_time(mean[0].item()),
            format_time(std[0].item()),
            num_runs,
        )
        _logger.info(
            "Execution time: %s ± %s per run (mean ± std. dev. of %s runs)",
            format_time(mean[1].item()),
            format_time(std[1].item()),
            num_runs,
        )


def speed_comparison_test(
    function_setups: Sequence[JITCompilableFunction],
    num_runs: int = 10,
    log_results: bool = False,
    normalisation: tuple[float, float] | None = None,
) -> tuple[list[tuple[Array, Array]], dict[str, Array]]:
    """
    Compare compilation time and runtime of a list of JIT-able functions.

    Each function is timed ``num_runs`` times with :func:`jit_test`. The compilation
    time of a run is estimated as (first call time − second call time).

    :param function_setups: Sequence of instances of :class:`JITCompilableFunction`
    :param num_runs: Number of times to average function timings over
    :param log_results: If :data:`True`, the results are formatted and logged
    :param normalisation: Tuple (compilation normalisation, execution normalisation).
        If provided, returned compilation/execution times are normalised so that this
        time is 1 time unit.
    :return: List of tuples (means, standard deviations) for each function containing
        JIT compilation and execution times as array components; Dictionary with key
        function name and value array of estimated compilation times in first column
        and execution time in second column. Unnamed functions are keyed
        ``function_<position>`` (1-based); functions sharing a name overwrite each
        other's dictionary entry.
    """
    timings_dict = {}
    results = []
    is_normalised = normalisation is not None
    for i, function in enumerate(function_setups):
        name = function.name
        name = name if name is not None else f"function_{i + 1}"
        if log_results:
            _logger.info("------------------- %s -------------------", name)

        timings = jnp.zeros((num_runs, 2))
        for j in range(num_runs):
            timings = timings.at[j, :].set(jit_test(*function.without_name()))

        # The first call includes compilation, so subtracting the second call's time
        # leaves an estimate of the time spent on compilation alone.
        timings = timings.at[:, 0].set(timings[:, 0] - timings[:, 1])

        if is_normalised:
            timings = timings.at[:, 0].set(timings[:, 0] / normalisation[0])
            timings = timings.at[:, 1].set(timings[:, 1] / normalisation[1])
        timings_dict[name] = timings

        # Summary statistics over the runs, per column (compilation, execution).
        mean = timings.mean(axis=0)
        std = timings.std(axis=0)
        results.append((mean, std))

        if log_results:
            # Use the same ``is not None`` test as the normalisation step above. A
            # truthiness test would raise for array-valued normalisations.
            _log_timing_summary(mean, std, num_runs, normalised=is_normalised)
    return results, timings_dict