Skip to content

Commit

Permalink
SROA by roundtripping to LLVM (#229)
Browse files Browse the repository at this point in the history
* WIP

* Round trip done

* comment

* comment

* indent

* Useless clone

* fixup

* func attrs

* force llvm

* fix

* fmt

* fix name

* fix

* assert on creation

* more printing

* fix

* Fix

* fix llvmfunc

* fix

* more prints

* fix output alias rewrite

* Math raising

* only ranked tensor

* improve debugging

* fix

* fix

* cusync

* bump xla

* fmt

---------

Co-authored-by: William S. Moses <[email protected]>
  • Loading branch information
ivanradanov and wsmoses authored Jan 15, 2025
1 parent ff78f0d commit 362f33f
Show file tree
Hide file tree
Showing 15 changed files with 814 additions and 29 deletions.
35 changes: 35 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,41 @@ hedron_compile_commands_setup()
hedron_compile_commands_setup_transitive()
hedron_compile_commands_setup_transitive_transitive()
hedron_compile_commands_setup_transitive_transitive_transitive()
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

# LLVM_COMMIT = "b39c5cb6977f35ad727d86b2dd6232099734ffd3"
# LLVM_SHA256 = ""
# http_archive(
# name = "llvm-raw",
# build_file_content = "# empty",
# sha256 = LLVM_SHA256,
# strip_prefix = "llvm-project-" + LLVM_COMMIT,
# urls = ["https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT)],
# )
#
#
# load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe")
# maybe(
# http_archive,
# name = "llvm_zlib",
# build_file = "@llvm-raw//utils/bazel/third_party_build:zlib-ng.BUILD",
# sha256 = "e36bb346c00472a1f9ff2a0a4643e590a254be6379da7cddd9daeb9a7f296731",
# strip_prefix = "zlib-ng-2.0.7",
# urls = [
# "https://github.com/zlib-ng/zlib-ng/archive/refs/tags/2.0.7.zip",
# ],
# )
#
# maybe(
# http_archive,
# name = "llvm_zstd",
# build_file = "@llvm-raw//utils/bazel/third_party_build:zstd.BUILD",
# sha256 = "7c42d56fac126929a6a85dbc73ff1db2411d04f104fae9bdea51305663a83fd0",
# strip_prefix = "zstd-1.5.2",
# urls = [
# "https://github.com/facebook/zstd/releases/download/v1.5.2/zstd-1.5.2.tar.gz"
# ],
# )

load("//third_party/jax:workspace.bzl", jax_workspace = "repo")
jax_workspace()
Expand Down
10 changes: 9 additions & 1 deletion src/enzyme_ad/jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -304,11 +304,15 @@ cc_library(
":EnzymeXLAOpsIncGen",
":EnzymeXLAPassesIncGen",
":EnzyeHLOPatternsIncGen",
"@llvm-project//mlir:DLTIDialect",
"@llvm-project//mlir:GPUPipelines",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:ExecutionEngine",
"@llvm-project//llvm:OrcJIT",
"@llvm-project//llvm:OrcTargetProcess",
"@llvm-project//llvm:Passes",
"@llvm-project//llvm:Scalar",
"@llvm-project//llvm:InstCombine",
":mhlo-derivatives",
":stablehlo-derivatives",
":chlo-derivatives",
Expand All @@ -326,7 +330,7 @@ cc_library(
"@llvm-project//mlir:AffineToStandard",
"@llvm-project//mlir:ArithToLLVM",
"@llvm-project//mlir:MemRefTransforms",
"@llvm-project//mlir:TransformDialect",
"@llvm-project//mlir:TransformDialect",
"@llvm-project//mlir:TransformDialectInterfaces",
"@llvm-project//mlir:TransformDialectTransforms",
"@llvm-project//mlir:GPUToGPURuntimeTransforms",
Expand Down Expand Up @@ -363,6 +367,10 @@ cc_library(
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:TranslateLib",
"@llvm-project//mlir:FromLLVMIRTranslation",
"@llvm-project//mlir:ToLLVMIRTranslationRegistration",
"@llvm-project//mlir:FromLLVMIRTranslationRegistration",
"@stablehlo//:reference_ops",
"@stablehlo//:stablehlo_ops",
"@stablehlo//:chlo_ops",
Expand Down
11 changes: 8 additions & 3 deletions src/enzyme_ad/jax/Dialect/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,21 +97,24 @@ class ReadOnlyKernelArg final
SmallVector<Attribute> outputAliases;
SmallVector<Type> resTys;
size_t out_idx = 0;
for (auto alias_attr : operand_aliases) {
auto alias = cast<OutputOperandAliasAttr>(alias_attr);
for (auto en : llvm::enumerate(operand_aliases)) {
auto idx = en.index();
auto alias = cast<OutputOperandAliasAttr>(en.value());
auto outputTupleIndices = alias.getOutputTupleIndices();
auto operandIndex = alias.getOperandIndex();
auto operandTupleIndices = alias.getOperandTupleIndices();

auto operand = fn.front().getArgument(operandIndex);
assert(launchOp.getInputs()[operandIndex].getType() ==
launchOp.getResultTypes()[idx]);
bool readonly =
fn.getArgAttr(operandIndex, LLVMDialect::getReadonlyAttrName()) ||
fn.getArgAttr(operandIndex, LLVMDialect::getReadnoneAttrName());

if (readonly) {
continue;
}
resTys.push_back(launchOp.getResultTypes()[out_idx]);
resTys.push_back(launchOp.getResultTypes()[idx]);
if (outputs == 1) {
outputAliases.push_back(OutputOperandAliasAttr::get(
launchOp->getContext(), {}, operandIndex, {}));
Expand All @@ -129,6 +132,8 @@ class ReadOnlyKernelArg final
launchOp.getInputs(), launchOp.getBackendConfigAttr(),
launchOp.getOperandLayoutsAttr(), /*resultLayouts*/ nullptr,
ArrayAttr::get(launchOp->getContext(), outputAliases));

assert(outputAliases.size() == newOp.getNumResults());
SmallVector<Value> replacements;
out_idx = 0;
for (auto alias_attr : operand_aliases) {
Expand Down
10 changes: 10 additions & 0 deletions src/enzyme_ad/jax/Passes/ArithRaising.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ struct ArithRaisingPass : public ArithRaisingPassBase<ArithRaisingPass> {
auto op = getOperation();

op->walk([=](arith::AddFOp addOp) {
if (!addOp.getType().isa<RankedTensorType>())
return;
OpBuilder builder(addOp);
Value newAddOp;
if (use_stablehlo)
Expand All @@ -52,6 +54,8 @@ struct ArithRaisingPass : public ArithRaisingPassBase<ArithRaisingPass> {
addOp.erase();
});
op->walk([=](complex::AddOp addOp) {
if (!addOp.getType().isa<RankedTensorType>())
return;
OpBuilder builder(addOp);
Value newAddOp;
if (use_stablehlo)
Expand All @@ -64,6 +68,8 @@ struct ArithRaisingPass : public ArithRaisingPassBase<ArithRaisingPass> {
addOp.erase();
});
op->walk([=](complex::ConjOp addOp) {
if (!addOp.getType().isa<RankedTensorType>())
return;
OpBuilder builder(addOp);
Value newAddOp;
newAddOp =
Expand All @@ -72,6 +78,8 @@ struct ArithRaisingPass : public ArithRaisingPassBase<ArithRaisingPass> {
addOp.erase();
});
op->walk([=](arith::AddIOp addOp) {
if (!addOp.getType().isa<RankedTensorType>())
return;
OpBuilder builder(addOp);
Value newAddOp;
if (use_stablehlo)
Expand All @@ -84,6 +92,8 @@ struct ArithRaisingPass : public ArithRaisingPassBase<ArithRaisingPass> {
addOp.erase();
});
op->walk([=](arith::ConstantOp constOp) {
if (!constOp.getType().isa<RankedTensorType>())
return;
auto CT = constOp.getType();
if (isa<TensorType>(CT)) {
OpBuilder builder(constOp);
Expand Down
Loading

0 comments on commit 362f33f

Please sign in to comment.