Skip to content

Commit

Permalink
Move FP32 kernels to OSS
Browse files Browse the repository at this point in the history
Summary:
X-link: facebookresearch/FBGEMM#653

- Move FP32 kernels to OSS

Differential Revision: D68119470
  • Loading branch information
q10 authored and facebook-github-bot committed Jan 13, 2025
1 parent 56fb6ad commit 622d3bb
Show file tree
Hide file tree
Showing 10 changed files with 7,320 additions and 0 deletions.
54 changes: 54 additions & 0 deletions bench/FP32Benchmark.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

#include <chrono>
#include <cmath>
#include <memory>
#include <random>

#ifdef USE_MKL
#include <mkl.h>
#endif

#include "./FbgemmFP32.h"
#include "bench/BenchUtils.h"

using namespace fbgemm;

// Benchmark driver for the FP32 GEMM path.
//
// Inputs read by this function:
//   GEMMBENCH_NUM_INSTANCES, --inst=    instance count (OpenMP builds only)
//   OMP_NUM_THREADS, --num_threads=     threading for single-instance runs
//   --repit=, --no-flush, --no-mkl, --avx512-256
// The exact parsing rules live in parseArgumentInt / parseArgumentBool, which
// are defined outside this file.
int main(int argc, const char* argv[]) {
  int num_instances = 1;
#ifdef _OPENMP
  // The environment variable can only raise the instance count; the result is
  // then passed to parseArgumentInt together with the --inst= flag.
  const char* env_instances = getenv("GEMMBENCH_NUM_INSTANCES");
  const bool have_env_instances =
      env_instances != nullptr && env_instances[0] != '\0';
  if (have_env_instances) {
    num_instances = std::max(atoi(env_instances), num_instances);
  }
  num_instances =
      parseArgumentInt(argc, argv, "--inst=", num_instances, num_instances);
  printf("Running %d instances\n", num_instances);

  if (num_instances <= 1) {
    // Single instance: the --num_threads= value is applied only when
    // OMP_NUM_THREADS is unset or empty, so the environment wins otherwise.
    const int requested_threads =
        parseArgumentInt(argc, argv, "--num_threads=", 1, 1);
    const char* env_threads = getenv("OMP_NUM_THREADS");
    const bool env_threads_set =
        env_threads != nullptr && env_threads[0] != '\0';
    if (!env_threads_set) {
      omp_set_num_threads(requested_threads);
    }
  } else {
    // Multi-instance mode: the OpenMP thread count is set explicitly to the
    // number of instances to be executed.
    omp_set_num_threads(num_instances);
#ifdef USE_MKL
    // Each instance should run MKL with a single thread.
    mkl_set_num_threads(1);
#endif
  }
#endif // _OPENMP

  const int repetitions = parseArgumentInt(argc, argv, "--repit=", 1, 1);
  // The two "--no-*" flags are negative switches; invert them once here.
  const bool flush = !parseArgumentBool(argc, argv, "--no-flush", false);
  const bool use_mkl = !parseArgumentBool(argc, argv, "--no-mkl", false);
  const bool enable_avx512_ymm =
      parseArgumentBool(argc, argv, "--avx512-256", false);
  fbgemmEnableAvx512Ymm(enable_avx512_ymm);

  performance_test<float>(num_instances, flush, repetitions, use_mkl);
}
55 changes: 55 additions & 0 deletions include/fbgemm/FbgemmFP32.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

#pragma once

// WARNING: this is a legacy fp32 fbgemm implementation and will soon be
// upgraded to match the new fbgemm interface.

#include <cpuinfo.h>
#include <cassert>
#include <cstdlib>
#include <memory>
#include <stdexcept>
#include <vector>

#include "fbgemm/FbgemmFPCommon.h"
#include "fbgemm/FbgemmPackMatrixB.h"
#include "fbgemm/Types.h"
#include "fbgemm/Utils.h"

namespace fbgemm {
// TypeConverter specialization for float: the conversion is the identity, so
// the input value is handed back unchanged.
template <>
struct TypeConverter<float> {
  float operator()(float src) const {
    return src;
  }
};

using GemmParamsFP32 = GemmParams<float>;
using PackedGemmMatrixFP32 = PackedGemmMatrixB<float>;

// Declaration of a cblas_gemm_compute overload parameterised on the packed-B
// element type T plus two integer template parameters.
//
// Parameters (as visible in this signature):
//   transa       transpose flag for A (matrix_op_t)
//   m            row count — presumably of A and C; confirm at the definition
//   A            float input matrix
//   Bp           pre-packed B matrix (PackedGemmMatrixB<T>)
//   beta         float scalar — presumably the scale applied to the existing
//                contents of C; confirm at the definition
//   C            float output matrix
//   thread_id    defaults to 0
//   num_threads  defaults to 1
//
// NOTE(review): _kernel_ncol_blocks and _brow do not appear in any function
// parameter type, so they cannot be deduced from a call and must be supplied
// explicitly. The definition is not visible in this header.
template <typename T, int _kernel_ncol_blocks, int _brow>
void cblas_gemm_compute(
const matrix_op_t transa,
const int m,
const float* A,
const PackedGemmMatrixB<T>& Bp,
const float beta,
float* C,
int thread_id = 0,
int num_threads = 1);

// Explicit instantiation declaration for the PackedGemmMatrixFP32 (float)
// variant of cblas_gemm_compute: translation units that include this header
// will not implicitly instantiate it and instead rely on an explicit
// instantiation definition provided in some source file.
//
// NOTE(review): no template arguments are spelled out here, so this can only
// name a cblas_gemm_compute template whose parameters are all deducible from
// the argument list — presumably one declared in FbgemmFPCommon.h; confirm.
extern template void cblas_gemm_compute(
const matrix_op_t transa,
const int m,
const float* A,
const PackedGemmMatrixFP32& Bp,
const float beta,
float* C,
int thread_id,
int num_threads);

// Declares the float specialization of getIsaHandlers, which returns a
// reference to the isa_descriptor<float> for the given instruction set.
// The second parameter is an unnamed float; presumably it exists only to
// select the element type — confirm against the primary template.
template <>
const isa_descriptor<float>& getIsaHandlers(inst_set_t isa, float);

} // namespace fbgemm
152 changes: 152 additions & 0 deletions src/fp32/FbgemmFP32.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

#define FBGEMM_EXPORTS
#include <array>
#include <cmath>
#include <utility>

#include "./FbgemmFP32UKernelsAvx2.h"
#include "./FbgemmFP32UKernelsAvx512.h"
#include "./FbgemmFP32UKernelsAvx512_256.h"
#include "fbgemm/Fbgemm.h"
#include "fbgemm/FbgemmFPCommon.h"

namespace fbgemm {

namespace {
// Optimized micro-kernel tables covering all supported row counts.
// The "2" in "?x2" should be the same as kernel_ncol_blocks.
// With kernel_ncol_blocks = 2, AVX2 can provide up to 6x2 kernels, due to the
// restriction on the number of ymm registers (16).
//
// Slot i holds the kernel named <i>x2; slot 0 is always nullptr. On aarch64
// builds only the nullptr slot is listed.
constexpr kernel_array_t<float> kernel_f32_avx2 = {
    nullptr,
#ifndef __aarch64__
    gemmkernel_1x2_Avx2_fp32_fA0fB0fC0,
    gemmkernel_2x2_Avx2_fp32_fA0fB0fC0,
    gemmkernel_3x2_Avx2_fp32_fA0fB0fC0,
    gemmkernel_4x2_Avx2_fp32_fA0fB0fC0,
    gemmkernel_5x2_Avx2_fp32_fA0fB0fC0,
    gemmkernel_6x2_Avx2_fp32_fA0fB0fC0,
#endif
};

// AVX-512 micro-kernel table: slot i holds the <i>x2 kernel for i = 1..14 and
// slot 0 is nullptr. On aarch64 builds only the nullptr slot is listed.
constexpr kernel_array_t<float> kernel_f32_avx512 = {
    nullptr,
#ifndef __aarch64__
    gemmkernel_1x2_Avx512_fp32_fA0fB0fC0,
    gemmkernel_2x2_Avx512_fp32_fA0fB0fC0,
    gemmkernel_3x2_Avx512_fp32_fA0fB0fC0,
    gemmkernel_4x2_Avx512_fp32_fA0fB0fC0,
    gemmkernel_5x2_Avx512_fp32_fA0fB0fC0,
    gemmkernel_6x2_Avx512_fp32_fA0fB0fC0,
    gemmkernel_7x2_Avx512_fp32_fA0fB0fC0,
    gemmkernel_8x2_Avx512_fp32_fA0fB0fC0,
    gemmkernel_9x2_Avx512_fp32_fA0fB0fC0,
    gemmkernel_10x2_Avx512_fp32_fA0fB0fC0,
    gemmkernel_11x2_Avx512_fp32_fA0fB0fC0,
    gemmkernel_12x2_Avx512_fp32_fA0fB0fC0,
    gemmkernel_13x2_Avx512_fp32_fA0fB0fC0,
    gemmkernel_14x2_Avx512_fp32_fA0fB0fC0,
#endif
};

// clang-format on
// Micro-kernel table for the AVX-512 "256-bit" mode: slots 1..6 reuse the
// AVX2 kernels, slots 7..14 use the dedicated Avx512_256 kernels, and slot 0
// is nullptr. On aarch64 builds only the nullptr slot is listed.
constexpr kernel_array_t<float> kernel_f32_avx512_256 = {
    nullptr,
#ifndef __aarch64__
    gemmkernel_1x2_Avx2_fp32_fA0fB0fC0,
    gemmkernel_2x2_Avx2_fp32_fA0fB0fC0,
    gemmkernel_3x2_Avx2_fp32_fA0fB0fC0,
    gemmkernel_4x2_Avx2_fp32_fA0fB0fC0,
    gemmkernel_5x2_Avx2_fp32_fA0fB0fC0,
    gemmkernel_6x2_Avx2_fp32_fA0fB0fC0,
    gemmkernel_7x2_Avx512_256_fp32_fA0fB0fC0,
    gemmkernel_8x2_Avx512_256_fp32_fA0fB0fC0,
    gemmkernel_9x2_Avx512_256_fp32_fA0fB0fC0,
    gemmkernel_10x2_Avx512_256_fp32_fA0fB0fC0,
    gemmkernel_11x2_Avx512_256_fp32_fA0fB0fC0,
    gemmkernel_12x2_Avx512_256_fp32_fA0fB0fC0,
    gemmkernel_13x2_Avx512_256_fp32_fA0fB0fC0,
    gemmkernel_14x2_Avx512_256_fp32_fA0fB0fC0,
#endif
};

} // namespace

// Returns the FP32 descriptor (kernel table paired with a partition table)
// for the requested instruction set.
//
// Mapping:
//   avx2, anyarch, sve            -> AVX2 kernels      + partition_avx2
//   avx512, avx512_vnni           -> AVX-512 kernels   + partition_avx512
//   avx512_ymm, avx512_vnni_ymm   -> AVX-512/256 table + partition_avx512
//
// Throws std::runtime_error("Unsupported uArch") for any other value.
template <>
const isa_descriptor<float>& getIsaHandlers(inst_set_t isa, float) {
  // Function-local statics: each descriptor is built once, on first call.
  static isa_descriptor<float> desc_avx2 =
      std::make_tuple(kernel_f32_avx2, partition_avx2);
  static isa_descriptor<float> desc_avx512 =
      std::make_tuple(kernel_f32_avx512, partition_avx512);
  static isa_descriptor<float> desc_avx512_256 =
      std::make_tuple(kernel_f32_avx512_256, partition_avx512);

  // No default label on purpose: the compiler can still warn when a new
  // inst_set_t enumerator is added but not handled here.
  const isa_descriptor<float>* selected = nullptr;
  switch (isa) {
    case inst_set_t::avx2:
    case inst_set_t::anyarch:
    case inst_set_t::sve:
      selected = &desc_avx2;
      break;

    case inst_set_t::avx512_vnni:
    case inst_set_t::avx512:
      selected = &desc_avx512;
      break;

    case inst_set_t::avx512_vnni_ymm:
    case inst_set_t::avx512_ymm:
      selected = &desc_avx512_256;
      break;
  }

  if (selected == nullptr) {
    throw std::runtime_error("Unsupported uArch");
  }
  return *selected;
}

#ifdef FBGEMM_FP32_FALLBACK_TO_REF_KERNEL
// Scalar reference implementation of the FP32 GEMM micro-kernel, compiled
// only when FBGEMM_FP32_FALLBACK_TO_REF_KERNEL is defined.
//
// For every column block jb, depth index k and row i it accumulates
//   C[i][jb * block + j] += A[i + k * kernel_nrows] * B[(jb * K + k) * block + j]
// with block = 2 * simd_len. On the first depth step (k == 0) the previous
// contents of C are scaled by gp->beta, or discarded when beta is zero.
//
// Parameters:
//   kernel_nrows  number of rows of A / C processed by this call
//   gp            kernel arguments; reads A, B, C, k, ldc, b_block_cols, beta
//   C_base        start of the whole output buffer (used by the bounds assert)
//   m_total       rows of the whole output (used by the bounds assert)
//   n_total       columns of the whole output (used by the bounds assert)
//   simd_len      number of floats per SIMD vector
//
// Note: when gp->k is zero the loops do not execute and C is left untouched.
template <>
FBGEMM_API void ref_kernel<float>(
    int kernel_nrows,
    GemmParams<float>* gp,
    const float* C_base,
    int m_total,
    int n_total,
    int simd_len) {
  // Each packed column block spans kernel_ncol_blocks SIMD vectors.
  const int kernel_ncol_blocks = 2;
  const int block_col_size = simd_len * kernel_ncol_blocks;
  // gp->ldc is divided by sizeof(float) to get the C row stride in elements
  // (so ldc appears to be expressed in bytes). Loop-invariant, computed once.
  const auto c_row_stride = gp->ldc / sizeof(float);

  for (int jb = 0; jb < gp->b_block_cols; ++jb) {
    for (int k = 0; k < gp->k; ++k) {
      // Offset of the B row for this (block, depth) pair.
      const auto b_row_offset = (jb * gp->k + k) * block_col_size;
      for (int i = 0; i < kernel_nrows; ++i) {
        const float a = gp->A[i + k * kernel_nrows];
        float* const C_row = gp->C + i * c_row_stride + jb * block_col_size;
        for (int j = 0; j < block_col_size; ++j) {
          float* C_ptr = C_row + j;
          // Widen before multiplying: m_total * n_total evaluated in int
          // overflows (undefined behaviour) for outputs of 2^31+ elements.
          assert(C_ptr < C_base + static_cast<long long>(m_total) * n_total);
          const float b = gp->B[b_row_offset + j];
          if (k == 0) {
            // First depth step: fold in beta * C, or overwrite when beta == 0
            // so that stale values in C never reach the result.
            if (gp->beta) {
              *C_ptr = std::fma(a, b, (gp->beta) * (*C_ptr));
            } else {
              *C_ptr = a * b;
            }
          } else {
            *C_ptr = std::fma(a, b, *C_ptr);
          }
        }
      }
    }
  }
}
#endif // FBGEMM_FP32_FALLBACK_TO_REF_KERNEL

// Explicit instantiation definition of cblas_gemm_compute for a float packed
// B matrix (PackedGemmMatrixB<float>), so that the code for this variant is
// emitted in this translation unit.
// NOTE(review): the function template being instantiated is defined outside
// this file — presumably in fbgemm/FbgemmFPCommon.h; confirm.
template void cblas_gemm_compute(
const matrix_op_t transa,
const int m,
const float* A,
const PackedGemmMatrixB<float>& Bp,
const float beta,
float* C,
int thread_id,
int num_threads);

} // namespace fbgemm
Loading

0 comments on commit 622d3bb

Please sign in to comment.