diff --git a/src/utils/k_means.rs b/src/utils/k_means.rs index 8b4fd9d..ac58a08 100644 --- a/src/utils/k_means.rs +++ b/src/utils/k_means.rs @@ -1,24 +1,34 @@ +#![allow(clippy::ptr_arg)] + use super::parallelism::{ParallelIterator, Parallelism}; use base::scalar::*; use half::f16; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; -pub fn k_means( - parallelism: &impl Parallelism, +pub fn k_means( + parallelism: &P, c: usize, dims: usize, - samples: &Vec>, + samples: &Vec>, is_spherical: bool, iterations: usize, -) -> Vec> { +) -> Vec> { assert!(c > 0); assert!(dims > 0); let n = samples.len(); if n <= c { quick_centers(c, dims, samples.clone(), is_spherical) } else { - let mut lloyd_k_means = LloydKMeans::new(parallelism, c, dims, samples, is_spherical); + let compute = |parallelism: &P, centroids: &Vec>| { + if n >= 1000 && c >= 1000 { + rabitq_index(parallelism, dims, n, c, samples, centroids) + } else { + flat_index(parallelism, dims, n, c, samples, centroids) + } + }; + let mut lloyd_k_means = + LloydKMeans::new(parallelism, c, dims, samples, is_spherical, compute); for _ in 0..iterations { parallelism.check(); if lloyd_k_means.iterate() { @@ -29,11 +39,11 @@ pub fn k_means( } } -pub fn k_means_lookup(vector: &[S], centroids: &[Vec]) -> usize { +pub fn k_means_lookup(vector: &[f32], centroids: &[Vec]) -> usize { assert_ne!(centroids.len(), 0); let mut result = (f32::INFINITY, 0); for i in 0..centroids.len() { - let dis = S::reduce_sum_of_d2(vector, ¢roids[i]); + let dis = f32::reduce_sum_of_d2(vector, ¢roids[i]); if dis <= result.0 { result = (dis, i); } @@ -41,52 +51,248 @@ pub fn k_means_lookup(vector: &[S], centroids: &[Vec]) -> usiz result.1 } -fn quick_centers( +fn quick_centers( c: usize, dims: usize, - samples: Vec>, + samples: Vec>, is_spherical: bool, -) -> Vec> { +) -> Vec> { let n = samples.len(); assert!(c >= n); let mut rng = rand::thread_rng(); let mut centroids = samples; for _ in n..c { let r = (0..dims) - .map(|_| S::from_f32(rng.gen_range(-1.0f32..1.0f32))) + .map(|_| f32::from_f32(rng.gen_range(-1.0f32..1.0f32))) .collect(); centroids.push(r); } if is_spherical { for i in 0..c { let centroid = &mut centroids[i]; - let l = S::reduce_sum_of_x2(centroid).sqrt(); - S::vector_mul_scalar_inplace(centroid, 1.0 / l); + let l = f32::reduce_sum_of_x2(centroid).sqrt(); + f32::vector_mul_scalar_inplace(centroid, 1.0 / l); } } centroids } -struct LloydKMeans<'a, P, S> { +fn rabitq_index( + parallelism: &P, + dims: usize, + n: usize, + c: usize, + samples: &Vec>, + centroids: &Vec>, +) -> Vec { + fn code_alpha(vector: &[f32]) -> (f32, f32, f32, f32) { + let dims = vector.len(); + let sum_of_abs_x = f32::reduce_sum_of_abs_x(vector); + let sum_of_x2 = f32::reduce_sum_of_x2(vector); + let dis_u = sum_of_x2.sqrt(); + let x0 = sum_of_abs_x / (sum_of_x2 * (dims as f32)).sqrt(); + let x_x0 = dis_u / x0; + let fac_norm = (dims as f32).sqrt(); + let max_x1 = 1.0f32 / (dims as f32 - 1.0).sqrt(); + let factor_err = 2.0f32 * max_x1 * (x_x0 * x_x0 - dis_u * dis_u).sqrt(); + let factor_ip = -2.0f32 / fac_norm * x_x0; + let cnt_pos = vector + .iter() + .map(|x| x.scalar_is_sign_positive() as i32) + .sum::(); + let cnt_neg = vector + .iter() + .map(|x| x.scalar_is_sign_negative() as i32) + .sum::(); + let factor_ppc = factor_ip * (cnt_pos - cnt_neg) as f32; + (sum_of_x2, factor_ppc, factor_ip, factor_err) + } + fn code_beta(vector: &[f32]) -> Vec { + let dims = vector.len(); + let mut code = Vec::new(); + for i in 0..dims { + code.push(vector[i].scalar_is_sign_positive() as u8); + } + code + } + let mut a0 = Vec::new(); + let mut a1 = Vec::new(); + let mut a2 = Vec::new(); + let mut a3 = Vec::new(); + let mut a4 = Vec::new(); + for vectors in centroids.chunks(32) { + use quantization::fast_scan::b4::pack; + let code_alphas = std::array::from_fn::<_, 32, _>(|i| { + if let Some(vector) = vectors.get(i) { + code_alpha(vector) + } else { + (0.0, 0.0, 0.0, 0.0) + } + }); + let code_betas = std::array::from_fn::<_, 32, _>(|i| { + let mut result = vec![0_u8; dims.div_ceil(4)]; + if let Some(vector) = vectors.get(i) { + let mut c = code_beta(vector); + c.resize(dims.next_multiple_of(4), 0); + for i in 0..dims.div_ceil(4) { + for j in 0..4 { + result[i] |= c[i * 4 + j] << j; + } + } + } + result + }); + a0.push(code_alphas.map(|x| x.0)); + a1.push(code_alphas.map(|x| x.1)); + a2.push(code_alphas.map(|x| x.2)); + a3.push(code_alphas.map(|x| x.3)); + a4.push(pack(dims.div_ceil(4) as _, code_betas).collect::>()); + } + parallelism + .into_par_iter(0..n) + .map(|i| { + fn gen(mut qvector: Vec) -> Vec { + let dims = qvector.len() as u32; + let t = dims.div_ceil(4); + qvector.resize(qvector.len().next_multiple_of(4), 0); + let mut lut = vec![0u8; t as usize * 16]; + for i in 0..t as usize { + unsafe { + // this hint is used to skip bound checks + std::hint::assert_unchecked(4 * i + 3 < qvector.len()); + std::hint::assert_unchecked(16 * i + 15 < lut.len()); + } + let t0 = qvector[4 * i + 0]; + let t1 = qvector[4 * i + 1]; + let t2 = qvector[4 * i + 2]; + let t3 = qvector[4 * i + 3]; + lut[16 * i + 0b0000] = 0; + lut[16 * i + 0b0001] = t0; + lut[16 * i + 0b0010] = t1; + lut[16 * i + 0b0011] = t1 + t0; + lut[16 * i + 0b0100] = t2; + lut[16 * i + 0b0101] = t2 + t0; + lut[16 * i + 0b0110] = t2 + t1; + lut[16 * i + 0b0111] = t2 + t1 + t0; + lut[16 * i + 0b1000] = t3; + lut[16 * i + 0b1001] = t3 + t0; + lut[16 * i + 0b1010] = t3 + t1; + lut[16 * i + 0b1011] = t3 + t1 + t0; + lut[16 * i + 0b1100] = t3 + t2; + lut[16 * i + 0b1101] = t3 + t2 + t0; + lut[16 * i + 0b1110] = t3 + t2 + t1; + lut[16 * i + 0b1111] = t3 + t2 + t1 + t0; + } + lut + } + fn fscan_process_lowerbound( + dims: u32, + lut: &(f32, f32, f32, f32, Vec), + (dis_u_2, factor_ppc, factor_ip, factor_err, t): ( + &[f32; 32], + &[f32; 32], + &[f32; 32], + &[f32; 32], + &[u8], + ), + epsilon: f32, + ) -> [Distance; 32] { + use quantization::fast_scan::b4::fast_scan_b4; + let &(dis_v_2, b, k, qvector_sum, ref s) = lut; + let r = fast_scan_b4(dims.div_ceil(4), t, s); + std::array::from_fn(|i| { + let rough = dis_u_2[i] + + dis_v_2 + + b * factor_ppc[i] + + ((2.0 * r[i] as f32) - qvector_sum) * factor_ip[i] * k; + let err = factor_err[i] * dis_v_2.sqrt(); + Distance::from_f32(rough - epsilon * err) + }) + } + use base::distance::Distance; + use quantization::quantize; + + let lut = { + let vector = &samples[i]; + let dis_v_2 = f32::reduce_sum_of_x2(vector); + let (k, b, qvector) = + quantize::quantize::<15>(f32::vector_to_f32_borrowed(vector).as_ref()); + let qvector_sum = if vector.len() <= 4369 { + quantize::reduce_sum_of_x_as_u16(&qvector) as f32 + } else { + quantize::reduce_sum_of_x_as_u32(&qvector) as f32 + }; + (dis_v_2, b, k, qvector_sum, gen(qvector)) + }; + + let mut result = (Distance::INFINITY, 0); + for block in 0..c.div_ceil(32) { + let lowerbound = fscan_process_lowerbound( + dims as _, + &lut, + (&a0[block], &a1[block], &a2[block], &a3[block], &a4[block]), + 1.9, + ); + for j in block * 32..std::cmp::min(block * 32 + 32, c) { + if lowerbound[j - block * 32] < result.0 { + let dis = + Distance::from_f32(f32::reduce_sum_of_d2(&samples[i], ¢roids[j])); + if dis <= result.0 { + result = (dis, j); + } + } + } + } + result.1 + }) + .collect::>() +} + +fn flat_index( + parallelism: &P, + _dims: usize, + n: usize, + c: usize, + samples: &Vec>, + centroids: &Vec>, +) -> Vec { + parallelism + .into_par_iter(0..n) + .map(|i| { + let mut result = (f32::INFINITY, 0); + for j in 0..c { + let dis_2 = f32::reduce_sum_of_d2(&samples[i], ¢roids[j]); + if dis_2 <= result.0 { + result = (dis_2, j); + } + } + result.1 + }) + .collect::>() +} + +struct LloydKMeans<'a, P, F> { parallelism: &'a P, dims: usize, c: usize, is_spherical: bool, - centroids: Vec>, + centroids: Vec>, assign: Vec, rng: StdRng, - samples: &'a Vec>, + samples: &'a Vec>, + compute: F, } const DELTA: f32 = f16::EPSILON.to_f32_const(); -impl<'a, P: Parallelism, S: ScalarLike> LloydKMeans<'a, P, S> { +impl<'a, P: Parallelism, F: Fn(&P, &Vec>) -> Vec> LloydKMeans<'a, P, F> { fn new( parallelism: &'a P, c: usize, dims: usize, - samples: &'a Vec>, + samples: &'a Vec>, is_spherical: bool, + compute: F, ) -> Self { let n = samples.len(); @@ -102,7 +308,7 @@ impl<'a, P: Parallelism, S: ScalarLike> LloydKMeans<'a, P, S> { .map(|i| { let mut result = (f32::INFINITY, 0); for j in 0..c { - let dis_2 = S::reduce_sum_of_d2(&samples[i], ¢roids[j]); + let dis_2 = f32::reduce_sum_of_d2(&samples[i], ¢roids[j]); if dis_2 <= result.0 { result = (dis_2, j); } @@ -120,6 +326,7 @@ impl<'a, P: Parallelism, S: ScalarLike> LloydKMeans<'a, P, S> { assign, rng, samples, + compute, } } @@ -134,18 +341,18 @@ impl<'a, P: Parallelism, S: ScalarLike> LloydKMeans<'a, P, S> { .parallelism .into_par_iter(0..n) .fold( - || (vec![vec![S::zero(); dims]; c], vec![0.0f32; c]), + || (vec![vec![f32::zero(); dims]; c], vec![0.0f32; c]), |(mut sum, mut count), i| { - S::vector_add_inplace(&mut sum[self.assign[i]], &samples[i]); + f32::vector_add_inplace(&mut sum[self.assign[i]], &samples[i]); count[self.assign[i]] += 1.0; (sum, count) }, ) .reduce( - || (vec![vec![S::zero(); dims]; c], vec![0.0f32; c]), + || (vec![vec![f32::zero(); dims]; c], vec![0.0f32; c]), |(mut sum, mut count), (sum_1, count_1)| { for i in 0..c { - S::vector_add_inplace(&mut sum[i], &sum_1[i]); + f32::vector_add_inplace(&mut sum[i], &sum_1[i]); count[i] += count_1[i]; } (sum, count) @@ -155,7 +362,7 @@ impl<'a, P: Parallelism, S: ScalarLike> LloydKMeans<'a, P, S> { let mut centroids = self .parallelism .into_par_iter(0..c) - .map(|i| S::vector_mul_scalar(&sum[i], 1.0 / count[i])) + .map(|i| f32::vector_mul_scalar(&sum[i], 1.0 / count[i])) .collect::>(); for i in 0..c { @@ -172,8 +379,8 @@ impl<'a, P: Parallelism, S: ScalarLike> LloydKMeans<'a, P, S> { o = (o + 1) % c; } centroids[i] = centroids[o].clone(); - S::kmeans_helper(&mut centroids[i], 1.0 + DELTA, 1.0 - DELTA); - S::kmeans_helper(&mut centroids[o], 1.0 - DELTA, 1.0 + DELTA); + f32::kmeans_helper(&mut centroids[i], 1.0 + DELTA, 1.0 - DELTA); + f32::kmeans_helper(&mut centroids[o], 1.0 - DELTA, 1.0 + DELTA); count[i] = count[o] / 2.0; count[o] -= count[i]; } @@ -182,25 +389,12 @@ impl<'a, P: Parallelism, S: ScalarLike> LloydKMeans<'a, P, S> { self.parallelism .into_par_iter(&mut centroids) .for_each(|centroid| { - let l = S::reduce_sum_of_x2(centroid).sqrt(); - S::vector_mul_scalar_inplace(centroid, 1.0 / l); + let l = f32::reduce_sum_of_x2(centroid).sqrt(); + f32::vector_mul_scalar_inplace(centroid, 1.0 / l); }); } - let assign = self - .parallelism - .into_par_iter(0..n) - .map(|i| { - let mut result = (f32::INFINITY, 0); - for j in 0..c { - let dis_2 = S::reduce_sum_of_d2(&samples[i], ¢roids[j]); - if dis_2 <= result.0 { - result = (dis_2, j); - } - } - result.1 - }) - .collect::>(); + let assign = (self.compute)(self.parallelism, ¢roids); let result = (0..n).all(|i| assign[i] == self.assign[i]); @@ -210,7 +404,7 @@ impl<'a, P: Parallelism, S: ScalarLike> LloydKMeans<'a, P, S> { result } - fn finish(self) -> Vec> { + fn finish(self) -> Vec> { self.centroids } } diff --git a/src/vchordrq/algorithm/build.rs b/src/vchordrq/algorithm/build.rs index 8df1905..c65bd85 100644 --- a/src/vchordrq/algorithm/build.rs +++ b/src/vchordrq/algorithm/build.rs @@ -48,8 +48,11 @@ pub fn build( let mut tuples_total = 0_u64; let samples = { let mut rand = rand::thread_rng(); - let max_number_of_samples = - internal_build.lists.last().unwrap().saturating_mul(256); + let max_number_of_samples = internal_build + .lists + .last() + .unwrap() + .saturating_mul(internal_build.sampling_factor); let mut samples = Vec::new(); let mut number_of_samples = 0_u32; heap_relation.traverse(false, |(_, vector)| { diff --git a/src/vchordrq/types.rs b/src/vchordrq/types.rs index 2002d7d..2301150 100644 --- a/src/vchordrq/types.rs +++ b/src/vchordrq/types.rs @@ -9,6 +9,9 @@ pub struct VchordrqInternalBuildOptions { pub lists: Vec, #[serde(default = "VchordrqInternalBuildOptions::default_spherical_centroids")] pub spherical_centroids: bool, + #[serde(default = "VchordrqInternalBuildOptions::default_sampling_factor")] + #[validate(range(min = 1, max = 1024))] + pub sampling_factor: u32, #[serde(default = "VchordrqInternalBuildOptions::default_build_threads")] #[validate(range(min = 1, max = 255))] pub build_threads: u16, @@ -30,6 +33,9 @@ impl VchordrqInternalBuildOptions { fn default_spherical_centroids() -> bool { false } + fn default_sampling_factor() -> u32 { + 256 + } fn default_build_threads() -> u16 { 1 } @@ -40,6 +46,7 @@ impl Default for VchordrqInternalBuildOptions { Self { lists: Self::default_lists(), spherical_centroids: Self::default_spherical_centroids(), + sampling_factor: Self::default_sampling_factor(), build_threads: Self::default_build_threads(), } } diff --git a/src/vchordrqfscan/algorithm/build.rs b/src/vchordrqfscan/algorithm/build.rs index 9bdc3b9..0528a5c 100644 --- a/src/vchordrqfscan/algorithm/build.rs +++ b/src/vchordrqfscan/algorithm/build.rs @@ -47,8 +47,11 @@ pub fn build( let mut tuples_total = 0_u64; let samples = { let mut rand = rand::thread_rng(); - let max_number_of_samples = - internal_build.lists.last().unwrap().saturating_mul(256); + let max_number_of_samples = internal_build + .lists + .last() + .unwrap() + .saturating_mul(internal_build.sampling_factor); let mut samples = Vec::new(); let mut number_of_samples = 0_u32; heap_relation.traverse(false, |(_, vector)| { diff --git a/src/vchordrqfscan/types.rs b/src/vchordrqfscan/types.rs index b3a1067..0fbe82e 100644 --- a/src/vchordrqfscan/types.rs +++ b/src/vchordrqfscan/types.rs @@ -9,6 +9,9 @@ pub struct VchordrqfscanInternalBuildOptions { pub lists: Vec, #[serde(default = "VchordrqfscanInternalBuildOptions::default_spherical_centroids")] pub spherical_centroids: bool, + #[serde(default = "VchordrqfscanInternalBuildOptions::default_sampling_factor")] + #[validate(range(min = 1, max = 1024))] + pub sampling_factor: u32, #[serde(default = "VchordrqfscanInternalBuildOptions::default_build_threads")] #[validate(range(min = 1, max = 255))] pub build_threads: u16, @@ -30,6 +33,9 @@ impl VchordrqfscanInternalBuildOptions { fn default_spherical_centroids() -> bool { false } + fn default_sampling_factor() -> u32 { + 256 + } fn default_build_threads() -> u16 { 1 } @@ -40,6 +46,7 @@ impl Default for VchordrqfscanInternalBuildOptions { Self { lists: Self::default_lists(), spherical_centroids: Self::default_spherical_centroids(), + sampling_factor: Self::default_sampling_factor(), build_threads: Self::default_build_threads(), } }