From 864757779c32113311392d3ff4976ef205fb2bc5 Mon Sep 17 00:00:00 2001 From: xla authors Date: Mon, 6 Jan 2025 03:56:15 -0800 Subject: [PATCH] [XLA:CPU] Decouple compiled function library from JIT compiler. PiperOrigin-RevId: 712473805 --- xla/backends/cpu/codegen/BUILD | 17 +++++ .../cpu/codegen/compiled_function_library.cc | 68 +++++++++++++++++++ .../cpu/codegen/compiled_function_library.h | 68 +++++++++++++++++++ 3 files changed, 153 insertions(+) create mode 100644 xla/backends/cpu/codegen/compiled_function_library.cc create mode 100644 xla/backends/cpu/codegen/compiled_function_library.h diff --git a/xla/backends/cpu/codegen/BUILD b/xla/backends/cpu/codegen/BUILD index 7014ef7c7d2f2..27f286addb514 100644 --- a/xla/backends/cpu/codegen/BUILD +++ b/xla/backends/cpu/codegen/BUILD @@ -281,3 +281,20 @@ cc_library( "@tsl//tsl/platform:statusor", ], ) + +cc_library( + name = "compiled_function_library", + srcs = ["compiled_function_library.cc"], + hdrs = ["compiled_function_library.h"], + deps = [ + "//xla/backends/cpu/runtime:function_library", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@llvm-project//llvm:ExecutionEngine", + "@llvm-project//llvm:OrcJIT", + ], +) diff --git a/xla/backends/cpu/codegen/compiled_function_library.cc b/xla/backends/cpu/codegen/compiled_function_library.cc new file mode 100644 index 0000000000000..7f111e5a3566b --- /dev/null +++ b/xla/backends/cpu/codegen/compiled_function_library.cc @@ -0,0 +1,68 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/codegen/compiled_function_library.h" + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" +#include "xla/backends/cpu/runtime/function_library.h" + +namespace xla::cpu { + +CompiledFunctionLibrary::CompiledFunctionLibrary( + std::unique_ptr execution_session, + std::unique_ptr object_layer, + absl::flat_hash_map symbols_map) + : execution_session_(std::move(execution_session)), + object_layer_(std::move(object_layer)), + symbols_map_(std::move(symbols_map)) { + DCHECK(execution_session_) << "Execution session must not be null"; +} + +CompiledFunctionLibrary::~CompiledFunctionLibrary() { + if (execution_session_) { + if (auto err = execution_session_->endSession()) { + execution_session_->reportError(std::move(err)); + } + } +} + +absl::StatusOr CompiledFunctionLibrary::ResolveFunction( + TypeId type_id, absl::string_view name) { + if (auto it = symbols_map_.find(name); it != symbols_map_.end()) { + if (it->second.type_id != type_id) { + return absl::Status( + absl::StatusCode::kInternal, + absl::StrFormat("Symbol %s has type id %d, expected %d", name, + it->second.type_id.value(), type_id.value())); + } + return it->second.ptr; + } + return absl::Status(absl::StatusCode::kNotFound, + absl::StrFormat("Function %s not found (type id: %d)", + name, type_id.value())); +} + +} // namespace xla::cpu diff --git a/xla/backends/cpu/codegen/compiled_function_library.h b/xla/backends/cpu/codegen/compiled_function_library.h new file mode 100644 index 0000000000000..b91100a66dd10 --- /dev/null +++ b/xla/backends/cpu/codegen/compiled_function_library.h @@ -0,0 +1,68 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_CODEGEN_COMPILED_FUNCTION_LIBRARY_H_ +#define XLA_BACKENDS_CPU_CODEGEN_COMPILED_FUNCTION_LIBRARY_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" +#include "xla/backends/cpu/runtime/function_library.h" + +namespace xla::cpu { + +// A CompiledFunctionLibrary is a FunctionLibrary that resolves function names +// to compiled functions using LLVM's ORC JIT. +class CompiledFunctionLibrary : public FunctionLibrary { + public: + struct ResolvedSymbol { + TypeId type_id; + void* ptr; + }; + + // Constructs a new CompiledFunctionLibrary. + // + // `execution_session` is the LLVM ORC execution session to use. + // `object_layer` is the LLVM ORC object linking layer with preloaded object + // files. + // `symbols_map` is a map from symbol names to resolved symbols. + CompiledFunctionLibrary( + std::unique_ptr execution_session, + std::unique_ptr object_layer, + absl::flat_hash_map symbols_map); + + ~CompiledFunctionLibrary() final; + + // Resolves the function with the given name and type ID. + absl::StatusOr ResolveFunction(TypeId type_id, + absl::string_view name) final; + + private: + std::unique_ptr execution_session_; + // Owns resources required for the execution session. + std::unique_ptr object_layer_; + // Caches the resolved symbols so we don't have to look them up every time a + // function is resolved. + absl::flat_hash_map symbols_map_; +}; + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_CODEGEN_COMPILED_FUNCTION_LIBRARY_H_