Reuse GELU implementation from PyTorch core
Pull Request resolved: #7041

kernels/optimized doesn't need to support embedded systems, so it can just take a header-only dep on PyTorch.

Note that, because ATen vec picks up Sleef in internal builds and ignores it in external ones, this PR also enables the optimized GELU kernel in OSS.

Testing: CI to make sure this doesn't break mobile build modes; happy to take advice on anything not currently covered that might break.
ghstack-source-id: 261541003
@exported-using-ghexport

Differential Revision: [D66335522](https://our.internmc.facebook.com/intern/diff/D66335522/)
swolchok committed Jan 15, 2025
1 parent 179a346 commit 8674880
Showing 14 changed files with 78 additions and 47 deletions.
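Nearly every script below gains the same two-step pattern: locate the installed torch package, then hand its path to CMake so find_package(Torch) can resolve the new header-only dependency. A minimal sketch of the pattern, assuming a host Python environment with torch installed:

if [[ -z "${PYTHON_EXECUTABLE:-}" ]]; then
  PYTHON_EXECUTABLE=python3
fi
# torch.__path__[0] is the installed torch/ directory; its share/cmake
# subdirectory holds the Torch CMake config that find_package(Torch)
# discovers through CMAKE_PREFIX_PATH.
CMAKE_PREFIX_PATH="$("${PYTHON_EXECUTABLE}" -c 'import torch as _; print(_.__path__[0])')"
cmake -DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH}" -Bcmake-out .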
8 changes: 8 additions & 0 deletions .ci/scripts/build_llama_android.sh
@@ -10,6 +10,12 @@ set -exu
# shellcheck source=/dev/null
source "$(dirname "${BASH_SOURCE[0]}")/utils.sh"

+if [[ -z "${PYTHON_EXECUTABLE:-}" ]]; then
+  PYTHON_EXECUTABLE=python3
+fi
+which "${PYTHON_EXECUTABLE}"
+CMAKE_PREFIX_PATH="$(python3 -c 'import torch as _; print(_.__path__[0])')"
+
install_executorch_and_backend_lib() {
echo "Installing executorch and xnnpack backend"
clean_executorch_install_folders
@@ -22,6 +28,7 @@ install_executorch_and_backend_lib() {
-DANDROID_ABI="${ANDROID_ABI}" \
-DCMAKE_INSTALL_PREFIX=cmake-android-out \
-DCMAKE_BUILD_TYPE=Release \
+ -DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH}" \
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
-DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
@@ -47,6 +54,7 @@ build_llama_runner() {
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
+ -DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH}" \
-Bcmake-android-out/examples/models/llama examples/models/llama

cmake --build cmake-android-out/examples/models/llama -j4 --config Release
1 change: 1 addition & 0 deletions .ci/scripts/test_llama.sh
@@ -152,6 +152,7 @@ cmake_install_executorch_libraries() {
rm -rf cmake-out
retry cmake \
-DCMAKE_INSTALL_PREFIX=cmake-out \
+ -DCMAKE_PREFIX_PATH="$(python3 -c 'import torch as _; print(_.__path__[0])')" \
-DCMAKE_BUILD_TYPE="$CMAKE_BUILD_TYPE" \
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
5 changes: 3 additions & 2 deletions .ci/scripts/test_model.sh
@@ -50,10 +50,12 @@ prepare_artifacts_upload() {

build_cmake_executor_runner() {
echo "Building executor_runner"
+ CMAKE_PREFIX_PATH="$(python3 -c 'import torch as _; print(_.__path__[0])')"
rm -rf ${CMAKE_OUTPUT_DIR}
cmake -DCMAKE_BUILD_TYPE=Debug \
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
-DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \
+ -DCMAKE_PREFIX_PATH="$CMAKE_PREFIX_PATH" \
-B${CMAKE_OUTPUT_DIR} .

cmake --build ${CMAKE_OUTPUT_DIR} -j4 --config Debug
@@ -98,8 +100,7 @@ test_model() {

build_cmake_xnn_executor_runner() {
echo "Building xnn_executor_runner"
-  SITE_PACKAGES="$(${PYTHON_EXECUTABLE} -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')"
-  CMAKE_PREFIX_PATH="${SITE_PACKAGES}/torch"
+  CMAKE_PREFIX_PATH="$(python3 -c 'import torch as _; print(_.__path__[0])')"

(rm -rf ${CMAKE_OUTPUT_DIR} \
&& mkdir ${CMAKE_OUTPUT_DIR} \
4 changes: 4 additions & 0 deletions .ci/scripts/test_phi_3_mini.sh
@@ -22,8 +22,10 @@ NPROC=8
if hash nproc &> /dev/null; then NPROC=$(nproc); fi

cmake_install_executorch_libraries() {
+ CMAKE_PREFIX_PATH="$(python3 -c 'import torch as _; print(_.__path__[0])')"
cmake -DPYTHON_EXECUTABLE=python \
-DCMAKE_INSTALL_PREFIX=${BUILD_DIR} \
+ -DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH}" \
-DEXECUTORCH_ENABLE_LOGGING=1 \
-DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
@@ -39,8 +41,10 @@ cmake_install_executorch_libraries() {
}

cmake_build_phi_3_mini() {
+ CMAKE_PREFIX_PATH="$(python3 -c 'import torch as _; print(_.__path__[0])')"
cmake -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \
-DCMAKE_INSTALL_PREFIX=${BUILD_DIR} \
+ -DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH}" \
-DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
1 change: 1 addition & 0 deletions .ci/scripts/utils.sh
@@ -136,6 +136,7 @@ cmake_install_executorch_lib() {
clean_executorch_install_folders
retry cmake -DBUCK2="$BUCK" \
-DCMAKE_INSTALL_PREFIX=cmake-out \
+ -DCMAKE_PREFIX_PATH="$($PYTHON_EXECUTABLE -c 'import torch as _; print(_.__path__[0])')" \
-DCMAKE_BUILD_TYPE=Release \
-DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \
-Bcmake-out .
2 changes: 2 additions & 0 deletions .github/workflows/pull.yml
@@ -135,6 +135,8 @@ jobs:
CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
conda activate "${CONDA_ENV}"
+source .ci/scripts/utils.sh
+install_executorch "use-pt-pinned-commit"
BUILD_TOOL="cmake"
PYTHON_EXECUTABLE=python \
bash .ci/scripts/build_llama_android.sh "${BUILD_TOOL}"
2 changes: 2 additions & 0 deletions .github/workflows/trunk.yml
@@ -379,6 +379,7 @@ jobs:
rm -rf cmake-out
cmake \
-DCMAKE_INSTALL_PREFIX=cmake-out \
+ -DCMAKE_PREFIX_PATH="$(python -c 'import torch as _; print(_.__path__[0])')" \
-DCMAKE_BUILD_TYPE=Release \
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
@@ -396,6 +397,7 @@
cmake \
-DCMAKE_INSTALL_PREFIX=cmake-out \
-DCMAKE_BUILD_TYPE=Release \
+ -DCMAKE_PREFIX_PATH="$(python -c 'import torch as _; print(_.__path__[0])')" \
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
8 changes: 8 additions & 0 deletions build/build_android_llm_demo.sh
@@ -7,6 +7,12 @@

set -ex

+if [[ -z "${PYTHON_EXECUTABLE:-}" ]]; then
+  PYTHON_EXECUTABLE=python3
+fi
+which "${PYTHON_EXECUTABLE}"
+CMAKE_PREFIX_PATH="$(python3 -c 'import torch as _; print(_.__path__[0])')"
+
build_jar() {
pushd extension/android
./gradlew build
@@ -36,6 +42,7 @@ build_android_native_library() {
fi

cmake . -DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \
+ -DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH}" \
-DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" \
-DANDROID_ABI="${ANDROID_ABI}" \
-DANDROID_PLATFORM=android-26 \
@@ -69,6 +76,7 @@ build_android_native_library() {
-DANDROID_ABI="${ANDROID_ABI}" \
-DANDROID_PLATFORM=android-26 \
-DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \
+ -DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH}" \
-DEXECUTORCH_ENABLE_LOGGING=ON \
-DEXECUTORCH_LOG_LEVEL=Info \
-DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=BOTH \
9 changes: 9 additions & 0 deletions kernels/optimized/CMakeLists.txt
@@ -63,6 +63,15 @@ message("Generated files ${gen_command_sources}")

list(TRANSFORM _optimized_kernels__srcs PREPEND "${EXECUTORCH_ROOT}/")
add_library(optimized_kernels ${_optimized_kernels__srcs})
+# We require Torch headers, which setup.py puts in CMAKE_PREFIX_PATH
+# for us. Toolchains that we might be using for cross-compiling could
+# set CMAKE_FIND_ROOT_PATH, which prevents find_package from finding
+# headers not rooted under CMAKE_FIND_ROOT_PATH. This is reasonable
+# for binary dependencies because they probably aren't built for the
+# target platform, but for our header-only use case, we should just
+# ignore CMAKE_FIND_ROOT_PATH.
+find_package(Torch CONFIG REQUIRED NO_CMAKE_FIND_ROOT_PATH)
+target_include_directories(optimized_kernels PRIVATE ${TORCH_INCLUDE_DIRS})
target_link_libraries(
optimized_kernels PRIVATE executorch_core cpublas extension_threadpool
)
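The comment block above is about cross-compilation: toolchain files such as the Android NDK's set CMAKE_FIND_ROOT_PATH to the target sysroot, re-rooting find_package searches away from host paths. A hedged sketch of the configure this option handles, with illustrative values:

# android.toolchain.cmake sets CMAKE_FIND_ROOT_PATH to the NDK sysroot, so a
# plain find_package(Torch) would never look in the host's site-packages.
# NO_CMAKE_FIND_ROOT_PATH exempts this single header-only lookup from that
# re-rooting, while binary dependencies remain correctly re-rooted.
cmake . \
  -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" \
  -DCMAKE_PREFIX_PATH="$(python3 -c 'import torch as _; print(_.__path__[0])')" \
  -Bcmake-android-out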
51 changes: 15 additions & 36 deletions kernels/optimized/cpu/op_gelu.cpp
@@ -13,6 +13,7 @@

#include <cmath>

+#include <ATen/native/cpu/Gelu.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>

@@ -46,48 +47,26 @@ void gelu(
  CTYPE* out_data = output.mutable_data_ptr<CTYPE>();
  size_t lim = input.numel();

-  // TODO: Add fast path for tanh using sleef's tanh
  if (approximate == "tanh") {
-    // 0.5 * x * (1 + Tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))
-    for (size_t i = 0; i < lim; ++i) {
-      const CTYPE x = in_data[i];
-      const CTYPE kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
-      const CTYPE kKappa = 0.044715;
-      auto x_cube = x * x * x;
-      auto inner = kBeta * (x + kKappa * x_cube);
-      out_data[i] = CTYPE(0.5) * x * (CTYPE(1) + std::tanh(inner));
+    using Vec = at::vec::Vectorized<CTYPE>;
+    int i = 0;
+    for (; i < lim - (lim % Vec::size()); i += Vec::size()) {
+      Vec x = Vec::loadu(in_data + i);
+      at::native::vectorized_gelu_approximated_with_tanh(x).store(out_data + i);
    }
-  } else if (approximate == "none") { // dont appx
-    // GELU(x) = x * Φ(x) where Φ(x) is the is the Cumulative Distribution
-    // Function for Gaussian Distribution.
-
-#ifndef __aarch64__
-    for (size_t i = 0; i < lim; ++i) {
-      const CTYPE x = in_data[i];
-      out_data[i] = CTYPE(0.5) * x * (CTYPE(1) + std::erf(x * M_SQRT1_2));
+    for (; i < lim; ++i) {
+      out_data[i] = at::native::scalar_gelu_approximated_with_tanh(in_data[i]);
    }
-#else
-    size_t i = 0;
-    if (std::is_same<CTYPE, float>::value) {
-      for (; i + 4 < lim; i += 4) {
-        const float32x4_t in =
-            vld1q_f32(static_cast<const float*>(&in_data[i]));
-        const float32x4_t m_sqrt1_2x4 = {
-            M_SQRT1_2, M_SQRT1_2, M_SQRT1_2, M_SQRT1_2};
-        const float32x4_t ones = vmovq_n_f32(1.0);
-        const float32x4_t halves = vmovq_n_f32(0.5);
-        float32x4_t out = Sleef_erff4_u10(vmulq_f32(in, m_sqrt1_2x4));
-        vst1q_f32(
-            static_cast<float*>(&out_data[i]),
-            vmulq_f32(vmulq_f32(vaddq_f32(out, ones), in), halves));
-      }
-    }
+  } else if (approximate == "none") {
+    using Vec = at::vec::Vectorized<CTYPE>;
+    int i = 0;
+    for (; i < lim - (lim % Vec::size()); i += Vec::size()) {
+      Vec x = Vec::loadu(in_data + i);
+      at::native::vectorized_gelu(x).store(out_data + i);
+    }
    for (; i < lim; ++i) {
-      const CTYPE x = in_data[i];
-      out_data[i] = CTYPE(0.5) * x * (CTYPE(1) + std::erf(x * M_SQRT1_2));
+      out_data[i] = at::native::scalar_gelu(in_data[i]);
    }
-#endif // __aarch64__

  } else {
    ET_KERNEL_CHECK_MSG(
        context,
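For reference, the two quantities this kernel computes, which the removed scalar code spelled out and the ATen helpers now implement:

\mathrm{GELU}(x) = x\,\Phi(x) = \tfrac{1}{2}\,x\,\bigl(1 + \operatorname{erf}(x/\sqrt{2})\bigr) \qquad (\texttt{approximate == "none"})

\mathrm{GELU}_{\tanh}(x) = \tfrac{1}{2}\,x\,\Bigl(1 + \tanh\bigl(\sqrt{2/\pi}\,(x + 0.044715\,x^{3})\bigr)\Bigr) \qquad (\texttt{approximate == "tanh"})

The removed kBeta = M_SQRT2 * M_2_SQRTPI * 0.5 is exactly \sqrt{2/\pi}. In the new code, the vectorized loop consumes Vec::size() lanes per iteration and the scalar helper finishes the lim % Vec::size() tail elements, so both branches share the same main-loop/remainder structure.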
16 changes: 10 additions & 6 deletions kernels/optimized/cpu/targets.bzl
@@ -28,12 +28,9 @@ _OPTIMIZED_ATEN_OPS = (
    op_target(name = "op_sigmoid"),
    op_target(
        name = "op_gelu",
-       deps = select({
-           "DEFAULT": [],
-           "ovr_config//cpu:arm64": [
-               "fbsource//third-party/sleef:sleef_arm",
-           ],
-       }),
+       deps = [
+           "//executorch/runtime/core/portable_type/c10:aten_headers_for_executorch",
+       ],
    ),
    op_target(
        name = "op_le",
@@ -94,6 +91,13 @@ _OPTIMIZED_ATEN_OPS = (
),
)


+def get_sleef_preprocessor_flags():
+    if runtime.is_oss:
+        return []
+    return ["-DAT_BUILD_ARM_VEC256_WITH_SLEEF"]
+

def define_common_targets():
"""Defines targets that should be shared between fbcode and xplat.
9 changes: 7 additions & 2 deletions kernels/optimized/optimized-oss.yaml
@@ -1,8 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This yaml file contains operators that have optimized kernels available.
-# Note that this is a copy of optimized.yaml that does not include gelu and
-# log_softmax, due to the OSS build not currently including sleef.
+# Note that this is a copy of optimized.yaml that does not include log_softmax,
+# due to the OSS build not currently including sleef.
# TODO (T183193812)

- op: add.out
@@ -40,6 +40,11 @@
- arg_meta: null
kernel_name: torch::executor::opt_sigmoid_out

+- op: gelu.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::opt_gelu_out
+
- op: le.Scalar_out
kernels:
- arg_meta: null
@@ -134,5 +134,5 @@ def define_op_target(name, deps):

def is_op_disabled(name):
    # TODO (gjcomer) Enable ops with sleef dependency in OSS
-   disabled_ops = ["op_gelu", "op_log_softmax"]
+   disabled_ops = ["op_log_softmax"]
    return name in disabled_ops
7 changes: 7 additions & 0 deletions test/run_oss_cpp_tests.sh
@@ -22,13 +22,20 @@ elif [[ $(uname) == "Linux" ]]; then
export LLVM_COV="${LLVM_COV:-llvm-cov}"
fi

+if [[ -z "${PYTHON_EXECUTABLE:-}" ]]; then
+  PYTHON_EXECUTABLE=python3
+fi
+which "${PYTHON_EXECUTABLE}"
+
build_executorch() {
BUILD_VULKAN="OFF"
if [ -x "$(command -v glslc)" ]; then
BUILD_VULKAN="ON"
fi
+ CMAKE_PREFIX_PATH="$(python3 -c 'import torch as _; print(_.__path__[0])')"
cmake . \
-DCMAKE_INSTALL_PREFIX=cmake-out \
+ -DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH}" \
-DEXECUTORCH_USE_CPP_CODE_COVERAGE=ON \
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
