-
Notifications
You must be signed in to change notification settings - Fork 522
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: X-link: facebookresearch/FBGEMM#653 - Move FP32 kernels to OSS Differential Revision: D68119470
- Loading branch information
1 parent
56fb6ad
commit 622d3bb
Showing
10 changed files
with
7,320 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. | ||
|
||
#include <chrono> | ||
#include <cmath> | ||
#include <memory> | ||
#include <random> | ||
|
||
#ifdef USE_MKL | ||
#include <mkl.h> | ||
#endif | ||
|
||
#include "./FbgemmFP32.h" | ||
#include "bench/BenchUtils.h" | ||
|
||
using namespace fbgemm; | ||
|
||
int main(int argc, const char* argv[]) { | ||
int num_instances = 1; | ||
#ifdef _OPENMP | ||
const char* inst = getenv("GEMMBENCH_NUM_INSTANCES"); | ||
if (inst != nullptr && *inst) { | ||
num_instances = std::max(atoi(inst), num_instances); | ||
} | ||
num_instances = | ||
parseArgumentInt(argc, argv, "--inst=", num_instances, num_instances); | ||
printf("Running %d instances\n", num_instances); | ||
if (num_instances > 1) { | ||
// Set-up execution for multi-instance mode | ||
// Number of threads in OpenMP parallel region is explicitly | ||
// set to the number of instances to be executed. | ||
omp_set_num_threads(num_instances); | ||
#ifdef USE_MKL | ||
// each instance should be run with a single thread | ||
mkl_set_num_threads(1); | ||
#endif | ||
} else { | ||
// When running single instance use OMP_NUM_THREADS to determine | ||
// parallelism. Default behaviour is using a single thread. | ||
int num_threads = parseArgumentInt(argc, argv, "--num_threads=", 1, 1); | ||
const char* val = getenv("OMP_NUM_THREADS"); | ||
if (val == nullptr || !*val) { | ||
omp_set_num_threads(num_threads); | ||
} | ||
} | ||
|
||
#endif | ||
|
||
int repetitions = parseArgumentInt(argc, argv, "--repit=", 1, 1); | ||
bool no_flush = parseArgumentBool(argc, argv, "--no-flush", false); | ||
bool no_mkl = parseArgumentBool(argc, argv, "--no-mkl", false); | ||
bool enable_avx512_ymm = parseArgumentBool(argc, argv, "--avx512-256", false); | ||
fbgemmEnableAvx512Ymm(enable_avx512_ymm); | ||
performance_test<float>(num_instances, !no_flush, repetitions, !no_mkl); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. | ||
|
||
#pragma once | ||
|
||
// WARNING: this is a legacy fp16 fbgemm implementation and will soon be | ||
// upgraded to match with new fbgemm interface. | ||
|
||
#include <cpuinfo.h> | ||
#include <cassert> | ||
#include <cstdlib> | ||
#include <memory> | ||
#include <stdexcept> | ||
#include <vector> | ||
|
||
#include "fbgemm/FbgemmFPCommon.h" | ||
#include "fbgemm/FbgemmPackMatrixB.h" | ||
#include "fbgemm/Types.h" | ||
#include "fbgemm/Utils.h" | ||
|
||
namespace fbgemm { | ||
template <> | ||
struct TypeConverter<float> { | ||
float operator()(float src) const { | ||
return src; | ||
} | ||
}; | ||
|
||
using GemmParamsFP32 = GemmParams<float>; | ||
using PackedGemmMatrixFP32 = PackedGemmMatrixB<float>; | ||
|
||
template <typename T, int _kernel_ncol_blocks, int _brow> | ||
void cblas_gemm_compute( | ||
const matrix_op_t transa, | ||
const int m, | ||
const float* A, | ||
const PackedGemmMatrixB<T>& Bp, | ||
const float beta, | ||
float* C, | ||
int thread_id = 0, | ||
int num_threads = 1); | ||
|
||
extern template void cblas_gemm_compute( | ||
const matrix_op_t transa, | ||
const int m, | ||
const float* A, | ||
const PackedGemmMatrixFP32& Bp, | ||
const float beta, | ||
float* C, | ||
int thread_id, | ||
int num_threads); | ||
|
||
template <> | ||
const isa_descriptor<float>& getIsaHandlers(inst_set_t isa, float); | ||
|
||
} // namespace fbgemm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. | ||
|
||
#define FBGEMM_EXPORTS | ||
#include <array> | ||
#include <cmath> | ||
#include <utility> | ||
|
||
#include "./FbgemmFP32UKernelsAvx2.h" | ||
#include "./FbgemmFP32UKernelsAvx512.h" | ||
#include "./FbgemmFP32UKernelsAvx512_256.h" | ||
#include "fbgemm/Fbgemm.h" | ||
#include "fbgemm/FbgemmFPCommon.h" | ||
|
||
namespace fbgemm { | ||
|
||
namespace { | ||
// optimized kernels to cover all cases | ||
// 2 in ?x2 should be the same as kernel_ncol_blocks. | ||
// Here with kernel_ncol_blocks = 2, we can provide up to 6x2 kernels, due to | ||
// the restrictions of ymm register numbers (16). | ||
constexpr kernel_array_t<float> kernel_f32_avx2 = { | ||
#ifndef __aarch64__ | ||
nullptr, | ||
gemmkernel_1x2_Avx2_fp32_fA0fB0fC0, | ||
gemmkernel_2x2_Avx2_fp32_fA0fB0fC0, | ||
gemmkernel_3x2_Avx2_fp32_fA0fB0fC0, | ||
gemmkernel_4x2_Avx2_fp32_fA0fB0fC0, | ||
gemmkernel_5x2_Avx2_fp32_fA0fB0fC0, | ||
gemmkernel_6x2_Avx2_fp32_fA0fB0fC0}; | ||
#else | ||
nullptr}; | ||
#endif | ||
|
||
constexpr kernel_array_t<float> kernel_f32_avx512 = { | ||
#ifndef __aarch64__ | ||
nullptr, | ||
gemmkernel_1x2_Avx512_fp32_fA0fB0fC0, | ||
gemmkernel_2x2_Avx512_fp32_fA0fB0fC0, | ||
gemmkernel_3x2_Avx512_fp32_fA0fB0fC0, | ||
gemmkernel_4x2_Avx512_fp32_fA0fB0fC0, | ||
gemmkernel_5x2_Avx512_fp32_fA0fB0fC0, | ||
gemmkernel_6x2_Avx512_fp32_fA0fB0fC0, | ||
gemmkernel_7x2_Avx512_fp32_fA0fB0fC0, | ||
gemmkernel_8x2_Avx512_fp32_fA0fB0fC0, | ||
gemmkernel_9x2_Avx512_fp32_fA0fB0fC0, | ||
gemmkernel_10x2_Avx512_fp32_fA0fB0fC0, | ||
gemmkernel_11x2_Avx512_fp32_fA0fB0fC0, | ||
gemmkernel_12x2_Avx512_fp32_fA0fB0fC0, | ||
gemmkernel_13x2_Avx512_fp32_fA0fB0fC0, | ||
gemmkernel_14x2_Avx512_fp32_fA0fB0fC0}; | ||
#else | ||
nullptr}; | ||
#endif | ||
|
||
// clang-format on | ||
constexpr kernel_array_t<float> kernel_f32_avx512_256 = { | ||
#ifndef __aarch64__ | ||
nullptr, | ||
gemmkernel_1x2_Avx2_fp32_fA0fB0fC0, | ||
gemmkernel_2x2_Avx2_fp32_fA0fB0fC0, | ||
gemmkernel_3x2_Avx2_fp32_fA0fB0fC0, | ||
gemmkernel_4x2_Avx2_fp32_fA0fB0fC0, | ||
gemmkernel_5x2_Avx2_fp32_fA0fB0fC0, | ||
gemmkernel_6x2_Avx2_fp32_fA0fB0fC0, | ||
gemmkernel_7x2_Avx512_256_fp32_fA0fB0fC0, | ||
gemmkernel_8x2_Avx512_256_fp32_fA0fB0fC0, | ||
gemmkernel_9x2_Avx512_256_fp32_fA0fB0fC0, | ||
gemmkernel_10x2_Avx512_256_fp32_fA0fB0fC0, | ||
gemmkernel_11x2_Avx512_256_fp32_fA0fB0fC0, | ||
gemmkernel_12x2_Avx512_256_fp32_fA0fB0fC0, | ||
gemmkernel_13x2_Avx512_256_fp32_fA0fB0fC0, | ||
gemmkernel_14x2_Avx512_256_fp32_fA0fB0fC0}; | ||
#else | ||
nullptr}; | ||
#endif | ||
|
||
} // namespace | ||
|
||
template <> | ||
const isa_descriptor<float>& getIsaHandlers(inst_set_t isa, float) { | ||
static isa_descriptor<float> avx2_descriptor = | ||
std::make_tuple(kernel_f32_avx2, partition_avx2); | ||
static isa_descriptor<float> avx512_descriptor = | ||
std::make_tuple(kernel_f32_avx512, partition_avx512); | ||
static isa_descriptor<float> avx512_256_descriptor = | ||
std::make_tuple(kernel_f32_avx512_256, partition_avx512); | ||
|
||
switch (isa) { | ||
case inst_set_t::sve: | ||
case inst_set_t::anyarch: | ||
case inst_set_t::avx2: | ||
return avx2_descriptor; | ||
|
||
case inst_set_t::avx512: | ||
case inst_set_t::avx512_vnni: | ||
return avx512_descriptor; | ||
|
||
case inst_set_t::avx512_ymm: | ||
case inst_set_t::avx512_vnni_ymm: | ||
return avx512_256_descriptor; | ||
} | ||
|
||
throw std::runtime_error("Unsupported uArch"); | ||
} | ||
|
||
#ifdef FBGEMM_FP32_FALLBACK_TO_REF_KERNEL | ||
template <> | ||
FBGEMM_API void ref_kernel<float>( | ||
int kernel_nrows, | ||
GemmParams<float>* gp, | ||
const float* C_base, | ||
int m_total, | ||
int n_total, | ||
int simd_len) { | ||
int kernel_ncol_blocks = 2; | ||
int block_col_size = simd_len * kernel_ncol_blocks; | ||
for (int jb = 0; jb < gp->b_block_cols; ++jb) { | ||
for (int k = 0; k < gp->k; ++k) { | ||
for (int i = 0; i < kernel_nrows; ++i) { | ||
float a = gp->A[i + k * kernel_nrows]; | ||
for (int j = 0; j < block_col_size; ++j) { | ||
float* C_ptr = | ||
gp->C + i * (gp->ldc / sizeof(float)) + jb * block_col_size + j; | ||
assert(C_ptr < C_base + m_total * n_total); | ||
float b = gp->B[(jb * gp->k + k) * block_col_size + j]; | ||
if (k == 0) { | ||
if (gp->beta) { | ||
*C_ptr = std::fma(a, b, (gp->beta) * (*C_ptr)); | ||
} else { | ||
*C_ptr = a * b; | ||
} | ||
} else { | ||
*C_ptr = std::fma(a, b, *C_ptr); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
#endif // FBGEMM_FP32_FALLBACK_TO_REF_KERNEL | ||
|
||
template void cblas_gemm_compute( | ||
const matrix_op_t transa, | ||
const int m, | ||
const float* A, | ||
const PackedGemmMatrixB<float>& Bp, | ||
const float beta, | ||
float* C, | ||
int thread_id, | ||
int num_threads); | ||
|
||
} // namespace fbgemm |
Oops, something went wrong.