# © Crown Copyright GCHQ
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Scalar-valued kernel functions."""
from typing import Callable, Union
import equinox as eqx
import jax.numpy as jnp
from jax import Array, vmap
from jax.scipy.special import factorial
from jaxtyping import Shaped
from typing_extensions import override
from coreax.kernels.base import ProductKernel, ScalarValuedKernel, UniCompositeKernel
from coreax.util import squared_distance
[docs]
class LinearKernel(ScalarValuedKernel):
    r"""
    Define a linear kernel.

    Given :math:`\rho =` ``output_scale`` and :math:`a =` ``constant``, the linear
    kernel is defined as :math:`k: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}`,
    :math:`k(x, y) = a + \rho x^T y`.

    :param output_scale: Kernel normalisation constant, :math:`\rho`, must be positive
    :param constant: Additive constant, :math:`a`, must be non-negative
    """

    # Both fields are passed through ``float`` on construction.
    output_scale: float = eqx.field(default=1.0, converter=float)
    constant: float = eqx.field(default=0.0, converter=float)

    def __check_init__(self):
        """Reject a non-positive ``output_scale`` or a negative ``constant``."""
        if self.output_scale <= 0:
            raise ValueError("'output_scale' must be positive")
        if self.constant < 0:
            raise ValueError("'constant' must be non-negative")

    @override
    def compute_elementwise(self, x, y):
        # k(x, y) = rho * <x, y> + a
        inner_product = jnp.dot(x, y)
        return self.output_scale * inner_product + self.constant

    @override
    def grad_x_elementwise(self, x, y):
        # d/dx (rho * <x, y> + a) = rho * y
        other = jnp.asarray(y)
        return self.output_scale * other

    @override
    def grad_y_elementwise(self, x, y):
        # d/dy (rho * <x, y> + a) = rho * x
        other = jnp.asarray(x)
        return self.output_scale * other

    @override
    def divergence_x_grad_y_elementwise(self, x, y):
        # div_x(rho * x) = rho * d; atleast_1d makes a scalar input count as d = 1.
        dimension = len(jnp.atleast_1d(x))
        return jnp.array(self.output_scale * dimension)
[docs]
class PolynomialKernel(ScalarValuedKernel):
    r"""
    Define a polynomial kernel.

    Given :math:`\rho =` ``output_scale``, :math:`c =` ``constant``, and
    :math:`p =` ``degree``, the polynomial kernel is defined as
    :math:`k: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}`,
    :math:`k(x, y) = \rho (x^T y + c)^p`.

    :param output_scale: Kernel normalisation constant, :math:`\rho`, must be positive
    :param constant: Additive constant, :math:`c`, must be non-negative
    :param degree: Degree of kernel, :math:`p`, must be an integer greater than 1
    """

    output_scale: float = 1.0
    constant: float = 0.0
    degree: int = 2

    def __check_init__(self):
        """Ensure degree is an integer greater than 1 and other attributes are valid."""
        min_degree = 2
        if self.output_scale <= 0:
            raise ValueError("'output_scale' must be positive")
        if self.constant < 0:
            raise ValueError("'constant' must be non-negative")
        if not isinstance(self.degree, int) or self.degree < min_degree:
            raise ValueError("'degree' must be a positive integer greater than 1")

    @override
    def compute_elementwise(self, x, y):
        return self.output_scale * (jnp.dot(x, y) + self.constant) ** self.degree

    @override
    def grad_x_elementwise(self, x, y):
        # Chain rule: rho * p * y * (<x, y> + c)^(p - 1)
        return (
            self.output_scale
            * self.degree
            * jnp.asarray(y)
            * (jnp.dot(x, y) + self.constant) ** (self.degree - 1)
        )

    @override
    def grad_y_elementwise(self, x, y):
        # Chain rule: rho * p * x * (<x, y> + c)^(p - 1)
        return (
            self.output_scale
            * self.degree
            * jnp.asarray(x)
            * (jnp.dot(x, y) + self.constant) ** (self.degree - 1)
        )

    @override
    def divergence_x_grad_y_elementwise(self, x, y):
        dot = jnp.dot(x, y)
        body = dot + self.constant
        # atleast_1d lets a scalar (0-d) input count as dimension 1; calling len()
        # on a 0-d array raises, even though the other methods accept scalars.
        d = len(jnp.atleast_1d(x))
        # div_x of grad_y: product rule on x * (<x, y> + c)^(p - 1)
        return (
            self.output_scale
            * self.degree
            * (
                ((self.degree - 1) * dot * body ** (self.degree - 2))
                + (d * body ** (self.degree - 1))
            )
        )
[docs]
class ExponentialKernel(ScalarValuedKernel):
    r"""
    Define an exponential kernel.

    Given :math:`\lambda =` ``length_scale`` and :math:`\rho =` ``output_scale``, the
    exponential kernel is defined as
    :math:`k: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}`,
    :math:`k(x, y) = \rho * \exp(-\frac{||x-y||}{2 \lambda^2})` where
    :math:`||\cdot||` is the usual :math:`L_2`-norm.

    .. warning::
        The exponential kernel is not differentiable when :math:`x=y`; the gradient
        and divergence methods divide by :math:`||x-y||`.

    :param length_scale: Kernel smoothing/bandwidth parameter, :math:`\lambda`, must be
        positive
    :param output_scale: Kernel normalisation constant, :math:`\rho`, must be positive
    """

    length_scale: float = 1.0
    output_scale: float = 1.0

    def __check_init__(self):
        """Reject non-positive ``length_scale`` or ``output_scale``."""
        if self.length_scale <= 0:
            raise ValueError("'length_scale' must be positive")
        if self.output_scale <= 0:
            raise ValueError("'output_scale' must be positive")

    @override
    def compute_elementwise(self, x, y):
        distance = jnp.linalg.norm(jnp.subtract(x, y))
        return self.output_scale * jnp.exp(-distance / (2 * self.length_scale**2))

    @override
    def grad_x_elementwise(self, x, y):
        # The kernel depends on x and y only through x - y, so grad_x = -grad_y.
        return -self.grad_y_elementwise(x, y)

    @override
    def grad_y_elementwise(self, x, y):
        difference = jnp.subtract(x, y)
        distance = jnp.linalg.norm(difference)
        bandwidth = 2 * self.length_scale**2
        decay = jnp.exp(-distance / bandwidth)
        return self.output_scale * difference * decay / (bandwidth * distance)

    @override
    def divergence_x_grad_y_elementwise(self, x, y):
        dimension = len(jnp.atleast_1d(x))
        difference = jnp.subtract(x, y)
        distance = jnp.linalg.norm(difference)
        bandwidth = 2 * self.length_scale**2
        decay = jnp.exp(-distance / bandwidth)
        # Gradient (in x) of decay / distance, contracted with the difference below.
        radial_gradient = (-decay * difference / distance**2) * (
            (1 / distance) + 1 / bandwidth
        )
        # Contribution from div_x(x - y) = d.
        isotropic = decay / distance
        return (self.output_scale / bandwidth) * (
            jnp.dot(radial_gradient, difference) + dimension * isotropic
        )
[docs]
class LaplacianKernel(ScalarValuedKernel):
    r"""
    Define a Laplacian kernel.

    Given :math:`\lambda =` ``length_scale`` and :math:`\rho =` ``output_scale``, the
    Laplacian kernel is defined as
    :math:`k: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}`,
    :math:`k(x, y) = \rho * \exp(-\frac{||x-y||_1}{2 \lambda^2})` where
    :math:`||\cdot||_1` is the :math:`L_1`-norm.

    :param length_scale: Kernel smoothing/bandwidth parameter, :math:`\lambda`, must be
        positive
    :param output_scale: Kernel normalisation constant, :math:`\rho`, must be positive
    """

    # Both fields are passed through ``float`` on construction.
    length_scale: float = eqx.field(default=1.0, converter=float)
    output_scale: float = eqx.field(default=1.0, converter=float)

    def __check_init__(self):
        """Check attributes are valid."""
        if self.length_scale <= 0:
            raise ValueError("'length_scale' must be positive")
        if self.output_scale <= 0:
            raise ValueError("'output_scale' must be positive")

    @override
    def compute_elementwise(self, x, y):
        return self.output_scale * jnp.exp(
            -jnp.linalg.norm(jnp.subtract(x, y), ord=1) / (2 * self.length_scale**2)
        )

    @override
    def grad_x_elementwise(self, x, y):
        # The kernel depends on x and y only through x - y, so grad_x = -grad_y.
        return -self.grad_y_elementwise(x, y)

    @override
    def grad_y_elementwise(self, x, y):
        # d/dy_i of -|x_i - y_i| is sign(x_i - y_i)
        return (
            jnp.sign(jnp.subtract(x, y))
            / (2 * self.length_scale**2)
            * self.compute_elementwise(x, y)
        )

    @override
    def divergence_x_grad_y_elementwise(self, x, y):
        k = self.compute_elementwise(x, y)
        # atleast_1d lets a scalar (0-d) input count as dimension 1; calling len()
        # on a 0-d array raises, even though the other methods accept scalars.
        d = len(jnp.atleast_1d(x))
        return -d * k / (4 * self.length_scale**4)
[docs]
class SquaredExponentialKernel(ScalarValuedKernel):
    r"""
    Define a squared exponential kernel.

    Given :math:`\lambda =` ``length_scale`` and :math:`\rho =` ``output_scale``, the
    squared exponential kernel is defined as
    :math:`k: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}`,
    :math:`k(x, y) = \rho * \exp(-\frac{||x-y||^2}{2 \lambda^2})` where
    :math:`||\cdot||` is the usual :math:`L_2`-norm.

    :param length_scale: Kernel smoothing/bandwidth parameter, :math:`\lambda`, must be
        positive
    :param output_scale: Kernel normalisation constant, :math:`\rho`, must be positive
    """

    # Both fields are passed through ``float`` on construction.
    length_scale: float = eqx.field(default=1.0, converter=float)
    output_scale: float = eqx.field(default=1.0, converter=float)

    def __check_init__(self):
        """Check attributes are valid."""
        if self.length_scale <= 0:
            raise ValueError("'length_scale' must be positive")
        if self.output_scale <= 0:
            raise ValueError("'output_scale' must be positive")

    @override
    def compute_elementwise(self, x, y):
        return self.output_scale * jnp.exp(
            -squared_distance(x, y) / (2 * self.length_scale**2)
        )

    @override
    def grad_x_elementwise(self, x, y):
        # The kernel depends on x and y only through x - y, so grad_x = -grad_y.
        return -self.grad_y_elementwise(x, y)

    @override
    def grad_y_elementwise(self, x, y):
        # grad_y k = (x - y) / lambda^2 * k
        return (
            jnp.subtract(x, y) / self.length_scale**2 * self.compute_elementwise(x, y)
        )

    @override
    def divergence_x_grad_y_elementwise(self, x, y):
        k = self.compute_elementwise(x, y)
        scale = 1 / self.length_scale**2
        # atleast_1d lets a scalar (0-d) input count as dimension 1; calling len()
        # on a 0-d array raises, even though the other methods accept scalars.
        d = len(jnp.atleast_1d(x))
        return scale * k * (d - scale * squared_distance(x, y))
[docs]
class PCIMQKernel(ScalarValuedKernel):
    r"""
    Define a pre-conditioned inverse multi-quadric (PCIMQ) kernel.

    Given :math:`\lambda =` ``length_scale`` and :math:`\rho =` ``output_scale``, the
    PCIMQ kernel is defined as
    :math:`k: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}`,
    :math:`k(x, y) = \frac{\rho}{\sqrt{1 + \frac{||x-y||^2}{2 \lambda^2}}}`
    where :math:`||\cdot||` is the usual :math:`L_2`-norm.

    :param length_scale: Kernel smoothing/bandwidth parameter, :math:`\lambda`, must be
        positive
    :param output_scale: Kernel normalisation constant, :math:`\rho`, must be positive
    """

    # Both fields are passed through ``float`` on construction.
    length_scale: float = eqx.field(default=1.0, converter=float)
    output_scale: float = eqx.field(default=1.0, converter=float)

    def __check_init__(self):
        """Check attributes are valid."""
        if self.length_scale <= 0:
            raise ValueError("'length_scale' must be positive")
        if self.output_scale <= 0:
            raise ValueError("'output_scale' must be positive")

    @override
    def compute_elementwise(self, x, y):
        scaling = 2 * self.length_scale**2
        mq_array = squared_distance(x, y) / scaling
        return self.output_scale / jnp.sqrt(1 + mq_array)

    @override
    def grad_x_elementwise(self, x, y):
        # The kernel depends on x and y only through x - y, so grad_x = -grad_y.
        return -self.grad_y_elementwise(x, y)

    @override
    def grad_y_elementwise(self, x, y):
        # (k / rho)^3 is the unit-scale kernel cubed, i.e. (1 + ...)^(-3/2)
        return (
            self.output_scale
            * jnp.subtract(x, y)
            / (2 * self.length_scale**2)
            * (self.compute_elementwise(x, y) / self.output_scale) ** 3
        )

    @override
    def divergence_x_grad_y_elementwise(self, x, y):
        # Work with the unit-scale kernel; output_scale is re-applied at the end.
        k = self.compute_elementwise(x, y) / self.output_scale
        scale = 2 * self.length_scale**2
        # atleast_1d lets a scalar (0-d) input count as dimension 1; calling len()
        # on a 0-d array raises, even though the other methods accept scalars.
        d = len(jnp.atleast_1d(x))
        return (
            self.output_scale
            / scale
            * (d * k**3 - 3 * k**5 * squared_distance(x, y) / scale)
        )
[docs]
class RationalQuadraticKernel(ScalarValuedKernel):
    r"""
    Define a rational quadratic kernel.

    Given :math:`\lambda =` ``length_scale``, :math:`\rho =` ``output_scale``, and
    :math:`\alpha =` ``relative_weighting``, the rational quadratic kernel is defined as
    :math:`k: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}`,
    :math:`k(x, y) = \rho * (1 + \frac{||x-y||^2}{2 \alpha \lambda^2})^{-\alpha}` where
    :math:`||\cdot||` is the usual :math:`L_2`-norm.

    :param length_scale: Kernel smoothing/bandwidth parameter, :math:`\lambda`, must be
        positive
    :param output_scale: Kernel normalisation constant, :math:`\rho`, must be positive
    :param relative_weighting: Parameter controlling the relative weighting of
        large-scale and small-scale variations, :math:`\alpha`. As
        :math:`\alpha \to \infty` the rational quadratic kernel is identical to the
        squared exponential kernel. Must be non-negative; note that a value of
        exactly zero passes validation but makes the divisor
        :math:`2 \alpha \lambda^2` zero
    """

    length_scale: float = 1.0
    output_scale: float = 1.0
    relative_weighting: float = 1.0

    def __check_init__(self):
        """Reject non-positive scales and a negative ``relative_weighting``."""
        if self.length_scale <= 0:
            raise ValueError("'length_scale' must be positive")
        if self.output_scale <= 0:
            raise ValueError("'output_scale' must be positive")
        if self.relative_weighting < 0:
            raise ValueError("'relative_weighting' must be non-negative")

    @override
    def compute_elementwise(self, x, y):
        denominator = 2 * self.relative_weighting * self.length_scale**2
        base = 1 + squared_distance(x, y) / denominator
        return self.output_scale * base**-self.relative_weighting

    @override
    def grad_x_elementwise(self, x, y):
        # The kernel depends on x and y only through x - y, so grad_x = -grad_y.
        return -self.grad_y_elementwise(x, y)

    @override
    def grad_y_elementwise(self, x, y):
        denominator = 2 * self.relative_weighting * self.length_scale**2
        base = 1 + squared_distance(x, y) / denominator
        direction = self.output_scale * jnp.subtract(x, y) / self.length_scale**2
        return direction * base ** (-self.relative_weighting - 1)

    @override
    def divergence_x_grad_y_elementwise(self, x, y):
        dimension = len(jnp.atleast_1d(x))
        sq_dist = squared_distance(x, y)
        exponent = self.relative_weighting + 1
        alpha_lambda_sq = self.relative_weighting * self.length_scale**2
        base = 1 + sq_dist / (2 * alpha_lambda_sq)
        prefactor = self.output_scale / self.length_scale**2
        # Contribution from div_x(x - y) = d.
        isotropic = prefactor * base**-exponent
        # Contribution from differentiating the power of ``base``.
        radial = -(prefactor * exponent * sq_dist / alpha_lambda_sq) * base ** -(
            exponent + 1
        )
        return dimension * isotropic + radial
[docs]
class MaternKernel(ScalarValuedKernel):
    r"""
    Define Matérn kernel with smoothness parameter a multiple of :math:`\frac{1}{2}`.

    Given :math:`\lambda =` ``length_scale`` and :math:`\rho =` ``output_scale``, the
    Matérn kernel with smoothness parameter :math:`\nu` set to be a multiple of
    :math:`\frac{1}{2}`, i.e. :math:`\nu = p + \frac{1}{2}` where
    :math:`p =` ``degree`` :math:`\in \mathbb{N}`, is defined as
    :math:`k: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}`,

    .. math::

        k(x, y) = \rho^2 * \exp\left(-\frac{\sqrt{2p+1}||x-y||}{\lambda}\right)
        \frac{p!}{(2p)!}\sum_{i=0}^p\frac{(p+i)!}{i!(p-i)!}
        \left(2\sqrt{2p+1}\frac{||x-y||}{\lambda}\right)^{p-i}

    where :math:`||\cdot||` is the usual :math:`L_2`-norm. Note that the output
    scale enters the kernel squared.

    :param length_scale: Kernel smoothing/bandwidth parameter, :math:`\lambda`, must be
        positive
    :param output_scale: Kernel normalisation constant, :math:`\rho`, must be positive
    :param degree: Kernel degree, :math:`p`, must be a non-negative integer
    """

    # The two scales are passed through ``float`` on construction; ``degree`` is not
    # converted and must already be an ``int``.
    length_scale: float = eqx.field(default=1.0, converter=float)
    output_scale: float = eqx.field(default=1.0, converter=float)
    degree: int = 1

    def __check_init__(self):
        """Reject non-positive scales and a ``degree`` that is not a non-negative int."""
        if self.length_scale <= 0:
            raise ValueError("'length_scale' must be positive")
        if self.output_scale <= 0:
            raise ValueError("'output_scale' must be positive")
        if not isinstance(self.degree, int) or self.degree < 0:
            raise ValueError("'degree' must be a non-negative integer")

    def _compute_summation_term(
        self,
        body: float,
        iteration: Union[Shaped[Array, " *number_of_iterations"], int],
    ) -> Shaped[Array, ""]:
        r"""
        Compute a single term of the Matérn kernel summation.

        Given :math:`p =` ``degree``, :math:`i =` ``iteration`` and
        :math:`b =` ``body``, compute

        .. math::

            \frac{(p+i)!}{i!(p-i)!} \left(2 b\right)^{p-i}.

        Summing this over :math:`i = 0, \dots, p` gives the summation in the kernel
        definition.

        :param body: Value of
            :math:`\left(\sqrt{2p+1}\frac{||x-y||}{\lambda}\right)`
        :param iteration: Index :math:`i` of the term to compute
        :return: The :math:`i`-th term of the summation
        """
        remaining = self.degree - iteration
        coefficient = factorial(self.degree + iteration) / (
            factorial(iteration) * factorial(remaining)
        )
        return coefficient * (2 * body) ** remaining

    @override
    def compute_elementwise(self, x, y):
        distance = jnp.linalg.norm(jnp.subtract(x, y))
        body = (jnp.sqrt(2 * self.degree + 1) * distance) / self.length_scale
        prefactor = (
            self.output_scale**2
            * jnp.exp(-body)
            * factorial(self.degree)
            / factorial(2 * self.degree)
        )
        if self.degree > 0:
            # Evaluate every term i = 0, ..., p in one vectorised call, then sum.
            term_function = vmap(self._compute_summation_term, in_axes=(None, 0))
            indices = jnp.arange(self.degree + 1)
            summation = term_function(body, indices).sum()
        else:
            # For p = 0 the summation has the single term 1.
            summation = 1.0
        return prefactor * summation
[docs]
class PeriodicKernel(ScalarValuedKernel):
    r"""
    Define a periodic kernel.

    Given :math:`\lambda =` ``length_scale``, :math:`\rho =` ``output_scale``, and
    :math:`p =` ``periodicity``, the periodic kernel is defined as
    :math:`k: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}`,
    :math:`k(x, y) = \rho * \exp(\frac{-2 \sin^2(\pi ||x-y||/p)}{\lambda^2})` where
    :math:`||\cdot||` is the usual :math:`L_2`-norm.

    .. warning::
        The periodic kernel is not differentiable when :math:`x=y`; the gradient
        and divergence methods divide by :math:`||x-y||`.

    :param length_scale: Kernel smoothing/bandwidth parameter, :math:`\lambda`, must be
        positive
    :param output_scale: Kernel normalisation constant, :math:`\rho`, must be positive
    :param periodicity: Parameter controlling the periodicity of the kernel :math:`p`;
        it is used as a divisor and is not validated on construction
    """

    length_scale: float = 1.0
    output_scale: float = 1.0
    periodicity: float = 1.0

    def __check_init__(self):
        """Check attributes are valid."""
        if self.length_scale <= 0:
            raise ValueError("'length_scale' must be positive")
        if self.output_scale <= 0:
            raise ValueError("'output_scale' must be positive")

    @override
    def compute_elementwise(self, x, y):
        return self.output_scale * (
            jnp.exp(
                -2
                * jnp.sin(
                    jnp.pi * jnp.linalg.norm(jnp.subtract(x, y)) / self.periodicity
                )
                ** 2
                / self.length_scale**2
            )
        )

    @override
    def grad_x_elementwise(self, x, y):
        # The kernel depends on x and y only through x - y, so grad_x = -grad_y.
        return -self.grad_y_elementwise(x, y)

    @override
    def grad_y_elementwise(self, x, y):
        dist = jnp.linalg.norm(jnp.subtract(x, y))
        body = jnp.pi * dist / self.periodicity
        return (
            (
                4
                * jnp.subtract(x, y)
                * self.output_scale
                * jnp.pi
                / (dist * self.periodicity * self.length_scale**2)
            )
            * jnp.sin(body)
            * jnp.cos(body)
            * jnp.exp(-(2 / self.length_scale**2) * jnp.sin(body) ** 2)
        )

    @override
    def divergence_x_grad_y_elementwise(self, x, y):
        d = len(jnp.atleast_1d(x))
        sub = jnp.subtract(x, y)
        dist = jnp.linalg.norm(sub)
        factor = jnp.pi / self.periodicity
        func_body = factor * dist
        grad_factor = sub / dist
        # Overall prefactor of grad_y; carries the single factor of output_scale.
        output_factor = 4 * factor * self.output_scale / self.length_scale**2
        # Coefficient produced by differentiating the exponential term. It must NOT
        # include output_scale: the divergence is linear in output_scale, which is
        # already applied once through ``output_factor`` in the return statement.
        exponent_factor = 4 * factor / self.length_scale**2

        # grad_y = output_factor * sub * func_1 * func_2 * func_3 * func_4
        func_1 = 1 / dist
        func_2 = jnp.sin(func_body)
        func_3 = jnp.cos(func_body)
        func_4 = jnp.exp(-(2 / self.length_scale**2) * func_2**2)

        # Contribution from div_x(x - y) = d.
        first_term = func_1 * func_2 * func_3 * func_4
        # Gradient in x of func_1 * func_2 * func_3 * func_4 (product rule, one line
        # per differentiated factor), contracted with ``sub`` below.
        second_term = (
            -grad_factor * func_1**2 * func_2 * func_3 * func_4
            - grad_factor * factor * func_1 * func_2**2 * func_4
            + grad_factor * factor * func_1 * func_3**2 * func_4
            - (exponent_factor * sub * func_1**2 * func_2**2 * func_3**2 * func_4)
        )
        return output_factor * (d * first_term + jnp.dot(second_term, sub))
[docs]
class LocallyPeriodicKernel(ProductKernel):
    r"""
    Define a locally periodic kernel.

    The locally periodic kernel is defined as
    :math:`k: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}`,
    :math:`k(x, y) = r(x,y)l(x,y)` where :math:`r` is the periodic kernel and
    :math:`l` is the squared exponential kernel.

    .. warning::
        The locally periodic kernel is not differentiable when :math:`x=y`.

    :param periodic_length_scale: Periodic kernel smoothing/bandwidth parameter
    :param periodic_output_scale: Periodic kernel normalisation constant
    :param periodicity: Parameter controlling the periodicity of the Periodic kernel
    :param squared_exponential_length_scale: SquaredExponential kernel
        smoothing/bandwidth parameter
    :param squared_exponential_output_scale: SquaredExponential kernel normalisation
        constant
    """

    def __init__(
        self,
        periodic_length_scale: float = 1.0,
        periodic_output_scale: float = 1.0,
        periodicity: float = 1.0,
        squared_exponential_length_scale: float = 1.0,
        squared_exponential_output_scale: float = 1.0,
    ):
        """Build the periodic and squared exponential factors of the product."""
        periodic_factor = PeriodicKernel(
            length_scale=periodic_length_scale,
            output_scale=periodic_output_scale,
            periodicity=periodicity,
        )
        local_factor = SquaredExponentialKernel(
            length_scale=squared_exponential_length_scale,
            output_scale=squared_exponential_output_scale,
        )
        self.first_kernel = periodic_factor
        self.second_kernel = local_factor
[docs]
class PoissonKernel(ScalarValuedKernel):
    r"""
    Define a Poisson kernel.

    Given :math:`r =` ``index``, :math:`0 < r < 1`, and :math:`\rho =` ``output_scale``,
    the Poisson kernel is defined as
    :math:`k: [0, 2\pi) \times [0, 2\pi) \to \mathbb{R}`,
    :math:`k(x, y) = \frac{\rho}{1 - 2r\cos(x-y) + r^2}`.

    .. warning::
        Unlike many other kernels in Coreax, the Poisson kernel is not defined on
        arbitrary :math:`\mathbb{R}^d`, but instead a subset of the positive real line
        :math:`[0, 2\pi)`. We do not check that inputs to methods in this class lie in
        the correct domain, therefore unexpected behaviour may occur. For example,
        passing :math:`n`-vectors to the ``compute`` method will be interpreted as one
        observation of an :math:`n`-dimensional vector, and not :math:`n` observations
        of a one-dimensional vector, and therefore would be an invalid use of this
        kernel function.

    :param index: Kernel parameter indexing the family of Poisson kernel functions,
        :math:`r`, must lie strictly between 0 and 1
    :param output_scale: Kernel normalisation constant, :math:`\rho`, must be positive
    """

    # Both fields are passed through ``float`` on construction.
    index: float = eqx.field(default=0.5, converter=float)
    output_scale: float = eqx.field(default=1.0, converter=float)

    def __check_init__(self):
        """Reject an ``index`` outside (0, 1) or a non-positive ``output_scale``."""
        if self.index <= 0 or self.index >= 1:
            raise ValueError("'index' must be be between 0 and 1 exclusive")
        if self.output_scale <= 0:
            raise ValueError("'output_scale' must be positive")

    @override
    def compute_elementwise(self, x, y):
        separation = jnp.linalg.norm(jnp.subtract(x, y))
        denominator = 1 - 2 * self.index * jnp.cos(separation) + self.index**2
        return self.output_scale / denominator

    @override
    def grad_x_elementwise(self, x, y):
        # The kernel depends on x and y only through x - y, so grad_x = -grad_y.
        return -self.grad_y_elementwise(x, y)

    @override
    def grad_y_elementwise(self, x, y):
        # No norm is taken here, so the result keeps the dimensionality of x and y;
        # this ensures calls to 'grad_y' and 'grad_x' have expected dimensionality.
        difference = jnp.subtract(x, y)
        numerator = 2 * self.output_scale * self.index * jnp.sin(difference)
        denominator = 1 - 2 * self.index * jnp.cos(difference) + self.index**2
        return numerator / denominator**2

    @override
    def divergence_x_grad_y_elementwise(self, x, y):
        separation = jnp.linalg.norm(jnp.subtract(x, y))
        denominator = 1 - 2 * self.index * jnp.cos(separation) + self.index**2
        # Term from differentiating the sine in the numerator of grad_y.
        cosine_term = (
            2 * self.output_scale * self.index * jnp.cos(separation)
        ) / denominator**2
        # Term from differentiating the squared denominator of grad_y.
        sine_term = (
            8 * self.output_scale * self.index**2 * jnp.sin(separation) ** 2
        ) / denominator**3
        return cosine_term - sine_term
[docs]
class SteinKernel(UniCompositeKernel):
    r"""
    Define the Stein kernel, i.e. the application of the Stein operator.

    .. math::

        \mathcal{A}_\mathbb{P}(g(\mathbf{x})) := \nabla_\mathbf{x} g(\mathbf{x})
        + g(\mathbf{x}) \nabla_\mathbf{x} \log f_X(\mathbf{x})^\intercal

    w.r.t. probability measure :math:`\mathbb{P}` to the base kernel
    :math:`k(\mathbf{x}, \mathbf{y})`. Here, differentiable vector-valued
    :math:`g: \mathbb{R}^d \to \mathbb{R}^d`, and
    :math:`\nabla_\mathbf{x} \log f_X(\mathbf{x})` is the *score function* of measure
    :math:`\mathbb{P}`.

    :math:`\mathbb{P}` is assumed to admit a density function :math:`f_X` w.r.t.
    d-dimensional Lebesgue measure. The score function is assumed to be Lipschitz.

    The key property of a Stein operator is zero expectation under
    :math:`\mathbb{P}`, i.e.
    :math:`\mathbb{E}_\mathbb{P}[\mathcal{A}_\mathbb{P} f(\mathbf{x})] = 0`, for
    positive differentiable :math:`f_X`.

    The Stein kernel for base kernel :math:`k(\mathbf{x}, \mathbf{y})` is defined as

    .. math::

        k_\mathbb{P}(\mathbf{x}, \mathbf{y}) = \nabla_\mathbf{x} \cdot
        \nabla_\mathbf{y}
        k(\mathbf{x}, \mathbf{y}) + \nabla_\mathbf{x} \log f_X(\mathbf{x})
        \cdot \nabla_\mathbf{y} k(\mathbf{x}, \mathbf{y}) + \nabla_\mathbf{y} \log
        f_X(\mathbf{y}) \cdot \nabla_\mathbf{x} k(\mathbf{x}, \mathbf{y}) +
        (\nabla_\mathbf{x} \log f_X(\mathbf{x}) \cdot \nabla_\mathbf{y} \log
        f_X(\mathbf{y})) k(\mathbf{x}, \mathbf{y}).

    This kernel requires a 'base' kernel to evaluate. The base kernel can be any
    other implemented subclass of the Kernel abstract base class; even another Stein
    kernel.

    The score function
    :math:`\nabla_\mathbf{x} \log f_X: \mathbb{R}^d \to \mathbb{R}^d` can be any
    suitable Lipschitz score function, e.g. one that is learned from score matching
    (:class:`~coreax.score_matching.ScoreMatching`), computed explicitly from a density
    function, or known analytically.

    :param base_kernel: Initialised kernel object with which to evaluate
        the Stein kernel
    :param score_function: A vector-valued callable defining a score function
        :math:`\mathbb{R}^d \to \mathbb{R}^d`
    """

    score_function: Callable[[Shaped[Array, " n d"]], Shaped[Array, " n d"]]

    @override
    def compute_elementwise(self, x, y):
        base = self.base_kernel
        kernel_value = base.compute_elementwise(x, y)
        divergence = base.divergence_x_grad_y_elementwise(x, y)
        grad_x = base.grad_x_elementwise(x, y)
        grad_y = base.grad_y_elementwise(x, y)
        score_x = self.score_function(x)
        score_y = self.score_function(y)
        # The four terms of the Stein kernel, summed in the order of its definition.
        grad_x_term = grad_x @ score_y
        grad_y_term = grad_y @ score_x
        score_term = (kernel_value * score_x) @ score_y
        return divergence + grad_x_term + grad_y_term + score_term