Skip to content

Commit b12e49e

Browse files
committed
Reuse GELU implementation from PyTorch core
Pull Request resolved: #7041 kernels/optimized doesn't need to support embedded systems, so it can just take a header-only dep on PyTorch. Note that, because we will pick up Sleef internally and ignore it externally thanks to ATen vec, this PR gets to enable optimized GELU in OSS. Testing: CI to make sure this doesn't break mobile build modes; happy to take advice on anything not currently covered that might break. ghstack-source-id: 263918321 @exported-using-ghexport Differential Revision: [D66335522](https://our.internmc.facebook.com/intern/diff/D66335522/)
1 parent 41acdf4 commit b12e49e

File tree

15 files changed

+87
-47
lines changed

15 files changed

+87
-47
lines changed

.ci/scripts/build_llama_android.sh

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@ set -exu
1010
# shellcheck source=/dev/null
1111
source "$(dirname "${BASH_SOURCE[0]}")/utils.sh"
1212

13+
# Resolve the Python interpreter once. Callers may override it through the
# PYTHON_EXECUTABLE environment variable; otherwise fall back to python3.
# The ":-" form keeps this safe under `set -u` when the variable is unset.
PYTHON_EXECUTABLE="${PYTHON_EXECUTABLE:-python3}"
# Print the resolved interpreter for the CI log; with `set -e` active this
# also aborts early when the interpreter is not on PATH.
command -v "${PYTHON_EXECUTABLE}"
# Locate the installed torch package with the SAME interpreter that was just
# resolved, so the CMake prefix path matches the Python environment selected
# by the caller rather than whatever `python3` happens to be first on PATH.
CMAKE_PREFIX_PATH="$("${PYTHON_EXECUTABLE}" -c 'import torch as _; print(_.__path__[0])')"
18+
1319
install_executorch_and_backend_lib() {
1420
echo "Installing executorch and xnnpack backend"
1521
clean_executorch_install_folders
@@ -22,6 +28,7 @@ install_executorch_and_backend_lib() {
2228
-DANDROID_ABI="${ANDROID_ABI}" \
2329
-DCMAKE_INSTALL_PREFIX=cmake-android-out \
2430
-DCMAKE_BUILD_TYPE=Release \
31+
-DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH}" \
2532
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
2633
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
2734
-DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
@@ -47,6 +54,7 @@ build_llama_runner() {
4754
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
4855
-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
4956
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
57+
-DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH}" \
5058
-Bcmake-android-out/examples/models/llama examples/models/llama
5159

5260
cmake --build cmake-android-out/examples/models/llama -j4 --config Release

.ci/scripts/test_llama.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ cmake_install_executorch_libraries() {
154154
rm -rf cmake-out
155155
retry cmake \
156156
-DCMAKE_INSTALL_PREFIX=cmake-out \
157+
-DCMAKE_PREFIX_PATH="$(python3 -c 'import torch as _; print(_.__path__[0])')" \
157158
-DCMAKE_BUILD_TYPE="$CMAKE_BUILD_TYPE" \
158159
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
159160
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \

.ci/scripts/test_model.sh

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,12 @@ prepare_artifacts_upload() {
5050

5151
build_cmake_executor_runner() {
5252
echo "Building executor_runner"
53+
CMAKE_PREFIX_PATH="$(python3 -c 'import torch as _; print(_.__path__[0])')"
5354
rm -rf ${CMAKE_OUTPUT_DIR}
5455
cmake -DCMAKE_BUILD_TYPE=Debug \
5556
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
5657
-DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \
58+
-DCMAKE_PREFIX_PATH="$CMAKE_PREFIX_PATH" \
5759
-B${CMAKE_OUTPUT_DIR} .
5860

5961
cmake --build ${CMAKE_OUTPUT_DIR} -j4 --config Debug
@@ -98,8 +100,7 @@ test_model() {
98100

99101
build_cmake_xnn_executor_runner() {
100102
echo "Building xnn_executor_runner"
101-
SITE_PACKAGES="$(${PYTHON_EXECUTABLE} -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')"
102-
CMAKE_PREFIX_PATH="${SITE_PACKAGES}/torch"
103+
CMAKE_PREFIX_PATH="$(python3 -c 'import torch as _; print(_.__path__[0])')"
103104

104105
(rm -rf ${CMAKE_OUTPUT_DIR} \
105106
&& mkdir ${CMAKE_OUTPUT_DIR} \

.ci/scripts/test_phi_3_mini.sh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@ NPROC=8
2222
if hash nproc &> /dev/null; then NPROC=$(nproc); fi
2323

2424
cmake_install_executorch_libraries() {
25+
CMAKE_PREFIX_PATH="$(python3 -c 'import torch as _; print(_.__path__[0])')"
2526
cmake -DPYTHON_EXECUTABLE=python \
2627
-DCMAKE_INSTALL_PREFIX=${BUILD_DIR} \
28+
-DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH}" \
2729
-DEXECUTORCH_ENABLE_LOGGING=1 \
2830
-DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
2931
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
@@ -39,8 +41,10 @@ cmake_install_executorch_libraries() {
3941
}
4042

4143
cmake_build_phi_3_mini() {
44+
CMAKE_PREFIX_PATH="$(python3 -c 'import torch as _; print(_.__path__[0])')"
4245
cmake -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \
4346
-DCMAKE_INSTALL_PREFIX=${BUILD_DIR} \
47+
-DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH}" \
4448
-DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
4549
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
4650
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \

.ci/scripts/utils.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ cmake_install_executorch_lib() {
136136
clean_executorch_install_folders
137137
retry cmake -DBUCK2="$BUCK" \
138138
-DCMAKE_INSTALL_PREFIX=cmake-out \
139+
-DCMAKE_PREFIX_PATH="$($PYTHON_EXECUTABLE -c 'import torch as _; print(_.__path__[0])')" \
139140
-DCMAKE_BUILD_TYPE=Release \
140141
-DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \
141142
-Bcmake-out .

.github/workflows/pull.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ jobs:
147147
CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
148148
conda activate "${CONDA_ENV}"
149149
150+
source .ci/scripts/utils.sh
151+
install_executorch "use-pt-pinned-commit"
150152
BUILD_TOOL="cmake"
151153
PYTHON_EXECUTABLE=python \
152154
bash .ci/scripts/build_llama_android.sh "${BUILD_TOOL}"

.github/workflows/trunk.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,7 @@ jobs:
394394
rm -rf cmake-out
395395
cmake \
396396
-DCMAKE_INSTALL_PREFIX=cmake-out \
397+
-DCMAKE_PREFIX_PATH="$(python -c 'import torch as _; print(_.__path__[0])')" \
397398
-DCMAKE_BUILD_TYPE=Release \
398399
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
399400
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
@@ -411,6 +412,7 @@ jobs:
411412
cmake \
412413
-DCMAKE_INSTALL_PREFIX=cmake-out \
413414
-DCMAKE_BUILD_TYPE=Release \
415+
-DCMAKE_PREFIX_PATH="$(python -c 'import torch as _; print(_.__path__[0])')" \
414416
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
415417
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
416418
-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \

build/Utils.cmake

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,3 +321,20 @@ function(resolve_python_executable)
321321
)
322322
endif()
323323
endfunction()
324+
325+
# find_package(Torch CONFIG REQUIRED) replacement for targets that
# have a header-only Torch dependency. Because find_package sets
# variables in the parent scope, we use a macro to preserve this
# rather than maintaining our own list of those variables.
#
# Side effects: everything find_package(Torch) defines becomes visible
# in the caller's scope. The CMAKE_FIND_ROOT_PATH_MODE_{PACKAGE,LIBRARY,
# INCLUDE} variables are forced to BOTH only for the duration of the
# lookup and are restored afterwards.
macro(find_package_torch_headers)
  # We cannot simply use CMAKE_FIND_ROOT_PATH_BOTH, because that does
  # not propagate into TorchConfig.cmake.
  foreach(mode_kind IN ITEMS PACKAGE LIBRARY INCLUDE)
    # Stash the caller's current setting so it can be put back below.
    set(OLD_CMAKE_FIND_ROOT_PATH_MODE_${mode_kind} ${CMAKE_FIND_ROOT_PATH_MODE_${mode_kind}})
    set(CMAKE_FIND_ROOT_PATH_MODE_${mode_kind} BOTH)
  endforeach()
  find_package(Torch CONFIG REQUIRED)
  foreach(mode_kind IN ITEMS PACKAGE LIBRARY INCLUDE)
    # Deliberately unquoted: if the mode was not set before the call, the
    # saved value expands to nothing and set() with no value leaves the
    # variable unset again instead of defining it as an empty string.
    set(CMAKE_FIND_ROOT_PATH_MODE_${mode_kind} ${OLD_CMAKE_FIND_ROOT_PATH_MODE_${mode_kind}})
    # A macro runs in the caller's scope, so drop the temporary to avoid
    # leaking OLD_* variables into every directory that invokes this macro.
    unset(OLD_CMAKE_FIND_ROOT_PATH_MODE_${mode_kind})
  endforeach()
endmacro()

build/build_android_llm_demo.sh

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@
77

88
set -ex
99

10+
# Pick the Python interpreter: honour a caller-supplied PYTHON_EXECUTABLE and
# default to python3 when it is unset or empty.
PYTHON_EXECUTABLE="${PYTHON_EXECUTABLE:-python3}"
# Log which interpreter will be used; under `set -e` a missing interpreter
# stops the script here instead of failing later inside the CMake configure.
command -v "${PYTHON_EXECUTABLE}"
# Ask that same interpreter where torch is installed. Using a hard-coded
# `python3` here would ignore the PYTHON_EXECUTABLE override and could point
# CMake at a torch install from a different environment.
CMAKE_PREFIX_PATH="$("${PYTHON_EXECUTABLE}" -c 'import torch as _; print(_.__path__[0])')"
15+
1016
build_jar() {
1117
pushd extension/android
1218
./gradlew build
@@ -36,6 +42,7 @@ build_android_native_library() {
3642
fi
3743

3844
cmake . -DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \
45+
-DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH}" \
3946
-DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" \
4047
-DANDROID_ABI="${ANDROID_ABI}" \
4148
-DANDROID_PLATFORM=android-26 \
@@ -69,6 +76,7 @@ build_android_native_library() {
6976
-DANDROID_ABI="${ANDROID_ABI}" \
7077
-DANDROID_PLATFORM=android-26 \
7178
-DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \
79+
-DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH}" \
7280
-DEXECUTORCH_ENABLE_LOGGING=ON \
7381
-DEXECUTORCH_LOG_LEVEL=Info \
7482
-DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=BOTH \

kernels/optimized/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ message("Generated files ${gen_command_sources}")
6161

6262
list(TRANSFORM _optimized_kernels__srcs PREPEND "${EXECUTORCH_ROOT}/")
6363
add_library(optimized_kernels ${_optimized_kernels__srcs})
64+
find_package_torch_headers()
65+
target_include_directories(optimized_kernels PRIVATE ${TORCH_INCLUDE_DIRS})
6466
target_link_libraries(
6567
optimized_kernels PRIVATE executorch_core cpublas extension_threadpool
6668
)

kernels/optimized/cpu/op_gelu.cpp

Lines changed: 15 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
#include <cmath>
1515

16+
#include <ATen/native/cpu/Gelu.h>
1617
#include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
1718
#include <executorch/runtime/kernel/kernel_includes.h>
1819
#include <executorch/runtime/platform/assert.h>
@@ -47,48 +48,26 @@ void gelu(
4748
CTYPE* out_data = output.mutable_data_ptr<CTYPE>();
4849
size_t lim = input.numel();
4950

50-
// TODO: Add fast path for tanh using sleef's tanh
5151
if (approximate == "tanh") {
52-
// 0.5 * x * (1 + Tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))
53-
for (size_t i = 0; i < lim; ++i) {
54-
const CTYPE x = in_data[i];
55-
const CTYPE kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
56-
const CTYPE kKappa = 0.044715;
57-
auto x_cube = x * x * x;
58-
auto inner = kBeta * (x + kKappa * x_cube);
59-
out_data[i] = CTYPE(0.5) * x * (CTYPE(1) + std::tanh(inner));
52+
using Vec = at::vec::Vectorized<CTYPE>;
53+
int i = 0;
54+
for (; i < lim - (lim % Vec::size()); i += Vec::size()) {
55+
Vec x = Vec::loadu(in_data + i);
56+
at::native::vectorized_gelu_approximated_with_tanh(x).store(out_data + i);
6057
}
61-
} else if (approximate == "none") { // dont appx
62-
// GELU(x) = x * Φ(x) where Φ(x) is the is the Cumulative Distribution
63-
// Function for Gaussian Distribution.
64-
65-
#ifndef __aarch64__
66-
for (size_t i = 0; i < lim; ++i) {
67-
const CTYPE x = in_data[i];
68-
out_data[i] = CTYPE(0.5) * x * (CTYPE(1) + std::erf(x * M_SQRT1_2));
58+
for (; i < lim; ++i) {
59+
out_data[i] = at::native::scalar_gelu_approximated_with_tanh(in_data[i]);
6960
}
70-
#else
71-
size_t i = 0;
72-
if constexpr (std::is_same_v<CTYPE, float>) {
73-
for (; i + 4 < lim; i += 4) {
74-
const float32x4_t in =
75-
vld1q_f32(static_cast<const float*>(&in_data[i]));
76-
const float32x4_t m_sqrt1_2x4 = {
77-
M_SQRT1_2, M_SQRT1_2, M_SQRT1_2, M_SQRT1_2};
78-
const float32x4_t ones = vmovq_n_f32(1.0);
79-
const float32x4_t halves = vmovq_n_f32(0.5);
80-
float32x4_t out = Sleef_erff4_u10(vmulq_f32(in, m_sqrt1_2x4));
81-
vst1q_f32(
82-
static_cast<float*>(&out_data[i]),
83-
vmulq_f32(vmulq_f32(vaddq_f32(out, ones), in), halves));
84-
}
61+
} else if (approximate == "none") {
62+
using Vec = at::vec::Vectorized<CTYPE>;
63+
int i = 0;
64+
for (; i < lim - (lim % Vec::size()); i += Vec::size()) {
65+
Vec x = Vec::loadu(in_data + i);
66+
at::native::vectorized_gelu(x).store(out_data + i);
8567
}
8668
for (; i < lim; ++i) {
87-
const CTYPE x = in_data[i];
88-
out_data[i] = CTYPE(0.5) * x * (CTYPE(1) + std::erf(x * M_SQRT1_2));
69+
out_data[i] = at::native::scalar_gelu(in_data[i]);
8970
}
90-
#endif // __aarch64__
91-
9271
} else {
9372
ET_KERNEL_CHECK_MSG(
9473
context,

kernels/optimized/cpu/targets.bzl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,9 @@ _OPTIMIZED_ATEN_OPS = (
2828
op_target(name = "op_sigmoid"),
2929
op_target(
3030
name = "op_gelu",
31-
deps = select({
32-
"DEFAULT": [],
33-
"ovr_config//cpu:arm64": [
34-
"fbsource//third-party/sleef:sleef_arm",
35-
],
36-
}) + [
31+
deps = [
3732
"//executorch/kernels/portable/cpu/util:activation_ops_util",
33+
"//executorch/runtime/core/portable_type/c10:aten_headers_for_executorch",
3834
],
3935
),
4036
op_target(
@@ -96,6 +92,13 @@ _OPTIMIZED_ATEN_OPS = (
9692
),
9793
)
9894

95+
96+
def get_sleef_preprocessor_flags():
    """Returns the Sleef-related preprocessor flags for the current build.

    Yields an empty list when `runtime.is_oss` is truthy; otherwise a list
    holding the single define `-DAT_BUILD_ARM_VEC256_WITH_SLEEF`.
    """
    return [] if runtime.is_oss else ["-DAT_BUILD_ARM_VEC256_WITH_SLEEF"]
100+
101+
99102
def define_common_targets():
100103
"""Defines targets that should be shared between fbcode and xplat.
101104

kernels/optimized/optimized-oss.yaml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
#
33
# This yaml file contains operators that have optimized kernels available.
4-
# Note that this is a copy of optimized.yaml that does not include gelu and
5-
# log_softmax, due to the OSS build not currently including sleef.
4+
# Note that this is a copy of optimized.yaml that does not include log_softmax,
5+
# due to the OSS build not currently including sleef.
66
# TODO (T183193812)
77

88
- op: add.out
@@ -40,6 +40,11 @@
4040
- arg_meta: null
4141
kernel_name: torch::executor::opt_sigmoid_out
4242

43+
- op: gelu.out
44+
kernels:
45+
- arg_meta: null
46+
kernel_name: torch::executor::opt_gelu_out
47+
4348
- op: le.Scalar_out
4449
kernels:
4550
- arg_meta: null

shim/xplat/executorch/kernels/optimized/op_registration_util.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,5 +134,5 @@ def define_op_target(name, deps):
134134

135135
def is_op_disabled(name):
    """Returns True when `name` is one of the op targets listed as disabled."""
    # TODO (gjcomer) Enable ops with sleef dependency in OSS
    return name in ("op_log_softmax",)

test/run_oss_cpp_tests.sh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,20 @@ elif [[ $(uname) == "Linux" ]]; then
2222
export LLVM_COV="${LLVM_COV:-llvm-cov}"
2323
fi
2424

25+
# Default the Python interpreter to python3 unless the caller exported
# PYTHON_EXECUTABLE with a non-empty value.
PYTHON_EXECUTABLE="${PYTHON_EXECUTABLE:-python3}"
# Print the resolved interpreter path for the log. `command -v` is the
# shell builtin for this lookup and does not depend on an external `which`.
command -v "${PYTHON_EXECUTABLE}"
29+
2530
build_executorch() {
2631
BUILD_VULKAN="OFF"
2732
if [ -x "$(command -v glslc)" ]; then
2833
BUILD_VULKAN="ON"
2934
fi
35+
CMAKE_PREFIX_PATH="$(python3 -c 'import torch as _; print(_.__path__[0])')"
3036
cmake . \
3137
-DCMAKE_INSTALL_PREFIX=cmake-out \
38+
-DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH}" \
3239
-DEXECUTORCH_USE_CPP_CODE_COVERAGE=ON \
3340
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
3441
-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \

0 commit comments

Comments
 (0)