# Source code for coreax.solvers.recombination

# © Crown Copyright GCHQ
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""
Recombination solvers.

Take a dataset :math:`\{(x_i, w_i)\}_{i=1}^n`, where each node :math:`x_i \in \Omega`
is paired with a non-negative weight :math:`w_i \in \mathbb{R}_{\ge 0}`, and the sum of
all weights is one, :math:`\sum_{i=1}^n w_i = 1` (a strict requirement for a probability
measure). Such a dataset defines the discrete measure
:math:`\nu_n := \sum_{i=1}^n w_i \delta_{x_i}`.

.. note::
    Given any weighted dataset, we can use normalisation to satisfy the sum to
    one condition, provided :math:`\sum_{i=1}^n w_i \neq 0`.

Combined with :math:`m-1` test-functions :math:`\Phi^\prime = \{\phi_i\}_{i=1}^{m-1}`,
where :math:`\phi_i \colon \Omega \to \mathbb{R}`, that parametrise a set of :math:`m`
test-functions :math:`\Phi = \{x \mapsto 1\} \cup \Phi^\prime`, there exists a dataset
push-forward measure :math:`\mu_n := \Phi_* \nu_n`.

A recombination solver attempts to find a reduced measure (a coresubset)
:math:`\hat{\mu}_{m^\prime}`, which is given as a basic-feasible solution (BFS) to the
following linear-programming problem (with trivial objective)

.. math::
    \begin{align}
        \mathbf{Y} \mathbf{\hat{w}} &= \text{CoM}(\mu_n),\\
        \mathbf{\hat{w}} &\ge 0,
    \end{align}

where the system variables and "centre-of-mass" are defined as

.. math::
    \begin{gather}
    \mathbf{Y} := \left[\Phi(x_1), \dots, \Phi(x_n)\right] \in \mathbb{R}^{m \times n},\
    \mathbf{\hat{w}} \in \mathbb{R}^n \ge 0,\\
    \text{CoM}(\mu_n) := \sum_{i=1}^n w_i \Phi(x_i)
            = \left[ \sum_{i=1}^n w_i \phi_j(x_i) \right]_{j=1}^m \in \mathbb{R}^m.\\
    \end{gather}

.. note::
    The source dataset is, by definition, a solution to the linear-program that is not
    necessarily a BFS. Hence, one may consider the fundamental problem of recombination
    as that of finding a BFS given a solution that is not a BFS.

Basic feasible solutions to the linear-program above are (up to a permutation of the
indices) of the form
:math:`\mathbf{\hat{w}} = \{\hat{w}_1, \dots, \hat{w}_{m^\prime}, \mathbf{0}\}`; i.e.
BFSs are feasible solutions with :math:`n-m^\prime` weights equal to zero. Given a BFS,
the reduced measure (the coresubset) can be constructed by explicitly removing the nodes
associated with each zero valued (implicitly removed) weight

.. math::
    \begin{gather}
    \hat{\nu}_{m^\prime} = \sum_{i \in I} \hat{w}_i \delta_{x_i},\\
    I = \{i \in \{1, \dots, n\} \mid \hat{w}_i \neq 0\}.
    \end{gather}

Due to Tchakaloff's theorem :cite:`tchakaloff1957,bayer2006tchakaloff`, which follows
from Caratheodory's convex hull theorem :cite:`caratheodory1907,loera2018caratheodory`,
we know there always exists a basic-feasible solution to the linear-program, with at
most :math:`m^\prime = \text{dim}(\text{span}(\Phi))` non-zero weights. Hence, we have
an upper bound on the size of a coresubset, controlled by the choice of test-functions.

.. note::
    A basic feasible solution (coresubset produced by recombination) is non-unique. In
    fact, there may exist up to :math:`\binom{n}{m^\prime}` basic feasible solutions
    (coresubsets) for the described linear-program. In the context of Coreax, this means
    that a :class:`RecombinationSolver` is unlikely to ever be truly invariant to the
    presence of padding (see :class:`~coreax.solvers.PaddingInvariantSolver`), i.e. the
    padded problem may have an equivalent, but different, BFS to the unpadded problem.

Canonically, recombination is used for reducing the support of a quadrature/cubature
measure, against which integration of any function :math:`f \in \text{span}(\Phi)`
is identical to integration against a "target" (potentially continuous) measure
:math:`\mu`.
"""

import math
from collections.abc import Callable
from typing import Generic, Literal, NamedTuple, TypeVar

import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.tree_util as jtu
from jaxtyping import Array, Bool, DTypeLike, Float, Integer, Real, Shaped
from typing_extensions import override

from coreax import Coresubset, Data
from coreax.solvers.base import CoresubsetSolver

_Data = TypeVar("_Data", bound=Data)
_State = TypeVar("_State")


class RecombinationSolver(CoresubsetSolver[_Data, _State], Generic[_Data, _State]):
    r"""
    Solver which returns a :class:`~coreax.coreset.Coresubset` via recombination.

    Given :math:`m-1` explicitly provided test-functions :math:`\Phi^\prime`, a
    recombination solver finds a coresubset with :math:`m^\prime \le m` points, whose
    push-forward :math:`\hat{\mu}_{m^\prime}` has the same "centre-of-mass" (CoM) as the
    dataset push-forward :math:`\mu_n := \Phi_* \nu_n`.

    :param test_functions: A callable that applies a set of specified test-functions
        :math:`\Phi^\prime = \{\phi_1,\dots,\phi_{m-1}\}` where each function is a map
        :math:`\phi_i \colon \Omega\to\mathbb{R}`; a value of :data:`None` implies the
        identity map :math:`\Phi^\prime \colon x \mapsto x`, and necessarily assumes
        that :math:`x \in \Omega \subseteq \mathbb{R}^{m-1}`
    :param mode: 'implicit-explicit' explicitly removes :math:`n - m` points, yielding
        a coreset of size :math:`m`, with :math:`m - m^\prime` zero-weighted (implicitly
        removed) points; 'implicit' explicitly removes no points, yielding a coreset of
        size :math:`n` with :math:`n - m^\prime` zero-weighted (implicitly removed)
        points; 'explicit' explicitly removes :math:`n - m^\prime` points, yielding a
        coreset of size :math:`m^\prime`, but unlike the other methods is not JIT
        compatible as the coreset size :math:`m^\prime` is unknown at compile time.
    """

    test_functions: Callable[[Array], Real[Array, " m-1"]] | None = None
    mode: Literal["implicit-explicit", "implicit", "explicit"] = "implicit-explicit"

    def __check_init__(self):
        """Reject any `self.mode` outside the three supported recombination modes."""
        supported_modes = {"implicit-explicit", "implicit", "explicit"}
        if self.mode in supported_modes:
            return
        raise ValueError(
            "Invalid mode, expected 'implicit-explicit', 'implicit' or 'explicit'."
        )
class _EliminationState(NamedTuple):
    """Loop-carry state for the Gaussian-Elimination loop in Caratheodory recombination."""

    # Current node weights; one weight is zeroed on each elimination step.
    weights: Shaped[Array, " n"]
    # Current left null space basis (one basis vector per row). Despite the field name,
    # this holds the (n, n) basis matrix produced by `_resolve_null_basis`, not the
    # pushed-forward nodes themselves.
    nodes: Shaped[Array, "n n"]
    # Elimination step counter; starts as a Python int but is a traced integer array
    # inside `jax.lax.while_loop`.
    iteration: int | Integer[Array, ""]
class CaratheodoryRecombination(RecombinationSolver[Data, None]):
    r"""
    Recombination via Caratheodory measure reduction (Gaussian-Elimination).

    Proposed in :cite:`tchernychova2016recombination` (see Chapter 1.3.3.3) as an
    alternative to the Simplex algorithm for solving the recombination problem.

    Unlike the Simplex method, with time complexity :math:`\mathcal{O}(m^3 n + m n^2)`,
    Caratheodory recombination has time complexity of only :math:`\mathcal{O}(m n^2)`.

    .. note::
        Given :math:`n = cm`, for a rational constant :math:`c`, the above complexities
        can be alternatively represented as :math:`\mathcal{O}(m^4)` for the Simplex
        method and :math:`\mathcal{O}(m^3)` for Caratheodory recombination.

    :param test_functions: A callable that applies a set of specified test-functions
        :math:`\Phi^\prime = \{\phi_1,\dots,\phi_{m-1}\}` where each function is a map
        :math:`\phi_i \colon \Omega \to \mathbb{R}`; a value of :data:`None` implies the
        identity map :math:`\Phi^\prime \colon x \mapsto x`, and necessarily assumes
        that :math:`x \in \Omega \subseteq \mathbb{R}^{m-1}`
    :param mode: 'implicit-explicit' explicitly removes :math:`n - m` points, yielding
        a coreset of size :math:`m`, with :math:`m - m^\prime` zero-weighted (implicitly
        removed) points; 'implicit' explicitly removes no points, yielding a coreset of
        size :math:`n` with :math:`n - m^\prime` zero-weighted (implicitly removed)
        points; 'explicit' explicitly removes :math:`n - m^\prime` points, yielding a
        coreset of size :math:`m^\prime`, but unlike the other methods is not JIT
        compatible as the coreset size :math:`m^\prime` is unknown at compile time.
    :param rcond: A relative condition number; any singular value :math:`s` below the
        threshold :math:`\text{rcond} * \text{max}(s)` is treated as equal to zero; if
        rcond is :data:`None`, it defaults to `floating point eps * max(n, m)`, where
        `(n, m)` is the shape of the (affine-augmented) push-forward node matrix
    """

    rcond: float | None = None

    @override
    def reduce(
        self, dataset: Data, solver_state: None = None
    ) -> tuple[Coresubset, None]:
        """
        Reduce 'dataset' to a coresubset via Caratheodory recombination.

        :param dataset: The dataset to reduce; its weights are normalized (preserving
            zero weights) before reduction
        :param solver_state: Unused; returned unchanged
        :return: The recombination coresubset and the (unchanged) solver state
        """
        nodes, weights = jtu.tree_leaves(dataset.normalize(preserve_zeros=True))
        push_forward_nodes = _push_forward(nodes, self.test_functions)
        # Handle pre-existing zero-weighted nodes (not handled by the base algorithm
        # described in :cite:`tchernychova2016recombination`)
        safe_push_forward_nodes, safe_weights, indices = _co_linearize(
            push_forward_nodes, weights
        )
        largest_null_space_basis, null_space_rank = _resolve_null_basis(
            safe_push_forward_nodes, self.rcond
        )

        def _eliminate_cond(state: _EliminationState) -> Bool[Array, ""]:
            """
            If to continue the iterative Gaussian-Elimination procedure.

            On each iteration, we eliminate a basis vector from the left null space.
            We repeat until all basis vectors have been eliminated (the dimension of
            the null space is zero); once the number of iterations is the same as the
            rank of the original null space.

            .. note::
                The reason for using a while loop, rather than scanning over the basis
                vectors, is due to the dimension of the null space being unknown at JIT
                compile time, preventing us from slicing the left singular vectors down
                to only those which form a basis for the left null space.

            :param state: Elimination state information
            :return: Boolean indicating if to continue/exit the elimination loop.
            """
            return state.iteration < null_space_rank

        def _eliminate(state: _EliminationState) -> _EliminationState:
            """
            Eliminate a basis from the left null space.

            At least one weight is zeroed (implicitly removed from the dataset), and
            one left null space basis vector eliminated on each iteration. The mass
            that is "lost" in weight zeroing/elimination is redistributed among the
            remaining non-zero weights to preserve the total mass/weight sum.

            If the procedure is repeated until all the left null space basis vectors
            are eliminated, the resulting weights (when combined with the original
            nodes) are a BFS to the recombination problem/linear-program.

            :param state: Elimination state information
            :return: Updated `state` information resulting from the elimination step.
            """
            # Algorithm 6 - Chapter 3.3 of :cite:`tchernychova2016recombination`
            # Our Notation -> Their Notation
            # - `basis_index` (loop iteration) -> i
            # - `elimination_index` -> k^{(i)}
            # - `elimination_rescaling_factor` -> \alpha_{(i)}
            # - `updated_weights` -> \underline\Beta^{(i)}
            # - `null_space_basis_update` -> d_{l+1}^{(i)}\phi_1^{(i-1)}
            # - `updated_null_space_basis` -> \Psi^{(i)}
            _weights, null_space_basis, basis_index = state
            basis_vector = null_space_basis[basis_index]
            # Equation 3: Select the weight to eliminate; only strictly positive basis
            # entries are candidates (the others are masked to infinity).
            elimination_condition = jnp.where(
                basis_vector > 0, _weights / basis_vector, jnp.inf
            )
            elimination_index = jnp.argmin(elimination_condition)
            elimination_rescaling_factor = elimination_condition[elimination_index]
            # Equation 4: Eliminate the selected weight and redistribute its mass.
            # NOTE: Equation 5 is implicit from Equation 4 and is performed outside
            # of `_eliminate` via `_coresubset_nodes`.
            updated_weights = _weights - elimination_rescaling_factor * basis_vector
            updated_weights = updated_weights.at[elimination_index].set(0)
            # Equations 6, 7 and 8: Update the Null space basis, so that every basis
            # vector has a zero entry at `elimination_index`.
            null_space_basis_update = jnp.tensordot(
                null_space_basis[:, elimination_index],
                basis_vector / basis_vector[elimination_index],
                axes=0,
            )
            updated_null_space_basis = null_space_basis - null_space_basis_update
            updated_null_space_basis = updated_null_space_basis.at[
                :, elimination_index
            ].set(0)
            return _EliminationState(
                updated_weights, updated_null_space_basis, basis_index + 1
            )

        in_state = _EliminationState(safe_weights, largest_null_space_basis, 0)
        out_weights, *_ = jax.lax.while_loop(_eliminate_cond, _eliminate, in_state)
        coresubset_nodes = _coresubset_nodes(
            safe_push_forward_nodes,
            out_weights,
            indices,
            self.mode,
            is_affine_augmented=True,
        )
        return Coresubset(coresubset_nodes, dataset), solver_state
def _push_forward(
    nodes: Shaped[Array, " n"],
    test_functions: Callable[[Array], Real[Array, " m-1"]] | None,
    augment: bool = True,
) -> Shaped[Array, "n m"]:
    r"""
    Push the 'nodes' forward through the 'test_functions'.

    :param nodes: The nodes to push-forward through the test-functions; the leading
        axis indexes the nodes
    :param test_functions: A callable that applies a set of specified test-functions
        :math:`\Phi^\prime = \{\phi_1,\dots,\phi_{m-1}\}` where each function is a map
        :math:`\phi_i \colon \Omega \to \mathbb{R}`; a value of :data:`None` implies the
        identity map :math:`\Phi^\prime \colon x \mapsto x`, and necessarily assumes
        that :math:`x \in \Omega \subseteq \mathbb{R}^{m-1}`
    :param augment: If to prepend the affine-augmentation test function
        :math:`\{x \mapsto 1\}` to the explicitly pushed forward nodes
        :math:`\Phi^\prime(x)`, to yield :math:`\Phi(x)`; default behaviour prepends the
        affine-augmentation function
    :return: The pushed-forward nodes.
    """
    if test_functions is None:
        push_forward_nodes = nodes
    else:
        # Apply the test-functions to each node independently (mapped over axis 0).
        push_forward_nodes = jax.vmap(test_functions, in_axes=0)(nodes)
    if augment:
        # Prepend a column of ones: the constant test-function `x -> 1`.
        shape, dtype = push_forward_nodes.shape[0], push_forward_nodes.dtype
        affine_augmentation = jnp.ones((shape,), dtype)
        push_forward_nodes = jnp.c_[affine_augmentation, push_forward_nodes]
    return push_forward_nodes


def _co_linearize(
    nodes: Shaped[Array, "n m"], weights: Shaped[Array, " n"]
) -> tuple[Shaped[Array, "n m"], Shaped[Array, " n"], Shaped[Array, " n"]]:
    """
    Make zero-weighted nodes co-linear with the maximum weighted node.

    Due to the static shape requirements imposed by JAX, we implicitly remove nodes by
    setting their corresponding weight to zero. This is sufficient in the recombination
    algorithm for all but one scenario, the computation of the null space basis. Because
    the zero-weighted nodes still exist in the node matrix, they influence the SVD and
    yield an erroneous null space basis.

    We ameliorate this problem by setting the zero-weighted nodes equal (co-linear) to
    the largest weighted node (an arbitrary but consistent choice). Because the nodes
    are now co-linear to each other and the largest weighted node, we know that at least
    all but one of them can be safely eliminated by the recombination procedure. Thus,
    the nodes become effectively "invisible" to the elimination procedure.

    The only caveat is that we don't know which of the equal nodes will be retained post
    elimination. To handle this, we keep an index (reference) from the zero-weighted
    nodes to the largest weighted node, and we redistribute the largest weight equally
    over all the "co-linearized" nodes (preserving the CoM and allowing any node to be
    eliminated).

    :param nodes: The nodes to co-linearize
    :param weights: The weights to apply the co-linearization correction to
    :return: The co-linearized nodes, corrected weights, and co-linearized-to-original
        reference indices.
    """
    max_index = jnp.argmax(weights)
    non_zero_weights_mask = weights > 0
    n_zeros = jnp.sum(~non_zero_weights_mask)
    # Create a new set of indices that replace the zero-weighted node indices with the
    # maximum weighted node's index.
    indices = jnp.where(
        non_zero_weights_mask, jnp.arange(weights.shape[0]), max_index
    )
    # Renormalize the maximum weight; ensures the weight sum is preserved under the new
    # (co-linearized) indices; prevents co-linearization from changing the weight sum.
    weights = weights.at[max_index].divide(n_zeros + 1)
    return nodes[indices], weights[indices], indices


# pylint: disable=line-too-long
# Credit: https://github.com/patrick-kidger/lineax/blob/9b923c8df6556551fedc7adeea7979b5c7b3ffb0/lineax/_solver/svd.py#L67  # noqa: E501
# for the rank determination code.
# pylint: enable=line-too-long
def _resolve_null_basis(
    nodes: Shaped[Array, "n m"],
    rcond: float | None = None,
) -> tuple[Shaped[Array, "n n"], Integer[Array, ""]]:
    r"""
    Resolve the largest left null space basis, and its rank, for the passed node matrix.

    By largest left null space basis, we mean the null space basis under the assumption
    that the rank of the null space is maximal (assumed to be ``n``). If the rank is not
    maximal, then only the first :math:`n - m^\prime` basis vectors will be actual basis
    vectors for the null space (where :math:`m^\prime` is the rank of the node matrix).
    The remaining "basis" vectors can, and should, be ignored in upstream computations
    by using the left null space rank value as a cut-off index.

    :param nodes: Matrix of nodes (m-vectors) whose null space is to be determined
    :param rcond: The relative condition number of the Matrix of nodes
    :return: The largest left null space basis and its rank, for the passed node matrix.
    """
    q, s, _ = jsp.linalg.svd(nodes, full_matrices=True)
    _rcond = _resolve_rcond(nodes.shape, s.dtype, rcond)
    if s.size > 0:
        # Singular values are returned in descending order, so `s[0]` is the largest.
        _rcond *= s[0]
    # Use `jnp.sum` rather than the builtin `sum`, which would iterate over the traced
    # array element-by-element (unrolling one addition per singular value under JIT).
    matrix_rank = jnp.sum(s > _rcond)
    null_space_rank = jnp.maximum(0, nodes.shape[0] - matrix_rank)
    # Reverse the left singular vectors so the (left) null space vectors come first.
    largest_null_space_basis = q.T[::-1]
    return largest_null_space_basis, null_space_rank


# pylint: disable=line-too-long
# Credit: https://github.com/patrick-kidger/lineax/blob/9b923c8df6556551fedc7adeea7979b5c7b3ffb0/lineax/_misc.py#L34  # noqa: E501
# pylint: enable=line-too-long
def _resolve_rcond(
    shape: tuple[int, ...], dtype: DTypeLike, rcond: float | None = None
) -> Float[Array, ""]:
    """
    Resolve the relative condition number (rcond).

    :param shape: The shape of the matrix whose relative condition number is to be
        resolved
    :param dtype: The element dtype of the matrix whose rcond is to be resolved
    :param rcond: The relative condition number of a given matrix; if ``None``,
        ``rcond = dtype_floating_point_eps * max(shape)``; else if negative,
        ``rcond = dtype_floating_point_eps``
    :return: The resolved relative condition number (rcond)
    """
    epsilon = jnp.asarray(jnp.finfo(dtype).eps, dtype)
    if rcond is None:
        return epsilon * max(shape)
    return jnp.where(rcond < jnp.asarray(0), epsilon, rcond)


def _coresubset_nodes(
    push_forward_nodes: Shaped[Array, "n m"],
    weights: Shaped[Array, " n"],
    indices: Shaped[Array, " n"],
    mode: Literal["implicit-explicit", "implicit", "explicit"],
    is_affine_augmented: bool = False,
) -> Data:
    r"""
    Determine the coresubset nodes based on the 'mode'.

    :param push_forward_nodes: The dataset push forward nodes; only its shape is used
    :param weights: The coresubset weights
    :param indices: Map from each position in 'weights' to the corresponding index in
        the original dataset
    :param mode: 'implicit-explicit' explicitly removes :math:`n - m` points, yielding
        a coreset of size :math:`m`, with :math:`m - m^\prime` zero-weighted (implicitly
        removed) points; 'implicit' explicitly removes no points, yielding a coreset of
        size :math:`n` with :math:`n - m^\prime` zero-weighted (implicitly removed)
        points; 'explicit' explicitly removes :math:`n - m^\prime` points, yielding a
        coreset of size :math:`m^\prime`, but unlike the other methods is not JIT
        compatible as the coreset size :math:`m^\prime` is unknown at compile time.
    :param is_affine_augmented: If the 'push_forward_nodes' include the :math:`\phi_1`
        affine-augmentation map.
    :return: The coresubset nodes as defined by the 'mode'.
    :raises ValueError: If 'mode' is 'explicit' under a JAX transformation, or if
        'mode' is not a supported mode.
    """
    n, m = push_forward_nodes.shape
    m = m if is_affine_augmented else m + 1
    if mode == "implicit-explicit":
        # Inside the JIT context we cannot explicitly remove all the zero weights,
        # because we don't know how many non-zero weights there will be a priori
        # (`m^\prime` is unknown until after the singular value decomposition is
        # performed). However, we do have an upper bound on the number of non-zero
        # points `min(n, m) \ge m^\prime`. Thus, we need only return `min(n, m)`
        # weights where `min(n, m) - m^\prime` of these may be zero-weighted
        # (implicitly removed). The fill value is set to `argmin(weights)` to ensure we
        # always index a zero-weighted data point whenever the weight is zero.
        idx = jnp.flatnonzero(weights, size=min(n, m), fill_value=jnp.argmin(weights))
    elif mode == "implicit":
        idx = jnp.flatnonzero(weights, size=n, fill_value=jnp.argmin(weights))
    elif mode == "explicit":
        # Explicit mode is JIT incompatible
        try:
            idx = jnp.flatnonzero(weights)
        except jax.errors.ConcretizationTypeError as err:
            raise ValueError(
                "'explicit' mode is incompatible with transformations such as 'jax.jit'"
            ) from err
    else:
        # Should only get here if the `__check_init__` has been skipped/avoided, or if
        # this function is called from an unexpected place.
        raise ValueError(
            "Invalid mode, expected 'implicit-explicit', 'implicit' or 'explicit'."
        )
    return Data(indices[idx], weights[idx])
class TreeRecombination(RecombinationSolver[Data, None]):
    r"""
    Tree recombination based coresubset solver.

    Based on Algorithm 7 Chapter 3.3 of :cite:`tchernychova2016recombination`, which is
    an order of magnitude more efficient than Algorithm 5 in Chapter 3.2, originally
    introduced in :cite:`litterer2012recombination`. The time complexity is of order
    :math:`\mathcal{O}(\log_2(\frac{n}{c_r m}) m^3)`, where :math:`c_r` is the
    `tree_reduction_factor`. The time complexity can be equivalently expressed as
    :math:`\mathcal{O}(m^3)`, using the same arguments as used in
    :class:`CaratheodoryRecombination`.

    .. note::
        As the ratio of :math:`n / m` grows, the constant factor for the time complexity
        of :class:`TreeRecombination` increases at a logarithmic rate, rather than at a
        quadratic rate for plain :class:`CaratheodoryRecombination`. Hence, in general,
        we would expect :class:`TreeRecombination` to be the more efficient choice for
        all but the smallest values of :math:`n / m`.

    :param test_functions: the map :math:`\Phi^\prime = \{ \phi_1, \dots, \phi_{m-1} \}`
        where each :math:`\phi_i \colon \Omega \to \mathbb{R}` represents a linearly
        independent test-function; a value of :data:`None` implies the identity function
        (necessarily assuming :math:`\Omega \subseteq \mathbb{R}^{m-1}`)
    :param mode: 'implicit-explicit' explicitly removes :math:`n - m` points, yielding
        a coreset of size :math:`m`, with :math:`m - m^\prime` zero-weighted (implicitly
        removed) points; 'implicit' explicitly removes no points, yielding a coreset of
        size :math:`n` with :math:`n - m^\prime` zero-weighted (implicitly removed)
        points; 'explicit' explicitly removes :math:`n - m^\prime` points, yielding a
        coreset of size :math:`m^\prime`, but unlike the other methods is not JIT
        compatible as the coreset size :math:`m^\prime` is unknown at compile time.
    :param rcond: a relative condition number; any singular value :math:`s` below the
        threshold :math:`\text{rcond} * \text{max}(s)` is treated as equal to zero; if
        :code:`rcond is None`, it defaults to `floating point eps * max(shape)`, where
        `shape` is that of the (centroid) node matrix passed to each SVD
    :param tree_reduction_factor: The factor by which each tree reduction step reduces
        the number of clusters with non-zero weight; each step partitions the nodes into
        `tree_reduction_factor * m` clusters and retains at most `m` of them; must be an
        integer of at least two
    """

    rcond: float | None = None
    tree_reduction_factor: int = 2

    @override
    def reduce(
        self, dataset: Data, solver_state: None = None
    ) -> tuple[Coresubset, None]:
        """
        Reduce 'dataset' to a coresubset via tree-based Caratheodory recombination.

        :param dataset: The dataset to reduce; its weights are normalized (preserving
            zero weights) before reduction
        :param solver_state: Unused; returned unchanged
        :return: The recombination coresubset and the (unchanged) solver state
        """
        nodes, weights = jtu.tree_leaves(dataset.normalize(preserve_zeros=True))
        # Push the nodes forward through the test-functions \Phi^\prime.
        push_forward_nodes = _push_forward(nodes, self.test_functions, augment=False)
        n, m = push_forward_nodes.shape
        # We don't apply the affine-augmentation test-function \phi_1 here, instead
        # deferring it to `CaratheodoryRecombination.reduce`. Thus, we have to manually
        # correct the value for `m`.
        padding, count, depth = _prepare_tree(n, m + 1, self.tree_reduction_factor)
        car_recomb_solver = CaratheodoryRecombination(rcond=self.rcond, mode="implicit")

        def _tree_reduce(_, state: tuple[Array, Array]) -> tuple[Array, Array]:
            """
            Apply Tree-Based Caratheodory Recombination (Gaussian-Elimination).

            Partitions the dataset into 'count' clusters of size 'n / count' and then
            computes the cluster centroids. Caratheodory recombination is then performed
            on these centroids (rather than on the full dataset), with every node in the
            eliminated centroids' cluster being implicitly removed (given zero-weight).

            There are 'tree_reduction_factor * m' clusters, with each step reducing the
            number of remaining clusters down to 'm'. We can repeat the process until
            each cluster contains, at most, a single non-zero weighted point (at this
            point the recombination problem has been solved).

            :param _: Not used
            :param state: Tuple of node weights and indices; indices are passed to keep
                a correspondence between the original (padded) data indices and the
                current positions of the weights
            :return: Updated tuple of node weights and indices; weights are zeroed
                (implicitly removed) where appropriate; indices are shuffled to ensure
                balanced centroids in subsequent iterations (centroids are balanced when
                they are all constructed from subsets with as near to an equal number of
                non-zero weighted nodes as possible).
            """
            _weights, _indices = state
            # Index weights to a centroid; argsort ensures that centroids are balanced.
            centroid_indices = jnp.argsort(_weights).reshape(count, -1, order="F")
            # NOTE: padded indices (>= n) gather a clamped node here, but always carry
            # zero weight, so they do not contribute to any centroid.
            centroid_nodes, centroid_weights = _centroid(
                push_forward_nodes[_indices[centroid_indices]],
                _weights[centroid_indices],
            )
            centroid_dataset = Data(centroid_nodes, centroid_weights)
            # Solve the measure reduction problem on the centroid dataset.
            centroid_coresubset, _ = car_recomb_solver.reduce(centroid_dataset)
            coresubset_indices = centroid_coresubset.unweighted_indices
            coresubset_weights = centroid_coresubset.points.weights
            # Propagate centroid coresubset weights to the underlying weights for each
            # centroid, as defined by `centroid_indices`.
            weight_update_indices = centroid_indices[coresubset_indices]
            weight_update = coresubset_weights / centroid_weights[coresubset_indices]
            updated_weights = _weights[weight_update_indices] * weight_update[..., None]
            # Maintain a correspondence between the original data indices and the sorted
            # indices, used to construct the balanced centroids.
            updated_indices = _indices[weight_update_indices.reshape(-1, order="F")]
            return updated_weights.reshape(-1, order="F"), updated_indices

        in_state = (jnp.pad(weights, (0, padding)), jnp.arange(n + padding))
        out_weights, indices = jax.lax.fori_loop(0, depth, _tree_reduce, in_state)
        # Padded positions always carry zero weight, but their indices (>= n) do not
        # exist in the original dataset. `_coresubset_nodes` may still select them as
        # zero-weighted fill entries, which would leak out-of-range indices into the
        # returned coresubset; clamp them to a valid index (their weight stays zero).
        indices = jnp.minimum(indices, n - 1)
        coresubset_nodes = _coresubset_nodes(
            push_forward_nodes, out_weights, indices, self.mode
        )
        return Coresubset(coresubset_nodes, dataset), solver_state
def _prepare_tree(
    n: int, m: int, tree_reduction_factor: int = 2
) -> tuple[int, int, int]:
    r"""
    Compute the dataset padding, tree count and tree depth for tree recombination.

    :param n: Number of nodes
    :param m: Number of test-functions (including the affine-augmentation function)
    :param tree_reduction_factor: The factor by which each tree reduction step reduces
        the number of clusters with non-zero weight; must be at least two
    :return: The required amount of padding (to allow reshaping of the nodes into equal
        sized clusters), the tree count (number of clusters), and the maximum tree depth
        (number of tree-reduction iterations required to complete tree recombination)
    :raises ValueError: If 'n' or 'm' is less than one, or 'tree_reduction_factor' is
        less than two.
    """
    if n < 1:
        raise ValueError("'n' must be at least one")
    if m < 1:
        raise ValueError("'m' must be at least one")
    if tree_reduction_factor < 2:
        raise ValueError("'tree_reduction_factor' must be at least two")
    tree_count = tree_reduction_factor * m
    # We need the smallest depth >= 1 with `m * tree_reduction_factor**depth >= n`.
    # Start from a floating-point estimate, then correct it with exact integer
    # arithmetic: `math.log` rounding can overshoot on exact powers (e.g.
    # `math.log(125, 5)` evaluates slightly above 3), adding a wasted tree level, or
    # undershoot, which would yield negative padding.
    max_tree_depth = max(1, math.ceil(math.log(n / m, tree_reduction_factor)))
    while (
        max_tree_depth > 1 and m * tree_reduction_factor ** (max_tree_depth - 1) >= n
    ):
        max_tree_depth -= 1
    while m * tree_reduction_factor**max_tree_depth < n:
        max_tree_depth += 1
    padding = m * tree_reduction_factor**max_tree_depth - n
    return padding, tree_count, max_tree_depth


@jax.vmap
def _centroid(
    nodes: Shaped[Array, "tree_count n/tree_count m"],
    weights: Shaped[Array, "tree_count n/tree_count"],
) -> tuple[Shaped[Array, "n/tree_count m"], Shaped[Array, " n/tree_count"]]:
    """
    Compute the centroid mass and node centre (centre-of-mass).

    :param nodes: A set of clustered nodes where the leading axis indexes each cluster,
        which this function vmaps over, and the middle axis indexes each node within a
        given cluster.
    :param weights: A set of clustered weights associated with each node; has the same
        index layout as the nodes.
    :return: Cluster centroid (centre-of-mass) and total cluster mass for all clusters
    """
    # A cluster whose weights sum to zero yields NaN from `jnp.average`; map it to zero.
    centroid_nodes = jnp.nan_to_num(jnp.average(nodes, 0, weights))
    centroid_weights = jnp.sum(weights)
    return centroid_nodes, centroid_weights