# © Crown Copyright GCHQ
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Scalar-valued kernel functions."""
from collections.abc import Callable
import equinox as eqx
import jax.numpy as jnp
from jax import Array, vmap
from jax.scipy.special import factorial
from jaxtyping import Shaped
from typing_extensions import override
from coreax.kernels.base import ProductKernel, ScalarValuedKernel, UniCompositeKernel
from coreax.util import squared_distance
[docs]
class LinearKernel(ScalarValuedKernel):
r"""
Define a linear kernel.
Given output scale :math:`\rho` and constant :math:`a`, the linear kernel
is defined as :math:`k: \mathbb{R}^d\times \mathbb{R}^d \to \mathbb{R}`,
:math:`k(x, y) = a + \rho (x)^T(y)`.
:param output_scale: Kernel normalisation constant, :math:`\rho`, must be positive
:param constant: Additive constant, :math:`a`, must be non-negative
"""
output_scale: float = eqx.field(default=1.0, converter=float)
constant: float = eqx.field(default=0.0, converter=float)
def __check_init__(self):
"""Check attributes are valid."""
if self.output_scale <= 0:
raise ValueError("'output_scale' must be positive")
if self.constant < 0:
raise ValueError("'constant' must be non-negative")
[docs]
@override
def compute_elementwise(self, x, y):
return self.output_scale * jnp.dot(x, y) + self.constant
[docs]
@override
def grad_x_elementwise(self, x, y):
return self.output_scale * jnp.asarray(y)
[docs]
@override
def grad_y_elementwise(self, x, y):
return self.output_scale * jnp.asarray(x)
[docs]
@override
def divergence_x_grad_y_elementwise(self, x, y):
d = len(jnp.atleast_1d(x))
return jnp.array(self.output_scale * d)
[docs]
class PolynomialKernel(ScalarValuedKernel):
r"""
Define a polynomial kernel.
Given :math:`\rho =` ``output_scale``, :math:`c =` ``constant``, and
:math:`d=` ``degree``, the polynomial kernel is defined as
:math:`k: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}`,
:math:`k(x, y) = \rho (x^Ty + c)^d`.
:param output_scale: Kernel normalisation constant, :math:`\rho`, must be positive
:param constant: Additive constant, :math:`c`, must be non-negative
:param degree: Degree of kernel, must be a positive integer greater than 1
"""
output_scale: float = 1.0
constant: float = 0.0
degree: int = 2
def __check_init__(self):
"""Ensure degree is an integer greater than 1 and other attributes are valid."""
min_degree = 2
if self.output_scale <= 0:
raise ValueError("'output_scale' must be positive")
if self.constant < 0:
raise ValueError("'constant' must be non-negative")
if not isinstance(self.degree, int) or self.degree < min_degree:
raise ValueError("'degree' must be a positive integer greater than 1")
[docs]
@override
def compute_elementwise(self, x, y):
return self.output_scale * (jnp.dot(x, y) + self.constant) ** self.degree
[docs]
@override
def grad_x_elementwise(self, x, y):
return (
self.output_scale
* self.degree
* jnp.asarray(y)
* (jnp.dot(x, y) + self.constant) ** (self.degree - 1)
)
[docs]
@override
def grad_y_elementwise(self, x, y):
return (
self.output_scale
* self.degree
* jnp.asarray(x)
* (jnp.dot(x, y) + self.constant) ** (self.degree - 1)
)
[docs]
@override
def divergence_x_grad_y_elementwise(self, x, y):
dot = jnp.dot(x, y)
body = dot + self.constant
d = len(jnp.asarray(x))
return (
self.output_scale
* self.degree
* (
((self.degree - 1) * dot * body ** (self.degree - 2))
+ (d * body ** (self.degree - 1))
)
)
[docs]
class ExponentialKernel(ScalarValuedKernel):
r"""
Define an exponential kernel.
Given :math:`\lambda =` ``length_scale`` and :math:`\rho =` ``output_scale``, the
exponential kernel is defined as
:math:`k: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}`,
:math:`k(x, y) = \rho * \exp(-\frac{||x-y||}{2 \lambda^2})` where
:math:`||\cdot||` is the usual :math:`L_2`-norm.
.. warning::
The exponential kernel is not differentiable when :math:`x=y`.
:param length_scale: Kernel smoothing/bandwidth parameter, :math:`\lambda`, must be
positive
:param output_scale: Kernel normalisation constant, :math:`\rho`, must be positive
"""
length_scale: float = 1.0
output_scale: float = 1.0
def __check_init__(self):
"""Check attributes are valid."""
if self.length_scale <= 0:
raise ValueError("'length_scale' must be positive")
if self.output_scale <= 0:
raise ValueError("'output_scale' must be positive")
[docs]
@override
def compute_elementwise(self, x, y):
return self.output_scale * jnp.exp(
-jnp.linalg.norm(jnp.subtract(x, y)) / (2 * self.length_scale**2)
)
[docs]
@override
def grad_x_elementwise(self, x, y):
return -self.grad_y_elementwise(x, y)
[docs]
@override
def grad_y_elementwise(self, x, y):
sub = jnp.subtract(x, y)
dist = jnp.linalg.norm(sub)
factor = 2 * self.length_scale**2
return self.output_scale * sub * jnp.exp(-dist / factor) / (factor * dist)
[docs]
@override
def divergence_x_grad_y_elementwise(self, x, y):
d = len(jnp.atleast_1d(x))
sub = jnp.subtract(x, y)
dist = jnp.linalg.norm(sub)
factor = 2 * self.length_scale**2
exp = jnp.exp(-dist / factor)
first_term = (-exp * sub / dist**2) * ((1 / dist) + 1 / factor)
second_term = exp / dist
return (self.output_scale / factor) * (
jnp.dot(first_term, sub) + d * second_term
)
[docs]
class LaplacianKernel(ScalarValuedKernel):
r"""
Define a Laplacian kernel.
Given :math:`\lambda =` ``length_scale`` and :math:`\rho =` ``output_scale``, the
Laplacian kernel is defined as
:math:`k: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}`,
:math:`k(x, y) = \rho * \exp(-\frac{||x-y||_1}{2 \lambda^2})` where
:math:`||\cdot||_1` is the :math:`L_1`-norm.
:param length_scale: Kernel smoothing/bandwidth parameter, :math:`\lambda`, must be
positive
:param output_scale: Kernel normalisation constant, :math:`\rho`, must be positive
"""
length_scale: float = eqx.field(default=1.0, converter=float)
output_scale: float = eqx.field(default=1.0, converter=float)
def __check_init__(self):
"""Check attributes are valid."""
if self.length_scale <= 0:
raise ValueError("'length_scale' must be positive")
if self.output_scale <= 0:
raise ValueError("'output_scale' must be positive")
[docs]
@override
def compute_elementwise(self, x, y):
return self.output_scale * jnp.exp(
-jnp.linalg.norm(jnp.subtract(x, y), ord=1) / (2 * self.length_scale**2)
)
[docs]
@override
def grad_x_elementwise(self, x, y):
return -self.grad_y_elementwise(x, y)
[docs]
@override
def grad_y_elementwise(self, x, y):
return (
jnp.sign(jnp.subtract(x, y))
/ (2 * self.length_scale**2)
* self.compute_elementwise(x, y)
)
[docs]
@override
def divergence_x_grad_y_elementwise(self, x, y):
k = self.compute_elementwise(x, y)
d = len(jnp.asarray(x))
return -d * k / (4 * self.length_scale**4)
[docs]
class SquaredExponentialKernel(ScalarValuedKernel):
r"""
Define a squared exponential kernel.
Given :math:`\lambda =` ``length_scale`` and :math:`\rho =` ``output_scale``, the
squared exponential kernel is defined as
:math:`k: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}`,
:math:`k(x, y) = \rho * \exp(-\frac{||x-y||^2}{2 \lambda^2})` where
:math:`||\cdot||` is the usual :math:`L_2`-norm.
:param length_scale: Kernel smoothing/bandwidth parameter, :math:`\lambda`, must be
positive
:param output_scale: Kernel normalisation constant, :math:`\rho`, must be positive
"""
length_scale: float = eqx.field(default=1.0, converter=float)
output_scale: float = eqx.field(default=1.0, converter=float)
def __check_init__(self):
"""Check attributes are valid."""
if self.length_scale <= 0:
raise ValueError("'length_scale' must be positive")
if self.output_scale <= 0:
raise ValueError("'output_scale' must be positive")
[docs]
@override
def compute_elementwise(self, x, y):
return self.output_scale * jnp.exp(
-squared_distance(x, y) / (2 * self.length_scale**2)
)
[docs]
@override
def grad_x_elementwise(self, x, y):
return -self.grad_y_elementwise(x, y)
[docs]
@override
def grad_y_elementwise(self, x, y):
return (
jnp.subtract(x, y) / self.length_scale**2 * self.compute_elementwise(x, y)
)
[docs]
@override
def divergence_x_grad_y_elementwise(self, x, y):
k = self.compute_elementwise(x, y)
scale = 1 / self.length_scale**2
d = len(jnp.asarray(x))
return scale * k * (d - scale * squared_distance(x, y))
[docs]
def get_sqrt_kernel(self, dim: int) -> "SquaredExponentialKernel":
r"""
Return the square root kernel for this kernel.
The square root kernel for the squared exponential kernel is given in Table 1 of
:cite:`dwivedi2024kernelthinning` (it is equivalent to the Gaussian kernel).
Since the definition in the table does not support an arbitrary
`output_scale` for the original kernel, it has been derived from Definition
5: if the original kernel has an output scale of :math:`\rho`, the output
scale for the resulting square root kernel is multiplied by :math:`\sqrt{\rho}`.
:param dim: Dimension of the data.
"""
new_length_scale = self.length_scale / jnp.sqrt(2)
new_output_scale = jnp.sqrt(self.output_scale) * jnp.power(
2 / (jnp.pi * jnp.square(self.length_scale)), dim / 4
)
return SquaredExponentialKernel(
length_scale=new_length_scale, output_scale=new_output_scale
)
[docs]
class PCIMQKernel(ScalarValuedKernel):
r"""
Define a pre-conditioned inverse multi-quadric (PCIMQ) kernel.
Given :math:`\lambda =` ``length_scale`` and :math:`\rho =` ``output_scale``, the
PCIMQ kernel is defined as
:math:`k: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}`,
:math:`k(x, y) = \frac{\rho}{\sqrt{1 + \frac{||x-y||^2}{2 \lambda^2}}}`
where :math:`||\cdot||` is the usual :math:`L_2`-norm.
:param length_scale: Kernel smoothing/bandwidth parameter, :math:`\lambda`, must be
positive
:param output_scale: Kernel normalisation constant, :math:`\rho`, must be positive
"""
length_scale: float = eqx.field(default=1.0, converter=float)
output_scale: float = eqx.field(default=1.0, converter=float)
def __check_init__(self):
"""Check attributes are valid."""
if self.length_scale <= 0:
raise ValueError("'length_scale' must be positive")
if self.output_scale <= 0:
raise ValueError("'output_scale' must be positive")
[docs]
@override
def compute_elementwise(self, x, y):
scaling = 2 * self.length_scale**2
mq_array = squared_distance(x, y) / scaling
return self.output_scale / jnp.sqrt(1 + mq_array)
[docs]
@override
def grad_x_elementwise(self, x, y):
return -self.grad_y_elementwise(x, y)
[docs]
@override
def grad_y_elementwise(self, x, y):
return (
self.output_scale
* jnp.subtract(x, y)
/ (2 * self.length_scale**2)
* (self.compute_elementwise(x, y) / self.output_scale) ** 3
)
[docs]
@override
def divergence_x_grad_y_elementwise(self, x, y):
k = self.compute_elementwise(x, y) / self.output_scale
scale = 2 * self.length_scale**2
d = len(jnp.asarray(x))
return (
self.output_scale
/ scale
* (d * k**3 - 3 * k**5 * squared_distance(x, y) / scale)
)
[docs]
class RationalQuadraticKernel(ScalarValuedKernel):
r"""
Define a rational quadratic kernel.
Given :math:`\lambda =` ``length_scale``, :math:`\rho =` ``output_scale``, and
:math:`\alpha =` ``relative_weighting``, the rational quadratic kernel is defined as
:math:`k: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}`,
:math:`k(x, y) = \rho * (1 + \frac{||x-y||^2}{2 \alpha \lambda^2})^{-\alpha}` where
:math:`||\cdot||` is the usual :math:`L_2`-norm.
:param length_scale: Kernel smoothing/bandwidth parameter, :math:`\lambda`, must be
positive
:param output_scale: Kernel normalisation constant, :math:`\rho`, must be positive
:param relative_weighting: Parameter controlling the relative weighting of
large-scale and small-scale variations, :math:`\alpha`. As
:math:`alpha \to \infty` the rational quadratic kernel is identical to the
squared exponential kernel. Must be non-negative
"""
length_scale: float = 1.0
output_scale: float = 1.0
relative_weighting: float = 1.0
def __check_init__(self):
"""Check attributes are valid."""
if self.length_scale <= 0:
raise ValueError("'length_scale' must be positive")
if self.output_scale <= 0:
raise ValueError("'output_scale' must be positive")
if self.relative_weighting < 0:
raise ValueError("'relative_weighting' must be non-negative")
[docs]
@override
def compute_elementwise(self, x, y):
return (
self.output_scale
* (
1
+ squared_distance(x, y)
/ (2 * self.relative_weighting * self.length_scale**2)
)
** -self.relative_weighting
)
[docs]
@override
def grad_x_elementwise(self, x, y):
return -self.grad_y_elementwise(x, y)
[docs]
@override
def grad_y_elementwise(self, x, y):
return (self.output_scale * jnp.subtract(x, y) / self.length_scale**2) * (
1
+ squared_distance(x, y)
/ (2 * self.relative_weighting * self.length_scale**2)
) ** (-self.relative_weighting - 1)
[docs]
@override
def divergence_x_grad_y_elementwise(self, x, y):
d = len(jnp.atleast_1d(x))
sq_dist = squared_distance(x, y)
power = self.relative_weighting + 1
div = self.relative_weighting * self.length_scale**2
body = 1 + sq_dist / (2 * div)
factor = self.output_scale / self.length_scale**2
first_term = factor * body**-power
second_term = -(factor * power * sq_dist / div) * body ** -(power + 1)
return d * first_term + second_term
[docs]
class MaternKernel(ScalarValuedKernel):
r"""
Define Matérn kernel with smoothness parameter a multiple of :math:`\frac{1}{2}`.
Given :math:`\lambda =` ``length_scale`` and :math:`\rho =` ``output_scale``, the
Matérn kernel with smoothness parameter :math:`\nu` set to be a multiple of
:math:`\frac{1}{2}`, i.e. :math:`\nu = p + \frac{1}{2}` where
:math:`p=` ``degree`` :math:`\in\mathbb{N}`, is defined as
:math:`k: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}`,
.. math::
k(x, y) = \rho^2 * \exp\left(-\frac{\sqrt{2p+1}||x-y||}{\lambda}\right)
\frac{p!}{(2p)!}\sum_{i=0}^p\frac{(p+i)!}{i!(p-i)!}
\left(2\sqrt{2p+1}\frac{||x-y||}{\lambda}\right)^{p-i}
where :math:`||\cdot||` is the usual :math:`L_2`-norm.
:param length_scale: Kernel smoothing/bandwidth parameter, :math:`\lambda`, must be
positive
:param output_scale: Kernel normalisation constant, :math:`\rho`, must be positive
:param degree: Kernel degree, :math:`p`, must be a non-negative integer
"""
length_scale: float = eqx.field(default=1.0, converter=float)
output_scale: float = eqx.field(default=1.0, converter=float)
degree: int = 1
def __check_init__(self):
"""Check attributes are valid."""
if self.length_scale <= 0:
raise ValueError("'length_scale' must be positive")
if self.output_scale <= 0:
raise ValueError("'output_scale' must be positive")
if not isinstance(self.degree, int) or self.degree < 0:
raise ValueError("'degree' must be a non-negative integer")
def _compute_summation_term(
self,
body: float,
iteration: Shaped[Array, " *number_of_iterations"] | int,
) -> Shaped[Array, ""]:
r"""
Compute the summation term of the Matérn kernel for a given iteration.
Given :math:`p`=``degree``:math:`\in\mathbb{N}`, compute
.. math::
\gamma := \sum_{i=0}^p\frac{(p+i)!}{i!(p-i)!}
\left(2\sqrt{2p+1}\frac{||x-y||}{\lambda}\right)^{p-i}.
:param body: Float representing
:math:`\left(\sqrt{2p+1}\frac{||x-y||}{\lambda}\right)`
:param iteration: Current iteration
:return: :math:`\gamma` as a zero-dimensional array
"""
factorial_term = factorial(self.degree + iteration) / (
factorial(iteration) * factorial(self.degree - iteration)
)
distance_term = (2 * body) ** (self.degree - iteration)
return factorial_term * distance_term
[docs]
@override
def compute_elementwise(self, x, y):
norm = jnp.linalg.norm(jnp.subtract(x, y))
body = (jnp.sqrt(2 * self.degree + 1) * norm) / self.length_scale
factor = (
self.output_scale**2
* jnp.exp(-body)
* factorial(self.degree)
/ factorial(2 * self.degree)
)
summation = 1.0
if self.degree > 0:
mapped_function = vmap(self._compute_summation_term, in_axes=(None, 0))
summation = mapped_function(body, jnp.arange(self.degree + 1)).sum()
return factor * summation
[docs]
class PeriodicKernel(ScalarValuedKernel):
r"""
Define a periodic kernel.
Given :math:`\lambda =` ``length_scale``, :math:`\rho =` ``output_scale``, and
:math:`p =` ``periodicity``, the periodic kernel is defined as
:math:`k: \mathbb{R}^d\times \mathbb{R}^d \to \mathbb{R}`,
:math:`k(x, y) = \rho * \exp(\frac{-2 \sin^2(\pi ||x-y||/p)}{\lambda^2})` where
:math:`||\cdot||` is the usual :math:`L_2`-norm.
.. warning::
The periodic kernel is not differentiable when :math:`x=y`.
:param length_scale: Kernel smoothing/bandwidth parameter, :math:`\lambda`, must be
positive
:param output_scale: Kernel normalisation constant, :math:`\rho`, must be positive
:param periodicity: Parameter controlling the periodicity of the kernel :math:`p`
"""
length_scale: float = 1.0
output_scale: float = 1.0
periodicity: float = 1.0
def __check_init__(self):
"""Check attributes are valid."""
if self.length_scale <= 0:
raise ValueError("'length_scale' must be positive")
if self.output_scale <= 0:
raise ValueError("'output_scale' must be positive")
[docs]
@override
def compute_elementwise(self, x, y):
return self.output_scale * (
jnp.exp(
-2
* jnp.sin(
jnp.pi * jnp.linalg.norm(jnp.subtract(x, y)) / self.periodicity
)
** 2
/ self.length_scale**2
)
)
[docs]
@override
def grad_x_elementwise(self, x, y):
return -self.grad_y_elementwise(x, y)
[docs]
@override
def grad_y_elementwise(self, x, y):
dist = jnp.linalg.norm(jnp.subtract(x, y))
body = jnp.pi * dist / self.periodicity
return (
(
4
* jnp.subtract(x, y)
* self.output_scale
* jnp.pi
/ (dist * self.periodicity * self.length_scale**2)
)
* jnp.sin(body)
* jnp.cos(body)
* jnp.exp(-(2 / self.length_scale**2) * jnp.sin(body) ** 2)
)
[docs]
@override
def divergence_x_grad_y_elementwise(self, x, y):
d = len(jnp.atleast_1d(x))
sub = jnp.subtract(x, y)
dist = jnp.linalg.norm(sub)
factor = jnp.pi / self.periodicity
func_body = factor * dist
grad_factor = sub / dist
output_factor = 4 * factor * self.output_scale / self.length_scale**2
func_1 = 1 / dist
func_2 = jnp.sin(func_body)
func_3 = jnp.cos(func_body)
func_4 = jnp.exp(-(2 / self.length_scale**2) * func_2**2)
first_term = func_1 * func_2 * func_3 * func_4
second_term = (
-grad_factor * func_1**2 * func_2 * func_3 * func_4
- grad_factor * factor * func_1 * func_2**2 * func_4
+ grad_factor * factor * func_1 * func_3**2 * func_4
- (output_factor * sub * func_1**2 * func_2**2 * func_3**2 * func_4)
)
return output_factor * (d * first_term + jnp.dot(second_term, sub))
[docs]
class LocallyPeriodicKernel(ProductKernel):
r"""
Define a locally periodic kernel.
The periodic kernel is defined as
:math:`k: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}`,
:math:`k(x, y) = r(x,y)l(x,y)` where :math:`r` is the periodic kernel and
:math:`l` is the squared exponential kernel.
.. warning::
The locally periodic kernel is not differentiable when :math:`x=y`.
:param periodic_length_scale: Periodic kernel smoothing/bandwidth parameter
:param periodic_output_scale: Periodic kernel normalisation constant
:param periodicity: Parameter controlling the periodicity of the Periodic kernel
:param squared_exponential_length_scale: SquaredExponential kernel
smoothing/bandwidth parameter]
:param squared_exponential_output_scale: SquaredExponential Kernel normalisation
constant
"""
def __init__(
self,
periodic_length_scale: float = 1.0,
periodic_output_scale: float = 1.0,
periodicity: float = 1.0,
squared_exponential_length_scale: float = 1.0,
squared_exponential_output_scale: float = 1.0,
):
"""Initialise LocallyPeriodicKernel with ProductKernel attributes."""
self.first_kernel = PeriodicKernel(
length_scale=periodic_length_scale,
output_scale=periodic_output_scale,
periodicity=periodicity,
)
self.second_kernel = SquaredExponentialKernel(
length_scale=squared_exponential_length_scale,
output_scale=squared_exponential_output_scale,
)
[docs]
class PoissonKernel(ScalarValuedKernel):
r"""
Define a Poisson kernel.
Given :math:`r=` ``index``, :math:`0 < r < 1`, and :math:`\rho =` ``output_scale``,
the Poisson kernel is defined as
:math:`k: [0, 2\pi) \times [0, 2\pi) \to \mathbb{R}`,
:math:`k(x, y) = \frac{\rho}{1 - 2r\cos(x-y) + r^2}`.
.. warning::
Unlike many other kernels in Coreax, the Poisson kernel is not defined on
arbitrary :math:`\mathbb{R}^d`, but instead a subset of the positive real line
:math:`[0, 2\pi)`. We do not check that inputs to methods in this class lie in
the correct domain, therefore unexpected behaviour may occur. For example,
passing :math:`n`-vectors to the `compute` method will be interpreted as one
observation of a `:math:`n`- dimensional vector, and not :math:`n` observations
of a one dimensional vector, and therefore would be an invalid use of this
kernel function.
:param index: Kernel parameter indexing the family of Poisson kernel functions
:param output_scale: Kernel normalisation constant, :math:`\rho`, must be positive
"""
index: float = eqx.field(default=0.5, converter=float)
output_scale: float = eqx.field(default=1.0, converter=float)
def __check_init__(self):
"""Check attributes are valid."""
if self.index <= 0 or self.index >= 1:
raise ValueError("'index' must be be between 0 and 1 exclusive")
if self.output_scale <= 0:
raise ValueError("'output_scale' must be positive")
[docs]
@override
def compute_elementwise(self, x, y):
return self.output_scale / (
1
- 2 * self.index * jnp.cos(jnp.linalg.norm(jnp.subtract(x, y)))
+ self.index**2
)
[docs]
@override
def grad_x_elementwise(self, x, y):
return -self.grad_y_elementwise(x, y)
[docs]
@override
def grad_y_elementwise(self, x, y):
# Note that we do not take a norm here in order to maintain the dimensionality
# of the vectors x and y, this ensures calls to 'grad_y' and 'grad_x' have
# expected dimensionality.
distance = jnp.subtract(x, y)
return (2 * self.output_scale * self.index * jnp.sin(distance)) / (
1 - 2 * self.index * jnp.cos(distance) + self.index**2
) ** 2
[docs]
@override
def divergence_x_grad_y_elementwise(self, x, y):
distance = jnp.linalg.norm(jnp.subtract(x, y))
div = 1 - 2 * self.index * jnp.cos(distance) + self.index**2
first_term = (2 * self.output_scale * self.index * jnp.cos(distance)) / div**2
second_term = (
8 * self.output_scale * self.index**2 * jnp.sin(distance) ** 2
) / div**3
return first_term - second_term
[docs]
class SteinKernel(UniCompositeKernel):
r"""
Define the Stein kernel, i.e. the application of the Stein operator.
.. math::
\mathcal{A}_\mathbb{P}(g(\mathbf{x})) := \nabla_\mathbf{x} g(\mathbf{x})
+ g(\mathbf{x}) \nabla_\mathbf{x} \log f_X(\mathbf{x})^\intercal
w.r.t. probability measure :math:`\mathbb{P}` to the base kernel
:math:`k(\mathbf{x}, \mathbf{y})`. Here, differentiable vector-valued
:math:`g: \mathbb{R}^d \to \mathbb{R}^d`, and
:math:`\nabla_\mathbf{x} \log f_X(\mathbf{x})` is the *score function* of measure
:math:`\mathbb{P}`.
:math:`\mathbb{P}` is assumed to admit a density function :math:`f_X` w.r.t.
d-dimensional Lebesgue measure. The score function is assumed to be Lipschitz.
The key property of a Stein operator is zero expectation under
:math:`\mathbb{P}`, i.e.
:math:`\mathbb{E}_\mathbb{P}[\mathcal{A}_\mathbb{P} f(\mathbf{x})]`, for
positive differentiable :math:`f_X`.
The Stein kernel for base kernel :math:`k(\mathbf{x}, \mathbf{y})` is defined as
.. math::
k_\mathbb{P}(\mathbf{x}, \mathbf{y}) = \nabla_\mathbf{x} \cdot
\nabla_\mathbf{y}
k(\mathbf{x}, \mathbf{y}) + \nabla_\mathbf{x} \log f_X(\mathbf{x})
\cdot \nabla_\mathbf{y} k(\mathbf{x}, \mathbf{y}) + \nabla_\mathbf{y} \log
f_X(\mathbf{y}) \cdot \nabla_\mathbf{x} k(\mathbf{x}, \mathbf{y}) +
(\nabla_\mathbf{x} \log f_X(\mathbf{x}) \cdot \nabla_\mathbf{y} \log
f_X(\mathbf{y})) k(\mathbf{x}, \mathbf{y}).
This kernel requires a 'base' kernel to evaluate. The base kernel can be any
other implemented subclass of the Kernel abstract base class; even another Stein
kernel.
The score function
:math:`\nabla_\mathbf{x} \log f_X: \mathbb{R}^d \to \mathbb{R}^d` can be any
suitable Lipschitz score function, e.g. one that is learned from score matching
(:class:`~coreax.score_matching.ScoreMatching`), computed explicitly from a density
function, or known analytically.
:param base_kernel: Initialised kernel object with which to evaluate
the Stein kernel
:param score_function: A vector-valued callable defining a score function
:math:`\mathbb{R}^d \to \mathbb{R}^d`
"""
score_function: Callable[
[Shaped[Array, " n d"] | Shaped[Array, ""] | float | int],
Shaped[Array, " n d"] | Shaped[Array, " 1 1"],
]
[docs]
@override
def compute_elementwise(self, x, y):
k = self.base_kernel.compute_elementwise(x, y)
div = self.base_kernel.divergence_x_grad_y_elementwise(x, y)
gkx = self.base_kernel.grad_x_elementwise(x, y)
gky = self.base_kernel.grad_y_elementwise(x, y)
score_x = self.score_function(x)
score_y = self.score_function(y)
return div + gkx @ score_y + gky @ score_x + k * score_x @ score_y