Skip to content

Commit

Permalink
A mechanism to extract LLVM IR from JAX via XLA (#7)
Browse files Browse the repository at this point in the history
This relies on a debug feature of XLA, but that feature has been quite
stable so far. Note that the LLVM IR is not yet connected to JIT.

Co-authored-by: Ulysse Beaugnon <[email protected]>
Co-authored-by: William Moses <[email protected]>
  • Loading branch information
3 people authored Jul 12, 2023
1 parent ab31488 commit d9e2ae0
Show file tree
Hide file tree
Showing 13 changed files with 729 additions and 113 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/tag.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ jobs:
submodules: recursive
- run: |
export TAG=`echo $GITHUB_REF | cut -c2- `
echo $TAG
sed -i.bak "s~version = \"[0-9.]*\"~version = \"$TAG\"~g" BUILD
cat BUILD
- uses: bazelbuild/setup-bazelisk@v2
- name: Mount bazel cache # Optional
uses: actions/cache@v3
Expand Down
6 changes: 5 additions & 1 deletion BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ cc_library(
"@llvm-project//llvm:Support",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:IRReader",
"@llvm-project//llvm:OrcJIT",
"@enzyme//:EnzymeStatic"
],
)
Expand All @@ -36,7 +37,10 @@ py_package(
name = "enzyme_jax_data",
# Only include these Python packages.
packages = ["@//enzyme_jax:enzyme_call.so", "@llvm-project//clang:builtin_headers_gen"],
deps = ["//enzyme_jax:enzyme_call", "@llvm-project//clang:builtin_headers_gen"],
deps = [
"//enzyme_jax:enzyme_call",
"@llvm-project//clang:builtin_headers_gen",
],
prefix = "enzyme_jax/",
)

Expand Down
4 changes: 3 additions & 1 deletion WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ http_archive(
name = "xla",
sha256 = XLA_SHA256,
strip_prefix = "xla-" + XLA_COMMIT,
urls = ["https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)]
urls = ["https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)],
patch_args = ["-p1"],
patches = ["//:patches/xla.patch"],
)


Expand Down
57 changes: 44 additions & 13 deletions clang_compile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,12 @@
#include "clang/Frontend/TextDiagnosticBuffer.h"
#include "llvm/Support/Host.h"
#include "clang/FrontendTool/Utils.h"

#include "llvm/MC/TargetRegistry.h"
#include "llvm/CodeGen/CommandFlags.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/Linker/Linker.h"

#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"

#include <Python.h>
#include <pybind11/pybind11.h>
Expand Down Expand Up @@ -178,9 +182,25 @@ static TargetLibraryInfoImpl *createTLII(llvm::Triple &&TargetTriple,
return TLII;
}

static LLVMContext GlobalContext;
// Returns the TargetMachine instance or zero if no triple is provided.
// Looks up the target registered for `TheTriple` (honoring any -march value
// from the llvm::codegen command-line flags) and builds a TargetMachine
// configured from the global codegen option values.
// NOTE(review): the CPUStr and FeaturesStr parameters are currently unused;
// codegen::getCPUStr()/getFeaturesStr() are queried instead — confirm this
// is intentional (upstream opt.cpp passes the parameters through).
static TargetMachine* GetTargetMachine(llvm::Triple TheTriple, StringRef CPUStr,
StringRef FeaturesStr,
const llvm::TargetOptions &Options, CodeGenOpt::Level level) {
std::string Error;
const Target *TheTarget =
TargetRegistry::lookupTarget(codegen::getMArch(), TheTriple, Error);
// Some modules don't specify a triple, and this is okay.
// (lookupTarget can also fail for an unknown -march; `Error` is not
// surfaced to the caller in either case.)
if (!TheTarget) {
return nullptr;
}

return TheTarget->createTargetMachine(
TheTriple.getTriple(), codegen::getCPUStr(), codegen::getFeaturesStr(),
Options, codegen::getExplicitRelocModel(),
codegen::getExplicitCodeModel(), level);
}

std::unique_ptr<llvm::Module> GetLLVMFromJob(std::string filename, std::string filecontents, bool cpp, ArrayRef<std::string> pyargv, LLVMContext* Context) {
std::unique_ptr<llvm::Module> GetLLVMFromJob(std::string filename, std::string filecontents, bool cpp, ArrayRef<std::string> pyargv, LLVMContext* Context, std::unique_ptr<llvm::Module> linkMod) {
const llvm::opt::InputArgList Args;
const char *binary = cpp ? "clang++" : "clang";
// Buffer diagnostics from argument parsing so that we can output them using a
Expand Down Expand Up @@ -500,7 +520,7 @@ struct tensor<T, n0, N...>
return {};
}

if (!Context) Context=&GlobalContext;
assert(Context);
auto Act = std::make_unique<EmitLLVMOnlyAction>(Context);
Success = Clang->ExecuteAction(*Act);
if (!Success) {
Expand All @@ -509,6 +529,11 @@ struct tensor<T, n0, N...>
}

auto mod = Act->takeModule();

if (linkMod) {
Linker::linkModules(*mod, std::move(linkMod));
}

for (auto &f : *mod) {
if (f.empty()) continue;
if (f.getName() == "entry") continue;
Expand All @@ -527,9 +552,21 @@ struct tensor<T, n0, N...>
createTLII(llvm::Triple(mod->getTargetTriple()), Clang->getCodeGenOpts()));
FAM.registerPass([&] { return TargetLibraryAnalysis(*TLII); });


auto level = CodeGenOpt::Level::Aggressive; //OptimizationLevel::O3;

Triple ModuleTriple(mod->getTargetTriple());
std::string CPUStr, FeaturesStr;

auto ETM = llvm::orc::JITTargetMachineBuilder(llvm::Triple(mod->getTargetTriple())).createTargetMachine ();
if (!ETM) {
throw pybind11::value_error("failed to create targetmachine");
}
auto TM = std::move(ETM.get());

std::optional<PGOOptions> PGOOpt;
PassInstrumentationCallbacks PIC;
PassBuilder PB(nullptr, PTO, PGOOpt, &PIC);
PassBuilder PB(TM.get(), PTO, PGOOpt, &PIC);

augmentPassBuilder(PB);

Expand All @@ -541,14 +578,8 @@ struct tensor<T, n0, N...>
PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);

ModulePassManager MPM;

// Map our optimization levels into one of the distinct levels used to
// configure the pipeline.
OptimizationLevel Level = OptimizationLevel::O3;

MPM = PB.buildPerModuleDefaultPipeline(Level);

MPM.run(*mod, MAM);
PB.parsePassPipeline(MPM, "default<O3>");
MPM.run(*mod, MAM);
return mod;
}

2 changes: 1 addition & 1 deletion clang_compile.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@
#include <string>
#include "llvm/IR/Module.h"

std::unique_ptr<llvm::Module> GetLLVMFromJob(std::string filename, std::string filecontents, bool cpp, llvm::ArrayRef<std::string> pyargv, llvm::LLVMContext*ctx=nullptr);
std::unique_ptr<llvm::Module> GetLLVMFromJob(std::string filename, std::string filecontents, bool cpp, llvm::ArrayRef<std::string> pyargv, llvm::LLVMContext*ctx=nullptr, std::unique_ptr<llvm::Module> linkMod=nullptr);

#endif // ENZYME_JAX_CLANG_COMPILE_H
61 changes: 59 additions & 2 deletions enzyme_jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,61 @@ py_library(
visibility = ["//visibility:public"]
)

# Library wrapping XLA's local (CPU) client to compile MHLO text down to
# LLVM IR. The long *_impl dependency list mirrors what an xla_binary rule
# links so that the XLA client and its protobufs resolve in a standalone
# cc_library.
cc_library(
name = "compile_with_xla",
srcs = ["compile_with_xla.cc"],
deps = [
# This is similar to xla_binary rule and is needed to make XLA client compile.
"@tsl//tsl/framework:allocator",
"@tsl//tsl/framework:allocator_registry_impl",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:env_impl",
"@tsl//tsl/platform:tensor_float_32_utils",
"@tsl//tsl/profiler/backends/cpu:annotation_stack_impl",
"@tsl//tsl/profiler/backends/cpu:traceme_recorder_impl",
"@tsl//tsl/profiler/utils:time_utils_impl",
"@tsl//tsl/protobuf:autotuning_proto_cc_impl",
"@tsl//tsl/protobuf:dnn_proto_cc",
"@tsl//tsl/protobuf:dnn_proto_cc_impl",
"@tsl//tsl/protobuf:histogram_proto_cc",
"@tsl//tsl/protobuf:histogram_proto_cc_impl",
"@tsl//tsl/protobuf:protos_all_cc_impl",
"@tsl//tsl/util:determinism",

# This is similar to xla_binary rule and is needed to make XLA client compile.
"@xla//xla:autotune_results_proto_cc",
"@xla//xla:autotune_results_proto_cc_impl",
"@xla//xla/client",
"@xla//xla/client:client_library",
"@xla//xla/service/cpu:cpu_executable",
"@xla//xla/service/gpu:backend_configs_cc",
"@xla//xla/service/gpu:backend_configs_cc_impl",
"@xla//xla/service:hlo_proto_cc",
"@xla//xla/service:hlo_proto_cc_impl",
"@xla//xla/service:memory_space_assignment_proto_cc_impl",
"@xla//xla/stream_executor:dnn_proto_cc_impl",
"@xla//xla:xla_data_proto_cc",
"@xla//xla:xla_data_proto_cc_impl",
"@xla//xla:xla_proto_cc",
"@xla//xla:xla_proto_cc_impl",

# Make CPU target available to XLA.
"@xla//xla/service:cpu_plugin",

# MHLO stuff.
"@xla//xla/mlir_hlo",
"@xla//xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo",

# This is necessary for XLA protobufs to link
"@com_google_protobuf//:protobuf",

# MLIR dialects and parser.
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:Parser",
],
)

pybind_extension(
name = "enzyme_call",
srcs = ["enzyme_call.cc"],
Expand All @@ -15,6 +70,8 @@ pybind_extension(
"@llvm-project//llvm:Support",
"@llvm-project//llvm:OrcJIT",
"//:clang_compile",
":compile_with_xla",
"@com_google_absl//absl/status"
],
visibility = ["//visibility:public"]
)
visibility = ["//visibility:public"],
)
2 changes: 1 addition & 1 deletion enzyme_jax/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from enzyme_jax.primitives import cpp_call
from enzyme_jax.primitives import cpp_call, enzyme_jax_ir
78 changes: 78 additions & 0 deletions enzyme_jax/compile_with_xla.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <string>
#include <vector>

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Parser/Parser.h"
#include "xla/client/client_library.h"
#include "xla/client/executable_build_options.h"
#include "xla/client/xla_computation.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/service/cpu/cpu_executable.h"
#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h"
#include "xla/translate/mhlo_to_hlo/type_to_shape.h"

// Compile an MHLO module given as a string to LLVM IR using XLA.
//
// Parses `mhlo_text` as MLIR, lowers it to an XLA computation, compiles it
// with the local (CPU) client, and returns the LLVM IR that XLA embedded in
// the executable. Returns an error status if parsing, HLO conversion, or
// compilation fails.
absl::StatusOr<std::string> compile_mhlo_to_llvm_with_xla(
    llvm::StringRef mhlo_text) {
  // Parse MLIR.
  mlir::MLIRContext context;
  context.loadDialect<mlir::arith::ArithDialect>();
  context.loadDialect<mlir::func::FuncDialect>();
  context.loadDialect<mlir::mhlo::MhloDialect>();
  mlir::ParserConfig parser_config(&context);
  mlir::OwningOpRef<mlir::ModuleOp> parsed_module =
      mlir::parseSourceString<mlir::ModuleOp>(mhlo_text, parser_config);
  // parseSourceString returns a null OwningOpRef on syntax errors; bail out
  // instead of dereferencing it below.
  if (!parsed_module) {
    return absl::InvalidArgumentError("failed to parse MHLO module");
  }

  // Convert to XLA Computation. Propagate conversion failures rather than
  // silently compiling an empty/partial HLO proto.
  xla::HloProto hlo_proto;
  auto convert_status = mlir::ConvertMlirHloToHlo(
      *parsed_module, &hlo_proto,
      /*use_tuple_args=*/false, /*return_tuple=*/false);
  if (!convert_status.ok()) return convert_status;
  xla::XlaComputation xla_computation(hlo_proto.hlo_module());

  // Extract and convert the shapes from MHLO. The entry point is the
  // function named "main"; its argument types define the executable's
  // parameter shapes.
  mlir::SymbolTable symbol_table(*parsed_module);
  auto entry_point = symbol_table.lookup<mlir::FunctionOpInterface>("main");
  if (!entry_point) {
    return absl::InvalidArgumentError(
        "MHLO module does not contain a 'main' function");
  }
  std::vector<xla::Shape> shapes;
  shapes.reserve(entry_point.getNumArguments());
  for (mlir::Type type : entry_point.getArgumentTypes()) {
    shapes.push_back(xla::TypeToShape(type));
  }
  std::vector<const xla::Shape *> shape_pointers;
  shape_pointers.reserve(shapes.size());
  for (xla::Shape &shape : shapes) {
    shape_pointers.push_back(&shape);
  }

  // Compile with XLA, local client means targeting CPU.
  // XXX: this is using a debug feature of XLA to preserve LLVM IR. If the
  // feature ever disappears and is not recoverable with a local patch, this
  // will have to recreate the XLA pipeline. This may also be wiser in the long
  // term so we don't waste compile time running LLVM optimizations and code
  // generation only to throw away the binary.
  absl::StatusOr<xla::LocalClient *> local_client_or_error =
      xla::ClientLibrary::GetOrCreateLocalClient();
  if (!local_client_or_error.ok()) return local_client_or_error.status();
  xla::LocalClient *local_client = local_client_or_error.value();
  xla::ExecutableBuildOptions build_options;
  build_options.mutable_debug_options()->set_xla_embed_ir_in_executable(true);
  absl::StatusOr<std::vector<std::unique_ptr<xla::LocalExecutable>>>
      local_executables =
          local_client->Compile(xla_computation, shape_pointers, build_options);
  if (!local_executables.ok()) return local_executables.status();

  // Extract the LLVM IR stored in the executable.
  xla::LocalExecutable &local_executable = *local_executables.value()[0];
  auto *cpu_executable =
      static_cast<xla::cpu::CpuExecutable *>(local_executable.executable());
  const std::string &llvm_ir = cpu_executable->ir_module_string();
  return llvm_ir;
}
Loading

0 comments on commit d9e2ae0

Please sign in to comment.