Add: AVX2 f16 implementation of divergence

ashvardanian committed Oct 23, 2023
1 parent 9f17a89 commit edda544

Showing 2 changed files with 63 additions and 196 deletions.
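For context, the kernels added here compute epsilon-smoothed divergences between non-negative histograms P and Q. With the smoothing constant ε = 1e-3 used in the code, the Kullback-Leibler and Jensen-Shannon kernels accumulate:

    \[ D_{KL}(P \parallel Q) \approx \sum_i p_i \log \frac{p_i + \epsilon}{q_i + \epsilon} \]

    \[ D_{JS}(P, Q) \approx \frac{1}{2} \sum_i \left( p_i \log \frac{p_i + \epsilon}{m_i + \epsilon} + q_i \log \frac{q_i + \epsilon}{m_i + \epsilon} \right), \qquad m_i = p_i + q_i \]

These are the quantities realized in the diff below; the smoothing keeps the ratios finite for zero-probability bins.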
10 changes: 10 additions & 0 deletions cpp/bench.cxx
@@ -129,10 +129,14 @@ int main(int argc, char** argv) {
     register_<simsimd_f16_t>("neon_f16_ip", simsimd_neon_f16_ip, simsimd_accurate_f16_ip);
     register_<simsimd_f16_t>("neon_f16_cos", simsimd_neon_f16_cos, simsimd_accurate_f16_cos);
     register_<simsimd_f16_t>("neon_f16_l2sq", simsimd_neon_f16_l2sq, simsimd_accurate_f16_l2sq);
+    register_<simsimd_f16_t>("neon_f16_kl", simsimd_neon_f16_kl, simsimd_accurate_f16_kl);
+    register_<simsimd_f16_t>("neon_f16_js", simsimd_neon_f16_js, simsimd_accurate_f16_js);

     register_<simsimd_f32_t>("neon_f32_ip", simsimd_neon_f32_ip, simsimd_accurate_f32_ip);
     register_<simsimd_f32_t>("neon_f32_cos", simsimd_neon_f32_cos, simsimd_accurate_f32_cos);
     register_<simsimd_f32_t>("neon_f32_l2sq", simsimd_neon_f32_l2sq, simsimd_accurate_f32_l2sq);
+    register_<simsimd_f32_t>("neon_f32_kl", simsimd_neon_f32_kl, simsimd_accurate_f32_kl);
+    register_<simsimd_f32_t>("neon_f32_js", simsimd_neon_f32_js, simsimd_accurate_f32_js);

     register_<simsimd_i8_t>("neon_i8_cos", simsimd_neon_i8_cos, simsimd_accurate_i8_cos);
     register_<simsimd_i8_t>("neon_i8_l2sq", simsimd_neon_i8_l2sq, simsimd_accurate_i8_l2sq);
@@ -152,6 +156,8 @@ int main(int argc, char** argv) {
     register_<simsimd_f16_t>("avx2_f16_ip", simsimd_avx2_f16_ip, simsimd_accurate_f16_ip);
     register_<simsimd_f16_t>("avx2_f16_cos", simsimd_avx2_f16_cos, simsimd_accurate_f16_cos);
     register_<simsimd_f16_t>("avx2_f16_l2sq", simsimd_avx2_f16_l2sq, simsimd_accurate_f16_l2sq);
+    register_<simsimd_f16_t>("avx2_f16_kl", simsimd_avx2_f16_kl, simsimd_accurate_f16_kl);
+    register_<simsimd_f16_t>("avx2_f16_js", simsimd_avx2_f16_js, simsimd_accurate_f16_js);

     register_<simsimd_i8_t>("avx2_i8_cos", simsimd_avx2_i8_cos, simsimd_accurate_i8_cos);
     register_<simsimd_i8_t>("avx2_i8_l2sq", simsimd_avx2_i8_l2sq, simsimd_accurate_i8_l2sq);
@@ -173,10 +179,14 @@ int main(int argc, char** argv) {
     register_<simsimd_f16_t>("serial_f16_ip", simsimd_serial_f16_ip, simsimd_accurate_f16_ip);
     register_<simsimd_f16_t>("serial_f16_cos", simsimd_serial_f16_cos, simsimd_accurate_f16_cos);
     register_<simsimd_f16_t>("serial_f16_l2sq", simsimd_serial_f16_l2sq, simsimd_accurate_f16_l2sq);
+    register_<simsimd_f16_t>("serial_f16_kl", simsimd_serial_f16_kl, simsimd_accurate_f16_kl);
+    register_<simsimd_f16_t>("serial_f16_js", simsimd_serial_f16_js, simsimd_accurate_f16_js);

     register_<simsimd_f32_t>("serial_f32_ip", simsimd_serial_f32_ip, simsimd_accurate_f32_ip);
     register_<simsimd_f32_t>("serial_f32_cos", simsimd_serial_f32_cos, simsimd_accurate_f32_cos);
     register_<simsimd_f32_t>("serial_f32_l2sq", simsimd_serial_f32_l2sq, simsimd_accurate_f32_l2sq);
+    register_<simsimd_f32_t>("serial_f32_kl", simsimd_serial_f32_kl, simsimd_accurate_f32_kl);
+    register_<simsimd_f32_t>("serial_f32_js", simsimd_serial_f32_js, simsimd_accurate_f32_js);

     register_<simsimd_i8_t>("serial_i8_cos", simsimd_serial_i8_cos, simsimd_accurate_i8_cos);
     register_<simsimd_i8_t>("serial_i8_l2sq", simsimd_serial_i8_l2sq, simsimd_accurate_i8_l2sq);
249 changes: 53 additions & 196 deletions include/simsimd/probability.h
@@ -202,7 +202,7 @@ simsimd_neon_f16_js(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n) {
 #endif // SIMSIMD_TARGET_ARM_NEON
 #endif // SIMSIMD_TARGET_ARM

-#if SIMSIMD_TARGET_X86 && 0
+#if SIMSIMD_TARGET_X86
 #if SIMSIMD_TARGET_X86_AVX2

 /*
@@ -216,233 +216,90 @@ simsimd_neon_f16_js(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n) {
  *  - Requires compiler capabilities: avx2, f16c, fma.
  */

-__attribute__((target("avx2,f16c,fma"))) //
-inline static simsimd_f32_t
-simsimd_avx2_f16_kl(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n) {
-    __m256 d2_vec = _mm256_set1_ps(0);
-    simsimd_size_t i = 0;
-    for (; i + 8 <= n; i += 8) {
-        __m256 a_vec = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const*)(a + i)));
-        __m256 b_vec = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const*)(b + i)));
-        __m256 d_vec = _mm256_sub_ps(a_vec, b_vec);
-        d2_vec = _mm256_fmadd_ps(d_vec, d_vec, d2_vec);
-    }
-
-    d2_vec = _mm256_add_ps(_mm256_permute2f128_ps(d2_vec, d2_vec, 1), d2_vec);
-    d2_vec = _mm256_hadd_ps(d2_vec, d2_vec);
-    d2_vec = _mm256_hadd_ps(d2_vec, d2_vec);
-
-    simsimd_f32_t result;
-    _mm_store_ss(&result, _mm256_castps256_ps128(d2_vec));
-
-    // Accumulate the tail:
-    for (; i < n; ++i) {
-        simsimd_f32_t n = SIMSIMD_UNCOMPRESS_F16(a[i]) - SIMSIMD_UNCOMPRESS_F16(b[i]);
-        result += n * n;
-    }
-    return result;
+__attribute__((target("avx2,fma"))) //
+inline static __m256
+simsimd_avx2_f32_log(__m256 x) {
+    __m256 a = _mm256_fmadd_ps(_mm256_set1_ps(5.17591238022f), x, _mm256_set1_ps(-2.29561495781f));
+    __m256 b = _mm256_fmadd_ps(_mm256_set1_ps(0.844007015228f), x, _mm256_set1_ps(-5.68692588806f));
+    __m256 c = _mm256_fmadd_ps(_mm256_set1_ps(4.58445882797f), x, _mm256_set1_ps(-2.47071170807f));
+    __m256 d = _mm256_fmadd_ps(_mm256_set1_ps(0.0141278216615f), x, _mm256_set1_ps(-0.165253549814f));
+    __m256 x2 = _mm256_mul_ps(x, x);
+    __m256 x4 = _mm256_mul_ps(x2, x2);
+    return _mm256_add_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(c, d, x2), b, x2), _mm256_mul_ps(a, x4));
 }
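The new `simsimd_avx2_f32_log` helper evaluates a polynomial approximation of the natural logarithm entirely in registers, so the divergence loops never call `logf`. A scalar mirror of the same arithmetic, for illustration only (coefficients copied from the vector code above; accuracy is limited and the input is assumed positive):

    inline float f32_log_approximation(float x) {
        float a = 5.17591238022f * x + -2.29561495781f;
        float b = 0.844007015228f * x + -5.68692588806f;
        float c = 4.58445882797f * x + -2.47071170807f;
        float d = 0.0141278216615f * x + -0.165253549814f;
        float x2 = x * x;
        float x4 = x2 * x2;
        return ((c * d + x2) * b + x2) + a * x4; // same grouping as the FMA chain
    }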

 __attribute__((target("avx2,f16c,fma"))) //
 inline static simsimd_f32_t
-simsimd_avx2_f16_js(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n) {
-    __m256 ab_vec = _mm256_set1_ps(0);
+simsimd_avx2_f16_kl(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n) {
+    __m256 sum_vec = _mm256_set1_ps(0);
+    simsimd_f32_t epsilon = 1e-3;
+    __m256 epsilon_vec = _mm256_set1_ps(epsilon);
     simsimd_size_t i = 0;
     for (; i + 8 <= n; i += 8) {
         __m256 a_vec = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const*)(a + i)));
         __m256 b_vec = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const*)(b + i)));
-        ab_vec = _mm256_fmadd_ps(a_vec, b_vec, ab_vec);
+        __m256 ratio_vec = _mm256_div_ps(_mm256_add_ps(a_vec, epsilon_vec), _mm256_add_ps(b_vec, epsilon_vec));
+        __m256 log_ratio_vec = simsimd_avx2_f32_log(ratio_vec);
+        __m256 prod_vec = _mm256_mul_ps(a_vec, log_ratio_vec);
+        sum_vec = _mm256_add_ps(sum_vec, prod_vec);
     }

-    ab_vec = _mm256_add_ps(_mm256_permute2f128_ps(ab_vec, ab_vec, 1), ab_vec);
-    ab_vec = _mm256_hadd_ps(ab_vec, ab_vec);
-    ab_vec = _mm256_hadd_ps(ab_vec, ab_vec);
+    sum_vec = _mm256_add_ps(_mm256_permute2f128_ps(sum_vec, sum_vec, 1), sum_vec);
+    sum_vec = _mm256_hadd_ps(sum_vec, sum_vec);
+    sum_vec = _mm256_hadd_ps(sum_vec, sum_vec);

-    simsimd_f32_t result;
-    _mm_store_ss(&result, _mm256_castps256_ps128(ab_vec));
+    simsimd_f32_t sum;
+    _mm_store_ss(&sum, _mm256_castps256_ps128(sum_vec));

     // Accumulate the tail:
     for (; i < n; ++i)
-        result += SIMSIMD_UNCOMPRESS_F16(a[i]) * SIMSIMD_UNCOMPRESS_F16(b[i]);
-    return 1 - result;
+        sum += SIMSIMD_UNCOMPRESS_F16(a[i]) *
+               SIMSIMD_LOG((SIMSIMD_UNCOMPRESS_F16(a[i]) + epsilon) / (SIMSIMD_UNCOMPRESS_F16(b[i]) + epsilon));
+    return sum;
 }
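Both f16 kernels share the same signature, so calling them is symmetric. A minimal usage sketch, assuming the `<simsimd/probability.h>` include path, an AVX2-capable CPU, and zero-initialized buffers purely for illustration (real inputs should be normalized histograms):

    #include <simsimd/probability.h>

    int main(void) {
        simsimd_f16_t p[256] = {0}, q[256] = {0};          // stand-ins for real data
        simsimd_f32_t kl = simsimd_avx2_f16_kl(p, q, 256); // KL(P || Q), asymmetric
        simsimd_f32_t js = simsimd_avx2_f16_js(p, q, 256); // JS(P, Q), symmetric
        return kl != kl || js != js;                       // non-zero exit on NaN
    }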

 __attribute__((target("avx2,f16c,fma"))) //
 inline static simsimd_f32_t
-simsimd_avx2_f16_cos(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n) {
-
-    __m256 ab_vec = _mm256_set1_ps(0), a2_vec = _mm256_set1_ps(0), b2_vec = _mm256_set1_ps(0);
-    simsimd_size_t i = 0;
-    for (; i + 8 <= n; i += 8) {
-        __m256 a_vec = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const*)(a + i)));
-        __m256 b_vec = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const*)(b + i)));
-        ab_vec = _mm256_fmadd_ps(a_vec, b_vec, ab_vec);
-        a2_vec = _mm256_fmadd_ps(a_vec, a_vec, a2_vec);
-        b2_vec = _mm256_fmadd_ps(b_vec, b_vec, b2_vec);
-    }
-
-    ab_vec = _mm256_add_ps(_mm256_permute2f128_ps(ab_vec, ab_vec, 1), ab_vec);
-    ab_vec = _mm256_hadd_ps(ab_vec, ab_vec);
-    ab_vec = _mm256_hadd_ps(ab_vec, ab_vec);
-
-    a2_vec = _mm256_add_ps(_mm256_permute2f128_ps(a2_vec, a2_vec, 1), a2_vec);
-    a2_vec = _mm256_hadd_ps(a2_vec, a2_vec);
-    a2_vec = _mm256_hadd_ps(a2_vec, a2_vec);
-
-    b2_vec = _mm256_add_ps(_mm256_permute2f128_ps(b2_vec, b2_vec, 1), b2_vec);
-    b2_vec = _mm256_hadd_ps(b2_vec, b2_vec);
-    b2_vec = _mm256_hadd_ps(b2_vec, b2_vec);
-
-    simsimd_f32_t ab, a2, b2;
-    _mm_store_ss(&ab, _mm256_castps256_ps128(ab_vec));
-    _mm_store_ss(&a2, _mm256_castps256_ps128(a2_vec));
-    _mm_store_ss(&b2, _mm256_castps256_ps128(b2_vec));
-
-    // Accumulate the tail:
-    for (; i < n; ++i) {
-        simsimd_f32_t ai = SIMSIMD_UNCOMPRESS_F16(a[i]), bi = SIMSIMD_UNCOMPRESS_F16(b[i]);
-        ab += ai * bi, a2 += ai * ai, b2 += bi * bi;
-    }
-
-    // Replace simsimd_approximate_inverse_square_root with `rsqrtss`
-    __m128 a2_sqrt_recip = _mm_rsqrt_ss(_mm_set_ss((float)a2));
-    __m128 b2_sqrt_recip = _mm_rsqrt_ss(_mm_set_ss((float)b2));
-    __m128 result = _mm_mul_ss(a2_sqrt_recip, b2_sqrt_recip); // Multiply the reciprocal square roots
-    result = _mm_mul_ss(result, _mm_set_ss((float)ab));       // Multiply by ab
-    result = _mm_sub_ss(_mm_set_ss(1.0f), result);            // Subtract from 1
-    return ab != 0 ? _mm_cvtss_f32(result) : 1;               // Extract the final result
-}
+simsimd_avx2_f16_js(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n) {
+    __m256 sum_vec = _mm256_set1_ps(0);
+    simsimd_f32_t epsilon = 1e-3;
+    __m256 epsilon_vec = _mm256_set1_ps(epsilon);
+    simsimd_size_t i = 0;
+    for (; i + 8 <= n; i += 8) {
+        __m256 a_vec = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const*)(a + i)));
+        __m256 b_vec = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const*)(b + i)));
+        __m256 m_vec = _mm256_add_ps(a_vec, b_vec); // M = P + Q
+        __m256 ratio_a_vec = _mm256_div_ps(_mm256_add_ps(a_vec, epsilon_vec), _mm256_add_ps(m_vec, epsilon_vec));
+        __m256 ratio_b_vec = _mm256_div_ps(_mm256_add_ps(b_vec, epsilon_vec), _mm256_add_ps(m_vec, epsilon_vec));
+        __m256 log_ratio_a_vec = simsimd_avx2_f32_log(ratio_a_vec);
+        __m256 log_ratio_b_vec = simsimd_avx2_f32_log(ratio_b_vec);
+        __m256 prod_a_vec = _mm256_mul_ps(a_vec, log_ratio_a_vec);
+        __m256 prod_b_vec = _mm256_mul_ps(b_vec, log_ratio_b_vec);
+        sum_vec = _mm256_add_ps(sum_vec, prod_a_vec);
+        sum_vec = _mm256_add_ps(sum_vec, prod_b_vec);
+    }
+
+    sum_vec = _mm256_add_ps(_mm256_permute2f128_ps(sum_vec, sum_vec, 1), sum_vec);
+    sum_vec = _mm256_hadd_ps(sum_vec, sum_vec);
+    sum_vec = _mm256_hadd_ps(sum_vec, sum_vec);
+
+    simsimd_f32_t sum;
+    _mm_store_ss(&sum, _mm256_castps256_ps128(sum_vec));
+
+    // Accumulate the tail:
+    for (; i < n; ++i) {
+        simsimd_f32_t ai = SIMSIMD_UNCOMPRESS_F16(a[i]);
+        simsimd_f32_t bi = SIMSIMD_UNCOMPRESS_F16(b[i]);
+        simsimd_f32_t mi = ai + bi;
+        sum += ai * SIMSIMD_LOG((ai + epsilon) / (mi + epsilon));
+        sum += bi * SIMSIMD_LOG((bi + epsilon) / (mi + epsilon));
+    }
+    return sum / 2;
+}

-/*
- *  @file   x86_avx2_i8.h
- *  @brief  x86 AVX2 implementation of the most common similarity metrics for 8-bit signed integral numbers.
- *  @author Ash Vardanian
- *
- *  - Implements: L2 squared, cosine similarity, inner product (same as cosine).
- *  - As AVX2 doesn't support masked loads of 16-bit words, implementations have a separate `for`-loop for tails.
- *  - Uses `i8` for storage, `i16` for multiplication, and `i32` for accumulation, if no better option is available.
- *  - Requires compiler capabilities: avx2.
- */
-
-__attribute__((target("avx2"))) //
-inline static simsimd_f32_t
-simsimd_avx2_i8_kl(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n) {
-
-    __m256i d2_high_vec = _mm256_setzero_si256();
-    __m256i d2_low_vec = _mm256_setzero_si256();
-
-    simsimd_size_t i = 0;
-    for (; i + 32 <= n; i += 32) {
-        __m256i a_vec = _mm256_loadu_si256((__m256i const*)(a + i));
-        __m256i b_vec = _mm256_loadu_si256((__m256i const*)(b + i));
-
-        // Sign extend int8 to int16
-        __m256i a_low = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(a_vec));
-        __m256i a_high = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(a_vec, 1));
-        __m256i b_low = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(b_vec));
-        __m256i b_high = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(b_vec, 1));
-
-        // Subtract and multiply
-        __m256i d_low = _mm256_sub_epi16(a_low, b_low);
-        __m256i d_high = _mm256_sub_epi16(a_high, b_high);
-        __m256i d2_low_part = _mm256_madd_epi16(d_low, d_low);
-        __m256i d2_high_part = _mm256_madd_epi16(d_high, d_high);
-
-        // Accumulate into int32 vectors
-        d2_low_vec = _mm256_add_epi32(d2_low_vec, d2_low_part);
-        d2_high_vec = _mm256_add_epi32(d2_high_vec, d2_high_part);
-    }
-
-    // Accumulate the 32-bit integers from `d2_high_vec` and `d2_low_vec`
-    __m256i d2_vec = _mm256_add_epi32(d2_low_vec, d2_high_vec);
-    __m128i d2_sum = _mm_add_epi32(_mm256_extracti128_si256(d2_vec, 0), _mm256_extracti128_si256(d2_vec, 1));
-    d2_sum = _mm_hadd_epi32(d2_sum, d2_sum);
-    d2_sum = _mm_hadd_epi32(d2_sum, d2_sum);
-    int d2 = _mm_extract_epi32(d2_sum, 0);
-
-    // Take care of the tail:
-    for (; i < n; ++i) {
-        int n = a[i] - b[i];
-        d2 += n * n;
-    }
-
-    return (simsimd_f32_t)d2;
-}
-
-__attribute__((target("avx2"))) //
-inline static simsimd_f32_t
-simsimd_avx2_i8_cos(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n) {
-
-    __m256i ab_high_vec = _mm256_setzero_si256();
-    __m256i ab_low_vec = _mm256_setzero_si256();
-    __m256i a2_high_vec = _mm256_setzero_si256();
-    __m256i a2_low_vec = _mm256_setzero_si256();
-    __m256i b2_high_vec = _mm256_setzero_si256();
-    __m256i b2_low_vec = _mm256_setzero_si256();
-
-    simsimd_size_t i = 0;
-    for (; i + 32 <= n; i += 32) {
-        __m256i a_vec = _mm256_loadu_si256((__m256i const*)(a + i));
-        __m256i b_vec = _mm256_loadu_si256((__m256i const*)(b + i));
-
-        // Unpack int8 to int32
-        __m256i a_low = _mm256_cvtepi8_epi32(_mm256_castsi256_si128(a_vec));
-        __m256i a_high = _mm256_cvtepi8_epi32(_mm256_extracti128_si256(a_vec, 1));
-        __m256i b_low = _mm256_cvtepi8_epi32(_mm256_castsi256_si128(b_vec));
-        __m256i b_high = _mm256_cvtepi8_epi32(_mm256_extracti128_si256(b_vec, 1));
-
-        // Multiply and accumulate
-        ab_low_vec = _mm256_add_epi32(ab_low_vec, _mm256_mullo_epi32(a_low, b_low));
-        ab_high_vec = _mm256_add_epi32(ab_high_vec, _mm256_mullo_epi32(a_high, b_high));
-        a2_low_vec = _mm256_add_epi32(a2_low_vec, _mm256_mullo_epi32(a_low, a_low));
-        a2_high_vec = _mm256_add_epi32(a2_high_vec, _mm256_mullo_epi32(a_high, a_high));
-        b2_low_vec = _mm256_add_epi32(b2_low_vec, _mm256_mullo_epi32(b_low, b_low));
-        b2_high_vec = _mm256_add_epi32(b2_high_vec, _mm256_mullo_epi32(b_high, b_high));
-    }
-
-    // Horizontal sum across the 256-bit register
-    __m256i ab_vec = _mm256_add_epi32(ab_low_vec, ab_high_vec);
-    __m128i ab_sum = _mm_add_epi32(_mm256_extracti128_si256(ab_vec, 0), _mm256_extracti128_si256(ab_vec, 1));
-    ab_sum = _mm_hadd_epi32(ab_sum, ab_sum);
-    ab_sum = _mm_hadd_epi32(ab_sum, ab_sum);
-
-    __m256i a2_vec = _mm256_add_epi32(a2_low_vec, a2_high_vec);
-    __m128i a2_sum = _mm_add_epi32(_mm256_extracti128_si256(a2_vec, 0), _mm256_extracti128_si256(a2_vec, 1));
-    a2_sum = _mm_hadd_epi32(a2_sum, a2_sum);
-    a2_sum = _mm_hadd_epi32(a2_sum, a2_sum);
-
-    __m256i b2_vec = _mm256_add_epi32(b2_low_vec, b2_high_vec);
-    __m128i b2_sum = _mm_add_epi32(_mm256_extracti128_si256(b2_vec, 0), _mm256_extracti128_si256(b2_vec, 1));
-    b2_sum = _mm_hadd_epi32(b2_sum, b2_sum);
-    b2_sum = _mm_hadd_epi32(b2_sum, b2_sum);
-
-    // Further reduce to a single sum for each vector
-    int ab = _mm_extract_epi32(ab_sum, 0);
-    int a2 = _mm_extract_epi32(a2_sum, 0);
-    int b2 = _mm_extract_epi32(b2_sum, 0);
-
-    // Take care of the tail:
-    for (; i < n; ++i) {
-        int ai = a[i], bi = b[i];
-        ab += ai * bi, a2 += ai * ai, b2 += bi * bi;
-    }
-
-    // Replace simsimd_approximate_inverse_square_root with `rsqrtss`
-    __m128 a2_sqrt_recip = _mm_rsqrt_ss(_mm_set_ss((float)a2));
-    __m128 b2_sqrt_recip = _mm_rsqrt_ss(_mm_set_ss((float)b2));
-    __m128 result = _mm_mul_ss(a2_sqrt_recip, b2_sqrt_recip); // Multiply the reciprocal square roots
-    result = _mm_mul_ss(result, _mm_set_ss((float)ab));       // Multiply by ab
-    result = _mm_sub_ss(_mm_set_ss(1.0f), result);            // Subtract from 1
-    return ab != 0 ? _mm_cvtss_f32(result) : 1;               // Extract the final result
-}
-
-__attribute__((target("avx2"))) //
-inline static simsimd_f32_t
-simsimd_avx2_i8_js(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n) {
-    return simsimd_avx2_i8_cos(a, b, n);
-}
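The halving at the end of the new `simsimd_avx2_f16_js` mirrors the textbook identity relating the two divergences:

    \[ D_{JS}(P, Q) = \frac{1}{2} D_{KL}(P \parallel M) + \frac{1}{2} D_{KL}(Q \parallel M), \qquad M = \frac{1}{2}(P + Q) \]

The kernel fuses both KL-style terms into a single pass, reusing the same loads of `a_vec` and `b_vec` and a single horizontal reduction; it forms the unnormalized mixture M = P + Q inside the smoothed ratios and divides the accumulated sum by two once at the end.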

 #endif // SIMSIMD_TARGET_X86_AVX2

-#if SIMSIMD_TARGET_X86_AVX512
+#if SIMSIMD_TARGET_X86_AVX512 && 0

 /*
  *  @file   x86_avx512_f32.h
