Skip to content

Commit

Permalink
Improve: Expose intersections to Rust (#238)
Browse files Browse the repository at this point in the history
Closes #184
  • Loading branch information
GoWind authored Nov 20, 2024
1 parent 96ee8a5 commit 4037db5
Showing 1 changed file with 82 additions and 0 deletions.
82 changes: 82 additions & 0 deletions rust/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ extern "C" {
fn simsimd_kl_f32(a: *const f32, b: *const f32, c: usize, d: *mut Distance);
fn simsimd_kl_f64(a: *const f64, b: *const f64, c: usize, d: *mut Distance);

fn simsimd_intersect_u16(a: *const u16, b: *const u16, a_length: usize, b_length: usize, d: *mut Distance);
fn simsimd_intersect_u32(a: *const u32, b: *const u32, a_length: usize, b_length: usize, d: *mut Distance);

fn simsimd_uses_neon() -> i32;
fn simsimd_uses_neon_f16() -> i32;
fn simsimd_uses_neon_bf16() -> i32;
Expand Down Expand Up @@ -315,6 +318,16 @@ where
fn vdot(a: &[Self], b: &[Self]) -> Option<ComplexProduct>;
}

/// `Sparse` provides trait methods for sparse vectors.
pub trait Sparse
where
Self: Sized,
{
/// Computes the number of common elements between two sparse vectors.
/// both vectors must be sorted in ascending order.
fn intersect(a: &[Self], b: &[Self]) -> Option<Distance>;
}

impl BinarySimilarity for u8 {
fn hamming(a: &[Self], b: &[Self]) -> Option<Distance> {
if a.len() != b.len() {
Expand Down Expand Up @@ -379,6 +392,28 @@ impl SpatialSimilarity for i8 {
}
}

impl Sparse for u16 {

fn intersect(a: &[Self], b: &[Self]) -> Option<Distance> {
let mut distance_value: Distance = 0.0;
let distance_ptr: *mut Distance = &mut distance_value as *mut Distance;
unsafe { simsimd_intersect_u16(a.as_ptr(), b.as_ptr(), a.len(), b.len(), distance_ptr) };
Some(distance_value)
}

}

impl Sparse for u32 {

fn intersect(a: &[Self], b: &[Self]) -> Option<Distance> {
let mut distance_value: Distance = 0.0;
let distance_ptr: *mut Distance = &mut distance_value as *mut Distance;
unsafe { simsimd_intersect_u32(a.as_ptr(), b.as_ptr(), a.len(), b.len(), distance_ptr) };
Some(distance_value)
}

}

impl SpatialSimilarity for f16 {
fn cos(a: &[Self], b: &[Self]) -> Option<Distance> {
if a.len() != b.len() {
Expand Down Expand Up @@ -1121,4 +1156,51 @@ mod tests {
assert_almost_equal(0.025, result, 0.01);
}
}

#[test]
fn test_intersect_u16() {
{
let a_u16: &[u16] = &[153, 16384, 17408];
let b_u16: &[u16] = &[15360, 16384, 7408];

if let Some(result) = Sparse::intersect(a_u16, b_u16) {
println!("The result of intersect_u16 is {:.8}", result);
assert_almost_equal(1.0, result, 0.0001);
}
}

{
let a_u16: &[u16] = &[153, 11638, 08];
let b_u16: &[u16] = &[15360, 16384, 7408];

if let Some(result) = Sparse::intersect(a_u16, b_u16) {
println!("The result of intersect_u16 is {:.8}", result);
assert_almost_equal(0.0, result, 0.0001);
}
}
}

#[test]
fn test_intersect_u32() {
{
let a_u32: &[u32] = &[11, 153];
let b_u32: &[u32] = &[11, 153, 7408, 16384];

if let Some(result) = Sparse::intersect(a_u32, b_u32) {
println!("The result of intersect_u32 is {:.8}", result);
assert_almost_equal(2.0, result, 0.0001);
}
}

{
let a_u32: &[u32] = &[153, 7408, 11638];
let b_u32: &[u32] = &[153, 7408, 11638];

if let Some(result) = Sparse::intersect(a_u32, b_u32) {
println!("The result of intersect_u32 is {:.8}", result);
assert_almost_equal(3.0, result, 0.0001);
}
}

}
}

0 comments on commit 4037db5

Please sign in to comment.