Utility Functions
Functionality to perform simple, generic tasks and operations.
The functions within this module are simple solutions to various problems or requirements that are sufficiently generic to be useful across multiple areas of the codebase. Examples of this include computation of squared distances, definition of class factories and checks for numerical precision.
- coreax.util.KeyArray
JAX random key type annotations.
- exception coreax.util.NotCalculatedError
Raise when trying to use a variable that has not been calculated yet.
- class coreax.util.InvalidKernel(x)
Simple class without a compute method, used to test kernel handling.
It is used in several tests to ensure that invalid inputs are correctly caught.
- Parameters:
x (float) –
- coreax.util.tree_leaves_repeat(tree, length=2)
Flatten a PyTree to its leaves and (potentially) repeat the trailing leaf.
The PyTree ‘tree’ is flattened but, unlike standard flattening, None is treated as a valid leaf, and the trailing leaf is (potentially) repeated so that the number of leaves equals the ‘length’ parameter.
- Parameters:
- Return type:
- Returns:
The PyTree leaves, with the trailing leaf repeated as many times as required for the collection of leaves to have length ‘length’
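The flattening behaviour described above can be illustrated with a minimal sketch. It is not the library's implementation: the name `leaves_repeat_sketch` is hypothetical, and it assumes that nothing is repeated when the tree already has at least `length` leaves.

```python
import jax.tree_util as jtu


def leaves_repeat_sketch(tree, length=2):
    """Flatten `tree`, keeping None as a leaf, and repeat the last leaf as padding."""
    # `is_leaf` makes None count as a leaf rather than as an empty subtree.
    leaves = jtu.tree_leaves(tree, is_leaf=lambda x: x is None)
    # Assumption: if there are already `length` or more leaves, nothing is added.
    missing = max(length - len(leaves), 0)
    return leaves + leaves[-1:] * missing


print(leaves_repeat_sketch((1.0, None), length=3))  # [1.0, None, None]
print(leaves_repeat_sketch(5, length=2))  # [5, 5]
```

Standard flattening with `jax.tree_util.tree_leaves` would drop the `None` entry, which is why the `is_leaf` predicate is needed here.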
- coreax.util.tree_zero_pad_leading_axis(tree, pad_width)
Pad each array leaf of ‘tree’ with ‘pad_width’ trailing zeros along its leading axis.
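As a rough illustration of this kind of padding, the sketch below appends zeros to the end of axis 0 of every leaf. The helper name `zero_pad_leading_axis_sketch` is hypothetical, and it assumes every leaf is an array with at least one dimension; how the library treats non-array leaves is not covered here.

```python
import jax
import jax.numpy as jnp


def zero_pad_leading_axis_sketch(tree, pad_width):
    """Append `pad_width` zeros along axis 0 of every array leaf."""

    def _pad(leaf):
        # Pad only the end of the leading axis; leave all other axes untouched.
        widths = [(0, pad_width)] + [(0, 0)] * (leaf.ndim - 1)
        return jnp.pad(leaf, widths)

    return jax.tree_util.tree_map(_pad, tree)


tree = {"data": jnp.ones((3, 2)), "weights": jnp.ones(3)}
padded = zero_pad_leading_axis_sketch(tree, 2)
print(padded["data"].shape, padded["weights"].shape)  # (5, 2) (5,)
```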
- coreax.util.apply_negative_precision_threshold(x, precision_threshold=1e-08)
Round a number to 0.0 if it is negative but within precision_threshold of 0.0.
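The rounding rule can be sketched as follows. The function name `clip_small_negatives_sketch` is hypothetical, and treating the threshold as a strict bound is an assumption rather than documented behaviour.

```python
import jax.numpy as jnp


def clip_small_negatives_sketch(x, precision_threshold=1e-8):
    """Return 0.0 where x is negative but closer to zero than the threshold."""
    x = jnp.asarray(x)
    # Assumption: values exactly at -precision_threshold are left unchanged.
    is_small_negative = (x < 0.0) & (x > -abs(precision_threshold))
    return jnp.where(is_small_negative, 0.0, x)


tiny = clip_small_negatives_sketch(-1e-9)  # rounded up to zero
large = clip_small_negatives_sketch(-1.0)  # left as it is
```

A check like this is typically useful for quantities that are non-negative in exact arithmetic, such as squared distances, but can come out slightly negative in floating point.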
- coreax.util.pairwise(fn)
Transform a function so it returns all pairwise evaluations of its inputs.
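One common way to build such a transformation in JAX is to nest two `jax.vmap` calls. The sketch below shows that idea under the assumption that `fn` takes two vectors and that inputs are stacked along axis 0; `pairwise_sketch` is a hypothetical name, not the library function.

```python
import jax
import jax.numpy as jnp


def pairwise_sketch(fn):
    """Lift fn(x_i, y_j) to a function returning the matrix of all pairs."""

    def pairwise_fn(x, y):
        x = jnp.atleast_2d(x)
        y = jnp.atleast_2d(y)
        # Inner vmap: evaluate fn for every row of x against a single y_j.
        over_x = jax.vmap(fn, in_axes=(0, None))
        # Outer vmap: repeat for every row of y, placing j on the second axis.
        over_xy = jax.vmap(over_x, in_axes=(None, 0), out_axes=1)
        return over_xy(x, y)

    return pairwise_fn


x = jnp.array([[1.0, 0.0], [0.0, 2.0]])
y = jnp.array([[1.0, 1.0], [3.0, 0.0], [0.0, 1.0]])
result = pairwise_sketch(jnp.dot)(x, y)  # result[i, j] = dot(x[i], y[j])
```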
- coreax.util.squared_distance_pairwise(x, y)
Efficiently calculate the pairwise squared distances between two arrays.
- coreax.util.pairwise_difference(x, y)
Efficiently calculate the pairwise differences between two arrays of vectors.
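To make the two pairwise operations above concrete, here is a minimal sketch using broadcasting and the expansion ‖x − y‖² = ‖x‖² + ‖y‖² − 2 x·y. Both function names are hypothetical, the inputs are assumed to be arrays of row vectors, and the library may compute these quantities differently.

```python
import jax.numpy as jnp


def pairwise_difference_sketch(x, y):
    """Return an (n, m, d) array whose [i, j] entry is x[i] - y[j]."""
    x = jnp.atleast_2d(x)
    y = jnp.atleast_2d(y)
    return x[:, None, :] - y[None, :, :]


def squared_distance_pairwise_sketch(x, y):
    """Return an (n, m) array of squared Euclidean distances."""
    x = jnp.atleast_2d(x)
    y = jnp.atleast_2d(y)
    x_sq = jnp.sum(x**2, axis=1)[:, None]
    y_sq = jnp.sum(y**2, axis=1)[None, :]
    # Clamp at zero to guard against small negative round-off errors.
    return jnp.maximum(x_sq + y_sq - 2.0 * (x @ y.T), 0.0)


x = jnp.array([[0.0, 0.0], [1.0, 1.0]])
y = jnp.array([[1.0, 0.0], [0.0, 3.0], [2.0, 2.0]])
diffs = pairwise_difference_sketch(x, y)
sq_dists = squared_distance_pairwise_sketch(x, y)
```

The expanded form avoids materialising the full (n, m, d) difference array, which is the usual reason to prefer it when only the distances are needed.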
- coreax.util.solve_qp(kernel_mm, gramian_row_mean, **osqp_kwargs)
Solve quadratic programs with the jaxopt.OSQP solver.
Solves simplex weight problems of the form:
\[\mathbf{w}^{\mathrm{T}} \mathbf{k} \mathbf{w} + \bar{\mathbf{k}}^{\mathrm{T}} \mathbf{w} = 0\]
subject to
\[\mathbf{Aw} = \mathbf{1}, \qquad \mathbf{Gw} \le 0.\]
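For a simplex weight problem, the constraints are usually read as "the weights sum to one and are non-negative". The sketch below builds constraint matrices of that form and checks a candidate weight vector against them. The specific choices of A as a row of ones and G as the negated identity are assumptions made for illustration; the solver itself is not called.

```python
import jax.numpy as jnp

n = 4  # number of weights (illustrative value)

# Equality constraint: the weights sum to one.
A = jnp.ones((1, n))
b = jnp.ones(1)

# Inequality constraint: -w <= 0, i.e. every weight is non-negative.
G = -jnp.eye(n)
h = jnp.zeros(n)


def satisfies_simplex_constraints(w):
    """Check both constraint sets for a candidate weight vector."""
    return bool(jnp.allclose(A @ w, b)) and bool(jnp.all(G @ w <= h))


uniform_ok = satisfies_simplex_constraints(jnp.full(n, 1.0 / n))
negative_ok = satisfies_simplex_constraints(jnp.array([1.5, -0.5, 0.0, 0.0]))
```

Here the uniform weights pass both checks, while the second vector sums to one but fails the non-negativity condition.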
- coreax.util.sample_batch_indices(random_key, max_index, batch_size, num_batches)
Sample an array of indices of size num_batches x batch_size.
Each row (batch) of the sampled array will contain unique elements.
- Parameters:
- Return type:
- Returns:
Array of batch indices of size num_batches x batch_size
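One way to obtain rows with unique elements is to draw an independent permutation per batch and keep its first `batch_size` entries. The sketch below follows that approach; the function name is hypothetical, and it assumes indices are drawn from 0 to `max_index - 1`.

```python
import jax


def sample_batch_indices_sketch(random_key, max_index, batch_size, num_batches):
    """Return a (num_batches, batch_size) array with unique entries per row."""
    if batch_size > max_index:
        raise ValueError("batch_size must not exceed max_index")
    # One independent key per batch.
    keys = jax.random.split(random_key, num_batches)

    def one_batch(key):
        # A truncated permutation of range(max_index) has no repeated indices.
        return jax.random.permutation(key, max_index)[:batch_size]

    return jax.vmap(one_batch)(keys)


batches = sample_batch_indices_sketch(jax.random.PRNGKey(0), 10, 4, 3)
print(batches.shape)  # (3, 4)
```

Uniqueness holds within each row only; the same index may appear in different batches.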
- coreax.util.jit_test(fn, fn_args=(), fn_kwargs=None, jit_kwargs=None)
Verify JIT performance by comparing the timings of two successive runs of a function.
The function is called with supplied arguments twice, and timed for each run. These timings are returned in a 2-tuple.
- Parameters:
fn (Callable) – Function callable to test
fn_args (tuple) – Arguments passed during the calls to the passed function
fn_kwargs (Optional[dict]) – Keyword arguments passed during the calls to the passed function
jit_kwargs (Optional[dict]) – Keyword arguments that are partially applied to jax.jit() before being called to compile the passed function.
- Return type:
- Returns:
(First run time, Second run time)
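The timing procedure described above can be sketched as follows. This is an illustrative version under stated assumptions, not the library's code: the name `jit_timing_sketch` is hypothetical, and `time.perf_counter` is an arbitrary choice of clock.

```python
import time

import jax
import jax.numpy as jnp


def jit_timing_sketch(fn, fn_args=(), fn_kwargs=None, jit_kwargs=None):
    """Time two consecutive calls of the jitted function."""
    fn_kwargs = fn_kwargs or {}
    jit_kwargs = jit_kwargs or {}
    jitted = jax.jit(fn, **jit_kwargs)
    timings = []
    for _ in range(2):
        start = time.perf_counter()
        # Block so that asynchronous dispatch does not hide the real run time.
        jax.block_until_ready(jitted(*fn_args, **fn_kwargs))
        timings.append(time.perf_counter() - start)
    return tuple(timings)


first, second = jit_timing_sketch(lambda x: jnp.sum(x**2), (jnp.arange(1000.0),))
```

The first call includes tracing and compilation, so it is normally the slower of the two; a second timing that is not noticeably shorter can indicate that the function is being recompiled.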