Feature/kernel thinning backup #915

Merged
merged 12 commits on Jan 15, 2025
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -28,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
(https://github.com/gchq/coreax/pull/888)
- Added a method `SquaredExponentialKernel.get_sqrt_kernel` which returns a square
root kernel for the squared exponential kernel. (https://github.com/gchq/coreax/pull/883)
- Added a new coreset algorithm Kernel Thinning. (https://github.com/gchq/coreax/pull/915)

### Fixed

102 changes: 66 additions & 36 deletions coreax/solvers/coresubset.py
@@ -25,7 +25,7 @@
import jax.scipy as jsp
import jax.tree_util as jtu
from jax import lax
from jaxtyping import Array, ArrayLike, Scalar, Shaped
from jaxtyping import Array, ArrayLike, Bool, Float, Scalar, Shaped
from typing_extensions import override

from coreax.coreset import Coresubset
@@ -900,9 +900,8 @@ def reduce(
Reduce 'dataset' to a :class:`~coreax.coreset.Coresubset` with 'KernelThinning'.

This is done by first computing the number of halving steps required, referred
to as `m`, is calculated. The original data is clipped so that it is divisible
by a power of two. The kernel halving algorithm is then recursively applied to
halve the data.
to as `m`. The original data is clipped so that it is divisible by a power of
two. The kernel halving algorithm is then recursively applied to halve the data.

Subsequently, a `baseline_coreset` is added to the ensemble of coresets. The
best coreset is selected to minimise the Maximum Mean Discrepancy (MMD) and
@@ -927,7 +926,12 @@
best_coreset_indices = self.kt_choose(partition, dataset)
return self.kt_refine(Coresubset(Data(best_coreset_indices), dataset))
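
# --- Illustrative sketch, not part of this diff -----------------------------
# The docstring above says the data is clipped so its length is divisible by
# a power of two before `m` rounds of halving. A minimal, standalone version
# of that clipping arithmetic (the function name is hypothetical):
def clip_length(n: int, m: int) -> int:
    """Largest length <= n that is divisible by 2**m."""
    return (n // 2**m) * 2**m

# e.g. 100 points with m = 3 halvings -> clip to 96 = 12 * 2**3 points,
# which then halves cleanly to 48, 24, and finally 12 points.
# -----------------------------------------------------------------------------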

def kt_half_recursive(self, current_coreset, m, original_dataset):
def kt_half_recursive(
self,
current_coreset: Union[_Data, Coresubset[_Data]],
m: int,
original_dataset: _Data,
) -> list[Coresubset[_Data]]:
"""
Recursively halve the original dataset into coresets.

@@ -936,13 +940,14 @@ def kt_half_recursive(self, current_coreset, m, original_dataset):
:param original_dataset: The original dataset.
:return: Fully partitioned list of coresets.
"""
# If m == 0, do nothing except convert the original data to type Coresubset
if m == 0:
return [
Coresubset(Data(jnp.arange(len(current_coreset))), original_dataset)
]

# Recursively call self.kt_half on the coreset (or the dataset)
if hasattr(current_coreset, "coreset"):
if isinstance(current_coreset, Coresubset):
subset1, subset2 = self.kt_half(current_coreset.coreset)
else:
subset1, subset2 = self.kt_half(current_coreset)
@@ -952,7 +957,7 @@ def kt_half_recursive(self, current_coreset, m, original_dataset):
subset2 = eqx.tree_at(lambda x: x.pre_coreset_data, subset2, original_dataset)

# Update indices: map current subset's indices to original dataset
if hasattr(current_coreset, "nodes") and hasattr(current_coreset.nodes, "data"):
if isinstance(current_coreset, Coresubset):
parent_indices = current_coreset.nodes.data # Parent subset's indices
subset1_indices = subset1.nodes.data.flatten() # Indices relative to parent
subset2_indices = subset2.nodes.data.flatten() # Indices relative to parent
@@ -965,7 +970,7 @@ def kt_half_recursive(self, current_coreset, m, original_dataset):
subset1 = eqx.tree_at(lambda x: x.nodes.data, subset1, subset1_indices)
subset2 = eqx.tree_at(lambda x: x.nodes.data, subset2, subset2_indices)

# Recur for both subsets and concatenate results
# Recurse for both subsets and concatenate results
return self.kt_half_recursive(
subset1, m - 1, original_dataset
) + self.kt_half_recursive(subset2, m - 1, original_dataset)
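
# --- Illustrative sketch, not part of this diff -----------------------------
# The index update above composes index arrays: a child subset's indices are
# relative to its parent, so indexing the parent's indices with them recovers
# positions in the original dataset. A self-contained example:
import jax.numpy as jnp

parent_indices = jnp.array([2, 5, 7, 9])  # parent's rows in the original data
child_indices = jnp.array([0, 3])         # rows of the parent kept by halving
print(parent_indices[child_indices])      # -> [2 9], rows in the original data
# -----------------------------------------------------------------------------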
@@ -993,7 +998,7 @@ def kt_half(self, dataset: _Data) -> list[Coresubset[_Data]]:
coresets_masking = jnp.zeros(n)

# Initialise parameter
param = jnp.float32(0)
sigma = jnp.float32(0)
k = self.sqrt_kernel.compute_elementwise

def compute_kernel_distance(x1, x2):
@@ -1006,7 +1011,7 @@ def compute_kernel_distance(x1, x2):
"""
return jnp.sqrt(k(x1, x1) + k(x2, x2) - 2 * k(x1, x2))

def get_a_and_param(b, sigma):
def get_a_and_sigma(b, sigma):
"""Compute 'a' and new sigma parameter."""
a = jnp.maximum(b * sigma * jnp.sqrt(2 * jnp.log(2 / self.delta)), b**2)

@@ -1018,16 +1023,23 @@ def get_a_and_param(b, sigma):
return a, new_sigma
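
# --- Illustrative sketch, not part of this diff -----------------------------
# The sigma update inside `get_a_and_sigma` is elided by the fold above. A
# standalone version, assuming the update follows the formulas stated in the
# analytic test (tests/unit/test_solvers.py) further down this page:
import jax.numpy as jnp

def get_a_and_sigma_sketch(b, sigma, delta):
    a = jnp.maximum(b * sigma * jnp.sqrt(2 * jnp.log(2 / delta)), b**2)
    # sigma_new^2 = sigma^2 + max(b^2 * (1 + (b^2 - 2a) * sigma^2 / a^2) / a^2, 0)
    new_sigma_sq = sigma**2 + jnp.maximum(
        b**2 * (1 + (b**2 - 2 * a) * sigma**2 / a**2) / a**2, 0
    )
    return a, jnp.sqrt(new_sigma_sq)
# -----------------------------------------------------------------------------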

def get_alpha(
x1: jnp.ndarray,
x2: jnp.ndarray,
x1: Float[Array, "1 d"],
x2: Float[Array, "1 d"],
i: int,
current_first_coreset: jnp.ndarray,
original_dataset_masking: jnp.ndarray,
coreset_masking: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""
current_first_coreset: Float[Array, "n d"],
original_dataset_masking: Bool[Array, "2n d"],
coreset_masking: Bool[Array, "n d"],
) -> tuple[Float[Array, ""], Bool[Array, "2n d"], Bool[Array, "n d"]]:
r"""
Calculate the value of alpha and update the boolean arrays.

.. math::
\\alpha(x, y, S, S1) =
\\sum_{s \\in S}\\left(k(s, x) - k(s, y)\\right) - 2
\\sum_{s \\in S1}\\left(k(s, x) - k(s, y)\\right),

where :math:`S` is the set of data points already considered and
:math:`S1` is the current state of the first coreset.

:param x1: The first data point in the kernel evaluation.
:param x2: The second data point in the kernel evaluation.
:param i: The current index in the iteration.
@@ -1039,12 +1051,18 @@ def get_alpha(
- `original_array_masking`: Updated boolean array for the dataset.
- `coresets_masking`: Updated boolean array for the coresets.
"""
# Define the vectorised functions: k(·, x_1) and k(·, x_2)
k_vec_x1 = jax.vmap(lambda y: k(y, x1))
k_vec_x2 = jax.vmap(lambda y: k(y, x2))

# Define indexed versions of the above functions, where we can pass
# an index set: k(original_array[idx], x_1) and k(original_array[idx], x_2)
k_vec_x1_idx = jax.vmap(lambda y: k(original_array[y], x1))
k_vec_x2_idx = jax.vmap(lambda y: k(original_array[y], x2))
# Apply to original array and sum

# Because JAX array sizes are fixed, we sum only the first few elements
# and ignore the rest; this is achieved by taking a dot product with a
# boolean mask
term1 = jnp.dot(
(k_vec_x1(original_array) - k_vec_x2(original_array)),
original_dataset_masking,
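
# --- Illustrative sketch, not part of this diff -----------------------------
# Masked summation by a boolean dot product, as used for `term1` above: JAX
# array shapes are fixed under tracing, so "sum only the first few elements"
# is expressed as a dot with a 0/1 mask rather than a dynamic slice.
import jax.numpy as jnp

values = jnp.array([1.0, 2.0, 3.0, 4.0])
mask = jnp.array([1.0, 1.0, 0.0, 0.0])  # count only the first two entries
print(jnp.dot(values, mask))            # -> 3.0
# -----------------------------------------------------------------------------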
@@ -1092,19 +1110,19 @@ def probabilistic_swap(
def kernel_thinning_body_fun(
i: int,
state: tuple[
jnp.ndarray,
jnp.ndarray,
jnp.ndarray,
jnp.ndarray,
jnp.ndarray,
Float[Array, "n"], # first_coreset_indices
Float[Array, "n"], # second_coreset_indices
Float[Array, "1"], # sigma parameter
Bool[Array, "2n"], # original_array_masking
Bool[Array, "n"], # coresets_masking
KeyArrayLike,
],
) -> tuple[
jnp.ndarray,
jnp.ndarray,
jnp.ndarray,
jnp.ndarray,
jnp.ndarray,
Float[Array, "n"],
Float[Array, "n"],
Float[Array, "1"],
Bool[Array, "2n"],
Bool[Array, "n"],
KeyArrayLike,
]:
"""
@@ -1114,31 +1132,43 @@
:param state: A tuple containing:
- first_coreset_indices: The first array of indices.
- second_coreset_indices: The second array of indices.
- param: The scaling parameter.
- sigma: The sigma parameter.
- original_array_masking: Boolean array for masking.
- coresets_masking: Boolean array for masking coresets.
- random_key: A JAX random key.
:return: The updated state tuple after processing the current iteration.
"""
arr1, arr2, param, bool_arr_1, bool_arr_2, random_key = state
(
first_coreset_indices,
second_coreset_indices,
sigma,
original_array_masking,
coresets_masking,
random_key,
) = state
# Step 1: Get values from original array
x1 = original_array[i * 2]
x2 = original_array[i * 2 + 1]
# Step 2: Get a and new parameter
a, new_param = get_a_and_param(compute_kernel_distance(x1, x2), param)
a, new_sigma = get_a_and_sigma(compute_kernel_distance(x1, x2), sigma)
# Step 3: Compute alpha
alpha, new_bool_arr_1, new_bool_arr_2 = get_alpha(
x1, x2, i, arr1, bool_arr_1, bool_arr_2
x1,
x2,
i,
first_coreset_indices,
original_array_masking,
coresets_masking,
)
# Step 4: Get final values
(val1, val2), new_random_key = probabilistic_swap(i, a, alpha, random_key)
# Step 5: Update arrays
new_arr1 = arr1.at[i].set(val1)
new_arr2 = arr2.at[i].set(val2)
new_arr1 = first_coreset_indices.at[i].set(val1)
new_arr2 = second_coreset_indices.at[i].set(val2)
return (
new_arr1,
new_arr2,
new_param,
new_sigma,
new_bool_arr_1,
new_bool_arr_2,
new_random_key,
@@ -1151,7 +1181,7 @@ def kernel_thinning_body_fun(
(
first_coreset_indices,
second_coreset_indices,
param,
sigma,
original_array_masking,
coresets_masking,
self.random_key,
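
# --- Illustrative sketch, not part of this diff -----------------------------
# `probabilistic_swap` is elided above. Per the analytic test below, the swap
# probability is P = min(1, max(0, 0.5 * (1 - alpha / a))); which member of
# the pair goes to which coreset on a swap is an assumption here.
import jax
import jax.numpy as jnp

def probabilistic_swap_sketch(i, a, alpha, random_key):
    """Return the pair (2i, 2i+1), possibly swapped, plus a fresh key."""
    key, subkey = jax.random.split(random_key)
    p = jnp.clip(0.5 * (1 - alpha / a), 0.0, 1.0)
    swap = jax.random.uniform(subkey) < p
    first = jnp.where(swap, 2 * i + 1, 2 * i)
    second = jnp.where(swap, 2 * i, 2 * i + 1)
    return (first, second), key
# -----------------------------------------------------------------------------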
38 changes: 38 additions & 0 deletions tests/unit/test_solvers.py
@@ -2208,6 +2208,44 @@ def test_kt_half_analytic(self) -> None:
- Compute probability, update the sets, and proceed to the next pair,
using the updated :math:`\\sigma`.

Functions:

- :func:`b(x, y)` computes the distance measure between two points based
on the kernel:
.. math::
b(x, y) = \\sqrt{k(x, x) + k(y, y) - 2 \\cdot k(x, y)}

- :func:`get_swap_params(sigma, b_value, delta)` computes the parameters
required to update :math:`a` and :math:`\\sigma`:
.. math::
a = \\max\\left(b \\cdot \\sigma \\cdot \\sqrt{2 \\ln(2 / \\delta)},
b^2\\right)
.. math::
\\sigma^2_{new} = \\sigma^2 + \\max\\left(
\\frac{b^2 (1 + (b^2 - 2a) \\cdot \\sigma^2 / a^2)}{a^2},
0\\right)
.. math::
\\sigma_{new} = \\sqrt{\\sigma^2_{new}}

- :func:`alpha(x, y, S, S1)` computes the kernel-difference sum over all
elements of S minus twice the corresponding sum over the subset S1:
.. math::
\\alpha(x, y, S, S1) =
\\sum_{s \\in S}\\left(k(s, x) - k(s, y)\\right) - 2
\\sum_{s \\in S1}\\left(k(s, x) - k(s, y)\\right)

- :func:`get_probability(alpha_val, a)` computes the probability that
determines the assignment of points to the coresets:
.. math::
P = \\min\\left(1, \\max\\left(0, 0.5 \\cdot \\left(1 - \\frac{\\alpha}{a}\\right)\\right)\\right)

and for the square-root kernel, a ``length_scale`` of
:math:`\\frac{1}{\\sqrt{2}}` is chosen to simplify computations with the
``SquaredExponentialKernel``; in particular, it becomes:

.. math::
k(x, y) = e^{-||x - y||^2}
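
since, with length-scale :math:`\\ell`, the ``SquaredExponentialKernel``
takes the standard form
:math:`k(x, y) = e^{-||x - y||^2 / (2 \\ell^2)}`, and
:math:`\\ell = \\frac{1}{\\sqrt{2}}` gives :math:`2 \\ell^2 = 1`.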

Calculations for each pair:

Pair (1, 2):