diff --git a/.cspell/custom_misc.txt b/.cspell/custom_misc.txt
index 95c9e4b92..9e78c47cc 100644
--- a/.cspell/custom_misc.txt
+++ b/.cspell/custom_misc.txt
@@ -48,6 +48,7 @@ pmatrix
 PMLR
 primaryclass
 PRNG
+prob
 propto
 rcond
 recomb
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1554deb3e..160920bee 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -31,6 +31,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   (https://github.com/gchq/coreax/pull/909)
 - Added a method `SquaredExponentialKernel.get_sqrt_kernel` which returns a square
   root kernel for the squared exponential kernel. (https://github.com/gchq/coreax/pull/883)
+- Added a new coreset algorithm, Kernel Thinning. (https://github.com/gchq/coreax/pull/915)
 - Added (loose) lower bounds to all direct dependencies. (https://github.com/gchq/coreax/pull/920)
 
 ### Fixed
diff --git a/coreax/solvers/__init__.py b/coreax/solvers/__init__.py
index a2e15e79d..71cb9e40e 100644
--- a/coreax/solvers/__init__.py
+++ b/coreax/solvers/__init__.py
@@ -27,6 +27,7 @@
     GreedyKernelPointsState,
     HerdingState,
     KernelHerding,
+    KernelThinning,
     RandomSample,
     RPCholesky,
     RPCholeskyState,
@@ -49,6 +50,7 @@
     "RandomSample",
     "HerdingState",
     "KernelHerding",
+    "KernelThinning",
    "SteinThinning",
     "RPCholeskyState",
     "RPCholesky",
diff --git a/coreax/solvers/coresubset.py b/coreax/solvers/coresubset.py
index ce63a8ca8..c5fa34549 100644
--- a/coreax/solvers/coresubset.py
+++ b/coreax/solvers/coresubset.py
@@ -14,6 +14,7 @@
 
 """Solvers for constructing coresubsets."""
 
+import math
 from collections.abc import Callable
 from typing import Optional, TypeVar, Union
 
@@ -23,7 +24,8 @@
 import jax.random as jr
 import jax.scipy as jsp
 import jax.tree_util as jtu
-from jaxtyping import Array, ArrayLike, Scalar, Shaped
+from jax import lax
+from jaxtyping import Array, ArrayLike, Bool, Float, Scalar, Shaped
 from typing_extensions import override
 
 from coreax.coreset import Coresubset
@@ -33,6 +35,7 @@
     MinimalEuclideanNormSolver,
     RegularisedLeastSquaresSolver,
 )
+from coreax.metrics import MMD
 from coreax.score_matching import ScoreMatching, convert_stein_kernel
 from coreax.solvers.base import (
     CoresubsetSolver,
@@ -859,3 +862,381 @@ def _greedy_body(
     return Coresubset(updated_coreset_indices, dataset), GreedyKernelPointsState(
         padded_feature_gramian
     )
+
+
+class KernelThinning(CoresubsetSolver[_Data, None], ExplicitSizeSolver):
+    r"""
+    Kernel Thinning - a hierarchical coreset construction solver.
+
+    Kernel Thinning is a hierarchical, probabilistic algorithm for coreset
+    construction. It builds a coreset by repeatedly halving the dataset with
+    probabilistic swapping, producing several candidate coresets. The candidate
+    with the lowest maximum mean discrepancy (MMD) is selected and then refined
+    further to minimise the MMD between the original dataset and the coreset.
+    This implementation is a modification of the Kernel Thinning algorithm in
+    :cite:`dwivedi2024kernelthinning` to make it an :class:`ExplicitSizeSolver`.
+
+    :param kernel: A :class:`~coreax.kernels.ScalarValuedKernel` instance defining
+        the primary kernel function used for choosing the best coreset and
+        refining it.
+    :param random_key: Key for random number generation, enabling reproducibility
+        of probabilistic components in the algorithm.
+    :param delta: A float between 0 and 1 used to compute the swapping probability
+        during the splitting process. A recommended value is
+        :math:`1 / \log(\log(n))`, where :math:`n` is the length of the original
+        dataset.
+    :param sqrt_kernel: A :class:`~coreax.kernels.ScalarValuedKernel` instance
+        representing the square root kernel used for splitting the original
+        dataset.
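+
+    Example (an illustrative sketch mirroring the unit tests for this solver; the
+    data, kernel choices, and parameter values are arbitrary):
+
+    .. code-block:: python
+
+        import jax
+        import jax.numpy as jnp
+
+        from coreax.data import Data
+        from coreax.kernels import SquaredExponentialKernel
+        from coreax.solvers import KernelThinning
+
+        data = Data(jnp.arange(32, dtype=jnp.float32).reshape(-1, 1))
+        solver = KernelThinning(
+            coreset_size=4,
+            kernel=SquaredExponentialKernel(),
+            random_key=jax.random.PRNGKey(0),
+            delta=0.125,
+            sqrt_kernel=SquaredExponentialKernel(length_scale=1.0 / jnp.sqrt(2)),
+        )
+        coreset, _ = solver.reduce(data)  # coreset of 4 points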
+    """
+
+    kernel: ScalarValuedKernel
+    random_key: KeyArrayLike
+    delta: float
+    sqrt_kernel: ScalarValuedKernel
+
+    def reduce(
+        self, dataset: _Data, solver_state: None = None
+    ) -> tuple[Coresubset[_Data], None]:
+        """
+        Reduce ``dataset`` to a :class:`~coreax.coreset.Coresubset` with ``KernelThinning``.
+
+        This is done by first computing the number of halving steps required,
+        referred to as ``m``. The original data is clipped so that its length is
+        ``coreset_size`` times a power of two. The kernel halving algorithm is then
+        recursively applied to halve the data.
+
+        Subsequently, a ``baseline_coreset`` is added to the ensemble of candidate
+        coresets. The candidate with the lowest maximum mean discrepancy (MMD) is
+        selected and then refined further for optimal results. This final
+        refinement step can reintroduce the clipped data points if they are found
+        to reduce the MMD.
+
+        :param dataset: The original dataset to be reduced.
+        :param solver_state: The state of the solver.
+
+        :return: A tuple containing the final coreset and the solver state (None).
+        """
+        if self.coreset_size > len(dataset):
+            raise ValueError(MSG)
+        n = len(dataset)
+        m = math.floor(math.log2(n) - math.log2(self.coreset_size))
+        clipped_original_dataset = dataset[: self.coreset_size * 2**m]
+
+        partition = self.kt_half_recursive(clipped_original_dataset, m, dataset)
+        baseline_coreset = self.get_baseline_coreset(dataset, self.coreset_size)
+        partition.append(baseline_coreset)
+
+        best_coreset_indices = self.kt_choose(partition, dataset)
+        return self.kt_refine(Coresubset(Data(best_coreset_indices), dataset))
+
+    def kt_half_recursive(
+        self,
+        current_coreset: Union[_Data, Coresubset[_Data]],
+        m: int,
+        original_dataset: _Data,
+    ) -> list[Coresubset[_Data]]:
+        """
+        Recursively halve the original dataset into coresets.
+
+        :param current_coreset: The current coreset or dataset being partitioned.
+        :param m: The remaining depth of recursion.
+        :param original_dataset: The original dataset.
+        :return: Fully partitioned list of coresets.
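+
+        .. note::
+            Worked example (illustrative numbers): with ``n = 80`` and
+            ``coreset_size = 10``, :meth:`reduce` computes ``m = 3``, clips the
+            data to ``10 * 2**3 = 80`` points, and this method then returns
+            ``2**3 = 8`` candidate coresets of 10 points each.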
+ """ + # If m == 0, do not do anything just convert to original data to type Coresubset + if m == 0: + return [ + Coresubset(Data(jnp.arange(len(current_coreset))), original_dataset) + ] + + # Recursively call self.kt_half on the coreset (or the dataset) + if isinstance(current_coreset, Coresubset): + subset1, subset2 = self.kt_half(current_coreset.coreset) + else: + subset1, subset2 = self.kt_half(current_coreset) + + # Update pre_coreset_data for both subsets to point to the original dataset + subset1 = eqx.tree_at(lambda x: x.pre_coreset_data, subset1, original_dataset) + subset2 = eqx.tree_at(lambda x: x.pre_coreset_data, subset2, original_dataset) + + # Update indices: map current subset's indices to original dataset + if isinstance(current_coreset, Coresubset): + parent_indices = current_coreset.nodes.data # Parent subset's indices + subset1_indices = subset1.nodes.data.flatten() # Indices relative to parent + subset2_indices = subset2.nodes.data.flatten() # Indices relative to parent + + # Map subset indices back to original dataset + subset1_indices = parent_indices[subset1_indices] + subset2_indices = parent_indices[subset2_indices] + + # Update the subsets with the remapped indices + subset1 = eqx.tree_at(lambda x: x.nodes.data, subset1, subset1_indices) + subset2 = eqx.tree_at(lambda x: x.nodes.data, subset2, subset2_indices) + + # Recurse for both subsets and concatenate results + return self.kt_half_recursive( + subset1, m - 1, original_dataset + ) + self.kt_half_recursive(subset2, m - 1, original_dataset) + + def kt_half(self, dataset: _Data) -> list[Coresubset[_Data]]: + """ + Partition the given dataset into two subsets. + + First, initialise two coresubsets, each of which will contain half the points of + the original dataset. Divide the points of the original dataset into pairs and + probabilistically decide which point of the pair should go to which of the + coresets. This function uses variables such as `a`, `b`, `sigma`, and `delta`, + they refer to the corresponding parameters in the paper + :cite:`dwivedi2024kernelthinning`. + + :param dataset: The input dataset to be halved. + :return: A list containing the two partitioned coresets. + """ + n = len(dataset) // 2 + original_array = dataset.data + first_coreset_indices = jnp.zeros(n, dtype=jnp.int32) + second_coreset_indices = jnp.zeros(n, dtype=jnp.int32) + + original_array_masking = jnp.zeros(2 * n) + coresets_masking = jnp.zeros(n) + + # Initialise parameter + sigma = jnp.float32(0) + k = self.sqrt_kernel.compute_elementwise + + def compute_kernel_distance(x1, x2): + """ + Compute kernel distance between two data points. + + :param x1: The first data point. + :param x2: The second data point. + :return: The kernel distance between `x1` and `x2`. + """ + return jnp.sqrt(k(x1, x1) + k(x2, x2) - 2 * k(x1, x2)) + + def get_a_and_sigma(b, sigma): + """Compute 'a' and new sigma parameter.""" + a = jnp.maximum(b * sigma * jnp.sqrt(2 * jnp.log(2 / self.delta)), b**2) + + # Update sigma + new_sigma = jnp.sqrt( + sigma**2 + jnp.maximum(b**2 * (1 + (b**2 - 2 * a) * sigma**2 / a**2), 0) + ) + + return a, new_sigma + + def get_alpha( + x1: Float[Array, "1 d"], + x2: Float[Array, "1 d"], + i: int, + current_first_coreset: Float[Array, "n d"], + original_dataset_masking: Bool[Array, "2n d"], + coreset_masking: Bool[Array, "n d"], + ) -> tuple[Float[Array, ""], Bool[Array, "2n d"], Bool[Array, "n d"]]: + r""" + Calculate the value of alpha and update the boolean arrays. + + .. 
+
+        def get_alpha(
+            x1: Float[Array, "1 d"],
+            x2: Float[Array, "1 d"],
+            i: int,
+            current_first_coreset: Float[Array, "n d"],
+            original_dataset_masking: Bool[Array, "2n d"],
+            coreset_masking: Bool[Array, "n d"],
+        ) -> tuple[Float[Array, ""], Bool[Array, "2n d"], Bool[Array, "n d"]]:
+            r"""
+            Calculate the value of alpha and update the boolean arrays.
+
+            .. math::
+                \alpha(x, y, S, S_1) =
+                \sum_{s \in S}\left(k(s, x) - k(s, y)\right)
+                - 2\sum_{s \in S_1}\left(k(s, x) - k(s, y)\right),
+
+            where :math:`S` is the set of data points already considered and
+            :math:`S_1` is the current state of the first coreset.
+
+            :param x1: The first data point in the kernel evaluation.
+            :param x2: The second data point in the kernel evaluation.
+            :param i: The current index in the iteration.
+            :param current_first_coreset: Current first_coreset_indices.
+            :param original_dataset_masking: A boolean array that tracks indices.
+            :param coreset_masking: A boolean array that tracks indices.
+            :return: A tuple containing:
+                - ``alpha``: The computed value of alpha.
+                - ``original_array_masking``: Updated boolean array for the dataset.
+                - ``coresets_masking``: Updated boolean array for the coresets.
+            """
+            # Define the vectorised functions: k(., x1), k(., x2)
+            k_vec_x1 = jax.vmap(lambda y: k(y, x1))
+            k_vec_x2 = jax.vmap(lambda y: k(y, x2))
+
+            # Define indexed versions of the above functions, where we can pass an
+            # index set: k(original_array[...], x1) and k(original_array[...], x2)
+            k_vec_x1_idx = jax.vmap(lambda y: k(original_array[y], x1))
+            k_vec_x2_idx = jax.vmap(lambda y: k(original_array[y], x2))
+
+            # Because JAX array sizes are fixed, we only sum over the elements seen
+            # so far and ignore the rest; this is achieved by taking a dot product
+            # with a boolean mask
+            term1 = jnp.dot(
+                (k_vec_x1(original_array) - k_vec_x2(original_array)),
+                original_dataset_masking,
+            )
+            term2 = -2 * jnp.dot(
+                (
+                    k_vec_x1_idx(current_first_coreset)
+                    - k_vec_x2_idx(current_first_coreset)
+                ),
+                coreset_masking,
+            )
+            # For original_dataset_masking, set positions 2i and 2i+1 to 1
+            original_dataset_masking = original_dataset_masking.at[2 * i].set(1)
+            original_dataset_masking = original_dataset_masking.at[2 * i + 1].set(1)
+            # For coreset_masking, set the i-th position to 1
+            coreset_masking = coreset_masking.at[i].set(1)
+            # Combine all terms
+            alpha = term1 + term2
+            return alpha, original_dataset_masking, coreset_masking
+
+        def probabilistic_swap(
+            i: int, a: jnp.ndarray, alpha: jnp.ndarray, random_key: KeyArrayLike
+        ) -> tuple[tuple[int, int], KeyArrayLike]:
+            """
+            Perform a probabilistic swap based on the given parameters.
+
+            :param i: The current index in the dataset.
+            :param a: The swap threshold computed based on kernel parameters.
+            :param alpha: The calculated value for probabilistic swapping.
+            :param random_key: A random key for generating random numbers.
+            :return: A tuple containing:
+                - A tuple of indices indicating the swapped values.
+                - The updated random key.
+            """
+            key1, key2 = jax.random.split(random_key)
+
+            prob = jax.random.uniform(key1)
+            return lax.cond(
+                prob < 1 / 2 * (1 - alpha / a),
+                lambda _: (2 * i, 2 * i + 1),  # first case: val1 = x1, val2 = x2
+                lambda _: (2 * i + 1, 2 * i),  # second case: val1 = x2, val2 = x1
+                None,
+            ), key2
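+
+        # Illustrative arithmetic for the condition in probabilistic_swap above:
+        # the swap probability is 1/2 * (1 - alpha / a), so alpha == 0 gives a
+        # fair coin (probability 1/2), while alpha near +a or -a pushes the
+        # probability towards 0 or 1, making the assignment near-deterministic
+        # for pairs with large |alpha|.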
+
+        def kernel_thinning_body_fun(
+            i: int,
+            state: tuple[
+                Float[Array, "n"],  # first_coreset_indices
+                Float[Array, "n"],  # second_coreset_indices
+                Float[Array, "1"],  # sigma parameter
+                Bool[Array, "2n"],  # original_array_masking
+                Bool[Array, "n"],  # coresets_masking
+                KeyArrayLike,
+            ],
+        ) -> tuple[
+            Float[Array, "n"],
+            Float[Array, "n"],
+            Float[Array, "1"],
+            Bool[Array, "2n"],
+            Bool[Array, "n"],
+            KeyArrayLike,
+        ]:
+            """
+            Perform one iteration of the halving process.
+
+            :param i: The current iteration index.
+            :param state: A tuple containing:
+                - first_coreset_indices: The first array of indices.
+                - second_coreset_indices: The second array of indices.
+                - sigma: The sigma parameter.
+                - original_array_masking: Boolean array for masking.
+                - coresets_masking: Boolean array for masking coresets.
+                - random_key: A JAX random key.
+            :return: The updated state tuple after processing the current iteration.
+            """
+            (
+                first_coreset_indices,
+                second_coreset_indices,
+                sigma,
+                original_array_masking,
+                coresets_masking,
+                random_key,
+            ) = state
+            # Step 1: Get values from the original array
+            x1 = original_array[i * 2]
+            x2 = original_array[i * 2 + 1]
+            # Step 2: Get a and the new sigma parameter
+            a, new_sigma = get_a_and_sigma(compute_kernel_distance(x1, x2), sigma)
+            # Step 3: Compute alpha
+            alpha, new_bool_arr_1, new_bool_arr_2 = get_alpha(
+                x1,
+                x2,
+                i,
+                first_coreset_indices,
+                original_array_masking,
+                coresets_masking,
+            )
+            # Step 4: Get final values
+            (val1, val2), new_random_key = probabilistic_swap(i, a, alpha, random_key)
+            # Step 5: Update arrays
+            updated_first_coreset_indices = first_coreset_indices.at[i].set(val1)
+            updated_second_coreset_indices = second_coreset_indices.at[i].set(val2)
+            return (
+                updated_first_coreset_indices,
+                updated_second_coreset_indices,
+                new_sigma,
+                new_bool_arr_1,
+                new_bool_arr_2,
+                new_random_key,
+            )
+
+        (final_arr1, final_arr2, _, _, _, _) = lax.fori_loop(
+            0,  # start index
+            n,  # end index
+            kernel_thinning_body_fun,  # body function
+            (
+                first_coreset_indices,
+                second_coreset_indices,
+                sigma,
+                original_array_masking,
+                coresets_masking,
+                self.random_key,
+            ),
+        )
+        return [Coresubset(final_arr1, dataset), Coresubset(final_arr2, dataset)]
+
+    def get_baseline_coreset(
+        self, dataset: Data, baseline_coreset_size: int
+    ) -> Coresubset[_Data]:
+        """
+        Generate a baseline coreset by randomly sampling from the dataset.
+
+        :param dataset: The input dataset from which the baseline coreset is
+            sampled.
+        :param baseline_coreset_size: The number of points in the baseline coreset.
+        :return: A randomly sampled baseline coreset with the specified size.
+        """
+        baseline_coreset, _ = RandomSample(
+            coreset_size=baseline_coreset_size, random_key=self.random_key
+        ).reduce(dataset)
+        return baseline_coreset
+
+    def kt_choose(
+        self, candidate_coresets: list[Coresubset[_Data]], points: _Data
+    ) -> Shaped[Array, " coreset_size"]:
+        """
+        Select the best coreset from a list of candidate coresets based on MMD.
+
+        :param candidate_coresets: A list of candidate coresets to be evaluated.
+        :param points: The original dataset against which the coresets are
+            compared.
+        :return: The indices of the coreset with the smallest MMD relative to the
+            input dataset.
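+
+        .. note::
+            In :meth:`reduce`, the candidate list holds the ``2**m`` outputs of
+            the recursive halving plus one baseline random sample, so ``2**m + 1``
+            MMD values are evaluated in a single :func:`jax.vmap` call and the
+            argmin is returned.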
+ """ + refined_coreset, _ = KernelHerding( + coreset_size=self.coreset_size, kernel=self.kernel + ).refine(candidate_coreset) + return refined_coreset, None diff --git a/tests/unit/test_solvers.py b/tests/unit/test_solvers.py index fc8f8db7e..f1e051332 100644 --- a/tests/unit/test_solvers.py +++ b/tests/unit/test_solvers.py @@ -49,6 +49,7 @@ GreedyKernelPointsState, HerdingState, KernelHerding, + KernelThinning, MapReduce, RandomSample, RPCholesky, @@ -2154,3 +2155,192 @@ def solver_factory(self) -> Union[Solver, jtu.Partial]: return jtu.Partial( TreeRecombination, test_functions=None, rcond=None, tree_reduction_factor=3 ) + + +class TestKernelThinning(ExplicitSizeSolverTest): + """Test cases for :class:`coreax.solvers.coresubset.KernelThinning`.""" + + @override + @pytest.fixture(scope="class") + def solver_factory(self) -> Union[type[Solver], jtu.Partial]: + kernel = PCIMQKernel() + coreset_size = self.shape[0] // 10 + return jtu.Partial( + KernelThinning, + coreset_size=coreset_size, + random_key=self.random_key, + kernel=kernel, + delta=0.01, + sqrt_kernel=kernel, + ) + + def test_kt_half_analytic(self) -> None: + # pylint: disable=line-too-long + r""" + Test the halving step of kernel thinning on analytical example. + + We aim to split [1, 2, 3, 4, 5, 6, 7, 8] into two coresets, S1 and S2, each + containing 4 elements, enforcing two unique coresets. + + First, let S be the full dataset, with S1 and S2 as subsets. S1 will contain + half the elements, and S2 will contain the other half. Let :math:`k` represent + the square root kernel. We will use variables labelled :math:`a`, :math:`b`, + :math:`\\alpha`, :math:`\\sigma`, :math:`\\delta`, and probability, which + will be updated iteratively to form the coresets. + + We process pairs :math:`(x, y)` sequentially: :math:`(1, 2)`, :math:`(3, 4)`, + :math:`(5, 6)`, and :math:`(7, 8)`. For each pair, we compute a probability + that determines whether :math:`x` goes to S1 and :math:`y` to S2, or vice + versa. In either case, both :math:`x` and :math:`y` are added to S. + + If this probability is greater than 0.5, we add the x and y to S1 and S2 + respectively, otherwise we swap x and y and then add x to S1 and y to S2. + + The process is as follows: + + - Start with :math:`\\delta = \\frac{1}{8}` and :math:`\\sigma = 0`. + - Take a pair :math:`(x, y)`. + - Compute :math:`b` and :math:`\\alpha`. + - Compute :math:`a` and update :math:`\\sigma`. + - Compute probability, update the sets, and proceed to the next pair, + using the updated :math:`\\sigma`. + + Functions: + + - :func:`b(x, y)` computes the distance measure between two points based + on the kernel: + .. math:: + b(x, y) = \\sqrt{k(x, x) + k(y, y) - 2 \\cdot k(x, y)} + + - :func:`get_swap_params(sigma, b_value, delta)` computes the parameters + required to update :math:`a` and :math:`\\sigma`: + .. math:: + a = \\max\\left(b \\cdot \\sigma \\cdot \\sqrt{2 \\ln(2 / \\delta)}, + b^2\\right) + .. math:: + \\sigma^2_{new} = \\sigma^2 + \\max\\left( + \\frac{b^2 (1 + (b^2 - 2a) \\cdot \\sigma^2 / a^2)}{a^2}, + 0\\right) + .. math:: + \\sigma_{new} = \\sqrt{\\sigma^2_{new}} + + - :func:`alpha(x, y, S, S1)` computes the difference between the total + kernel sum for all elements in S and the subset S1: + .. math:: + \\alpha(x, y, S, S1) = + \\sum_{s \\in S}\\left(k(s, x) - k(s, y)\\right) - 2 + \\sum_{s \\in S1}\\left(k(s, x) - k(s, y)\\right) + + - :func:`get_probability(alpha_val, a)` computes the probability that + determines the assignment of points to the coresets: + .. 
+
+          .. math::
+              P = \min\left(1, \max\left(0,
+              0.5 \cdot \left(1 - \frac{\alpha}{a}\right)\right)\right)
+
+        For the square root kernel, choose a ``length_scale`` of
+        :math:`\frac{1}{\sqrt{2}}` to simplify computations with the
+        ``SquaredExponentialKernel``; in particular, it becomes:
+
+        .. math::
+            k(x, y) = e^{-||x - y||^2}
+
+        Calculations for each pair:
+
+        Pair (1, 2):
+          - Inputs: S=[], S1=[], S2=[], sigma=0, delta=1/8.
+          - Compute b:
+            b(1,2) = sqrt(k(1,1) + k(2,2) - 2*k(1,2)) = 1.1243847608566284.
+          - Compute alpha: alpha = 0 (as S and S1 are empty).
+          - Compute a:
+            a = max(b * sigma * sqrt(2 * log(2/delta)), b^2) = 1.264241099357605.
+          - Update sigma:
+            new_sigma^2 = sigma^2 + max(0, b^2 * (1 + (b^2 - 2*a) * sigma^2 / a^2)).
+            new_sigma = sqrt(new_sigma^2) = 1.1243847608566284.
+          - Compute probability:
+            p = 0.5 * (1 - alpha / a) = 0.5.
+          - Assign:
+            Since p <= 0.5, assign x=1 to S2, y=2 to S1, and add both to S.
+            S1 = [2], S2 = [1], S = [1, 2].
+
+        Pair (3, 4):
+          - Inputs: S=[1, 2], S1=[2], S2=[1], sigma=1.1243847608566284.
+          - Compute b:
+            b(3,4) = sqrt(k(3,3) + k(4,4) - 2*k(3,4)) = 1.1243847608566284.
+          - Compute alpha:
+            alpha = sum(k(s, 3) - k(s, 4) for s in S) - 2 * sum(k(s, 3) - k(s, 4) for s in S1).
+            alpha = -0.3313715159893036.
+          - Compute a:
+            a = max(b * sigma * sqrt(2 * log(2/delta)), b^2) = 2.9770602825192523.
+          - Update sigma:
+            new_sigma = sqrt(sigma^2 + max(0, b^2 * (1 + (b^2 - 2*a) * sigma^2 / a^2))).
+            new_sigma = 1.297198507467962.
+          - Compute probability:
+            p = 0.5 * (1 - alpha / a) = 0.5556541681289673.
+          - Assign:
+            Since p > 0.5, assign x=3 to S1 and y=4 to S2, and add both to S.
+            S1 = [2, 3], S2 = [1, 4], S = [1, 2, 3, 4].
+
+        Pair (5, 6):
+          - Inputs: S=[1, 2, 3, 4], S1=[2, 3], S2=[1, 4], sigma=1.297198507467962.
+          - Compute b:
+            b(5,6) = sqrt(k(5,5) + k(6,6) - 2*k(5,6)) = 1.1243847608566284.
+          - Compute alpha:
+            alpha = sum(k(s, 5) - k(s, 6) for s in S) - 2 * sum(k(s, 5) - k(s, 6) for s in S1).
+            alpha = 0.33124834299087524.
+          - Compute a:
+            a = max(b * sigma * sqrt(2 * log(2/delta)), b^2) = 3.434623326772776.
+          - Update sigma:
+            new_sigma = sqrt(sigma^2 + max(0, b^2 * (1 + (b^2 - 2*a) * sigma^2 / a^2))).
+            new_sigma = 1.3914653590235087.
+          - Compute probability:
+            p = 0.5 * (1 - alpha / a) = 0.4517780542373657.
+          - Assign:
+            Since p <= 0.5, assign x=5 to S2 and y=6 to S1, and add both to S.
+            S1 = [2, 3, 6], S2 = [1, 4, 5], S = [1, 2, 3, 4, 5, 6].
+
+        Pair (7, 8):
+          - Inputs: S=[1, 2, 3, 4, 5, 6], S1=[2, 3, 6], S2=[1, 4, 5], sigma=1.3914653590235087.
+          - Compute b:
+            b(7,8) = sqrt(k(7,7) + k(8,8) - 2*k(7,8)) = 1.1243847608566284.
+          - Compute alpha:
+            alpha = sum(k(s, 7) - k(s, 8) for s in S) - 2 * sum(k(s, 7) - k(s, 8) for s in S1).
+            alpha = -0.33124834299087524.
+          - Compute a:
+            a = max(b * sigma * sqrt(2 * log(2/delta)), b^2) = 3.6842159106604075.
+          - Update sigma:
+            new_sigma = sqrt(sigma^2 + max(0, b^2 * (1 + (b^2 - 2*a) * sigma^2 / a^2))).
+            new_sigma = 1.4490018035043584.
+          - Compute probability:
+            p = 0.5 * (1 - alpha / a) = 0.5449550747871399.
+          - Assign:
+            Since p > 0.5, assign x=7 to S1 and y=8 to S2, and add both to S.
+            S1 = [2, 3, 6, 7], S2 = [1, 4, 5, 8], S = [1, 2, 3, 4, 5, 6, 7, 8].
+
+        Final result:
+        S1 = [2, 3, 6, 7], S2 = [1, 4, 5, 8].
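+
+        As a quick sanity check on the common value of b (the kernel is
+        translation invariant, so b is the same for every adjacent pair; this
+        snippet is illustrative):
+
+        .. code-block:: python
+
+            import jax.numpy as jnp
+
+            # k(x, y) = exp(-||x - y||^2), so b(x, x + 1) = sqrt(2 - 2 * exp(-1))
+            b = jnp.sqrt(2 - 2 * jnp.exp(-1.0))  # 1.1243848...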
+ """ # noqa: E501 + # pylint: enable=line-too-long + length_scale = 1.0 / jnp.sqrt(2) + kernel = SquaredExponentialKernel() + sqrt_kernel = SquaredExponentialKernel(length_scale=length_scale) + delta = 1 / 8 + random_key = jax.random.PRNGKey(seed=0) + data = Data(jnp.array([1, 2, 3, 4, 5, 6, 7, 8])) + thinning_solver = KernelThinning( + coreset_size=2, + kernel=kernel, + random_key=random_key, + delta=delta, + sqrt_kernel=sqrt_kernel, + ) + + def deterministic_uniform(_key, _shape=None): + return 0.5 + + # Patch `jax.random.uniform` with `deterministic_uniform` + with patch("jax.random.uniform", side_effect=deterministic_uniform): + coresets = [ + jnp.asarray(s.coreset.data) for s in thinning_solver.kt_half(data) + ] + + np.testing.assert_array_equal(coresets[0], jnp.array([[2], [3], [6], [7]])) + np.testing.assert_array_equal(coresets[1], jnp.array([[1], [4], [5], [8]]))