Skip to content

Commit

Permalink
[GPU] Fix aot issue caused by fused einsum build (#2450)
Browse files Browse the repository at this point in the history
  • Loading branch information
lingzhi98 authored Oct 26, 2023
1 parent 949b526 commit ab902a7
Show file tree
Hide file tree
Showing 9 changed files with 40 additions and 20 deletions.
6 changes: 6 additions & 0 deletions docs/install/install_for_cpp.md
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,12 @@ $ bazel build -c opt --config=gpu //itex:libitex_gpu_cc.so

CC library location: `<Path to intel-extension-for-tensorflow>/bazel-bin/itex/libitex_gpu_cc.so`

NOTE: `libitex_gpu_cc.so` depends on `libitex_gpu_xetla.so`, so `libitex_gpu_xetla.so` should be copied to the same directory as `libitex_gpu_cc.so`:
```bash
$ cd <Path to intel-extension-for-tensorflow>
$ cp bazel-out/k8-opt-ST-*/bin/itex/core/kernels/gpu/libitex_gpu_xetla.so bazel-bin/itex/
```

For CPU support

```bash
Expand Down
1 change: 1 addition & 0 deletions itex/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ itex_xpu_binary(
"//itex/core/graph:xpu_graph",
"//itex/core/kernels:xpu_kernel_cc",
"//itex/core/profiler:gpu_profiler",
"//itex/core/kernels/gpu:libitex_gpu_xetla",
],
) + [
"//itex/core:protos_all_cc",
Expand Down
1 change: 1 addition & 0 deletions itex/core/kernels/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ itex_xetla_binary(
set_target = "gpu_orig_backend",
visibility = ["//visibility:public"],
deps = [
"//itex/core/kernels/gpu/xetla:fused_einsum_impl",
"//itex/core/kernels/gpu/xetla:mha_op",
"//itex/core/kernels/gpu/xetla:mlp_op",
],
Expand Down
17 changes: 6 additions & 11 deletions itex/core/kernels/gpu/linalg/BUILD
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
load("//itex:itex.bzl", "itex_xetla_library", "itex_xpu_library", "tf_copts")
load("//itex:itex.bzl", "itex_xpu_library", "tf_copts")

itex_xpu_library(
name = "linalg",
Expand Down Expand Up @@ -37,19 +37,14 @@ itex_xpu_library(
alwayslink = True,
)

itex_xetla_library(
name = "fused_einsum_impl",
srcs = ["fused_einsum_impl.cc"],
hdrs = [
"einsum_helper.h",
"fused_einsum_impl.h",
],
itex_xpu_library(
name = "fused_einsum_hdrs",
hdrs = ["fused_einsum_helper.h"],
copts = tf_copts(),
linkstatic = 1,
visibility = ["//visibility:public"],
deps = [
"//itex:core",
"@xetla//:xetla_header",
],
alwayslink = True,
)
Expand All @@ -58,17 +53,17 @@ itex_xpu_library(
name = "einsum_op_impl",
srcs = ["einsum_op_impl.cc"],
hdrs = [
"einsum_helper.h",
"//itex/core/kernels/common:einsum_hdrs",
],
copts = tf_copts(),
linkstatic = 1,
visibility = ["//visibility:public"],
deps = [
"//itex/core/kernels/common:fill_functor",
"//itex/core/kernels/gpu:libitex_gpu_xetla",
"//itex/core/kernels/gpu:matmul_op",
"//itex/core/kernels/gpu:reduction_ops",
"//itex/core/kernels/gpu/linalg:fused_einsum_impl",
"//itex/core/kernels/gpu/linalg:fused_einsum_hdrs",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
],
Expand Down
2 changes: 1 addition & 1 deletion itex/core/kernels/gpu/linalg/einsum_op_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ limitations under the License.

#include "itex/core/kernels/common/einsum_op_impl.h"

#include "itex/core/kernels/gpu/linalg/einsum_helper.h"
#include "itex/core/kernels/gpu/linalg/fused_einsum_helper.h"

namespace itex {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
* limitations under the License.
******************************************************************************/

#ifndef ITEX_CORE_KERNELS_GPU_LINALG_EINSUM_HELPER_H_
#define ITEX_CORE_KERNELS_GPU_LINALG_EINSUM_HELPER_H_
#ifndef ITEX_CORE_KERNELS_GPU_LINALG_FUSED_EINSUM_HELPER_H_
#define ITEX_CORE_KERNELS_GPU_LINALG_FUSED_EINSUM_HELPER_H_

#include "itex/core/utils/op_kernel.h"
#include "itex/core/utils/plugin_tensor.h"
Expand Down Expand Up @@ -122,4 +122,4 @@ void Dispatch(Args<T>& args) { // NOLINT
} // namespace functor
} // namespace itex

#endif // ITEX_CORE_KERNELS_GPU_LINALG_EINSUM_HELPER_H_
#endif // ITEX_CORE_KERNELS_GPU_LINALG_FUSED_EINSUM_HELPER_H_
17 changes: 17 additions & 0 deletions itex/core/kernels/gpu/xetla/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,20 @@ itex_xetla_library(
],
alwayslink = True,
)

itex_xetla_library(
name = "fused_einsum_impl",
srcs = ["fused_einsum_impl.cc"],
hdrs = [
"fused_einsum_impl.h",
],
copts = tf_copts(),
linkstatic = 1,
visibility = ["//visibility:public"],
deps = [
"//itex:core",
"//itex/core/kernels/gpu/linalg:fused_einsum_hdrs",
"@xetla//:xetla_header",
],
alwayslink = True,
)
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
* limitations under the License.
******************************************************************************/

#include "itex/core/kernels/gpu/linalg/fused_einsum_impl.h"
#include "itex/core/kernels/gpu/xetla/fused_einsum_impl.h"

#include "itex/core/kernels/gpu/linalg/einsum_helper.h"
#include "itex/core/kernels/gpu/linalg/fused_einsum_helper.h"
#include "itex/core/utils/op_requires.h"

namespace itex {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
* limitations under the License.
******************************************************************************/

#ifndef ITEX_CORE_KERNELS_GPU_LINALG_FUSED_EINSUM_IMPL_H_
#define ITEX_CORE_KERNELS_GPU_LINALG_FUSED_EINSUM_IMPL_H_
#ifndef ITEX_CORE_KERNELS_GPU_XETLA_FUSED_EINSUM_IMPL_H_
#define ITEX_CORE_KERNELS_GPU_XETLA_FUSED_EINSUM_IMPL_H_

#include <xetla.hpp>

Expand Down Expand Up @@ -236,4 +236,4 @@ class FusedEinsumKernel {

} // namespace gpu::xetla

#endif // ITEX_CORE_KERNELS_GPU_LINALG_FUSED_EINSUM_IMPL_H_
#endif // ITEX_CORE_KERNELS_GPU_XETLA_FUSED_EINSUM_IMPL_H_

0 comments on commit ab902a7

Please sign in to comment.