Source code for coreax.kernels.util
# © Crown Copyright GCHQ
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for the kernels subpackage."""
from math import ceil
import equinox as eqx
import jax.numpy as jnp
import jax.tree_util as jtu
from jax import Array
from jaxtyping import Shaped
from coreax.data import Data, _atleast_2d_consistent, as_data
from coreax.util import pairwise, squared_distance, tree_zero_pad_leading_axis
def _block_data_convert(
x: Data | Shaped[Array, " n d"], block_size: int | None
) -> tuple[Data, int]:
"""Convert 'x' into padded and weight normalized blocks of size 'block_size'."""
x = as_data(x).normalize(preserve_zeros=True)
block_size = len(x) if block_size is None else min(max(int(block_size), 1), len(x))
padding = ceil(len(x) / block_size) * block_size - len(x)
padded_x = tree_zero_pad_leading_axis(x, padding)
def _reshape(x: Array) -> Array:
_, *remaining_shape = jnp.shape(x)
try:
return x.reshape(-1, block_size, *remaining_shape)
except ZeroDivisionError as err:
if 0 in x.shape:
raise ValueError("'x' must not be empty") from err
raise
return jtu.tree_map(_reshape, padded_x, is_leaf=eqx.is_array), len(x)