Skip to content

Commit

Permalink
Add: i8, f16, and bf16 kernels
Browse files Browse the repository at this point in the history
The next steps would include more
AVX-512, AVX2, AMX, and SME on Arm.

ashvardanian/ParallelReductionsBenchmark#2
  • Loading branch information
ashvardanian committed Jan 12, 2025
1 parent d0e521e commit 3f54200
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 29 deletions.
62 changes: 41 additions & 21 deletions less_slow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1343,37 +1343,57 @@ BENCHMARK(f32x4x4_matmul_avx512);
* means the performance will still degrade—@b around 5ns in practice.
*
* Benchmark everything! Don't assume less work translates to faster execution.
* Read the specs of your hardware to understand its theoretical upper limits,
* and double-check them with stress-tests. Pure @b Assembly is perfect for this!
*/

#if defined(__AVX512F__)

extern "C" std::uint32_t f32_matmul_avx512_flops_asm_kernel(void);
typedef std::uint32_t (*theoretic_tops_kernel_t)(void);

static void f32_matmul_avx512_flops(bm::State &state) {
std::size_t flops = 0;
for (auto _ : state) bm::DoNotOptimize(flops += f32_matmul_avx512_flops_asm_kernel());
state.SetItemsProcessed(flops);
/// @brief Benchmarks a hand-written Assembly micro-kernel, accumulating the
///        number of scalar operations it reports per call.
/// @param state google-benchmark state driving the timing loop.
/// @param kernel Pointer to an `extern "C"` kernel returning its op count.
static void measure_tops(bm::State &state, theoretic_tops_kernel_t kernel) {
    std::size_t total_ops = 0;
    // `DoNotOptimize` keeps the accumulation from being elided by the compiler.
    for (auto _ : state) bm::DoNotOptimize(total_ops += kernel());
    state.SetItemsProcessed(total_ops);
}

BENCHMARK(f32_matmul_avx512_flops)->MinTime(10);
BENCHMARK(f32_matmul_avx512_flops)->MinTime(10)->Threads(physical_cores());
/**
* Assuming we are not aiming for dynamic dispatch, we can simply check for
* the available features at compile time with more preprocessing directives:
*
* @see Arm Feature Detection: https://developer.arm.com/documentation/101028/0010/Feature-test-macros
*/
#if defined(__AVX512F__)

// Hand-written AVX-512 f32 FMA kernel (defined in less_slow_amd64.S);
// each call returns the number of floating-point operations it performed.
extern "C" std::uint32_t tops_f32_avx512_asm_kernel(void);
// Measure both single-threaded and with one thread per physical core
// to estimate per-core and whole-chip peak throughput.
BENCHMARK_CAPTURE(measure_tops, tops_f32_avx512, tops_f32_avx512_asm_kernel)->MinTime(10);
BENCHMARK_CAPTURE(measure_tops, tops_f32_avx512, tops_f32_avx512_asm_kernel)->MinTime(10)->Threads(physical_cores());

#endif // defined(__AVX512F__)

#if defined(__ARM_NEON)

extern "C" std::uint32_t f32_matmul_neon_flops_asm_kernel(void);

static void f32_matmul_neon_flops(benchmark::State &state) {
std::size_t flops = 0;
for (auto _ : state) bm::DoNotOptimize(flops += f32_matmul_neon_flops_asm_kernel());
state.SetItemsProcessed(flops);
}

BENCHMARK(f32_matmul_neon_flops)->MinTime(10);
BENCHMARK(f32_matmul_neon_flops)->MinTime(10)->Threads(physical_cores());

#endif
// Baseline f32 NEON FMLA kernel (defined in less_slow_aarch64.S);
// each call returns the number of floating-point operations it performed.
extern "C" std::uint32_t tops_f32_neon_asm_kernel(void);
BENCHMARK_CAPTURE(measure_tops, tops_f32_neon, tops_f32_neon_asm_kernel)->MinTime(10);
BENCHMARK_CAPTURE(measure_tops, tops_f32_neon, tops_f32_neon_asm_kernel)->MinTime(10)->Threads(physical_cores());

#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
// Half-precision FMLA: twice the lanes per register compared to f32.
extern "C" std::uint32_t tops_f16_neon_asm_kernel(void);
BENCHMARK_CAPTURE(measure_tops, tops_f16_neon, tops_f16_neon_asm_kernel)->MinTime(10);
BENCHMARK_CAPTURE(measure_tops, tops_f16_neon, tops_f16_neon_asm_kernel)->MinTime(10)->Threads(physical_cores());
#endif // defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)

#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC)
// Brain-float matrix-multiply-accumulate (BFMMLA) kernel.
extern "C" std::uint32_t tops_bf16_neon_asm_kernel(void);
BENCHMARK_CAPTURE(measure_tops, tops_bf16_neon, tops_bf16_neon_asm_kernel)->MinTime(10);
BENCHMARK_CAPTURE(measure_tops, tops_bf16_neon, tops_bf16_neon_asm_kernel)->MinTime(10)->Threads(physical_cores());
#endif // defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC)

#if defined(__ARM_FEATURE_DOTPROD)
// Signed 8-bit dot-product (SDOT) kernel — integer ops, not FLOPs.
extern "C" std::uint32_t tops_i8_neon_asm_kernel(void);
BENCHMARK_CAPTURE(measure_tops, tops_i8_neon, tops_i8_neon_asm_kernel)->MinTime(10);
BENCHMARK_CAPTURE(measure_tops, tops_i8_neon, tops_i8_neon_asm_kernel)->MinTime(10)->Threads(physical_cores());
#endif // defined(__ARM_FEATURE_DOTPROD)

#endif // defined(__ARM_NEON)

#pragma endregion // Compute vs Memory Bounds with Matrix Multiplications

Expand Down
79 changes: 73 additions & 6 deletions less_slow_aarch64.S
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,26 @@

.section .text
.global i32_add_asm_kernel
.global f32_matmul_neon_flops_asm_kernel
.global tops_f32_neon_asm_kernel
.global tops_f16_neon_asm_kernel
.global tops_bf16_neon_asm_kernel
.global tops_i8_neon_asm_kernel

# ----------------------------------------------------------------------------
# Simple function that adds two 32-bit signed integers using AArch64 ABI.
# Arguments in W0 (a) and W1 (b). Return value in W0.
# Simple function that adds two 32-bit integers.
# AArch64 ABI: W0 = 'a', W1 = 'b'. Return in W0.
# ----------------------------------------------------------------------------
i32_add_asm_kernel:
add w0, w0, w1 // W0 = a + b; result returned in W0 per AAPCS64
ret

# ----------------------------------------------------------------------------
# NEON micro-kernel to maximize FLOPs.
# f32 micro-kernel maximizing FLOPs:
# Each FMLA vD.4s, vN.4s, vM.4s => 4 multiplies + 4 adds = 8 FLOPs.
# We'll do 10 instructions => 80 FLOPs total.
# Let's do 10 instructions => 10 × 8 = 80 FLOPs total.
# Return 80 in W0.
# ----------------------------------------------------------------------------
f32_matmul_neon_flops_asm_kernel:
tops_f32_neon_asm_kernel:
fmla v0.4s, v1.4s, v2.4s
fmla v3.4s, v4.4s, v5.4s
fmla v6.4s, v7.4s, v8.4s
Expand All @@ -36,3 +40,66 @@ f32_matmul_neon_flops_asm_kernel:
ret

# ----------------------------------------------------------------------------
# f16 micro-kernel:
# Requires Armv8.2 half-precision vector arithmetic.
# Each FMLA vD.8h, vN.8h, vM.8h => 8 multiplies + 8 adds = 16 FLOPs.
# We'll do 10 instructions => 160 FLOPs total, returning 160 in W0.
# Registers are read uninitialized on purpose: only throughput matters here.
# NOTE(review): destinations v9, v12, and v15 overlap the callee-saved
# d8-d15 range (AAPCS64) and are clobbered without save/restore — confirm
# callers tolerate this micro-benchmark shortcut.
# ----------------------------------------------------------------------------
tops_f16_neon_asm_kernel:
fmla v0.8h, v1.8h, v2.8h
fmla v3.8h, v4.8h, v5.8h
fmla v6.8h, v7.8h, v8.8h
fmla v9.8h, v10.8h, v11.8h
fmla v12.8h, v13.8h, v14.8h
fmla v15.8h, v16.8h, v17.8h
fmla v18.8h, v19.8h, v20.8h
fmla v21.8h, v22.8h, v23.8h
fmla v24.8h, v25.8h, v26.8h
fmla v27.8h, v28.8h, v29.8h

mov w0, #160
ret

# ----------------------------------------------------------------------------
# bf16 micro-kernel:
# Requires Armv8.6 BF16 instructions (BFMMLA, etc.).
# Each BFMMLA vD.4s, vN.8h, vM.8h multiplies a 2x4 bf16 matrix by a 4x2
# bf16 matrix, accumulating into a 2x2 f32 block: 4 results, each from
# 4 multiplies + 4 adds => 32 FLOPs per instruction, not 16.
# We'll do 10 instructions => 320 FLOPs total, returning 320 in W0.
# NOTE(review): destinations v9, v12, and v15 overlap the callee-saved
# d8-d15 range (AAPCS64) and are clobbered without save/restore — confirm
# callers tolerate this micro-benchmark shortcut.
# ----------------------------------------------------------------------------
tops_bf16_neon_asm_kernel:
bfmmla v0.4s, v1.8h, v2.8h
bfmmla v3.4s, v4.8h, v5.8h
bfmmla v6.4s, v7.8h, v8.8h
bfmmla v9.4s, v10.8h, v11.8h
bfmmla v12.4s, v13.8h, v14.8h
bfmmla v15.4s, v16.8h, v17.8h
bfmmla v18.4s, v19.8h, v20.8h
bfmmla v21.4s, v22.8h, v23.8h
bfmmla v24.4s, v25.8h, v26.8h
bfmmla v27.4s, v28.8h, v29.8h

mov w0, #320
ret

# ----------------------------------------------------------------------------
# i8 micro-kernel:
# Requires the Armv8.4 dot-product (SDOT) extension.
# Each SDOT vD.4s, vN.16b, vM.16b computes four 4-way i8 dot products:
# 16 multiplies + 16 adds = 32 integer ops per instruction.
# We'll do 10 instructions => 320 ops total, returning 320 in W0.
# NOTE(review): destinations v9, v12, and v15 overlap the callee-saved
# d8-d15 range (AAPCS64) and are clobbered without save/restore — confirm
# callers tolerate this micro-benchmark shortcut.
# ----------------------------------------------------------------------------
tops_i8_neon_asm_kernel:
sdot v0.4s, v1.16b, v2.16b
sdot v3.4s, v4.16b, v5.16b
sdot v6.4s, v7.16b, v8.16b
sdot v9.4s, v10.16b, v11.16b
sdot v12.4s, v13.16b, v14.16b
sdot v15.4s, v16.16b, v17.16b
sdot v18.4s, v19.16b, v20.16b
sdot v21.4s, v22.16b, v23.16b
sdot v24.4s, v25.16b, v26.16b
sdot v27.4s, v28.16b, v29.16b

mov w0, #320
ret

# ----------------------------------------------------------------------------
4 changes: 2 additions & 2 deletions less_slow_amd64.S
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

.section .text
.global i32_add_asm_kernel
.global f32_matmul_avx512_flops_asm_kernel
.global tops_f32_avx512_asm_kernel

# ----------------------------------------------------------------------------
# Simple function that adds two 32-bit signed integers using System V AMD64.
Expand All @@ -21,7 +21,7 @@ i32_add_asm_kernel:
# AVX-512 micro-kernel maximizing FLOPs across all ZMM registers.
# ----------------------------------------------------------------------------

f32_matmul_avx512_flops_asm_kernel:
tops_f32_avx512_asm_kernel:
#
# Each vfmadd231ps does: DEST = DEST + (SRC1 * SRC2)
# That is 16 multiplies + 16 adds = 32 FLOPs per instruction.
Expand Down

0 comments on commit 3f54200

Please sign in to comment.