Skip to content

Commit

Permalink
Reuse GELU implementation from PyTorch core
Browse files Browse the repository at this point in the history
kernels/optimized doesn't need to support embedded systems, so it can just take a header-only dep on PyTorch.

Note that, because we will pick up Sleef internally and ignore it
externally thanks to ATen vec, this PR gets to enable optimized GELU in OSS.

Testing: CI to make sure this doesn't break mobile build modes; happy to take advice on anything not currently covered that might break.

Differential Revision: [D66335522](https://our.internmc.facebook.com/intern/diff/D66335522/)

ghstack-source-id: 255095943
Pull Request resolved: #7041
  • Loading branch information
swolchok committed Nov 23, 2024
1 parent 85363ea commit 5036f3d
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 42 deletions.
2 changes: 2 additions & 0 deletions kernels/optimized/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ message("Generated files ${gen_command_sources}")

list(TRANSFORM _optimized_kernels__srcs PREPEND "${EXECUTORCH_ROOT}/")
add_library(optimized_kernels ${_optimized_kernels__srcs})
find_package(Torch CONFIG REQUIRED)
target_include_directories(optimized_kernels PRIVATE ${TORCH_INCLUDE_DIRS})
target_link_libraries(
optimized_kernels PRIVATE executorch_core cpublas extension_threadpool
)
Expand Down
51 changes: 15 additions & 36 deletions kernels/optimized/cpu/op_gelu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include <cmath>

#include <ATen/native/cpu/Gelu.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>

Expand Down Expand Up @@ -46,48 +47,26 @@ void gelu(
CTYPE* out_data = output.mutable_data_ptr<CTYPE>();
size_t lim = input.numel();

// TODO: Add fast path for tanh using sleef's tanh
if (approximate == "tanh") {
// 0.5 * x * (1 + Tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))
for (size_t i = 0; i < lim; ++i) {
const CTYPE x = in_data[i];
const CTYPE kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
const CTYPE kKappa = 0.044715;
auto x_cube = x * x * x;
auto inner = kBeta * (x + kKappa * x_cube);
out_data[i] = CTYPE(0.5) * x * (CTYPE(1) + std::tanh(inner));
using Vec = at::vec::Vectorized<CTYPE>;
int i = 0;
for (; i < lim - (lim % Vec::size()); i += Vec::size()) {
Vec x = Vec::loadu(in_data + i);
at::native::vectorized_gelu_approximated_with_tanh(x).store(out_data + i);
}
} else if (approximate == "none") { // dont appx
// GELU(x) = x * Φ(x) where Φ(x) is the is the Cumulative Distribution
// Function for Gaussian Distribution.

#ifndef __aarch64__
for (size_t i = 0; i < lim; ++i) {
const CTYPE x = in_data[i];
out_data[i] = CTYPE(0.5) * x * (CTYPE(1) + std::erf(x * M_SQRT1_2));
for (; i < lim; ++i) {
out_data[i] = at::native::scalar_gelu_approximated_with_tanh(in_data[i]);
}
#else
size_t i = 0;
if (std::is_same<CTYPE, float>::value) {
for (; i + 4 < lim; i += 4) {
const float32x4_t in =
vld1q_f32(static_cast<const float*>(&in_data[i]));
const float32x4_t m_sqrt1_2x4 = {
M_SQRT1_2, M_SQRT1_2, M_SQRT1_2, M_SQRT1_2};
const float32x4_t ones = vmovq_n_f32(1.0);
const float32x4_t halves = vmovq_n_f32(0.5);
float32x4_t out = Sleef_erff4_u10(vmulq_f32(in, m_sqrt1_2x4));
vst1q_f32(
static_cast<float*>(&out_data[i]),
vmulq_f32(vmulq_f32(vaddq_f32(out, ones), in), halves));
}
} else if (approximate == "none") {
using Vec = at::vec::Vectorized<CTYPE>;
int i = 0;
for (; i < lim - (lim % Vec::size()); i += Vec::size()) {
Vec x = Vec::loadu(in_data + i);
at::native::vectorized_gelu(x).store(out_data + i);
}
for (; i < lim; ++i) {
const CTYPE x = in_data[i];
out_data[i] = CTYPE(0.5) * x * (CTYPE(1) + std::erf(x * M_SQRT1_2));
out_data[i] = at::native::scalar_gelu(in_data[i]);
}
#endif // __aarch64__

} else {
ET_KERNEL_CHECK_MSG(
context,
Expand Down
53 changes: 47 additions & 6 deletions kernels/optimized/cpu/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,9 @@ _OPTIMIZED_ATEN_OPS = (
op_target(name = "op_sigmoid"),
op_target(
name = "op_gelu",
deps = select({
"DEFAULT": [],
"ovr_config//cpu:arm64": [
"fbsource//third-party/sleef:sleef_arm",
],
}),
deps = [
":aten_headers_for_executorch",
],
),
op_target(
name = "op_le",
Expand Down Expand Up @@ -94,6 +91,13 @@ _OPTIMIZED_ATEN_OPS = (
),
)


def get_sleef_preprocessor_flags():
if runtime.is_oss:
return []
return ["-DAT_BUILD_ARM_VEC256_WITH_SLEEF"]


def define_common_targets():
"""Defines targets that should be shared between fbcode and xplat.
Expand All @@ -110,6 +114,43 @@ def define_common_targets():
aten_op_targets = [":{}".format(op["name"]) for op in enabled_ops]
all_op_targets = aten_op_targets

runtime.cxx_library(
name = "aten_headers_for_executorch",
srcs = [],
visibility = ["//executorch/kernels/optimized/..."],
exported_deps = select({
"DEFAULT": [],
"ovr_config//cpu:arm64": [
"fbsource//third-party/sleef:sleef_arm",
] if not runtime.is_oss else [],
# fbsource//third-party/sleef:sleef currently fails to
# link with missing symbols, hence the fbcode-specific dep below.
}),
fbcode_exported_deps = [
"//caffe2:aten-headers-cpu",
"//caffe2:generated-config-header",
"//caffe2/c10/core:base",
] + select({
"DEFAULT": [],
"ovr_config//cpu:x86_64": [
"third-party//sleef:sleef",
]
}),
xplat_exported_deps = [
"//xplat/caffe2:aten_header",
"//xplat/caffe2:generated_aten_config_header",
"//xplat/caffe2/c10:c10",
],
exported_preprocessor_flags = select({
"ovr_config//cpu:x86_64": [
"-DCPU_CAPABILITY=AVX2",
"-DCPU_CAPABILITY_AVX2",
"-DHAVE_AVX2_CPU_DEFINITION",
] + get_sleef_preprocessor_flags(),
"ovr_config//cpu:arm64": get_sleef_preprocessor_flags(),
}) + ["-DSTANDALONE_TORCH_HEADER"],
)

runtime.cxx_library(
name = "binary_ops",
exported_headers = ["binary_ops.h"],
Expand Down
5 changes: 5 additions & 0 deletions kernels/optimized/optimized-oss.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@
- arg_meta: null
kernel_name: torch::executor::opt_sigmoid_out

- op: gelu.out
kernels:
- arg_meta: null
kernel_name: torch::executor::opt_gelu_out

- op: le.Scalar_out
kernels:
- arg_meta: null
Expand Down

0 comments on commit 5036f3d

Please sign in to comment.