# © Crown Copyright GCHQ
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example coreset generation using a video of a pouncing cat.
This example showcases how a coreset can be generated from video data. In this context,
a coreset is a set of frames that best capture the information in the original video.
Firstly, principal component analysis (PCA) is applied to the video data to reduce
dimensionality. Then, a coreset is generated using Stein kernel herding, with a
SquaredExponentialKernel base kernel. The score function (gradient of the log-density
function) for the Stein kernel is estimated by applying kernel density estimation (KDE)
to the data, and then taking gradients.
To reduce computational requirements, a map reduce approach is used, splitting the
original dataset into distinct segments, with each segment handled on a different
process.
The coreset attained from Stein kernel herding is compared to a coreset generated via
uniform random sampling. Coreset quality is measured using maximum mean discrepancy
(MMD).
"""
from pathlib import Path
import equinox as eqx
import imageio
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from jax import Array, random
from sklearn.decomposition import PCA
from coreax.data import Data
from coreax.kernels import SquaredExponentialKernel, SteinKernel, median_heuristic
from coreax.metrics import MMD
from coreax.score_matching import KernelDensityMatching
from coreax.solvers import KernelHerding, MapReduce, RandomSample
# Examples are written to be easy to read, copy and paste by users, so we ignore the
# pylint warnings raised that go against this approach
# pylint: disable=too-many-statements
# pylint: disable=too-many-locals
# pylint: disable=duplicate-code
[docs]
def main(
in_path: Path = Path("../examples/data/pounce/pounce.gif"),
out_path: Path | None = None,
) -> tuple[float, float]:
"""
Run the 'pounce' example for video sampling with Stein kernel herding.
Take a video of a pouncing cat, apply PCA and then generate a coreset using
Stein kernel herding. Compare the result from this to a coreset generated
via uniform random sampling. Coreset quality is measured using maximum mean
discrepancy (MMD).
To reduce computational requirements, a map reduce approach is used, splitting the
original dataset into distinct segments, with each segment handled on a different
process.
:param in_path: Path to directory containing input video, assumed relative to this
module file unless an absolute path is given
:param out_path: Path to save output to, if not :data:`None`, assumed relative to
this module file unless an absolute path is given
:return: Coreset MMD, random sample MMD
"""
# Convert input and absolute paths to absolute paths
if not in_path.is_absolute():
in_path = Path(__file__).parent.joinpath(in_path)
if out_path is not None and not out_path.is_absolute():
out_path = Path(__file__).parent.joinpath(out_path)
# Create output directory
if out_path is not None:
out_path.mkdir(exist_ok=True)
# Read in the data as a video. Frame 0 is missing A from RGBA.
_, *image_data = imageio.v2.mimread(in_path)
raw_data = np.asarray(image_data)
raw_data_reshaped = raw_data.reshape(raw_data.shape[0], -1)
# Fix random behaviour
random_seed = 1_989
np.random.seed(random_seed)
# Run PCA to reduce the dimension of the images whilst minimising effects on some of
# the statistical properties, i.e. variance.
num_principle_components = 25
pca = PCA(num_principle_components)
principle_components_data = pca.fit_transform(raw_data_reshaped)
# Setup the original data object
data = Data(principle_components_data)
# Request a 10 frame summary of the video
coreset_size = 10
# Set the length_scale parameter of the underlying squared exponential kernel
num_points_length_scale_selection = min(principle_components_data.shape[0], 1_000)
generator = np.random.default_rng(random_seed)
idx = generator.choice(
principle_components_data.shape[0],
num_points_length_scale_selection,
replace=False,
)
length_scale: Array = median_heuristic(principle_components_data[idx])
# Learn a score function via kernel density estimation
kernel_density_score_matcher = KernelDensityMatching(
length_scale=length_scale.item()
)
score_function = kernel_density_score_matcher.match(
principle_components_data[idx, :]
)
# Run kernel herding with a Stein kernel
sample_key = random.key(random_seed)
herding_solver = KernelHerding(
coreset_size,
kernel=SteinKernel(
SquaredExponentialKernel(length_scale=length_scale),
score_function=score_function,
),
)
mapped_herding_solver = MapReduce(herding_solver, leaf_size=20)
herding_coreset, _ = eqx.filter_jit(mapped_herding_solver.reduce)(data)
# Get and sort the coreset indices ready for producing the output video
coreset_indices_herding = jnp.sort(herding_coreset.unweighted_indices)
# Generate a coreset via uniform random sampling for comparison
random_solver = RandomSample(coreset_size, sample_key, unique=True)
random_coreset, _ = eqx.filter_jit(random_solver.reduce)(data)
# Define a reference kernel to use for comparisons of MMD. We'll use a normalised
# SquaredExponentialKernel (which is also a Gaussian kernel)
print("Computing MMD...")
mmd_kernel = SquaredExponentialKernel(
length_scale=length_scale,
output_scale=1.0 / (length_scale * jnp.sqrt(2.0 * jnp.pi)),
)
# Compute the MMD between the original data and the coreset generated via herding
mmd_metric = MMD(kernel=mmd_kernel)
maximum_mean_discrepancy_herding = herding_coreset.compute_metric(mmd_metric)
# Compute the MMD between the original data and the coreset generated via random
# sampling
maximum_mean_discrepancy_random = random_coreset.compute_metric(mmd_metric)
# Print the MMD values
print(f"Random sampling coreset MMD: {maximum_mean_discrepancy_random}")
print(f"Herding coreset MMD: {maximum_mean_discrepancy_herding}")
# Save a new video. Y_ is the original sequence with dimensions preserved
coreset_images = raw_data[coreset_indices_herding]
if out_path is not None:
imageio.mimsave(
out_path / Path("pounce_map_reduce_coreset.gif"), coreset_images
)
# Plot to visualise which frames were chosen from the sequence action frames are
# where the "pounce" occurs
action_frames = np.arange(63, 85)
x = np.arange(num_points_length_scale_selection)
y = np.zeros(num_points_length_scale_selection)
y[coreset_indices_herding] = 1.0
z = np.zeros(num_points_length_scale_selection)
z[jnp.intersect1d(coreset_indices_herding, action_frames)] = 1.0
plt.figure(figsize=(20, 3))
plt.bar(x, y, alpha=0.5)
plt.bar(x, z)
plt.xlabel("Frame")
plt.ylabel("Chosen")
plt.tight_layout()
if out_path is not None:
plt.savefig(out_path / "pounce_map_reduce_frames.png")
plt.close()
return (
float(maximum_mean_discrepancy_herding),
float(maximum_mean_discrepancy_random),
)
# pylint: enable=too-many-statements
# pylint: enable=too-many-locals
# pylint: enable=duplicate-code
if __name__ == "__main__":
main()