Utility Functions

Functionality to perform simple, generic tasks and operations.

The functions within this module are simple solutions to various problems or requirements that are sufficiently generic to be useful across multiple areas of the codebase. Examples of this include computation of squared distances, definition of class factories and checks for numerical precision.

coreax.util.KeyArray

Type annotation for JAX random keys.

exception coreax.util.NotCalculatedError

Raise when trying to use a variable that has not been calculated yet.

class coreax.util.InvalidKernel(x)

Simple class without a compute method, used to test how kernels are validated.

This is used across several tests to ensure that invalid inputs are correctly caught.

Parameters:

x (float) –

coreax.util.tree_leaves_repeat(tree, length=2)

Flatten a PyTree to its leaves and (potentially) repeat the trailing leaf.

The PyTree ‘tree’ is flattened, but unlike the standard flattening, None is treated as a valid leaf and the trailing leaf (potentially) repeated such that the length of the collection of leaves is given by the ‘length’ parameter.

Parameters:
  • tree (Any) – The PyTree to flatten and whose trailing leaf to (potentially) repeat

  • length (int) – The length of the flattened PyTree after any repetition; values smaller than the number of leaves have no effect, as the effective length is max(len(tree_leaves), length)

Return type:

list[Any]

Returns:

The PyTree leaves, with the trailing leaf repeated as many times as required for the collection of leaves to have length max(len(tree_leaves), length)
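The flatten-then-repeat behaviour can be sketched with jax.tree_util as below. This is an illustrative re-implementation under stated assumptions, not coreax's source: the name tree_leaves_repeat_sketch is hypothetical, and the sketch assumes the tree has at least one leaf.

```python
from typing import Any

import jax


def tree_leaves_repeat_sketch(tree: Any, length: int = 2) -> list[Any]:
    """Hypothetical sketch: flatten ``tree`` (keeping None) and repeat the last leaf."""
    # By default JAX treats None as an empty subtree; is_leaf makes it a real leaf.
    leaves = jax.tree_util.tree_leaves(tree, is_leaf=lambda x: x is None)
    # Repeat the trailing leaf until there are ``length`` leaves. If there are
    # already at least ``length`` leaves, nothing is repeated, so the effective
    # length is max(len(leaves), length). Assumes ``leaves`` is non-empty.
    num_repeats = max(length - len(leaves), 0)
    return leaves + [leaves[-1]] * num_repeats


print(tree_leaves_repeat_sketch((1, None), length=4))  # [1, None, None, None]
print(tree_leaves_repeat_sketch((1, 2, 3), length=2))  # [1, 2, 3]
```

Treating None as a leaf matters here: without is_leaf, (1, None) would flatten to [1] and the repeated value would be 1 rather than None.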

coreax.util.tree_zero_pad_leading_axis(tree, pad_width)

Pad the leading axis of each array leaf of ‘tree’ with ‘pad_width’ trailing zeros.

Parameters:
  • tree (Any) – The PyTree whose array leaves to pad with trailing zeros

  • pad_width (int) – The number of trailing zeros to pad with

Return type:

Any

Returns:

A copy of the original PyTree with the array leaves padded
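A minimal sketch of padding along the leading axis, assuming that only jax.Array leaves with at least one dimension are padded and every other leaf is passed through unchanged (that filtering rule is an assumption, not taken from coreax). The name tree_zero_pad_leading_axis_sketch is hypothetical.

```python
from typing import Any

import jax
import jax.numpy as jnp


def tree_zero_pad_leading_axis_sketch(tree: Any, pad_width: int) -> Any:
    """Hypothetical sketch: append ``pad_width`` zeros along axis 0 of array leaves."""

    def _pad(leaf: Any) -> Any:
        # Assumption: non-array leaves and 0-d arrays are left untouched.
        if not isinstance(leaf, jax.Array) or leaf.ndim == 0:
            return leaf
        # (0, pad_width) pads the end of the leading axis; other axes get (0, 0).
        widths = [(0, pad_width)] + [(0, 0)] * (leaf.ndim - 1)
        return jnp.pad(leaf, widths)  # constant mode, fill value 0

    # tree_map returns a new PyTree; the input tree is not modified.
    return jax.tree_util.tree_map(_pad, tree)


padded = tree_zero_pad_leading_axis_sketch({"a": jnp.ones((2, 3)), "b": 7}, pad_width=2)
print(padded["a"].shape)  # (4, 3)
```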

coreax.util.apply_negative_precision_threshold(x, precision_threshold=1e-08)

Round a number to 0.0 if it is negative but within precision_threshold of 0.0.

Parameters:
  • x (ArrayLike) – Scalar value we wish to compare to 0.0

  • precision_threshold (float) – Positive threshold we compare against for precision

Return type:

Array

Returns:

x, rounded to 0.0 if it is between -precision_threshold and 0.0
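The thresholding rule can be expressed with a single jnp.where. This is a sketch, not coreax's implementation: the name apply_negative_precision_threshold_sketch is hypothetical, and treating the boundary value -precision_threshold as "within" the threshold is an assumption.

```python
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike


def apply_negative_precision_threshold_sketch(
    x: ArrayLike, precision_threshold: float = 1e-8
) -> Array:
    """Hypothetical sketch: snap tiny negative values to 0.0, leave the rest alone."""
    x = jnp.asarray(x)
    # True only in [-precision_threshold, 0): small negatives, e.g. round-off error.
    is_tiny_negative = (x < 0.0) & (x >= -precision_threshold)
    return jnp.where(is_tiny_negative, 0.0, x)


print(apply_negative_precision_threshold_sketch(-1e-10))  # 0.0
print(apply_negative_precision_threshold_sketch(-0.5))  # -0.5
```

Larger negative values are kept as they are, so a genuinely negative result is not hidden by the rounding.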

coreax.util.pairwise(fn)

Transform a function so it returns all pairwise evaluations of its inputs.

Parameters:

fn (Callable[[ArrayLike, ArrayLike], Array]) – the function to apply the pairwise transform to.

Return type:

Callable[[ArrayLike, ArrayLike], Array]

Returns:

function that returns an array whose entries are the evaluations of fn for every pairwise combination of its input arguments.
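One way to build such a transform is with two nested jax.vmap calls, one mapping over the rows of each input. The sketch below is an assumption about the technique, not coreax's source; pairwise_sketch is a hypothetical name.

```python
from typing import Callable

import jax
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike


def pairwise_sketch(
    fn: Callable[[ArrayLike, ArrayLike], Array],
) -> Callable[[ArrayLike, ArrayLike], Array]:
    """Hypothetical sketch: lift fn(x_i, y_j) to all pairs of rows of x and y."""

    def pairwise_fn(x: ArrayLike, y: ArrayLike) -> Array:
        # Inner vmap: hold one row x_i fixed and map over every row y_j.
        over_y = jax.vmap(fn, in_axes=(None, 0))
        # Outer vmap: map over every row x_i, so result[i, j] = fn(x_i, y_j).
        return jax.vmap(over_y, in_axes=(0, None))(jnp.asarray(x), jnp.asarray(y))

    return pairwise_fn


x = jnp.arange(6.0).reshape(3, 2)
y = jnp.arange(4.0).reshape(2, 2)
gram = pairwise_sketch(jnp.dot)(x, y)  # pairwise dot products, same as x @ y.T
print(gram.shape)  # (3, 2)
```

Because fn only ever sees single vectors, the same transform turns any per-pair function (distance, difference, kernel) into its pairwise version.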

coreax.util.squared_distance(x, y)

Calculate the squared distance between two vectors.

Parameters:
  • x – First vector

  • y – Second vector

Return type:

Array

Returns:

Dot product of x - y with itself, i.e. the squared distance between x and y
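As a worked check of the formula \(\|x - y\|^2 = (x - y) \cdot (x - y)\), here is a minimal sketch; squared_distance_sketch is a hypothetical name, not the library function.

```python
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike


def squared_distance_sketch(x: ArrayLike, y: ArrayLike) -> Array:
    """Hypothetical sketch: (x - y) . (x - y), the squared Euclidean distance."""
    diff = jnp.asarray(x) - jnp.asarray(y)
    return jnp.dot(diff, diff)


# x - y = (-3, -4), so the squared distance is 9 + 16 = 25.
print(squared_distance_sketch(jnp.array([1.0, 2.0]), jnp.array([4.0, 6.0])))  # 25.0
```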

coreax.util.squared_distance_pairwise(x, y)

Efficiently calculate pairwise squared distances between two arrays of vectors.

Parameters:
  • x (ArrayLike) – First set of vectors as an \(n \times d\) array

  • y (ArrayLike) – Second set of vectors as an \(m \times d\) array

Return type:

Array

Returns:

Pairwise squared distances between x and y as an \(n \times m\) array
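Two common strategies produce the same \(n \times m\) result; which one coreax uses is not stated here, so both functions below (hypothetical names) are sketches. The first applies nested vmaps to a per-pair squared distance; the second uses the expansion \(\|x_i - y_j\|^2 = \|x_i\|^2 + \|y_j\|^2 - 2\,x_i \cdot y_j\), which needs only one matrix product.

```python
import jax
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike


def squared_distance_pairwise_vmap(x: ArrayLike, y: ArrayLike) -> Array:
    """Hypothetical sketch: (n, d), (m, d) -> (n, m) via two nested vmaps."""

    def sq_dist(a: Array, b: Array) -> Array:
        diff = a - b
        return jnp.dot(diff, diff)

    inner = jax.vmap(sq_dist, in_axes=(None, 0))  # one x_i against all y_j
    return jax.vmap(inner, in_axes=(0, None))(jnp.asarray(x), jnp.asarray(y))


def squared_distance_pairwise_expanded(x: ArrayLike, y: ArrayLike) -> Array:
    """Hypothetical sketch: same result via ||x_i||^2 + ||y_j||^2 - 2 x_i . y_j."""
    x, y = jnp.asarray(x), jnp.asarray(y)
    x_sq = jnp.sum(x**2, axis=1)[:, None]  # shape (n, 1)
    y_sq = jnp.sum(y**2, axis=1)[None, :]  # shape (1, m)
    # Cancellation in floating point can give tiny negatives; clamp them to zero.
    return jnp.maximum(x_sq + y_sq - 2.0 * x @ y.T, 0.0)


x = jnp.array([[0.0, 0.0], [1.0, 1.0], [2.0, 0.0]])
y = jnp.array([[1.0, 0.0], [0.0, 2.0]])
print(squared_distance_pairwise_vmap(x, y))  # rows: [1, 4], [1, 2], [1, 8]
```

The expanded form maps onto a single optimised matrix multiply, at the cost of possible round-off when points are close together; the vmap form computes each difference directly.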

coreax.util.difference(x, y)

Calculate the vector difference between a pair of vectors.

Parameters:
  • x – First vector

  • y – Second vector

Return type:

Array

Returns:

Vector difference x - y

coreax.util.pairwise_difference(x, y)

Efficiently calculate pairwise differences between two arrays of vectors.

Parameters:
  • x (ArrayLike) – First set of vectors as an \(n \times d\) array

  • y (ArrayLike) – Second set of vectors as an \(m \times d\) array

Return type:

Array

Returns:

Pairwise differences between x and y as an \(n \times m \times d\) array
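The \(n \times m \times d\) result can also be obtained by broadcasting, inserting singleton axes so that \((n, 1, d) - (1, m, d)\) expands to \((n, m, d)\). This is a sketch with a hypothetical name, not a statement of how coreax computes it.

```python
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike


def pairwise_difference_broadcast(x: ArrayLike, y: ArrayLike) -> Array:
    """Hypothetical sketch: (n, d), (m, d) -> (n, m, d) with entry [i, j] = x_i - y_j."""
    x, y = jnp.asarray(x), jnp.asarray(y)
    # (n, 1, d) - (1, m, d) broadcasts to (n, m, d).
    return x[:, None, :] - y[None, :, :]


x = jnp.arange(6.0).reshape(3, 2)
y = jnp.arange(8.0).reshape(4, 2)
diffs = pairwise_difference_broadcast(x, y)
print(diffs.shape)  # (3, 4, 2)
```

Either way the full \(n \times m \times d\) array is materialised, so memory grows with the product of the two set sizes and the dimension.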

coreax.util.solve_qp(kernel_mm, gramian_row_mean, **osqp_kwargs)

Solve quadratic programs with the jaxopt.OSQP solver.

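Solves simplex weight problems of the form:

\[\min_{\mathbf{w}} \; \frac{1}{2} \mathbf{w}^{\mathrm{T}} \mathbf{k} \mathbf{w} - \bar{\mathbf{k}}^{\mathrm{T}} \mathbf{w}\]

subject to

\[\mathbf{A}\mathbf{w} = 1, \qquad \mathbf{G}\mathbf{w} \le \mathbf{0},\]

where \(\mathbf{A} = \mathbf{1}^{\mathrm{T}}\) and \(\mathbf{G} = -\mathbf{I}\), so that the weights \(\mathbf{w}\) are non-negative and sum to one.

Parameters: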
  • kernel_mm (ArrayLike) – \(m \times m\) coreset Gram matrix

  • gramian_row_mean (ArrayLike) – \(m \times 1\) array of Gram matrix row means

Return type:

Array

Returns:

Optimised solution for the quadratic program
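jaxopt.OSQP accepts problems in the standard form on the first line below. Assuming the simplex weight problem is to minimise \(\tfrac{1}{2}\mathbf{w}^{\mathrm{T}}\mathbf{k}\mathbf{w} - \bar{\mathbf{k}}^{\mathrm{T}}\mathbf{w}\) over non-negative weights summing to one (the minus sign is our assumption: it makes the objective half the squared maximum mean discrepancy, up to an additive constant, when \(\bar{\mathbf{k}}\) holds mean kernel evaluations between each coreset point and the full dataset), the inputs of solve_qp would map onto the OSQP parameters as follows.

```latex
\begin{aligned}
&\min_{\mathbf{x}} \; \tfrac{1}{2}\mathbf{x}^{\mathrm{T}} \mathbf{Q} \mathbf{x} + \mathbf{c}^{\mathrm{T}} \mathbf{x}
  \quad \text{s.t.} \quad \mathbf{A}\mathbf{x} = \mathbf{b}, \;\; \mathbf{G}\mathbf{x} \le \mathbf{h}, \\
&\mathbf{x} = \mathbf{w}, \quad \mathbf{Q} = \mathbf{k}, \quad \mathbf{c} = -\bar{\mathbf{k}}, \quad
  \mathbf{A} = \mathbf{1}^{\mathrm{T}}, \quad \mathbf{b} = 1, \quad
  \mathbf{G} = -\mathbf{I}_m, \quad \mathbf{h} = \mathbf{0}.
\end{aligned}
```

In jaxopt these pairs are supplied as params_obj=(Q, c), params_eq=(A, b) and params_ineq=(G, h) to OSQP.run, with the weights read from the primal part of the returned solution.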

coreax.util.sample_batch_indices(random_key, max_index, batch_size, num_batches)

Sample an array of indices of size num_batches x batch_size.

Each row (batch) of the sampled array will contain unique elements.

Parameters:
  • random_key (ArrayLike) – Key for random number generation

  • max_index (int) – Largest index we wish to sample

  • batch_size (int) – Size of the batch we wish to sample

  • num_batches (int) – Number of batches to sample

Return type:

Array

Returns:

Array of batch indices of size num_batches x batch_size
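A common way to guarantee unique indices within each batch is to permute the index range once per batch and keep a prefix. The sketch below uses that technique; the name sample_batch_indices_sketch is hypothetical, and treating max_index as an exclusive upper bound (indices drawn from range(max_index)) is an assumption, not taken from coreax.

```python
import jax
from jax import Array


def sample_batch_indices_sketch(
    random_key: Array, max_index: int, batch_size: int, num_batches: int
) -> Array:
    """Hypothetical sketch: (num_batches, batch_size) array, unique indices per row."""
    if batch_size > max_index:
        # A row of unique indices cannot be longer than the pool it is drawn from.
        raise ValueError("batch_size must not exceed max_index")
    # One independent key per batch (row).
    batch_keys = jax.random.split(random_key, num_batches)

    def one_batch(key: Array) -> Array:
        # Shuffle range(max_index) and keep a prefix; a prefix of a permutation
        # never repeats an index, so each row contains unique elements.
        return jax.random.permutation(key, max_index)[:batch_size]

    return jax.vmap(one_batch)(batch_keys)


indices = sample_batch_indices_sketch(
    jax.random.PRNGKey(0), max_index=10, batch_size=4, num_batches=3
)
print(indices.shape)  # (3, 4)
```

Indices are unique within a row but may repeat across rows, since each batch is sampled independently.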

coreax.util.jit_test(fn, fn_args=(), fn_kwargs=None, jit_kwargs=None)

Verify JIT performance by comparing the timings of the first and second runs of a compiled function.

The function is called twice with the supplied arguments, and each run is timed; the first run includes compilation. The two timings are returned as a 2-tuple.

Parameters:
  • fn (Callable) – Function callable to test

  • fn_args (tuple) – Arguments passed during the calls to the passed function

  • fn_kwargs (Optional[dict]) – Keyword arguments passed during the calls to the passed function

  • jit_kwargs (Optional[dict]) – Keyword arguments that are partially applied to jax.jit() before being called to compile the passed function.

Return type:

tuple[float, float]

Returns:

(First run time, Second run time)
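The timing pattern can be sketched as follows. This is an assumption-laden illustration, not coreax's implementation: jit_test_sketch is a hypothetical name, and waiting on the result with jax.block_until_ready is our choice, made because JAX dispatches work asynchronously and an unblocked timer would mostly measure dispatch.

```python
import time
from functools import partial
from typing import Callable, Optional

import jax
import jax.numpy as jnp


def jit_test_sketch(
    fn: Callable,
    fn_args: tuple = (),
    fn_kwargs: Optional[dict] = None,
    jit_kwargs: Optional[dict] = None,
) -> tuple[float, float]:
    """Hypothetical sketch: time a first (compiling) and second (cached) jitted call."""
    fn_kwargs = {} if fn_kwargs is None else fn_kwargs
    jit_kwargs = {} if jit_kwargs is None else jit_kwargs
    # Partially apply jit_kwargs to jax.jit, then compile fn with it.
    jitted = partial(jax.jit, **jit_kwargs)(fn)

    def _timed_call() -> float:
        start = time.perf_counter()
        # Wait for the computation to finish before stopping the clock.
        jax.block_until_ready(jitted(*fn_args, **fn_kwargs))
        return time.perf_counter() - start

    first_run = _timed_call()  # includes tracing and compilation
    second_run = _timed_call()  # reuses the cached compiled function
    return first_run, second_run


first, second = jit_test_sketch(lambda v: jnp.sum(v**2), fn_args=(jnp.ones(1000),))
print(f"first run {first:.4f}s, second run {second:.4f}s")
```

The second time is usually much smaller, but the exact ratio depends on the function and hardware, so tests built on this should compare the two timings rather than assert fixed values.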

class coreax.util.SilentTQDM(iterable, *_args, **_kwargs)

Class implementing the interface of tqdm that does nothing.

It can be used in place of tqdm to silence all output.

Based on code by Pro Q.

Additional parameters are accepted and ignored to match the interface of tqdm.

Parameters:

iterable (Iterable[Any]) – Iterable of tasks to (not) indicate progress for

write(*_args, **_kwargs)

Do nothing instead of writing to output.

Return type:

None
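A no-op stand-in only needs to accept tqdm's arguments, iterate over the wrapped iterable and provide a silent write. The class below is a hypothetical sketch (SilentTQDMSketch is not a coreax name) covering only the members documented above; other tqdm methods would need similar no-op stubs if code calls them.

```python
from typing import Any, Iterable, Iterator


class SilentTQDMSketch:
    """Hypothetical sketch: tqdm-compatible wrapper that shows no progress."""

    def __init__(self, iterable: Iterable[Any], *_args: Any, **_kwargs: Any) -> None:
        # Extra arguments such as desc or total are accepted and ignored.
        self.iterable = iterable

    def __iter__(self) -> Iterator[Any]:
        # Yield the items unchanged, without drawing a progress bar.
        return iter(self.iterable)

    def write(self, *_args: Any, **_kwargs: Any) -> None:
        """Do nothing instead of writing to output."""


for item in SilentTQDMSketch(range(3), desc="ignored"):
    print(item)  # prints 0, 1, 2 on separate lines
```

A caller can then select between this class and tqdm with a verbosity flag, since both are constructed and iterated the same way.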