Utility Functions

Functionality to perform simple, generic tasks and operations.

The functions in this module are simple solutions to problems or requirements that are generic enough to be useful across multiple areas of the codebase. Examples include computing squared distances, defining class factories and checking numerical precision.

coreax.util.KeyArray

JAX random key type annotations.

exception coreax.util.NotCalculatedError

Bases: Exception

Raise when trying to use a variable that has not been calculated yet.

class coreax.util.JITCompilableFunction(fn, fn_args=(), fn_kwargs=None, jit_kwargs=None, name=None)

Bases: NamedTuple

Parameters for jit_test().

Parameters:
  • fn (Callable) – JIT-compilable callable to test

  • fn_args (tuple) – Arguments passed during the calls to the passed function

  • fn_kwargs (Dict[str, Any] | None) – Keyword arguments passed during the calls to the passed function

  • jit_kwargs (Dict[str, Any] | None) – Keyword arguments partially applied to jax.jit() before it is called to compile the passed function

  • name (str | None)

without_name()

Return the tuple (fn, fn_args, fn_kwargs, jit_kwargs).

Return type:

Tuple[Callable, Tuple, Optional[Dict[str, Any]], Optional[Dict[str, Any]]]
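As a rough illustration of how a named tuple of this shape bundles a function with its call arguments, the sketch below defines a stand-in class with the same fields. FunctionSetup and add are hypothetical names used only here; this is not the coreax implementation and it does not import coreax or JAX.

```python
from typing import Any, Callable, Dict, NamedTuple, Optional


class FunctionSetup(NamedTuple):
    """Hypothetical stand-in mirroring the documented fields."""

    fn: Callable
    fn_args: tuple = ()
    fn_kwargs: Optional[Dict[str, Any]] = None
    jit_kwargs: Optional[Dict[str, Any]] = None
    name: Optional[str] = None

    def without_name(self):
        """Return every field except the name, in declaration order."""
        return (self.fn, self.fn_args, self.fn_kwargs, self.jit_kwargs)


def add(a, b):
    return a + b


setup = FunctionSetup(fn=add, fn_args=(1, 2), name="add")
# The 4-tuple unpacks directly into the pieces needed to call the function.
fn, fn_args, fn_kwargs, jit_kwargs = setup.without_name()
result = fn(*fn_args, **(fn_kwargs or {}))
```

Dropping the name leaves a tuple whose layout matches the positional parameters documented for jit_test().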

class coreax.util.InvalidKernel(x)

Bases: object

Simple class without a compute method, used to test kernel handling.

This is used across several tests to ensure that invalid inputs are correctly caught.

Parameters:

x (float)

coreax.util.tree_leaves_repeat(tree, length=2)

Flatten a PyTree to its leaves and (potentially) repeat the trailing leaf.

The PyTree ‘tree’ is flattened but, unlike standard flattening, None is treated as a valid leaf, and the trailing leaf is (potentially) repeated so that the length of the collection of leaves is given by the ‘length’ parameter.

Parameters:
  • tree (Any) – The PyTree to flatten and whose trailing leaf to (potentially) repeat

  • length (int) – The length of the flattened PyTree after any repetition; the effective length is max(len(tree_leaves), length), so leaves are never truncated

Return type:

list[Any]

Returns:

The PyTree leaves, with the trailing leaf repeated as many times as required for the collection of leaves to have length ‘length’
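The repetition rule can be sketched in plain Python. The helper below is a hypothetical analogue that only understands list, tuple and dict containers and assumes the tree has at least one leaf; the name leaves_repeat is made up for this example and the real function works on JAX PyTrees.

```python
def leaves_repeat(tree, length=2):
    """Hypothetical pure-Python analogue of the documented behaviour."""

    def flatten(node):
        if isinstance(node, (list, tuple)):
            return [leaf for child in node for leaf in flatten(child)]
        if isinstance(node, dict):
            return [leaf for key in sorted(node) for leaf in flatten(node[key])]
        return [node]  # anything else, including None, counts as a leaf

    leaves = flatten(tree)  # assumed non-empty in this sketch
    # Never truncate: the target is the larger of the two lengths.
    target = max(len(leaves), length)
    return leaves + [leaves[-1]] * (target - len(leaves))


single = leaves_repeat(1.0, 3)
with_none = leaves_repeat((1, None), 4)
already_long = leaves_repeat([1, 2, 3], 2)
```

Here a lone leaf is repeated to fill the requested length, a trailing None is kept and repeated, and a tree that already has more leaves than requested is returned unchanged.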

coreax.util.tree_zero_pad_leading_axis(tree, pad_width)

Pad each array leaf of ‘tree’ with ‘pad_width’ trailing zeros along its leading axis.

Parameters:
  • tree (Any) – The PyTree whose array leaves to pad with trailing zeros

  • pad_width (int) – The number of trailing zeros to pad with

Return type:

Any

Returns:

A copy of the original PyTree with the array leaves padded

coreax.util.apply_negative_precision_threshold(x, precision_threshold=1e-08)

Round a number to 0.0 if it is negative but within precision_threshold of 0.0.

Parameters:
  • x (Union[Shaped[Array, ''], float, int]) – Scalar value we wish to compare to 0.0

  • precision_threshold (float) – Positive threshold we compare against for precision

Return type:

Shaped[Array, '']

Returns:

x, rounded to 0.0 if it is between -precision_threshold and 0.0
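The rounding rule is easiest to see on plain floats. The sketch below is a hypothetical scalar analogue (clip_small_negative is not a coreax name), and the strict inequalities at the boundaries are an assumption, since the documentation does not say whether the endpoints are included.

```python
def clip_small_negative(x, precision_threshold=1e-8):
    """Hypothetical scalar analogue of the documented rounding rule."""
    # Only tiny negative values are snapped to zero (boundary handling assumed).
    if -precision_threshold < x < 0.0:
        return 0.0
    return x


tiny_negative = clip_small_negative(-1e-12)
large_negative = clip_small_negative(-0.5)
positive = clip_small_negative(0.3)
```

A rule like this is typically useful when a quantity that should be non-negative, such as a squared distance, comes out slightly below zero through floating-point error.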

coreax.util.pairwise(fn)

Transform a function so it returns all pairwise evaluations of its inputs.

Parameters:

fn (Callable[[Union[Shaped[Array, 'd'], Shaped[Array, ''], float, int], Union[Shaped[Array, 'd'], Shaped[Array, ''], float, int]], Shaped[Array, '*d']]) – the function to apply the pairwise transform to.

Return type:

Callable[[Union[Shaped[Array, 'n d'], Shaped[Array, 'd'], Shaped[Array, ''], float, int], Union[Shaped[Array, 'm d'], Shaped[Array, 'd'], Shaped[Array, ''], float, int]], Shaped[Array, 'n m *d']]

Returns:

function that returns an array whose entries are the evaluations of fn for every pairwise combination of its input arguments.
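To make the shape of the result concrete, the following list-based sketch applies the same idea without JAX: a function of two vectors becomes a function of two collections of vectors that returns an n by m grid. The names pairwise_lists and squared_distance_lists are hypothetical stand-ins, not the library functions.

```python
def pairwise_lists(fn):
    """Hypothetical list-based analogue: evaluate fn on every (x_i, y_j) pair."""

    def wrapped(xs, ys):
        return [[fn(x, y) for y in ys] for x in xs]

    return wrapped


def squared_distance_lists(x, y):
    """Dot product of (x - y) with itself, for equal-length sequences."""
    return sum((a - b) ** 2 for a, b in zip(x, y))


xs = [[0.0, 0.0], [3.0, 4.0]]  # n = 2 points in d = 2 dimensions
ys = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]]  # m = 3 points
distances = pairwise_lists(squared_distance_lists)(xs, ys)
```

Entry [i][j] of distances holds the value of the wrapped function for the i-th row of xs and the j-th row of ys, which matches the 'n m *d' shape in the return type above.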

coreax.util.squared_distance(x, y)

Calculate the squared distance between two vectors.

Parameters:
  • x (Union[Shaped[Array, 'd'], Shaped[Array, ''], float, int]) – First vector argument

  • y (Union[Shaped[Array, 'd'], Shaped[Array, ''], float, int]) – Second vector argument

Return type:

Shaped[Array, '']

Returns:

Dot product of x - y and x - y, the squared distance between x and y

coreax.util.difference(x, y)

Calculate the vector difference of a pair of vectors.

Parameters:
  • x (Union[Shaped[Array, 'd'], Shaped[Array, ''], float, int]) – First vector

  • y (Union[Shaped[Array, 'd'], Shaped[Array, ''], float, int]) – Second vector

Return type:

Shaped[Array, '']

Returns:

Vector difference x - y

coreax.util.median_heuristic(x)

Compute the median heuristic for setting kernel bandwidth.

Analysis of the performance of the median heuristic can be found in [garreau2018median].

Parameters:

x (Union[Shaped[Array, 'n d'], Shaped[Array, 'n'], Shaped[Array, ''], float, int]) – Input array of vectors

Return type:

Shaped[Array, '']

Returns:

Bandwidth parameter, computed from the median heuristic, as a zero-dimensional array
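For orientation, one common form of the median heuristic sets the bandwidth to the square root of half the median squared distance between distinct pairs of points. The sketch below implements that form in plain Python; the exact convention used by coreax is not stated in this section, so treat the formula and the name median_heuristic_sketch as assumptions.

```python
import math
from itertools import combinations
from statistics import median


def median_heuristic_sketch(points):
    """Bandwidth from the median squared distance over distinct pairs (assumed convention)."""
    squared = [
        sum((a - b) ** 2 for a, b in zip(p, q))
        for p, q in combinations(points, 2)
    ]
    return math.sqrt(median(squared) / 2.0)


points = [[0.0, 0.0], [2.0, 0.0], [0.0, 2.0]]
# Squared distances between distinct pairs are 4, 4 and 8, so the median is 4.
bandwidth = median_heuristic_sketch(points)
```

With these three points the result is the square root of 2, returned as a single scalar in the same way the documented function returns a zero-dimensional array.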

coreax.util.sample_batch_indices(random_key, max_index, batch_size, num_batches)

Sample an array of indices of size num_batches x batch_size.

Each row (batch) of the sampled array will contain unique elements.

Parameters:
  • random_key (ArrayLike) – Key for random number generation

  • max_index (int) – Largest index we wish to sample

  • batch_size (int) – Size of the batch we wish to sample

  • num_batches (int) – Number of batches to sample

Return type:

Shaped[Array, 'num_batches batch_size']

Returns:

Array of batch indices of size num_batches x batch_size
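The key property, that indices are unique within a batch but may repeat across batches, can be sketched with the standard library. This hypothetical analogue swaps the JAX random key for an integer seed and assumes indices are drawn from range(max_index); whether the real function treats max_index as inclusive is not stated here.

```python
import random


def sample_batch_indices_sketch(seed, max_index, batch_size, num_batches):
    """Hypothetical analogue: each row is drawn without replacement from range(max_index)."""
    rng = random.Random(seed)
    return [rng.sample(range(max_index), batch_size) for _ in range(num_batches)]


batches = sample_batch_indices_sketch(seed=0, max_index=10, batch_size=4, num_batches=3)
```

Sampling each row independently without replacement is what allows the same index to appear in more than one batch while never appearing twice in the same one.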

coreax.util.jit_test(fn, fn_args=(), fn_kwargs=None, jit_kwargs=None)

Measure execution times of two runs of a JIT-compilable function.

The function is called twice with the supplied arguments, and each call is timed. The two timings are returned as a 2-tuple. Comparing the first run, which typically includes compilation, with the second helps verify the benefit of JIT compilation.

Parameters:
  • fn (Callable) – JIT-compilable callable to test

  • fn_args (tuple) – Arguments passed during the calls to the passed function

  • fn_kwargs (Optional[dict]) – Keyword arguments passed during the calls to the passed function

  • jit_kwargs (Optional[dict]) – Keyword arguments partially applied to jax.jit() before it is called to compile the passed function

Return type:

tuple[float, float]

Returns:

(First run time, Second run time), in seconds
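The two-run timing pattern described above can be sketched without JAX. The helper below is a hypothetical stand-in (time_two_runs is not a coreax name) that performs no compilation at all; it only shows how the pair of timings is collected.

```python
import time


def time_two_runs(fn, fn_args=(), fn_kwargs=None):
    """Call fn twice with the same arguments and return both wall-clock durations."""
    fn_kwargs = fn_kwargs or {}
    durations = []
    for _ in range(2):
        start = time.perf_counter()
        fn(*fn_args, **fn_kwargs)
        durations.append(time.perf_counter() - start)
    return tuple(durations)


# No JIT is involved here, so the two timings should be of similar size.
first, second = time_two_runs(sum, fn_args=(range(1000),))
```

With a function wrapped in jax.jit(), the first call also pays for tracing and compilation. Because JAX dispatches work asynchronously, a timing loop of this kind would usually also wait for the result, for example with block_until_ready(), before stopping the clock.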

coreax.util.format_time(num)

Standardise the format of the input time.

Floats will be converted to a standard format, e.g. 0.4531 -> “453.1 ms”.

Parameters:

num (float) – Float to be converted

Return type:

str

Returns:

Formatted time as a string
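A formatter of this kind might look like the sketch below, which reproduces the documented example of 0.4531 becoming "453.1 ms". The choice of units, the four significant digits and the name format_time_sketch are all assumptions made for illustration, not the library's actual rules.

```python
def format_time_sketch(seconds):
    """Hypothetical formatter: pick s, ms or us and keep four significant digits."""
    for scale, unit in ((1.0, "s"), (1e3, "ms"), (1e6, "us")):
        if abs(seconds) * scale >= 1.0:
            return f"{seconds * scale:.4g} {unit}"
    # Fall through for zero and anything below one microsecond.
    return f"{seconds * 1e9:.4g} ns"


milliseconds = format_time_sketch(0.4531)
whole_seconds = format_time_sketch(2.5)
```

Choosing the largest unit in which the value is at least 1 keeps the printed number short and avoids long runs of leading zeros.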

coreax.util.speed_comparison_test(function_setups, num_runs=10, log_results=False, normalisation=None)

Compare compilation time and runtime of a list of JIT-able functions.

Parameters:
  • function_setups (Sequence[JITCompilableFunction]) – Sequence of instances of JITCompilableFunction

  • num_runs (int) – Number of times to average function timings over

  • log_results (bool) – If True, the results are formatted and logged

  • normalisation (Optional[Tuple[float, float]]) – Tuple (compilation normalisation, execution normalisation). If provided, returned compilation and execution times are normalised so that the corresponding given time equals 1 time unit.

Return type:

tuple[list[tuple[Array, Array]], dict[str, Array]]

Returns:

List of tuples (means, standard deviations), one per function, whose array components hold the JIT compilation and execution times; and a dictionary mapping each function name to an array with estimated compilation times in the first column and execution times in the second column

class coreax.util.SilentTQDM(iterable, *_args, **_kwargs)

Bases: Generic[T]

Class implementing the interface of tqdm that does nothing.

It can substitute tqdm to silence all output.

Based on code by Pro Q.

Additional parameters are accepted and ignored to match the interface of tqdm.

Parameters:

iterable (Iterable[Any]) – Iterable of tasks to (not) indicate progress for

static write(*_args, **_kwargs)

Do nothing instead of writing to output.

Return type:

None
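The idea of a silent drop-in can be sketched in a few lines. QuietProgress below is a hypothetical class written for this example, not the coreax implementation; it accepts and ignores extra arguments, iterates over the wrapped iterable unchanged, and exposes a write method that prints nothing.

```python
from typing import Generic, Iterable, Iterator, TypeVar

T = TypeVar("T")


class QuietProgress(Generic[T]):
    """Hypothetical no-op progress wrapper with a tqdm-like shape."""

    def __init__(self, iterable: Iterable[T], *_args, **_kwargs):
        # Extra arguments are accepted and discarded.
        self.iterable = iterable

    def __iter__(self) -> Iterator[T]:
        return iter(self.iterable)

    @staticmethod
    def write(*_args, **_kwargs) -> None:
        """Do nothing instead of printing."""


squares = [n * n for n in QuietProgress(range(4), desc="ignored")]
```

Because the wrapper has the same call shape as a progress bar, code that loops over a progress-wrapped iterable can switch between the noisy and silent versions without other changes.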