Skip to content

Commit

Permalink
[Experimental] Add Kleidi i8mm gemm kernels (pytorch#1295)
Browse files Browse the repository at this point in the history
* Update git ignore

* [experimental] Add Kleidi compile def at the top level

* [Experimental] Add Kleidi i8mm gemm kernels

Add kernel level tests, with basic cross compilation support.

Tested with S24 + r26c

```
[----------] 6 tests from test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.k_eq_gs_32
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.k_eq_gs_32 (0 ms)
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.large_k_n_gs32
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.large_k_n_gs32 (79 ms)
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.even_n_gs32
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.even_n_gs32 (28 ms)
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.k_eq_gs128
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.k_eq_gs128 (3 ms)
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.clamp_k_eq_gs128
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.clamp_k_eq_gs128 (3 ms)
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.m_clamp_k_eq_gs128
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.m_clamp_k_eq_gs128 (5 ms)
[----------] 6 tests from test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm (121 ms total)

[----------] 6 tests from test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.k_eq_gs_32
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.k_eq_gs_32 (0 ms)
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.large_k_n_gs32
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.large_k_n_gs32 (79 ms)
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.even_n_gs32
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.even_n_gs32 (28 ms)
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.k_eq_gs128
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.k_eq_gs128 (3 ms)
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.clamp_k_eq_gs128
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.clamp_k_eq_gs128 (3 ms)
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.m_clamp_k_eq_gs128
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.m_clamp_k_eq_gs128 (5 ms)
[----------] 6 tests from test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm (121 ms total)
```

* [Exeprimental] Kleidi: rename arg name for packing functions

* [Experimental] Change kernel cmake_out dir to avoid conflict
  • Loading branch information
digantdesai authored Nov 20, 2024
1 parent 72fb597 commit ca52cdc
Show file tree
Hide file tree
Showing 10 changed files with 557 additions and 17 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -371,4 +371,7 @@ venv/
sweep/

# Model checkpoints
checkpoints/
checkpoints/

# Experimental
torchao/experimental/cmake-out
4 changes: 4 additions & 0 deletions torchao/experimental/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ if(NOT TORCHAO_INCLUDE_DIRS)
endif()

option(TORCHAO_BUILD_KLEIDIAI "Download, build, and link against Arm KleidiAI library (arm64 only)" OFF)
if(TORCHAO_BUILD_KLEIDIAI)
message(STATUS "Building with Arm KleidiAI library")
add_compile_definitions(TORCHAO_ENABLE_KLEIDI=1)
endif()
include(CMakePrintHelpers)

add_compile_options("-Wall" "-Werror" "-Wno-deprecated")
Expand Down
5 changes: 2 additions & 3 deletions torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


if (CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
if ((CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") OR (CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64"))
add_library(
torchao_kernels_aarch64
${CMAKE_CURRENT_SOURCE_DIR}/reduction/find_min_and_max.cpp
Expand All @@ -25,7 +24,7 @@ if (CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")

# Temporarily exposing this to the parent scope until we wire
# this up properly from the top level
set(TORCHAO_ENABLE_KLEIDI ON PARENT_SCOPE)
set(TORCHAO_BUILD_KLEIDI ON PARENT_SCOPE)
target_link_libraries(torchao_kernels_aarch64 PUBLIC kleidiai)
endif()
endif()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,14 @@ size_t activation_data_size(int m, int k, int group_size) {
}

void prepare_activation_data(
void* activation_data,
void* prepared_activation_data,
int m,
int k,
int group_size,
const float* activations) {
(void)group_size; // unused
kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data(
get_ukernel(), activation_data, m, k, activations);
get_ukernel(), prepared_activation_data, m, k, activations);
}

size_t weight_data_size(int n, int k, int group_size) {
Expand All @@ -63,7 +63,7 @@ size_t weight_data_size(int n, int k, int group_size) {
}

void prepare_weight_data(
void* weight_data,
void* prepared_weight_data,
int n,
int k,
int group_size,
Expand All @@ -73,7 +73,7 @@ void prepare_weight_data(
const float* bias) {
kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data(
get_ukernel(),
weight_data,
prepared_weight_data,
n,
k,
group_size,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,15 @@ size_t activation_data_size(int m, int k, int group_size) {
}

void prepare_activation_data(
void* activation_data,
void* prepared_activation_data,
int m,
int k,
int group_size,
const float* activations) {
(void) group_size; // unused
kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data(
get_ukernel(),
activation_data,
prepared_activation_data,
m,
k,
activations);
Expand All @@ -64,7 +64,7 @@ size_t weight_data_size(int n, int k, int group_size) {
}

void prepare_weight_data(
void* weight_data,
void* prepared_weight_data,
int n,
int k,
int group_size,
Expand All @@ -74,7 +74,7 @@ void prepare_weight_data(
const float* bias) {
kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data(
get_ukernel(),
weight_data,
prepared_weight_data,
n,
k,
group_size,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h>
#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h>

namespace torchao::kernels::cpu::aarch64::kleidi {
namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {
namespace neon_i8mm_8x4x32 {

const Ukernel get_ukernel() {
return Ukernel{
.get_m_step =
kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
.get_n_step =
kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
.get_mr =
kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
.get_nr =
kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
.get_kr =
kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
.get_sr =
kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
.get_lhs_packed_offset =
kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
.get_rhs_packed_offset =
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
.get_dst_offset =
kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
.get_dst_size =
kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
.run_matmul =
kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm};
}

size_t activation_data_size(int m, int k, int group_size) {
(void)group_size; // unused
return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size(
get_ukernel(), m, k);
}

void prepare_activation_data(
void* prepared_activation_data,
int m,
int k,
int group_size,
const float* activations) {
(void)group_size; // unused
kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data(
get_ukernel(), prepared_activation_data, m, k, activations);
}

size_t weight_data_size(int n, int k, int group_size) {
return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size(
get_ukernel(), n, k, group_size);
}

void prepare_weight_data(
void* prepared_weight_data,
int n,
int k,
int group_size,
const int8_t* weight_qvals,
const float* weight_scales,
const int8_t* weight_zeros,
const float* bias) {
kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data(
get_ukernel(),
prepared_weight_data,
n,
k,
group_size,
weight_qvals,
weight_scales,
weight_zeros,
bias);
}

void kernel(
float32_t* output,
int output_m_stride,
int m,
int n,
int k,
int group_size,
const void* weight_data,
const void* activation_data,
float clamp_min,
float clamp_max) {
if (clamp_min == 0 && clamp_max == 0) {
clamp_min = std::numeric_limits<float>::lowest();
clamp_max = std::numeric_limits<float>::max();
}

auto ukernel = get_ukernel();
ukernel.run_matmul(
m,
n,
k,
group_size,
activation_data,
weight_data,
output,
/*dst_stride_row=*/n * sizeof(float),
/*dst_stride_col=*/sizeof(float),
clamp_min,
clamp_max);
}

size_t get_preferred_alignement() {
return 16;
}
} // namespace neon_i8mm_8x4x32
} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p
} // namespace torchao::kernels::cpu::aarch64::kleidi
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h>

#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h>

namespace torchao::kernels::cpu::aarch64::kleidi {
namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {
namespace neon_i8mm_4x8x32 {

const Ukernel get_ukernel() {
return Ukernel{
.get_m_step =
kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
.get_n_step =
kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
.get_mr =
kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
.get_nr =
kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
.get_kr =
kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
.get_sr =
kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
.get_lhs_packed_offset =
kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
.get_rhs_packed_offset =
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
.get_dst_offset =
kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
.get_dst_size =
kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
.run_matmul =
kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm};
}

size_t activation_data_size(int m, int k, int group_size) {
(void)group_size; // unused
return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size(
get_ukernel(), m, k);
}

void prepare_activation_data(
void* prepared_activation_data,
int m,
int k,
int group_size,
const float* activations) {
(void)group_size; // unused
kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data(
get_ukernel(), prepared_activation_data, m, k, activations);
}

size_t weight_data_size(int n, int k, int group_size) {
return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size(
get_ukernel(), n, k, group_size);
}

void prepare_weight_data(
void* prepared_weight_data,
int n,
int k,
int group_size,
const int8_t* weight_qvals,
const float* weight_scales,
const int8_t* weight_zeros,
const float* bias) {
kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data(
get_ukernel(),
prepared_weight_data,
n,
k,
group_size,
weight_qvals,
weight_scales,
weight_zeros,
bias);
}

void kernel(
float32_t* output,
int output_m_stride,
int m,
int n,
int k,
int group_size,
const void* weight_data,
const void* activation_data,
float clamp_min,
float clamp_max) {
if (clamp_min == 0 && clamp_max == 0) {
clamp_min = std::numeric_limits<float>::lowest();
clamp_max = std::numeric_limits<float>::max();
}

auto ukernel = get_ukernel();
ukernel.run_matmul(
m,
n,
k,
group_size,
activation_data,
weight_data,
output,
/*dst_stride_row=*/n * sizeof(float),
/*dst_stride_col=*/sizeof(float),
clamp_min,
clamp_max);
}

size_t get_preferred_alignement() {
return 16;
}

} // namespace neon_i8mm_4x8x32
} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p
} // namespace torchao::kernels::cpu::aarch64::kleidi
25 changes: 23 additions & 2 deletions torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ FetchContent_Declare(
)
FetchContent_MakeAvailable(googletest)

if (ANDROID_ABI)
# We are cross compiling, delay test discovery till runtime
set(CMAKE_GTEST_DISCOVER_TESTS_DISCOVERY_MODE PRE_TEST)
endif()

add_compile_options("-Wall" "-Werror")

include(CMakePrintHelpers)
Expand All @@ -35,13 +40,29 @@ endif()

add_subdirectory(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64)

# The TORCHAO_ENABLE_KLEIDI cmake variable should be set by `torchao_kernels_aarch64"
if(TORCHAO_ENABLE_KLEIDI)
# The TORCHAO_BUILD_KLEIDI cmake variable should be set by `torchao_kernels_aarch64"
if(TORCHAO_BUILD_KLEIDI)
add_compile_definitions(TORCHAO_ENABLE_KLEIDI)
endif()

if(TORCHAO_BUILD_ARM_I8MM)
add_compile_definitions(TORCHAO_ENABLE_ARM_I8MM)
endif()

enable_testing()

if (ANDROID_ABI)
# Given where we are today this is sufficent. But needs to be revisited.
# This is also needed for native builds, but keeping it only for cross builds
# for now given the hacky nature.
file(GLOB DOTPROD_SRC_FILES test*.cpp)
message(SRC_FILES: ${DOTPROD_SRC_FILES})
set_property(SOURCE
${DOTPROD_SRC_FILES}
APPEND_STRING PROPERTY
COMPILE_FLAGS " -march=armv8.2-a+dotprod ")
endif()

add_executable(test_quantization test_quantization.cpp)
target_link_libraries(
test_quantization
Expand Down
Loading

0 comments on commit ca52cdc

Please sign in to comment.