Utility Functions
Functionality to perform simple, generic tasks and operations.
The functions within this module are simple solutions to various problems or requirements that are sufficiently generic to be useful across multiple areas of the codebase. Examples of this include computation of squared distances, definition of class factories and checks for numerical precision.
- coreax.util.KeyArray
JAX random key type annotations.
- exception coreax.util.NotCalculatedError
Bases: Exception
Raise when trying to use a variable that has not been calculated yet.
- class coreax.util.JITCompilableFunction(fn, fn_args=(), fn_kwargs=None, jit_kwargs=None, name=None)
Bases: NamedTuple
Parameters for jit_test().
- Parameters:
fn (Callable) – JIT-compilable callable to test
fn_args (tuple) – Positional arguments passed to each call of the function
fn_kwargs (dict[str, Any] | None) – Keyword arguments passed to each call of the function
jit_kwargs (dict[str, Any] | None) – Keyword arguments partially applied to jax.jit() before it is used to compile the function
name (str | None)
- class coreax.util.InvalidKernel(x)
Bases: object
Simple class with no compute method, used in place of a kernel in tests.
It is used across several tests to check that passing an invalid kernel is correctly caught.
- Parameters:
x (float)
- coreax.util.tree_leaves_repeat(tree, length=2)
Flatten a PyTree to its leaves and (potentially) repeat the trailing leaf.
The PyTree ‘tree’ is flattened, but unlike the standard flattening, None is treated as a valid leaf and the trailing leaf is (potentially) repeated such that the length of the collection of leaves is given by the ‘length’ parameter.
- Parameters:
- Return type:
- Returns:
The PyTree leaves, with the trailing leaf repeated as many times as required for the collection of leaves to have length ‘length’
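The flattening rule above can be sketched with jax.tree_util. This is a minimal illustration under stated assumptions, not the coreax source: the name tree_leaves_repeat_sketch is hypothetical, and returning the leaves unchanged when there are already more than ‘length’ of them is an assumption.

```python
import jax


def tree_leaves_repeat_sketch(tree, length=2):
    """Hypothetical sketch of the documented tree_leaves_repeat behaviour."""
    # is_leaf makes None a leaf instead of an empty subtree.
    leaves = jax.tree_util.tree_leaves(tree, is_leaf=lambda x: x is None)
    num_repeats = length - len(leaves)
    # Assumption: if there are already `length` or more leaves, nothing is repeated.
    return leaves + leaves[-1:] * max(num_repeats, 0)


leaves = tree_leaves_repeat_sketch((1, None), length=4)  # [1, None, None, None]
```

Treating None as a leaf matters here: with the default flattening, (1, None) would yield only [1], and the repeated trailing leaf would be 1 rather than None.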
- coreax.util.tree_zero_pad_leading_axis(tree, pad_width)
Pad the leading axis of each array leaf of ‘tree’ with ‘pad_width’ trailing zeros.
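A minimal sketch of this padding, assuming that "trailing zeros" means zeros appended after the last entry along axis 0, and that non-array leaves and zero-dimensional arrays are returned unchanged. The name tree_zero_pad_leading_axis_sketch is hypothetical; this is not the coreax implementation.

```python
import jax
import jax.numpy as jnp


def tree_zero_pad_leading_axis_sketch(tree, pad_width):
    """Hypothetical sketch: append `pad_width` zeros along axis 0 of each array leaf."""

    def pad_leaf(leaf):
        # Assumption: non-array leaves and 0-d arrays are left untouched.
        if not isinstance(leaf, jax.Array) or leaf.ndim == 0:
            return leaf
        # Pad only after the end of the leading axis; other axes are unchanged.
        widths = [(0, pad_width)] + [(0, 0)] * (leaf.ndim - 1)
        return jnp.pad(leaf, widths)  # default mode="constant" pads with zeros

    return jax.tree_util.tree_map(pad_leaf, tree)


padded = tree_zero_pad_leading_axis_sketch({"a": jnp.ones((2, 3)), "b": "tag"}, 2)
```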
- coreax.util.apply_negative_precision_threshold(x, precision_threshold=1e-08)
Round a number to 0.0 if it is negative but within precision_threshold of 0.0.
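The rounding rule can be expressed element-wise with jnp.where. This is a sketch, not the coreax source; the name apply_negative_precision_threshold_sketch is hypothetical, and treating a value exactly at -precision_threshold as within the threshold is an assumption.

```python
import jax.numpy as jnp


def apply_negative_precision_threshold_sketch(x, precision_threshold=1e-8):
    """Hypothetical sketch: zero out small negative values caused by rounding error."""
    x = jnp.asarray(x)
    # Assumption: the boundary is inclusive, i.e. -threshold <= x < 0 maps to 0.0.
    is_tiny_negative = (x < 0.0) & (x >= -precision_threshold)
    # Values outside that band (including all non-negative values) pass through.
    return jnp.where(is_tiny_negative, 0.0, x)


cleaned = apply_negative_precision_threshold_sketch(jnp.array([-1e-9, -0.5, 2.0]))
```

A typical use is a quantity that is non-negative in exact arithmetic, such as a squared distance, where floating-point cancellation can produce a value like -1e-12.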
- coreax.util.pairwise(fn)
Transform a function so it returns all pairwise evaluations of its inputs.
- Parameters:
fn (Callable[[Union[Shaped[Array, 'd'], Shaped[Array, ''], float, int], Union[Shaped[Array, 'd'], Shaped[Array, ''], float, int]], Shaped[Array, '*d']]) – the function to apply the pairwise transform to.
- Return type:
Callable[[Union[Shaped[Array, 'n d'], Shaped[Array, 'd'], Shaped[Array, ''], float, int], Union[Shaped[Array, 'm d'], Shaped[Array, 'd'], Shaped[Array, ''], float, int]], Shaped[Array, 'n m *d']]
- Returns:
function that returns an array whose entries are the evaluations of fn for every pairwise combination of its input arguments.
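The pairwise transform can be sketched with two nested jax.vmap calls. This is an illustration only: pairwise_sketch and squared_distance are hypothetical names, and unlike the documented signature the sketch handles only the 2-D case, inputs of shape (n, d) and (m, d), not 1-D or scalar inputs.

```python
import jax
import jax.numpy as jnp


def pairwise_sketch(fn):
    """Hypothetical sketch: evaluate fn on every (row of x, row of y) pair."""
    # Inner vmap: fix one row of x, map over the rows of y -> shape (m,).
    # Outer vmap: map that over the rows of x -> shape (n, m).
    return jax.vmap(jax.vmap(fn, in_axes=(None, 0)), in_axes=(0, None))


def squared_distance(a, b):
    return jnp.sum((a - b) ** 2)


x = jnp.array([[0.0, 0.0], [1.0, 1.0]])
y = jnp.array([[0.0, 0.0], [1.0, 0.0], [2.0, 2.0]])
distances = pairwise_sketch(squared_distance)(x, y)
# distances == [[0., 1., 8.], [2., 1., 2.]], shape (2, 3)
```

Writing fn for a single pair and lifting it with vmap avoids hand-written broadcasting, which is the kind of convenience the transform provides.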
- coreax.util.sample_batch_indices(random_key, max_index, batch_size, num_batches)
Sample an array of indices of size num_batches x batch_size.
Each row (batch) of the sampled array will contain unique elements.
- Parameters:
- Return type:
Shaped[Array, 'num_batches batch_size']
- Returns:
Array of batch indices of size num_batches x batch_size
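One way to obtain rows of unique indices is to sample each batch without replacement under its own key. The sketch below assumes indices are drawn from the half-open range [0, max_index) and requires batch_size <= max_index; the name sample_batch_indices_sketch is hypothetical and this is not necessarily how coreax samples.

```python
import jax


def sample_batch_indices_sketch(random_key, max_index, batch_size, num_batches):
    """Hypothetical sketch returning an array of shape (num_batches, batch_size)."""
    # One independent key per batch (row).
    keys = jax.random.split(random_key, num_batches)

    def sample_one_batch(key):
        # Sampling without replacement keeps the indices within a row unique.
        # Assumption: indices come from range(max_index), i.e. [0, max_index).
        return jax.random.choice(key, max_index, shape=(batch_size,), replace=False)

    return jax.vmap(sample_one_batch)(keys)


indices = sample_batch_indices_sketch(jax.random.PRNGKey(0), 10, 4, 3)
```

Uniqueness holds within each row only; the same index may appear in different batches.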
- coreax.util.jit_test(fn, fn_args=(), fn_kwargs=None, jit_kwargs=None)
Measure execution times of two runs of a JIT-compilable function.
The function is called twice with the supplied arguments and each run is timed; the two timings are returned as a 2-tuple. Because the first call of a JIT-compiled function includes compilation while the second reuses the compiled code, comparing the two timings helps verify JIT performance.
- Parameters:
fn (Callable) – JIT-compilable callable to test
fn_args (tuple) – Positional arguments passed to each call of the function
fn_kwargs (Optional[dict]) – Keyword arguments passed to each call of the function
jit_kwargs (Optional[dict]) – Keyword arguments partially applied to jax.jit() before it is used to compile the function
- Return type:
- Returns:
(First run time, Second run time), in seconds
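The timing procedure described above can be sketched as follows. This is a minimal illustration, not the coreax source: the name jit_test_sketch is hypothetical, and the use of time.perf_counter and jax.block_until_ready are choices made for this sketch.

```python
import time
from functools import partial

import jax
import jax.numpy as jnp


def jit_test_sketch(fn, fn_args=(), fn_kwargs=None, jit_kwargs=None):
    """Hypothetical sketch: time two calls of a jitted function, in seconds."""
    fn_kwargs = {} if fn_kwargs is None else fn_kwargs
    jit_kwargs = {} if jit_kwargs is None else jit_kwargs
    # Partially apply the JIT options to jax.jit, then use it to wrap fn.
    jitted_fn = partial(jax.jit, **jit_kwargs)(fn)
    timings = []
    for _ in range(2):
        start = time.perf_counter()
        # Wait for JAX's asynchronous dispatch so the timing covers the computation.
        jax.block_until_ready(jitted_fn(*fn_args, **fn_kwargs))
        timings.append(time.perf_counter() - start)
    return timings[0], timings[1]


first_run, second_run = jit_test_sketch(lambda x: jnp.sin(x) ** 2, fn_args=(jnp.ones(1000),))
```

Without block_until_ready, the measured time could cover only the dispatch of the work rather than its completion.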
- coreax.util.format_time(num)
Standardise the format of the input time.
Floats will be converted to a standard format, e.g. 0.4531 -> “453.1 ms”.
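A possible way to pick a unit and format the value is sketched below. The unit table and the one-decimal formatting are assumptions chosen only to reproduce the documented example; the name format_time_sketch is hypothetical and coreax may format other values differently.

```python
def format_time_sketch(num):
    """Hypothetical sketch: format a time in seconds with a readable unit."""
    # Assumed unit table: (size of one unit in seconds, suffix), largest first.
    units = [(1.0, "s"), (1e-3, "ms"), (1e-6, "\u00b5s"), (1e-9, "ns")]
    for scale, suffix in units:
        if abs(num) >= scale:
            return f"{num / scale:.1f} {suffix}"
    # Anything below one nanosecond is still reported in nanoseconds.
    return f"{num / 1e-9:.1f} ns"


label = format_time_sketch(0.4531)  # "453.1 ms"
```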
- coreax.util.speed_comparison_test(function_setups, num_runs=10, log_results=False, normalisation=None)
Compare compilation time and runtime of a list of JIT-able functions.
- Parameters:
function_setups (Sequence[JITCompilableFunction]) – Sequence of instances of JITCompilableFunction
num_runs (int) – Number of times to average function timings over
log_results (bool) – If True, the results are formatted and logged
normalisation (Optional[tuple[float, float]]) – Tuple (compilation normalisation, execution normalisation). If provided, returned compilation/execution times are normalised so that this time is 1 time unit.
- Return type:
- Returns:
List of (means, standard deviations) tuples, one per function, holding the JIT compilation and execution times as array components; and a dictionary mapping each function name to an array with estimated compilation times in the first column and execution times in the second column
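The averaging and normalisation can be sketched as below. Several details are assumptions, labelled in the code too: the compilation time is estimated as first-run time minus second-run time, jax.clear_caches() is called so that every repetition recompiles, and functions are passed as plain (fn, args) pairs instead of JITCompilableFunction instances. The names two_run_timings and speed_comparison_sketch are hypothetical, and the sketch returns only the per-function (means, standard deviations) list.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


def two_run_timings(fn, args):
    """Hypothetical helper: time two consecutive calls of a freshly jitted fn."""
    # Assumption: clear JAX's caches so the first call must trace and compile again.
    jax.clear_caches()
    jitted_fn = jax.jit(fn)
    timings = []
    for _ in range(2):
        start = time.perf_counter()
        jax.block_until_ready(jitted_fn(*args))
        timings.append(time.perf_counter() - start)
    return timings


def speed_comparison_sketch(functions, num_runs=10, normalisation=None):
    """Return one (means, standard deviations) pair per (fn, args) entry."""
    summaries = []
    for fn, args in functions:
        runs = np.array([two_run_timings(fn, args) for _ in range(num_runs)])
        # Assumed estimate: compilation time = first run minus second run.
        compilation = runs[:, 0] - runs[:, 1]
        execution = runs[:, 1]
        times = np.stack([compilation, execution], axis=1)
        if normalisation is not None:
            # Express times in units of (compilation, execution) normalisation.
            times = times / np.asarray(normalisation)
        summaries.append((times.mean(axis=0), times.std(axis=0)))
    return summaries


results = speed_comparison_sketch([(lambda x: jnp.sin(x) ** 2, (jnp.ones(1000),))], num_runs=3)
```

Each summary holds two array components, compilation first and execution second, matching the column order described in the return value above.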