Skip to content

Commit

Permalink
Fix: AVX2 int8 angular distance
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Oct 27, 2023
1 parent 2e1a714 commit 143aa34
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 23 deletions.
4 changes: 2 additions & 2 deletions cpp/bench.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

#include <benchmark/benchmark.h>

#define SIMSIMD_RSQRT sqrtf
#define SIMSIMD_LOG logf
#define SIMSIMD_RSQRT(x) (1 / sqrtf(x))
#define SIMSIMD_LOG(x) (logf(x))
#include <simsimd/simsimd.h>

namespace bm = benchmark;
Expand Down
44 changes: 23 additions & 21 deletions include/simsimd/spatial.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
a2 += ai * ai; \
b2 += bi * bi; \
} \
return ab != 0 ? 1 - ab * SIMSIMD_RSQRT(a2) * SIMSIMD_RSQRT(b2) : 1; \
return ab != 0 ? (1 - ab * SIMSIMD_RSQRT(a2) * SIMSIMD_RSQRT(b2)) : 1; \
}

#ifdef __cplusplus
Expand Down Expand Up @@ -650,31 +650,31 @@ __attribute__((target("avx2"))) //
inline static simsimd_f32_t
simsimd_avx2_i8_cos(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n) {

__m256i ab_high_vec = _mm256_setzero_si256();
__m256i ab_low_vec = _mm256_setzero_si256();
__m256i a2_high_vec = _mm256_setzero_si256();
__m256i ab_high_vec = _mm256_setzero_si256();
__m256i a2_low_vec = _mm256_setzero_si256();
__m256i b2_high_vec = _mm256_setzero_si256();
__m256i a2_high_vec = _mm256_setzero_si256();
__m256i b2_low_vec = _mm256_setzero_si256();
__m256i b2_high_vec = _mm256_setzero_si256();

simsimd_size_t i = 0;
for (; i + 32 <= n; i += 32) {
__m256i a_vec = _mm256_loadu_si256((__m256i const*)(a + i));
__m256i b_vec = _mm256_loadu_si256((__m256i const*)(b + i));

// Unpack int8 to int32
__m256i a_low = _mm256_cvtepi8_epi32(_mm256_castsi256_si128(a_vec));
__m256i a_high = _mm256_cvtepi8_epi32(_mm256_extracti128_si256(a_vec, 1));
__m256i b_low = _mm256_cvtepi8_epi32(_mm256_castsi256_si128(b_vec));
__m256i b_high = _mm256_cvtepi8_epi32(_mm256_extracti128_si256(b_vec, 1));
// Unpack int8 to int16
__m256i a_low_16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(a_vec, 0));
__m256i a_high_16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(a_vec, 1));
__m256i b_low_16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(b_vec, 0));
__m256i b_high_16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(b_vec, 1));

// Multiply and accumulate
ab_low_vec = _mm256_add_epi32(ab_low_vec, _mm256_mullo_epi32(a_low, b_low));
ab_high_vec = _mm256_add_epi32(ab_high_vec, _mm256_mullo_epi32(a_high, b_high));
a2_low_vec = _mm256_add_epi32(a2_low_vec, _mm256_mullo_epi32(a_low, a_low));
a2_high_vec = _mm256_add_epi32(a2_high_vec, _mm256_mullo_epi32(a_high, a_high));
b2_low_vec = _mm256_add_epi32(b2_low_vec, _mm256_mullo_epi32(b_low, b_low));
b2_high_vec = _mm256_add_epi32(b2_high_vec, _mm256_mullo_epi32(b_high, b_high));
// Multiply int16 pairs and sum adjacent products into int32 lanes (`madd`), then accumulate in int32
ab_low_vec = _mm256_add_epi32(ab_low_vec, _mm256_madd_epi16(a_low_16, b_low_16));
ab_high_vec = _mm256_add_epi32(ab_high_vec, _mm256_madd_epi16(a_high_16, b_high_16));
a2_low_vec = _mm256_add_epi32(a2_low_vec, _mm256_madd_epi16(a_low_16, a_low_16));
a2_high_vec = _mm256_add_epi32(a2_high_vec, _mm256_madd_epi16(a_high_16, a_high_16));
b2_low_vec = _mm256_add_epi32(b2_low_vec, _mm256_madd_epi16(b_low_16, b_low_16));
b2_high_vec = _mm256_add_epi32(b2_high_vec, _mm256_madd_epi16(b_high_16, b_high_16));
}

// Horizontal sum across the 256-bit register
Expand Down Expand Up @@ -704,13 +704,15 @@ simsimd_avx2_i8_cos(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t
ab += ai * bi, a2 += ai * ai, b2 += bi * bi;
}

// Replace simsimd_approximate_inverse_square_root with `rsqrtss`
// Compute approximate reciprocal square roots with `rsqrtss` (relative error up to 1.5 * 2^-12)
__m128 a2_sqrt_recip = _mm_rsqrt_ss(_mm_set_ss((float)a2));
__m128 b2_sqrt_recip = _mm_rsqrt_ss(_mm_set_ss((float)b2));
__m128 result = _mm_mul_ss(a2_sqrt_recip, b2_sqrt_recip); // Multiply the reciprocal square roots
result = _mm_mul_ss(result, _mm_set_ss((float)ab)); // Multiply by ab
result = _mm_sub_ss(_mm_set_ss(1.0f), result); // Subtract from 1
return ab != 0 ? _mm_cvtss_f32(result) : 1; // Extract the final result

// Compute cosine similarity: ab / sqrt(a2 * b2)
__m128 denom = _mm_mul_ss(a2_sqrt_recip, b2_sqrt_recip); // Reciprocal of sqrt(a2 * b2)
__m128 result = _mm_mul_ss(_mm_set_ss((float)ab), denom); // ab * reciprocal of sqrt(a2 * b2)

return ab != 0 ? 1 - _mm_cvtss_f32(result) : 0; // Extract the final result
}

__attribute__((target("avx2"))) //
Expand Down

0 comments on commit 143aa34

Please sign in to comment.