diff --git a/WORKSPACE b/WORKSPACE index 72f6283d6..bba22de5d 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -11,6 +11,41 @@ hedron_compile_commands_setup() hedron_compile_commands_setup_transitive() hedron_compile_commands_setup_transitive_transitive() hedron_compile_commands_setup_transitive_transitive_transitive() +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +# LLVM_COMMIT = "b39c5cb6977f35ad727d86b2dd6232099734ffd3" +# LLVM_SHA256 = "" +# http_archive( +# name = "llvm-raw", +# build_file_content = "# empty", +# sha256 = LLVM_SHA256, +# strip_prefix = "llvm-project-" + LLVM_COMMIT, +# urls = ["https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT)], +# ) +# +# +# load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe") +# maybe( +# http_archive, +# name = "llvm_zlib", +# build_file = "@llvm-raw//utils/bazel/third_party_build:zlib-ng.BUILD", +# sha256 = "e36bb346c00472a1f9ff2a0a4643e590a254be6379da7cddd9daeb9a7f296731", +# strip_prefix = "zlib-ng-2.0.7", +# urls = [ +# "https://github.com/zlib-ng/zlib-ng/archive/refs/tags/2.0.7.zip", +# ], +# ) +# +# maybe( +# http_archive, +# name = "llvm_zstd", +# build_file = "@llvm-raw//utils/bazel/third_party_build:zstd.BUILD", +# sha256 = "7c42d56fac126929a6a85dbc73ff1db2411d04f104fae9bdea51305663a83fd0", +# strip_prefix = "zstd-1.5.2", +# urls = [ +# "https://github.com/facebook/zstd/releases/download/v1.5.2/zstd-1.5.2.tar.gz" +# ], +# ) load("//third_party/jax:workspace.bzl", jax_workspace = "repo") jax_workspace() diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index ca05de105..d69dca5fb 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -304,11 +304,15 @@ cc_library( ":EnzymeXLAOpsIncGen", ":EnzymeXLAPassesIncGen", ":EnzyeHLOPatternsIncGen", + "@llvm-project//mlir:DLTIDialect", "@llvm-project//mlir:GPUPipelines", "@llvm-project//llvm:Core", "@llvm-project//llvm:ExecutionEngine", "@llvm-project//llvm:OrcJIT", "@llvm-project//llvm:OrcTargetProcess", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Scalar", + "@llvm-project//llvm:InstCombine", ":mhlo-derivatives", ":stablehlo-derivatives", ":chlo-derivatives", @@ -326,7 +330,7 @@ cc_library( "@llvm-project//mlir:AffineToStandard", "@llvm-project//mlir:ArithToLLVM", "@llvm-project//mlir:MemRefTransforms", - "@llvm-project//mlir:TransformDialect", + "@llvm-project//mlir:TransformDialect", "@llvm-project//mlir:TransformDialectInterfaces", "@llvm-project//mlir:TransformDialectTransforms", "@llvm-project//mlir:GPUToGPURuntimeTransforms", @@ -363,6 +367,10 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:TranslateLib", + "@llvm-project//mlir:FromLLVMIRTranslation", + "@llvm-project//mlir:ToLLVMIRTranslationRegistration", + "@llvm-project//mlir:FromLLVMIRTranslationRegistration", "@stablehlo//:reference_ops", "@stablehlo//:stablehlo_ops", "@stablehlo//:chlo_ops", diff --git a/src/enzyme_ad/jax/Dialect/Ops.cpp b/src/enzyme_ad/jax/Dialect/Ops.cpp index 1ee69e97f..f1111a45b 100644 --- a/src/enzyme_ad/jax/Dialect/Ops.cpp +++ b/src/enzyme_ad/jax/Dialect/Ops.cpp @@ -97,13 +97,16 @@ class ReadOnlyKernelArg final SmallVector outputAliases; SmallVector resTys; size_t out_idx = 0; - for (auto alias_attr : operand_aliases) { - auto alias = cast(alias_attr); + for (auto en : llvm::enumerate(operand_aliases)) { + auto idx = en.index(); + auto alias = cast(en.value()); auto outputTupleIndices = 
alias.getOutputTupleIndices(); auto operandIndex = alias.getOperandIndex(); auto operandTupleIndices = alias.getOperandTupleIndices(); auto operand = fn.front().getArgument(operandIndex); + assert(launchOp.getInputs()[operandIndex].getType() == + launchOp.getResultTypes()[idx]); bool readonly = fn.getArgAttr(operandIndex, LLVMDialect::getReadonlyAttrName()) || fn.getArgAttr(operandIndex, LLVMDialect::getReadnoneAttrName()); @@ -111,7 +114,7 @@ class ReadOnlyKernelArg final if (readonly) { continue; } - resTys.push_back(launchOp.getResultTypes()[out_idx]); + resTys.push_back(launchOp.getResultTypes()[idx]); if (outputs == 1) { outputAliases.push_back(OutputOperandAliasAttr::get( launchOp->getContext(), {}, operandIndex, {})); @@ -129,6 +132,8 @@ class ReadOnlyKernelArg final launchOp.getInputs(), launchOp.getBackendConfigAttr(), launchOp.getOperandLayoutsAttr(), /*resultLayouts*/ nullptr, ArrayAttr::get(launchOp->getContext(), outputAliases)); + + assert(outputAliases.size() == newOp.getNumResults()); SmallVector replacements; out_idx = 0; for (auto alias_attr : operand_aliases) { diff --git a/src/enzyme_ad/jax/Passes/ArithRaising.cpp b/src/enzyme_ad/jax/Passes/ArithRaising.cpp index fcab14d0c..eedc3b54b 100644 --- a/src/enzyme_ad/jax/Passes/ArithRaising.cpp +++ b/src/enzyme_ad/jax/Passes/ArithRaising.cpp @@ -40,6 +40,8 @@ struct ArithRaisingPass : public ArithRaisingPassBase { auto op = getOperation(); op->walk([=](arith::AddFOp addOp) { + if (!addOp.getType().isa()) + return; OpBuilder builder(addOp); Value newAddOp; if (use_stablehlo) @@ -52,6 +54,8 @@ struct ArithRaisingPass : public ArithRaisingPassBase { addOp.erase(); }); op->walk([=](complex::AddOp addOp) { + if (!addOp.getType().isa()) + return; OpBuilder builder(addOp); Value newAddOp; if (use_stablehlo) @@ -64,6 +68,8 @@ struct ArithRaisingPass : public ArithRaisingPassBase { addOp.erase(); }); op->walk([=](complex::ConjOp addOp) { + if (!addOp.getType().isa()) + return; OpBuilder builder(addOp); Value newAddOp; newAddOp = @@ -72,6 +78,8 @@ struct ArithRaisingPass : public ArithRaisingPassBase { addOp.erase(); }); op->walk([=](arith::AddIOp addOp) { + if (!addOp.getType().isa()) + return; OpBuilder builder(addOp); Value newAddOp; if (use_stablehlo) @@ -84,6 +92,8 @@ struct ArithRaisingPass : public ArithRaisingPassBase { addOp.erase(); }); op->walk([=](arith::ConstantOp constOp) { + if (!constOp.getType().isa()) + return; auto CT = constOp.getType(); if (isa(CT)) { OpBuilder builder(constOp); diff --git a/src/enzyme_ad/jax/Passes/LibDeviceFuncsRaisingPass.cpp b/src/enzyme_ad/jax/Passes/LibDeviceFuncsRaisingPass.cpp index 93efd1874..0a00f02fc 100644 --- a/src/enzyme_ad/jax/Passes/LibDeviceFuncsRaisingPass.cpp +++ b/src/enzyme_ad/jax/Passes/LibDeviceFuncsRaisingPass.cpp @@ -14,9 +14,135 @@ #include "src/enzyme_ad/jax/Passes/PassDetails.h" #include "src/enzyme_ad/jax/Passes/Passes.h" +#include "mlir/Conversion/LLVMCommon/VectorPattern.h" + using namespace mlir; using namespace mlir::enzyme; +template typename AttrConvert = + AttrConvertPassThrough> +class VectorConvertFromLLVMPattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SourceOp op, + PatternRewriter &rewriter) const override { + static_assert( + std::is_base_of, SourceOp>::value, + "expected single result op"); + // Determine attributes for the target op + AttrConvert attrConvert(op); + + auto operands = op->getOperands(); + auto llvmNDVectorTy = operands[0].getType(); + if (isa(llvmNDVectorTy)) { + 
return failure(); + } + Operation *newOp = rewriter.create( + op->getLoc(), rewriter.getStringAttr(TargetOp::getOperationName()), + operands, op->getResultTypes(), attrConvert.getAttrs()); + + rewriter.replaceOp(op, newOp->getResult(0)); + return success(); + } +}; + +arith::IntegerOverflowFlags +convertArithOverflowFlagsFromLLVM(LLVM::IntegerOverflowFlags llvmFlags) { + arith::IntegerOverflowFlags arithFlags{}; + const std::pair + flags[] = { + {arith::IntegerOverflowFlags::nsw, LLVM::IntegerOverflowFlags::nsw}, + {arith::IntegerOverflowFlags::nuw, LLVM::IntegerOverflowFlags::nuw}}; + for (auto [arithFlag, llvmFlag] : flags) { + if (bitEnumContainsAny(llvmFlags, llvmFlag)) + arithFlags = arithFlags | arithFlag; + } + return arithFlags; +} + +template +class AttrConvertOverflowFromLLVM { +public: + AttrConvertOverflowFromLLVM(SourceOp srcOp) { + // Copy the source attributes. + convertedAttr = NamedAttrList{srcOp->getAttrs()}; + // Get the name of the arith overflow attribute. + StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName(); + // Remove the source overflow attribute. + if (auto arithAttr = dyn_cast_if_present( + convertedAttr.erase(arithAttrName))) { + if (arithAttr.getValue() != LLVM::IntegerOverflowFlags::none) { + StringRef targetAttrName = TargetOp::getOverflowFlagsAttrName(); + convertedAttr.set(targetAttrName, arith::IntegerOverflowFlagsAttr::get( + srcOp->getContext(), + convertArithOverflowFlagsFromLLVM( + arithAttr.getValue()))); + } + } + } + + ArrayRef getAttrs() const { return convertedAttr.getAttrs(); } + +private: + NamedAttrList convertedAttr; +}; + +arith::FastMathFlags +convertArithFastMathFlagsFromLLVM(LLVM::FastmathFlags llvmFMF) { + arith::FastMathFlags arithFMF{}; + const std::pair flags[] = { + {arith::FastMathFlags::nnan, LLVM::FastmathFlags::nnan}, + {arith::FastMathFlags::ninf, LLVM::FastmathFlags::ninf}, + {arith::FastMathFlags::nsz, LLVM::FastmathFlags::nsz}, + {arith::FastMathFlags::arcp, LLVM::FastmathFlags::arcp}, + {arith::FastMathFlags::contract, LLVM::FastmathFlags::contract}, + {arith::FastMathFlags::afn, LLVM::FastmathFlags::afn}, + {arith::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc}}; + for (auto [arithFlag, llvmFlag] : flags) { + if (bitEnumContainsAny(llvmFMF, llvmFlag)) + arithFMF = arithFMF | arithFlag; + } + return arithFMF; +} + +arith::FastMathFlagsAttr +convertArithFastMathAttrFromLLVM(LLVM::FastmathFlagsAttr fmfAttr) { + auto arithFMF = fmfAttr.getValue(); + return arith::FastMathFlagsAttr::get( + fmfAttr.getContext(), convertArithFastMathFlagsFromLLVM(arithFMF)); +} + +// Attribute converter that populates a NamedAttrList by removing the fastmath +// attribute from the source operation attributes, and replacing it with an +// equivalent LLVM fastmath attribute. +template +class AttrConvertFastMathFromLLVM { +public: + AttrConvertFastMathFromLLVM(SourceOp srcOp) { + // Copy the source attributes. + convertedAttr = NamedAttrList{srcOp->getAttrs()}; + // Get the name of the arith fastmath attribute. + StringRef arithFMFAttrName = SourceOp::getFastmathAttrName(); + // Remove the source fastmath attribute. 
+ auto arithFMFAttr = dyn_cast_if_present( + convertedAttr.erase(arithFMFAttrName)); + if (arithFMFAttr && + arithFMFAttr.getValue() != mlir::LLVM::FastmathFlags::none) { + StringRef targetAttrName = TargetOp::getFastMathAttrName(); + convertedAttr.set(targetAttrName, + convertArithFastMathAttrFromLLVM(arithFMFAttr)); + } + } + + ArrayRef getAttrs() const { return convertedAttr.getAttrs(); } + +private: + NamedAttrList convertedAttr; +}; + namespace { template class CallToOpRaising : public OpRewritePattern { @@ -58,6 +184,151 @@ static void populateOpPatterns(MLIRContext *context, patterns.add>(context, f16Func); } +namespace { + +// From +// https://github.com/llvm/llvm-project/blob/7d8b4eb0ead277f41ff69525ed807f9f6e227f37/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp#L31 +// except we invert source and target +template +using ConvertFastMath = AttrConvertFastMathFromLLVM; + +template typename AttrConvert = + AttrConvertPassThrough> +using InvVectorConvertFromLLVMPattern = + VectorConvertFromLLVMPattern; + +template +using ConvertFMFMathFromLLVMPattern = + VectorConvertFromLLVMPattern; + +using AbsFOpLowering = + ConvertFMFMathFromLLVMPattern; +using CeilOpLowering = + ConvertFMFMathFromLLVMPattern; +using CopySignOpLowering = + ConvertFMFMathFromLLVMPattern; +using CosOpLowering = ConvertFMFMathFromLLVMPattern; +using CtPopFOpLowering = + VectorConvertFromLLVMPattern; +using Exp2OpLowering = + ConvertFMFMathFromLLVMPattern; +using ExpOpLowering = ConvertFMFMathFromLLVMPattern; +using FloorOpLowering = + ConvertFMFMathFromLLVMPattern; +using FmaOpLowering = ConvertFMFMathFromLLVMPattern; +using Log10OpLowering = + ConvertFMFMathFromLLVMPattern; +using Log2OpLowering = + ConvertFMFMathFromLLVMPattern; +using LogOpLowering = ConvertFMFMathFromLLVMPattern; +using PowFOpLowering = ConvertFMFMathFromLLVMPattern; +using FPowIOpLowering = + ConvertFMFMathFromLLVMPattern; +using RoundEvenOpLowering = + ConvertFMFMathFromLLVMPattern; +using RoundOpLowering = + ConvertFMFMathFromLLVMPattern; +using SinOpLowering = ConvertFMFMathFromLLVMPattern; +using SqrtOpLowering = + ConvertFMFMathFromLLVMPattern; +using FTruncOpLowering = + ConvertFMFMathFromLLVMPattern; + +using AddFOpLowering = + InvVectorConvertFromLLVMPattern; +using AddIOpLowering = + InvVectorConvertFromLLVMPattern; +using AndIOpLowering = + InvVectorConvertFromLLVMPattern; +using BitcastOpLowering = + InvVectorConvertFromLLVMPattern; +using DivFOpLowering = + InvVectorConvertFromLLVMPattern; +using DivSIOpLowering = + InvVectorConvertFromLLVMPattern; +using DivUIOpLowering = + InvVectorConvertFromLLVMPattern; +using ExtFOpLowering = + InvVectorConvertFromLLVMPattern; +using ExtSIOpLowering = + InvVectorConvertFromLLVMPattern; +using ExtUIOpLowering = + InvVectorConvertFromLLVMPattern; +using FPToSIOpLowering = + InvVectorConvertFromLLVMPattern; +using FPToUIOpLowering = + InvVectorConvertFromLLVMPattern; +using MaximumFOpLowering = + InvVectorConvertFromLLVMPattern; +using MaxNumFOpLowering = + InvVectorConvertFromLLVMPattern; +using MaxSIOpLowering = + InvVectorConvertFromLLVMPattern; +using MaxUIOpLowering = + InvVectorConvertFromLLVMPattern; +using MinimumFOpLowering = + InvVectorConvertFromLLVMPattern; +using MinNumFOpLowering = + InvVectorConvertFromLLVMPattern; +using MinSIOpLowering = + InvVectorConvertFromLLVMPattern; +using MinUIOpLowering = + InvVectorConvertFromLLVMPattern; +using MulFOpLowering = + InvVectorConvertFromLLVMPattern; +using MulIOpLowering = + InvVectorConvertFromLLVMPattern; +using NegFOpLowering = + 
InvVectorConvertFromLLVMPattern; +using OrIOpLowering = InvVectorConvertFromLLVMPattern; +using RemFOpLowering = + InvVectorConvertFromLLVMPattern; +using RemSIOpLowering = + InvVectorConvertFromLLVMPattern; +using RemUIOpLowering = + InvVectorConvertFromLLVMPattern; +using SelectOpLowering = + InvVectorConvertFromLLVMPattern; +using ShLIOpLowering = + InvVectorConvertFromLLVMPattern; +using ShRSIOpLowering = + InvVectorConvertFromLLVMPattern; +using ShRUIOpLowering = + InvVectorConvertFromLLVMPattern; +using SIToFPOpLowering = + InvVectorConvertFromLLVMPattern; +using SubFOpLowering = + InvVectorConvertFromLLVMPattern; +using SubIOpLowering = + InvVectorConvertFromLLVMPattern; +// using TruncFOpLowering = +// ConstrainedVectorConvertFromLLVMPattern; +// using ConstrainedTruncFOpLowering = ConstrainedVectorConvertFromLLVMPattern< +// arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true, +// arith::AttrConverterConstrainedFPFromLLVM>; +using TruncIOpLowering = + VectorConvertFromLLVMPattern; +using UIToFPOpLowering = + VectorConvertFromLLVMPattern; +using XOrIOpLowering = VectorConvertFromLLVMPattern; +} // namespace + void mlir::enzyme::populateLibDeviceFuncsToOpsPatterns( MLIRContext *context, RewritePatternSet &patterns) { // XXX: Keep in sync with @@ -130,6 +401,50 @@ void mlir::enzyme::populateLibDeviceFuncsToOpsPatterns( "__nv_tanh"); } +void populateLLVMToMathPatterns(MLIRContext *context, + RewritePatternSet &patterns) { + auto *converter = context; + // From + // https://github.com/llvm/llvm-project/blob/7d8b4eb0ead277f41ff69525ed807f9f6e227f37/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp#L306 + // patterns.add(converter); + patterns.add(converter); + + patterns + .add(converter); +} + namespace { class LibDeviceFuncsRaisingPass : public LibDeviceFuncsRaisingPassBase { @@ -138,6 +453,7 @@ class LibDeviceFuncsRaisingPass void runOnOperation() override { RewritePatternSet patterns(getOperation()->getContext()); + populateLLVMToMathPatterns(getOperation()->getContext(), patterns); populateLibDeviceFuncsToOpsPatterns(getOperation()->getContext(), patterns); if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { emitError(getOperation()->getLoc()) << "failed to raise __nv functions"; diff --git a/src/enzyme_ad/jax/Passes/LowerKernel.cpp b/src/enzyme_ad/jax/Passes/LowerKernel.cpp index 7613b6f96..2dc17767a 100644 --- a/src/enzyme_ad/jax/Passes/LowerKernel.cpp +++ b/src/enzyme_ad/jax/Passes/LowerKernel.cpp @@ -225,6 +225,7 @@ CallInfo CompileHostModule(std::string &key, mlir::ModuleOp modOp, std::unique_ptr ctx(new llvm::LLVMContext); auto llvmModule = translateModuleToLLVMIR(modOp, *ctx); if (!llvmModule) { + llvm::errs() << "modOp: " << *modOp << "\n"; llvm::errs() << "could not convert to LLVM IR\n"; return {}; } @@ -312,7 +313,9 @@ CallInfo CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc, std::string cubinFeatures, size_t cuLaunchKernelPtr, size_t cuModuleLoadDataPtr, size_t cuModuleGetFunctionPtr, bool compileLaunch, - bool run_init) { + bool run_init, enzymexla::KernelCallOp kernelCallOp, + bool debug, size_t cuResultHandlerPtr, + size_t cuStreamSynchronizePtr) { OpBuilder builder(op); @@ -532,8 +535,12 @@ CallInfo CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc, idx, i32, ptrty, ptrty, ptrty}; auto launch_ty = LLVM::LLVMFunctionType::get(i32, cutys); + mlir::Type curesulttys[] = {i32}; + auto curesult_handler_ty = + LLVM::LLVMFunctionType::get(voidty, curesulttys); LLVM::LLVMFuncOp launch = builder.create(loc, 
"cuLaunchKernel", launch_ty); + auto cusync_ty = LLVM::LLVMFunctionType::get(i32, {ptrty}); mlir::Type cufunctys[] = {ptrty, ptrty, ptrty}; auto funcload_ty = LLVM::LLVMFunctionType::get(i32, cufunctys); @@ -556,15 +563,61 @@ CallInfo CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc, loc, "nv_func_init", LLVM::LLVMFunctionType::get(ptrty, {}, false), LLVM::Linkage::External); - LLVM::GlobalOp printStrFunc; - { - std::string value = "found pointer func = %p\n"; + LLVM::LLVMFuncOp printfunc = nullptr; + LLVM::LLVMFuncOp putfunc = nullptr; + + if (debug) { + printfunc = builder.create( + loc, "printf", + LLVM::LLVMFunctionType::get(ptrty, {ptrty, ptrty}, false), + LLVM::Linkage::External); + printfunc.setVisibility(SymbolTable::Visibility::Private); + putfunc = builder.create( + loc, "puts", LLVM::LLVMFunctionType::get(voidty, {ptrty}, false), + LLVM::Linkage::External); + putfunc.setVisibility(SymbolTable::Visibility::Private); + } + + LLVM::GlobalOp loadModuleStr = nullptr; + if (debug) { + std::string value = "load Module result = %d\n"; + auto type = LLVM::LLVMArrayType::get( + mlir::IntegerType::get(builder.getContext(), 8), value.size() + 1); + loadModuleStr = builder.create( + loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, "strmod", + builder.getStringAttr(value + '\0')); + } + LLVM::GlobalOp loadFuncStr = nullptr; + if (debug) { + std::string value = "load Func result = %d\n"; auto type = LLVM::LLVMArrayType::get( mlir::IntegerType::get(builder.getContext(), 8), value.size() + 1); - printStrFunc = builder.create( + loadFuncStr = builder.create( loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, "strfunc", builder.getStringAttr(value + '\0')); } + LLVM::GlobalOp launchKernelStr = nullptr; + if (debug) { + std::string value = "launch Kernel result = %d\n"; + auto type = LLVM::LLVMArrayType::get( + mlir::IntegerType::get(builder.getContext(), 8), value.size() + 1); + launchKernelStr = builder.create( + loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, + "strlaunch", builder.getStringAttr(value + '\0')); + } + LLVM::GlobalOp modOpStr = nullptr; + if (debug) { + std::string opstr; + llvm::raw_string_ostream ss(opstr); + + ss << kernelCallOp; + std::string value = "modstr=" + modstr + "\n" + opstr + "\n\n"; + auto type = LLVM::LLVMArrayType::get( + mlir::IntegerType::get(builder.getContext(), 8), value.size() + 1); + modOpStr = builder.create( + loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, + "strmlirmod", builder.getStringAttr(value + '\0')); + } LLVM::GlobalOp binary = nullptr; submod.walk([&](gpu::BinaryOp op) { @@ -608,15 +661,36 @@ CallInfo CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc, SmallVector modargs = {modptr->getResult(0), addr_modbin->getResult(0)}; + mlir::Value loadModRes = nullptr; if (cuModuleLoadDataPtr) { auto addr_glob_int = builder.create( loc, i64, builder.getI64IntegerAttr(cuModuleLoadDataPtr)); auto addr_glob = builder.create(loc, ptrty, addr_glob_int); modargs.insert(modargs.begin(), addr_glob); - builder.create(loc, modload_ty, modargs); + loadModRes = builder.create(loc, modload_ty, modargs) + ->getResult(0); } else { - builder.create(loc, modload, modargs); + loadModRes = + builder.create(loc, modload, modargs)->getResult(0); + } + + if (debug) { + Value printargs1[] = { + builder.create(loc, loadModuleStr) + ->getResult(0), + builder.create(loc, ptrty, loadModRes) + ->getResult(0)}; + builder.create(loc, printfunc, printargs1); + } + if (cuResultHandlerPtr) { + auto addr_glob_int = 
builder.create( + loc, i64, builder.getI64IntegerAttr(cuResultHandlerPtr)); + auto addr_glob = + builder.create(loc, ptrty, addr_glob_int) + ->getResult(0); + mlir::Value args[2] = {addr_glob, loadModRes}; + builder.create(loc, curesult_handler_ty, args); } auto mod = builder.create(loc, ptrty, modptr); @@ -627,15 +701,36 @@ CallInfo CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc, SmallVector funcargs = {funcptr->getResult(0), mod->getResult(0), addr_kernstr->getResult(0)}; + mlir::Value loadFuncRes = nullptr; if (cuModuleGetFunctionPtr) { auto addr_glob_int = builder.create( loc, i64, builder.getI64IntegerAttr(cuModuleGetFunctionPtr)); auto addr_glob = builder.create(loc, ptrty, addr_glob_int); funcargs.insert(funcargs.begin(), addr_glob); - builder.create(loc, funcload_ty, funcargs); + loadFuncRes = + builder.create(loc, funcload_ty, funcargs) + ->getResult(0); } else { - builder.create(loc, funcload, funcargs); + loadFuncRes = builder.create(loc, funcload, funcargs) + ->getResult(0); + } + + if (debug) { + Value printargs1[] = { + builder.create(loc, loadFuncStr)->getResult(0), + builder.create(loc, ptrty, loadFuncRes) + ->getResult(0)}; + builder.create(loc, printfunc, printargs1); + } + if (cuResultHandlerPtr) { + auto addr_glob_int = builder.create( + loc, i64, builder.getI64IntegerAttr(cuResultHandlerPtr)); + auto addr_glob = + builder.create(loc, ptrty, addr_glob_int) + ->getResult(0); + mlir::Value args[2] = {addr_glob, loadFuncRes}; + builder.create(loc, curesult_handler_ty, args); } auto func = builder.create(loc, ptrty, funcptr); @@ -666,15 +761,65 @@ CallInfo CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc, params, builder.create(loc, ptrty)}; + Value kernRes; if (cuLaunchKernelPtr) { auto addr_glob_int = builder.create( loc, i64, builder.getI64IntegerAttr(cuLaunchKernelPtr)); auto addr_glob = builder.create(loc, ptrty, addr_glob_int); args.insert(args.begin(), addr_glob); - builder.create(loc, launch_ty, args)->getResult(0); + kernRes = + builder.create(loc, launch_ty, args)->getResult(0); } else { - builder.create(loc, launch, args)->getResult(0); + kernRes = + builder.create(loc, launch, args)->getResult(0); + } + if (debug) { + Value printargs1[] = { + builder.create(loc, launchKernelStr) + ->getResult(0), + builder.create(loc, ptrty, kernRes) + ->getResult(0)}; + builder.create(loc, printfunc, printargs1); + } + if (debug) { + Value printargs1[] = { + builder.create(loc, modOpStr)->getResult(0)}; + builder.create(loc, putfunc, printargs1); + } + if (cuResultHandlerPtr) { + auto addr_glob_int = builder.create( + loc, i64, builder.getI64IntegerAttr(cuResultHandlerPtr)); + auto addr_glob = + builder.create(loc, ptrty, addr_glob_int) + ->getResult(0); + mlir::Value args[2] = {addr_glob, kernRes}; + builder.create(loc, curesult_handler_ty, args); + } + + if (cuStreamSynchronizePtr) { + auto addr_glob_int = builder.create( + loc, i64, builder.getI64IntegerAttr(cuStreamSynchronizePtr)); + auto addr_glob = + builder.create(loc, ptrty, addr_glob_int); + mlir::Value args[2] = {addr_glob, op.getAsyncObject()}; + auto syncRes = + builder.create(loc, cusync_ty, args)->getResult(0); + + if (debug) { + Value printargs1[] = { + builder.create(loc, modOpStr)->getResult(0)}; + builder.create(loc, putfunc, printargs1); + } + if (cuResultHandlerPtr) { + auto addr_glob_int = builder.create( + loc, i64, builder.getI64IntegerAttr(cuResultHandlerPtr)); + auto addr_glob = + builder.create(loc, ptrty, addr_glob_int) + ->getResult(0); + mlir::Value args[2] 
= {addr_glob, syncRes};
+        builder.create(loc, curesult_handler_ty, args);
+      }
     }
 
     op.erase();
@@ -755,6 +900,12 @@ struct LowerKernelPass : public LowerKernelPassBase {
       auto *symbolOp =
           symbolTable.lookupNearestSymbolFrom(op, op.getFnAttr());
       auto fn = cast(symbolOp);
+      if (fn.getArguments().size() != op.getInputs().size()) {
+        op->emitError() << "kernel_call had " << op.getInputs().size()
+                        << " operands whereas the called kernel requires "
+                        << fn.getArguments().size() << " arguments\n";
+        return;
+      }
 
       Value vals[] = {op.getGridx(),  op.getGridy(),  op.getGridz(),
                       op.getBlockx(), op.getBlocky(), op.getBlockz(),
@@ -781,7 +932,8 @@ struct LowerKernelPass : public LowerKernelPassBase {
           data[5], data[6], data[7], toolkitPath.getValue(), linkFilesArray,
           indexBitWidth.getValue(), cubinChip.getValue(),
           cubinFeatures.getValue(), cuLaunchKernelPtr, cuModuleLoadDataPtr,
-          cuModuleGetFunctionPtr, compileLaunch, run_init);
+          cuModuleGetFunctionPtr, compileLaunch, run_init, op, debug,
+          cuResultHandlerPtr, cuStreamSynchronizePtr);
 
       std::string backendinfo((char *)&cdata, sizeof(CallInfo));
diff --git a/src/enzyme_ad/jax/Passes/PassDetails.h b/src/enzyme_ad/jax/Passes/PassDetails.h
index 5c0a7385f..f3c00740c 100644
--- a/src/enzyme_ad/jax/Passes/PassDetails.h
+++ b/src/enzyme_ad/jax/Passes/PassDetails.h
@@ -17,6 +17,7 @@
 #ifndef DIALECT_ENZYMEXLA_TRANSFORMS_PASSDETAILS_H
 #define DIALECT_ENZYMEXLA_TRANSFORMS_PASSDETAILS_H
 
+#include "mlir/IR/BuiltinOps.h"
 #include "mlir/Pass/Pass.h"
 #include "src/enzyme_ad/jax/Passes/Passes.h"
 
diff --git a/src/enzyme_ad/jax/Passes/Passes.h b/src/enzyme_ad/jax/Passes/Passes.h
index 6e4b3af11..ffd3adff3 100644
--- a/src/enzyme_ad/jax/Passes/Passes.h
+++ b/src/enzyme_ad/jax/Passes/Passes.h
@@ -24,6 +24,7 @@ std::unique_ptr createEnzymeHLOUnrollPass();
 std::unique_ptr createPrintPass();
 std::unique_ptr createLowerKernelPass();
 std::unique_ptr createLibDeviceFuncsRaisingPass();
+std::unique_ptr createSROAWrappersPass();
 
 void populateLibDeviceFuncsToOpsPatterns(MLIRContext *context,
                                          RewritePatternSet &patterns);
@@ -39,6 +40,8 @@ namespace mlir {
 template 
 void registerDialect(DialectRegistry &registry);
 
+class DLTIDialect;
+
 namespace mhlo {
 class MhloDialect;
 } // end namespace mhlo
@@ -109,14 +112,4 @@ class LLVMDialect;
 
 } // end namespace mlir
 
-static void regsiterenzymeXLAPasses() {
-  using namespace mlir;
-  registerArithRaisingPass();
-  registerPrintPass();
-  registerEnzymeHLOOptPass();
-  registerEnzymeHLOUnrollPass();
-  registerLowerKernelPass();
-  registerConsumingInterpreterPass();
-  registerLibDeviceFuncsRaisingPass();
-}
 #endif // ENZYMEXLA_PASSES_H
diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td
index 5529eba05..8e8645891 100644
--- a/src/enzyme_ad/jax/Passes/Passes.td
+++ b/src/enzyme_ad/jax/Passes/Passes.td
@@ -126,6 +126,18 @@ def PrintPass : Pass<"print"> {
   ];
 }
 
+def SROAWrappersPass : Pass<"sroa-wrappers", "mlir::ModuleOp"> {
+  let summary = "Run LLVM SROA and cleanup passes on ABI conversion wrapper functions";
+  let constructor = "mlir::enzyme::createSROAWrappersPass()";
+  let dependentDialects = [
+    "mlir::LLVM::LLVMDialect",
+    "mlir::DLTIDialect",
+    "mlir::NVVM::NVVMDialect",
+    "mlir::arith::ArithDialect",
+    "mlir::math::MathDialect",
+  ];
+}
+
 def LibDeviceFuncsRaisingPass : Pass<"libdevice-funcs-raise"> {
   let summary = "Raise libdevice function calls to arith/math operations";
   let dependentDialects = [
@@ -233,6 +245,27 @@ def LowerKernelPass : Pass<"lower-kernel"> {
       /*default=*/"false",
       /*description=*/"Run initialization of cuda module"
     >,
+    Option<
+      /*C++ variable name=*/"debug",
+      /*CLI argument=*/"debug",
+      /*type=*/"bool",
+      /*default=*/"false",
+      /*description=*/"Insert debug prints into the generated launch code"
+    >,
+    Option<
+      /*C++ variable name=*/"cuResultHandlerPtr",
+      /*CLI argument=*/"cuResultHandlerPtr",
+      /*type=*/"size_t",
+      /*default=*/"0",
+      /*description=*/"Address of a handler function invoked with each CUresult"
+    >,
+    Option<
+      /*C++ variable name=*/"cuStreamSynchronizePtr",
+      /*CLI argument=*/"cuStreamSynchronizePtr",
+      /*type=*/"size_t",
+      /*default=*/"0",
+      /*description=*/"Address of the function used to synchronize the stream after launch"
+    >,
   ];
 }
 
diff --git a/src/enzyme_ad/jax/Passes/SROAWrappers.cpp b/src/enzyme_ad/jax/Passes/SROAWrappers.cpp
new file mode 100644
index 000000000..bf6b48a1b
--- /dev/null
+++ b/src/enzyme_ad/jax/Passes/SROAWrappers.cpp
@@ -0,0 +1,161 @@
+//===- SROAWrappers.cpp - Run SROA on ABI conversion wrappers ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass that runs LLVM SROA on ABI conversion wrappers.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Target/LLVMIR/Export.h"
+#include "mlir/Target/LLVMIR/ModuleImport.h"
+
+#include "llvm/IR/PassManager.h"
+#include "llvm/Passes/PassBuilder.h"
+#include "llvm/Transforms/IPO/Attributor.h"
+#include "llvm/Transforms/InstCombine/InstCombine.h"
+#include "llvm/Transforms/Scalar/SROA.h"
+
+#include "src/enzyme_ad/jax/Passes/PassDetails.h"
+#include "src/enzyme_ad/jax/Passes/Passes.h"
+
+#include "llvm/Transforms/IPO/FunctionAttrs.h"
+
+#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
+#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
+#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
+
+#include 
+
+#define DEBUG_TYPE "sroa-wrappers"
+
+using namespace mlir::enzyme;
+
+namespace {
+struct SROAWrappersPass : public SROAWrappersPassBase {
+  void runOnOperation() override {
+    mlir::ModuleOp m = getOperation();
+
+    mlir::OpBuilder b(m);
+
+    auto mToTranslate = b.cloneWithoutRegions(m);
+
+    llvm::SmallVector toOpt;
+    for (auto [oldRegion, newRegion] :
+         llvm::zip(m->getRegions(), mToTranslate->getRegions())) {
+      for (auto &oldBlock : oldRegion.getBlocks()) {
+        assert(oldBlock.getNumArguments() == 0);
+        auto newBlock = b.createBlock(&newRegion, newRegion.end());
+        for (auto &op : oldBlock) {
+          assert(op.hasTrait());
+          // FIXME in reality, this check should be whether the entirety of
+          // the op (all nested ops, plus all transitively used symbols) is
+          // translatable to LLVM IR.
+          // FIXME we also need to mark them `used` so the llvm optimizer
+          // does not get rid of them.
+ if (llvm::isa(op.getDialect())) { + // There should be no need for mapping because all top level + // operations in the module should be isolated from above + b.clone(op); + toOpt.push_back(&op); + } + } + } + } + + mlir::PassManager pm(mToTranslate.getContext()); + pm.addPass(mlir::createConvertMathToLLVMPass()); + pm.addPass(mlir::createArithToLLVMConversionPass()); + pm.addPass(mlir::createConvertNVVMToLLVMPass()); + auto subres = pm.run(mToTranslate); + if (!subres.succeeded()) { + return; + } + + llvm::LLVMContext llvmCtx; + auto llvmModule = mlir::translateModuleToLLVMIR(mToTranslate, llvmCtx); + + { + using namespace llvm; + PipelineTuningOptions PTO; + PTO.LoopUnrolling = false; + PTO.LoopInterleaving = false; + PTO.LoopVectorization = false; + PTO.SLPVectorization = false; + PTO.MergeFunctions = false; + PTO.CallGraphProfile = false; + PTO.UnifiedLTO = false; + + LoopAnalysisManager LAM; + FunctionAnalysisManager FAM; + CGSCCAnalysisManager CGAM; + ModuleAnalysisManager MAM; + + PassInstrumentationCallbacks PIC; + PassBuilder PB(nullptr, PTO, std::nullopt, nullptr); + + PB.registerModuleAnalyses(MAM); + PB.registerCGSCCAnalyses(CGAM); + PB.registerFunctionAnalyses(FAM); + PB.registerLoopAnalyses(LAM); + PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); + + ModulePassManager MPM; + FunctionPassManager FPM; + MPM.addPass( + createModuleToFunctionPassAdaptor(SROAPass(SROAOptions::ModifyCFG))); + MPM.addPass(createModuleToFunctionPassAdaptor(InstCombinePass())); + MPM.addPass(createModuleToFunctionPassAdaptor(InstCombinePass())); + MPM.addPass(llvm::AttributorPass()); + MPM.run(*llvmModule, MAM); + } + auto translatedFromLLVMIR = mlir::translateLLVMIRToModule( + std::move(llvmModule), m->getContext(), /*emitExpensiveWarnings*/ true, + /*dropDICompositeTypeElements*/ false, /*loadAllDialects*/ false); + + b.setInsertionPoint(m); + mlir::ModuleOp newM = *translatedFromLLVMIR; + + for (auto op : toOpt) { + op->erase(); + } + for (auto [oldRegion, newRegion] : + llvm::zip(m->getRegions(), newM->getRegions())) { + for (auto [oldBlock, newBlock] : + llvm::zip(oldRegion.getBlocks(), newRegion.getBlocks())) { + b.setInsertionPointToEnd(&oldBlock); + for (auto &op : newBlock) { + assert(op.hasTrait()); + assert(llvm::isa(op.getDialect())); + // There should be no need for mapping because all top level + // operations in the module should be isolated from above + b.clone(op); + } + } + } + + mToTranslate->erase(); + } +}; + +} // end anonymous namespace + +namespace mlir { +namespace enzyme { +std::unique_ptr createSROAWrappersPass() { + return std::make_unique(); +} +} // namespace enzyme +} // namespace mlir diff --git a/src/enzyme_ad/jax/RegistryUtils.cpp b/src/enzyme_ad/jax/RegistryUtils.cpp index 673f8a7d2..2fc0b9c49 100644 --- a/src/enzyme_ad/jax/RegistryUtils.cpp +++ b/src/enzyme_ad/jax/RegistryUtils.cpp @@ -5,6 +5,7 @@ #include "Implementations/XLADerivatives.h" #include "mlir/Dialect/Func/Extensions/InlinerExtension.h" +#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "Dialect/Dialect.h" @@ -44,6 +45,9 @@ #include "stablehlo/dialect/StablehloOps.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h" + namespace mlir { namespace enzyme { void registerEnzymeJaxTransformExtension(mlir::DialectRegistry ®istry); @@ -82,6 +86,8 @@ void prepareRegistry(mlir::DialectRegistry ®istry) { 
   mlir::enzyme::registerXLAAutoDiffInterfaces(registry);
 
   mlir::func::registerInlinerExtension(registry);
+  mlir::LLVM::registerInlinerInterface(registry);
+  mlir::NVVM::registerInlinerInterface(registry);
 
   mlir::registerConvertNVVMToLLVMInterface(registry);
 
@@ -115,4 +121,7 @@ void prepareRegistry(mlir::DialectRegistry &registry) {
   mlir::linalg::registerTransformDialectExtension(registry);
 
   mlir::enzyme::registerEnzymeJaxTransformExtension(registry);
+
+  mlir::registerLLVMDialectImport(registry);
+  mlir::registerNVVMDialectImport(registry);
 }
diff --git a/src/enzyme_ad/jax/enzyme_call.cc b/src/enzyme_ad/jax/enzyme_call.cc
index 66c43f4bf..758cc589d 100644
--- a/src/enzyme_ad/jax/enzyme_call.cc
+++ b/src/enzyme_ad/jax/enzyme_call.cc
@@ -1046,7 +1046,7 @@ PYBIND11_MODULE(enzyme_call, m) {
   mlir::arith::registerArithPasses();
   mlir::memref::registerMemRefPasses();
   mlir::registerenzymePasses();
-  regsiterenzymeXLAPasses();
+  mlir::registerenzymexlaPasses();
   mlir::enzyme::registerGenerateApplyPatternsPass();
   mlir::enzyme::registerRemoveTransformPass();
   mlir::stablehlo::registerPasses();
diff --git a/src/enzyme_ad/jax/enzymexlamlir-opt.cpp b/src/enzyme_ad/jax/enzymexlamlir-opt.cpp
index 44d2463c6..63e757a1a 100644
--- a/src/enzyme_ad/jax/enzymexlamlir-opt.cpp
+++ b/src/enzyme_ad/jax/enzymexlamlir-opt.cpp
@@ -88,7 +88,7 @@ int main(int argc, char **argv) {
   prepareRegistry(registry);
 
   mlir::registerenzymePasses();
-  regsiterenzymeXLAPasses();
+  mlir::registerenzymexlaPasses();
 
   // Register the standard passes we want.
   mlir::registerCSEPass();
@@ -114,6 +114,12 @@ int main(int argc, char **argv) {
             PtrElementModel>(*ctx);
     LLVM::LLVMArrayType::attachInterface>(
         *ctx);
+
+    // This is unfortunate, but the SROAWrappers pass does a round trip to
+    // LLVM IR, and the LLVM-IR-to-MLIR translation loads all available
+    // dialects; loading dialects inside a pass is forbidden, so preload
+    // them here.
+    // ctx->loadAllAvailableDialects();
   });
 
   // Transform dialect and extensions.
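(Note: with the hand-written `regsiterenzymeXLAPasses()` helper removed from Passes.h, both drivers above rely on the tablegen-generated `mlir::registerenzymexlaPasses()`. The following is an illustrative sketch only, not part of the patch, showing how an opt-style driver wires up the registrations added here; the forward declaration of `prepareRegistry` stands in for whatever header actually exposes the function implemented in RegistryUtils.cpp.)

```cpp
// Sketch of a minimal opt-style driver; assumes the tablegen-generated
// mlir::registerenzymexlaPasses() is declared via Passes.h.
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"

#include "src/enzyme_ad/jax/Passes/Passes.h"

// Implemented in RegistryUtils.cpp; registers dialects, the LLVM/NVVM inliner
// interfaces, and the LLVM-IR import interfaces used by sroa-wrappers.
void prepareRegistry(mlir::DialectRegistry &registry);

int main(int argc, char **argv) {
  mlir::DialectRegistry registry;
  prepareRegistry(registry);

  // Replaces the old static (and misspelled) regsiterenzymeXLAPasses().
  mlir::registerenzymexlaPasses();

  return mlir::asMainReturnCode(
      mlir::MlirOptMain(argc, argv, "EnzymeXLA pass driver\n", registry));
}
```

The new pass can then be exercised exactly as the lit test below does, e.g. via `--pass-pipeline="builtin.module(sroa-wrappers)"`.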
diff --git a/test/lit_tests/sroa-wrappers.mlir b/test/lit_tests/sroa-wrappers.mlir new file mode 100644 index 000000000..cf22e16a9 --- /dev/null +++ b/test/lit_tests/sroa-wrappers.mlir @@ -0,0 +1,53 @@ +// RUN: enzymexlamlir-opt %s --pass-pipeline="builtin.module(inline{default-pipeline=canonicalize max-iterations=4},sroa-wrappers)" | FileCheck %s +#tbaa_root = #llvm.tbaa_root +#tbaa_type_desc = #llvm.tbaa_type_desc}> +#tbaa_tag = #llvm.tbaa_tag +module { + llvm.func local_unnamed_addr @_Z8tuplef2_5TupleI5Int6413CuTracedArrayIS0_Li0ELi1E2__EE(%arg0: !llvm.struct<(i64, array<1 x ptr<1>>)>) attributes {sym_visibility = "private"} { + %0 = llvm.extractvalue %arg0[0] : !llvm.struct<(i64, array<1 x ptr<1>>)> + %1 = llvm.extractvalue %arg0[1] : !llvm.struct<(i64, array<1 x ptr<1>>)> + %2 = llvm.extractvalue %1[0] : !llvm.array<1 x ptr<1>> + %3 = llvm.bitcast %2 : !llvm.ptr<1> to !llvm.ptr<1> + %4 = llvm.load %3 {alignment = 1 : i64, tbaa = [#tbaa_tag]} : !llvm.ptr<1> -> i64 + %5 = llvm.mul %4, %0 : i64 + llvm.store %5, %3 {alignment = 1 : i64, tbaa = [#tbaa_tag]} : i64, !llvm.ptr<1> + llvm.return + } + llvm.func ptx_kernelcc @"##call__Z8tuplef2_5TupleI5Int6413CuTracedArrayIS0_Li0ELi1E2__EE#258"(%arg0: !llvm.ptr<1>) attributes {sym_visibility = "private"} { + %0 = llvm.mlir.constant(1 : i64) : i64 + %1 = llvm.alloca %0 x !llvm.struct<(i64, array<1 x ptr<1>>)> : (i64) -> !llvm.ptr + %2 = llvm.mlir.constant(dense<[5, 0, 0, 0, 0, 0, 0, 0, 112, 231, 165, 87, 9, 117, 0, 0]> : tensor<16xui8>) : !llvm.array<16 x i8> + llvm.store %2, %1 : !llvm.array<16 x i8>, !llvm.ptr + %3 = llvm.getelementptr %1[8] : (!llvm.ptr) -> !llvm.ptr, ui8 + llvm.store %arg0, %3 : !llvm.ptr<1>, !llvm.ptr + %4 = llvm.load %1 : !llvm.ptr -> !llvm.struct<(i64, array<1 x ptr<1>>)> + llvm.call @_Z8tuplef2_5TupleI5Int6413CuTracedArrayIS0_Li0ELi1E2__EE(%4) : (!llvm.struct<(i64, array<1 x ptr<1>>)>) -> () + llvm.return + } + func.func @main(%arg0: tensor) -> tensor { + %c = stablehlo.constant dense<1> : tensor + %c_0 = stablehlo.constant dense<1> : tensor + %c_1 = stablehlo.constant dense<1> : tensor + %c_2 = stablehlo.constant dense<1> : tensor + %c_3 = stablehlo.constant dense<1> : tensor + %c_4 = stablehlo.constant dense<1> : tensor + %c_5 = stablehlo.constant dense<0> : tensor + %0 = enzymexla.kernel_call @"##call__Z8tuplef2_5TupleI5Int6413CuTracedArrayIS0_Li0ELi1E2__EE#258" blocks in(%c, %c_0, %c_1) threads in(%c_2, %c_3, %c_4) shmem = %c_5 (%arg0) {output_operand_aliases = [#stablehlo.output_operand_alias]} : (tensor) -> tensor + return %0 : tensor + } +} + +// CHECK: func.func @main(%arg0: tensor) -> tensor { +// CHECK-NEXT: %c = stablehlo.constant dense<1> : tensor +// CHECK-NEXT: %c_0 = stablehlo.constant dense<0> : tensor +// CHECK-NEXT: %0 = enzymexla.kernel_call @"##call__Z8tuplef2_5TupleI5Int6413CuTracedArrayIS0_Li0ELi1E2__EE#258" blocks in(%c, %c, %c) threads in(%c, %c, %c) shmem = %c_0 (%arg0) {output_operand_aliases = [#stablehlo.output_operand_alias]} : (tensor) -> tensor +// CHECK-NEXT: return %0 : tensor +// CHECK-NEXT: } +// CHECK: llvm.func ptx_kernelcc @"##call__Z8tuplef2_5TupleI5Int6413CuTracedArrayIS0_Li0ELi1E2__EE#258"(%arg0: !llvm.ptr<1>) { +// CHECK-NEXT: %0 = llvm.mlir.constant(5 : i64) : i64 +// CHECK-NEXT: %1 = llvm.load %arg0 {alignment = 1 : i64} : !llvm.ptr<1> -> i64 +// CHECK-NEXT: %2 = llvm.mul %1, %0 : i64 +// CHECK-NEXT: llvm.store %2, %arg0 {alignment = 1 : i64} : i64, !llvm.ptr<1> +// CHECK-NEXT: llvm.return +// CHECK-NEXT: } + diff --git a/third_party/xla/workspace.bzl 
b/third_party/xla/workspace.bzl index d7feaaeaf..7745d40ae 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -1,8 +1,10 @@ """Loads XLA.""" load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") -load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256") +# load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256") load("//:workspace.bzl", "XLA_PATCHES") +XLA_COMMIT = "1bb4fc18e73faa1c001d96bfe3a22f733987b018" +XLA_SHA256 = "" def repo(): http_archive(
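(The diff is truncated inside `repo()`, and the truncated body is left as-is above. For orientation only, a commit-pinned `http_archive` usually has the shape sketched below; the `name`, `strip_prefix`, and the way `XLA_PATCHES` is applied are assumptions rather than the repository's actual code, and `XLA_SHA256` is deliberately left empty as in the patch.)

```starlark
# Hedged sketch of a commit-pinned archive rule; fill in XLA_SHA256 once the
# hash for XLA_COMMIT is known so Bazel can verify the download.
def repo():
    http_archive(
        name = "xla",  # assumed repository name
        sha256 = XLA_SHA256,
        strip_prefix = "xla-" + XLA_COMMIT,
        urls = ["https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)],
        patch_cmds = XLA_PATCHES,  # assumption about how XLA_PATCHES is applied
    )
```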