From 1e5b9d7b2be915149591c011a68edc167c02785b Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 24 Nov 2024 20:18:00 +0000 Subject: [PATCH] Improve: cBLAS Bilinear Form benchmarks --- scripts/bench.cxx | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/scripts/bench.cxx b/scripts/bench.cxx index 4bc4b801..254b9759 100644 --- a/scripts/bench.cxx +++ b/scripts/bench.cxx @@ -664,6 +664,41 @@ void vdot_f64c_blas(simsimd_f64c_t const *a, simsimd_f64c_t const *b, simsimd_si cblas_zdotc_sub((int)n, (simsimd_f64_t const *)a, 1, (simsimd_f64_t const *)b, 1, result); } +void bilinear_f32_blas(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_f32_t const *c, simsimd_size_t n, + simsimd_distance_t *result) { + simsimd_f32_t intermediate[n]; + simsimd_f32_t alpha = 1.0f, beta = 0.0f; + cblas_sgemv(CblasRowMajor, CblasNoTrans, (int)n, (int)n, alpha, c, (int)n, b, 1, beta, intermediate, 1); + *result = cblas_sdot((int)n, a, 1, intermediate, 1); +} + +void bilinear_f64_blas(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_f64_t const *c, simsimd_size_t n, + simsimd_distance_t *result) { + simsimd_f64_t intermediate[n]; + simsimd_f64_t alpha = 1.0, beta = 0.0; + cblas_dgemv(CblasRowMajor, CblasNoTrans, (int)n, (int)n, alpha, c, n, b, 1, beta, intermediate, 1); + *result = cblas_ddot((int)n, a, 1, intermediate, 1); +} + +void bilinear_f32c_blas(simsimd_f32c_t const *a, simsimd_f32c_t const *b, simsimd_f32c_t const *c, simsimd_size_t n, + simsimd_distance_t *results) { + simsimd_f32c_t intermediate[n]; + simsimd_f32c_t alpha = {1.0f, 0.0f}, beta = {0.0f, 0.0f}; + cblas_cgemv(CblasRowMajor, CblasNoTrans, (int)n, (int)n, &alpha, c, n, b, 1, &beta, intermediate, 1); + simsimd_f32_t f32_result[2] = {0, 0}; + cblas_cdotu_sub((int)n, (simsimd_f32_t const *)a, 1, (simsimd_f32_t const *)intermediate, 1, f32_result); + results[0] = f32_result[0]; + results[1] = f32_result[1]; +} + +void bilinear_f64c_blas(simsimd_f64c_t const *a, simsimd_f64c_t const *b, simsimd_f64c_t const *c, simsimd_size_t n, + simsimd_distance_t *results) { + simsimd_f64c_t intermediate[n]; + simsimd_f64c_t alpha = {1.0, 0.0}, beta = {0.0, 0.0}; + cblas_zgemv(CblasRowMajor, CblasNoTrans, (int)n, (int)n, &alpha, c, n, b, 1, &beta, intermediate, 1); + cblas_zdotu_sub((int)n, (simsimd_f64_t const *)a, 1, (simsimd_f64_t const *)intermediate, 1, results); +} + #endif int main(int argc, char **argv) { @@ -738,6 +773,11 @@ int main(int argc, char **argv) { dense_("vdot_f32c_blas", vdot_f32c_blas, simsimd_vdot_f32c_accurate); dense_("vdot_f64c_blas", vdot_f64c_blas, simsimd_vdot_f64c_serial); + curved_("bilinear_f64_blas", bilinear_f64_blas, simsimd_bilinear_f64_serial); + curved_("bilinear_f64c_blas", bilinear_f64c_blas, simsimd_bilinear_f64c_serial); + curved_("bilinear_f32_blas", bilinear_f32_blas, simsimd_bilinear_f32_accurate); + curved_("bilinear_f32c_blas", bilinear_f32c_blas, simsimd_bilinear_f32c_accurate); + #endif #if SIMSIMD_TARGET_NEON @@ -995,6 +1035,8 @@ int main(int argc, char **argv) { fma_("fma_bf16_skylake", simsimd_fma_bf16_skylake, simsimd_fma_bf16_accurate, simsimd_l2_bf16_accurate); fma_("wsum_bf16_skylake", simsimd_wsum_bf16_skylake, simsimd_wsum_bf16_accurate, simsimd_l2_bf16_accurate); + curved_("bilinear_f64_skylake", simsimd_bilinear_f64_skylake, simsimd_bilinear_f64_serial); + curved_("bilinear_f64c_skylake", simsimd_bilinear_f64c_skylake, simsimd_bilinear_f64c_serial); #endif sparse_("intersect_u16_serial", simsimd_intersect_u16_serial, simsimd_intersect_u16_accurate); @@ -1003,12 +1045,16 @@ int main(int argc, char **argv) { sparse_("intersect_u32_accurate", simsimd_intersect_u32_accurate, simsimd_intersect_u32_accurate); curved_("bilinear_f64_serial", simsimd_bilinear_f64_serial, simsimd_bilinear_f64_serial); + curved_("bilinear_f64c_serial", simsimd_bilinear_f64c_serial, simsimd_bilinear_f64c_serial); curved_("mahalanobis_f64_serial", simsimd_mahalanobis_f64_serial, simsimd_mahalanobis_f64_serial); curved_("bilinear_f32_serial", simsimd_bilinear_f32_serial, simsimd_bilinear_f32_accurate); + curved_("bilinear_f32c_serial", simsimd_bilinear_f32c_serial, simsimd_bilinear_f32c_accurate); curved_("mahalanobis_f32_serial", simsimd_mahalanobis_f32_serial, simsimd_mahalanobis_f32_accurate); curved_("bilinear_f16_serial", simsimd_bilinear_f16_serial, simsimd_bilinear_f16_accurate); + curved_("bilinear_f16c_serial", simsimd_bilinear_f16c_serial, simsimd_bilinear_f16c_accurate); curved_("mahalanobis_f16_serial", simsimd_mahalanobis_f16_serial, simsimd_mahalanobis_f16_accurate); curved_("bilinear_bf16_serial", simsimd_bilinear_bf16_serial, simsimd_bilinear_bf16_accurate); + curved_("bilinear_bf16c_serial", simsimd_bilinear_bf16c_serial, simsimd_bilinear_bf16c_accurate); curved_("mahalanobis_bf16_serial", simsimd_mahalanobis_bf16_serial, simsimd_mahalanobis_bf16_accurate); dense_("dot_bf16_serial", simsimd_dot_bf16_serial, simsimd_dot_bf16_accurate);