From 5036f3dfa367d70546cd61c81c1d6a58e4d83a76 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Fri, 22 Nov 2024 16:01:09 -0800 Subject: [PATCH] Reuse GELU implementation from PyTorch core kernels/optimized doesn't need to support embedded systems, so it can just take a header-only dep on PyTorch. Note that, because we will pick up Sleef internally and ignore it externally thanks to ATen vec, this PR gets to enable optimized GELU in OSS. Testing: CI to make sure this doesn't break mobile build modes; happy to take advice on anything not currently covered that might break. Differential Revision: [D66335522](https://our.internmc.facebook.com/intern/diff/D66335522/) ghstack-source-id: 255095943 Pull Request resolved: https://github.com/pytorch/executorch/pull/7041 --- kernels/optimized/CMakeLists.txt | 2 ++ kernels/optimized/cpu/op_gelu.cpp | 51 ++++++++------------------ kernels/optimized/cpu/targets.bzl | 53 ++++++++++++++++++++++++---- kernels/optimized/optimized-oss.yaml | 5 +++ 4 files changed, 69 insertions(+), 42 deletions(-) diff --git a/kernels/optimized/CMakeLists.txt b/kernels/optimized/CMakeLists.txt index abdeeb7345..d818f5fcda 100644 --- a/kernels/optimized/CMakeLists.txt +++ b/kernels/optimized/CMakeLists.txt @@ -60,6 +60,8 @@ message("Generated files ${gen_command_sources}") list(TRANSFORM _optimized_kernels__srcs PREPEND "${EXECUTORCH_ROOT}/") add_library(optimized_kernels ${_optimized_kernels__srcs}) +find_package(Torch CONFIG REQUIRED) +target_include_directories(optimized_kernels PRIVATE ${TORCH_INCLUDE_DIRS}) target_link_libraries( optimized_kernels PRIVATE executorch_core cpublas extension_threadpool ) diff --git a/kernels/optimized/cpu/op_gelu.cpp b/kernels/optimized/cpu/op_gelu.cpp index e2c48f61c4..1c4b198cdb 100644 --- a/kernels/optimized/cpu/op_gelu.cpp +++ b/kernels/optimized/cpu/op_gelu.cpp @@ -13,6 +13,7 @@ #include +#include #include #include @@ -46,48 +47,26 @@ void gelu( CTYPE* out_data = output.mutable_data_ptr(); size_t lim = input.numel(); - // TODO: Add fast path for tanh using sleef's tanh if (approximate == "tanh") { - // 0.5 * x * (1 + Tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)) - for (size_t i = 0; i < lim; ++i) { - const CTYPE x = in_data[i]; - const CTYPE kBeta = M_SQRT2 * M_2_SQRTPI * 0.5; - const CTYPE kKappa = 0.044715; - auto x_cube = x * x * x; - auto inner = kBeta * (x + kKappa * x_cube); - out_data[i] = CTYPE(0.5) * x * (CTYPE(1) + std::tanh(inner)); + using Vec = at::vec::Vectorized; + int i = 0; + for (; i < lim - (lim % Vec::size()); i += Vec::size()) { + Vec x = Vec::loadu(in_data + i); + at::native::vectorized_gelu_approximated_with_tanh(x).store(out_data + i); } - } else if (approximate == "none") { // dont appx - // GELU(x) = x * Φ(x) where Φ(x) is the is the Cumulative Distribution - // Function for Gaussian Distribution. - -#ifndef __aarch64__ - for (size_t i = 0; i < lim; ++i) { - const CTYPE x = in_data[i]; - out_data[i] = CTYPE(0.5) * x * (CTYPE(1) + std::erf(x * M_SQRT1_2)); + for (; i < lim; ++i) { + out_data[i] = at::native::scalar_gelu_approximated_with_tanh(in_data[i]); } -#else - size_t i = 0; - if (std::is_same::value) { - for (; i + 4 < lim; i += 4) { - const float32x4_t in = - vld1q_f32(static_cast(&in_data[i])); - const float32x4_t m_sqrt1_2x4 = { - M_SQRT1_2, M_SQRT1_2, M_SQRT1_2, M_SQRT1_2}; - const float32x4_t ones = vmovq_n_f32(1.0); - const float32x4_t halves = vmovq_n_f32(0.5); - float32x4_t out = Sleef_erff4_u10(vmulq_f32(in, m_sqrt1_2x4)); - vst1q_f32( - static_cast(&out_data[i]), - vmulq_f32(vmulq_f32(vaddq_f32(out, ones), in), halves)); - } + } else if (approximate == "none") { + using Vec = at::vec::Vectorized; + int i = 0; + for (; i < lim - (lim % Vec::size()); i += Vec::size()) { + Vec x = Vec::loadu(in_data + i); + at::native::vectorized_gelu(x).store(out_data + i); } for (; i < lim; ++i) { - const CTYPE x = in_data[i]; - out_data[i] = CTYPE(0.5) * x * (CTYPE(1) + std::erf(x * M_SQRT1_2)); + out_data[i] = at::native::scalar_gelu(in_data[i]); } -#endif // __aarch64__ - } else { ET_KERNEL_CHECK_MSG( context, diff --git a/kernels/optimized/cpu/targets.bzl b/kernels/optimized/cpu/targets.bzl index 5e5f6dd7b9..03d96083b2 100644 --- a/kernels/optimized/cpu/targets.bzl +++ b/kernels/optimized/cpu/targets.bzl @@ -28,12 +28,9 @@ _OPTIMIZED_ATEN_OPS = ( op_target(name = "op_sigmoid"), op_target( name = "op_gelu", - deps = select({ - "DEFAULT": [], - "ovr_config//cpu:arm64": [ - "fbsource//third-party/sleef:sleef_arm", - ], - }), + deps = [ + ":aten_headers_for_executorch", + ], ), op_target( name = "op_le", @@ -94,6 +91,13 @@ _OPTIMIZED_ATEN_OPS = ( ), ) + +def get_sleef_preprocessor_flags(): + if runtime.is_oss: + return [] + return ["-DAT_BUILD_ARM_VEC256_WITH_SLEEF"] + + def define_common_targets(): """Defines targets that should be shared between fbcode and xplat. @@ -110,6 +114,43 @@ def define_common_targets(): aten_op_targets = [":{}".format(op["name"]) for op in enabled_ops] all_op_targets = aten_op_targets + runtime.cxx_library( + name = "aten_headers_for_executorch", + srcs = [], + visibility = ["//executorch/kernels/optimized/..."], + exported_deps = select({ + "DEFAULT": [], + "ovr_config//cpu:arm64": [ + "fbsource//third-party/sleef:sleef_arm", + ] if not runtime.is_oss else [], + # fbsource//third-party/sleef:sleef currently fails to + # link with missing symbols, hence the fbcode-specific dep below. + }), + fbcode_exported_deps = [ + "//caffe2:aten-headers-cpu", + "//caffe2:generated-config-header", + "//caffe2/c10/core:base", + ] + select({ + "DEFAULT": [], + "ovr_config//cpu:x86_64": [ + "third-party//sleef:sleef", + ] + }), + xplat_exported_deps = [ + "//xplat/caffe2:aten_header", + "//xplat/caffe2:generated_aten_config_header", + "//xplat/caffe2/c10:c10", + ], + exported_preprocessor_flags = select({ + "ovr_config//cpu:x86_64": [ + "-DCPU_CAPABILITY=AVX2", + "-DCPU_CAPABILITY_AVX2", + "-DHAVE_AVX2_CPU_DEFINITION", + ] + get_sleef_preprocessor_flags(), + "ovr_config//cpu:arm64": get_sleef_preprocessor_flags(), + }) + ["-DSTANDALONE_TORCH_HEADER"], + ) + runtime.cxx_library( name = "binary_ops", exported_headers = ["binary_ops.h"], diff --git a/kernels/optimized/optimized-oss.yaml b/kernels/optimized/optimized-oss.yaml index 52262e2dd5..75e8a8ea86 100644 --- a/kernels/optimized/optimized-oss.yaml +++ b/kernels/optimized/optimized-oss.yaml @@ -40,6 +40,11 @@ - arg_meta: null kernel_name: torch::executor::opt_sigmoid_out +- op: gelu.out + kernels: + - arg_meta: null + kernel_name: torch::executor::opt_gelu_out + - op: le.Scalar_out kernels: - arg_meta: null