diff --git a/src/enzyme_ad/jax/Dialect/Ops.cpp b/src/enzyme_ad/jax/Dialect/Ops.cpp index f1111a45b..583698aee 100644 --- a/src/enzyme_ad/jax/Dialect/Ops.cpp +++ b/src/enzyme_ad/jax/Dialect/Ops.cpp @@ -83,6 +83,7 @@ class ReadOnlyKernelArg final auto operand = fn.front().getArgument(operandIndex); bool readonly = + operand.use_empty() || fn.getArgAttr(operandIndex, LLVMDialect::getReadonlyAttrName()) || fn.getArgAttr(operandIndex, LLVMDialect::getReadnoneAttrName()); @@ -108,6 +109,7 @@ class ReadOnlyKernelArg final assert(launchOp.getInputs()[operandIndex].getType() == launchOp.getResultTypes()[idx]); bool readonly = + operand.use_empty() || fn.getArgAttr(operandIndex, LLVMDialect::getReadonlyAttrName()) || fn.getArgAttr(operandIndex, LLVMDialect::getReadnoneAttrName()); @@ -144,6 +146,7 @@ class ReadOnlyKernelArg final auto operand = fn.front().getArgument(operandIndex); bool readonly = + operand.use_empty() || fn.getArgAttr(operandIndex, LLVMDialect::getReadonlyAttrName()) || fn.getArgAttr(operandIndex, LLVMDialect::getReadnoneAttrName()); @@ -160,7 +163,106 @@ class ReadOnlyKernelArg final } }; +class ReadNoneKernelArg final + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(enzymexla::KernelCallOp launchOp, + PatternRewriter &rewriter) const override { + SymbolTableCollection symbolTable; + auto mod = launchOp->getParentOfType(); + symbolTable.getSymbolTable(mod); + auto fn = cast( + symbolTable.lookupNearestSymbolFrom(launchOp, launchOp.getFnAttr())); + + bool changed = false; + + SmallVector calls; + auto use_opt = symbolTable.getSymbolTable(mod).getSymbolUses(fn, mod); + if (!use_opt) + return failure(); + for (auto u : *use_opt) { + auto launch2 = dyn_cast(u.getUser()); + if (!launch2) + return failure(); + calls.push_back(launch2); + auto operand_aliases2 = launchOp.getOutputOperandAliases(); + assert(operand_aliases2.size() == launchOp.getNumResults()); + } + + BitVector deadArgs(fn.front().getNumArguments(), false); + for (auto arg : fn.front().getArguments()) { + auto operandIndex = arg.getArgNumber(); + bool readnone = arg.use_empty(); + // fn.getArgAttr(operandIndex, LLVMDialect::getReadnoneAttrName()); + if (!readnone) + continue; + + for (auto call : calls) { + auto operand_aliases = call.getOutputOperandAliases(); + for (auto alias_attr : operand_aliases) { + auto alias = cast(alias_attr); + auto aliasOperandIndex = alias.getOperandIndex(); + if (aliasOperandIndex == operandIndex) { + return failure(); + } + } + } + changed = true; + deadArgs[operandIndex] = true; + } + + if (!changed) + return failure(); + + rewriter.modifyOpInPlace(fn, [&]() { + // fn.eraseArguments(deadArgs); + if (auto T = dyn_cast(fn.getFunctionType())) { + SmallVector argStorage; + mlir::filterTypesOut(fn.getArgumentTypes(), deadArgs, argStorage); + auto fty2 = + LLVMFunctionType::get(T.getReturnType(), argStorage, T.getVarArg()); + mlir::function_interface_impl::eraseFunctionArguments(fn, deadArgs, + fty2); + } else { + fn.eraseArguments(deadArgs); + } + }); + + for (auto call : calls) { + BitVector nonLiveCallOperands(call.getNumOperands(), false); + for (int index : deadArgs.set_bits()) + nonLiveCallOperands.set(call.getInputs().getBeginOperandIndex() + + index); + + SmallVector outputAliases; + auto operand_aliases = call.getOutputOperandAliases(); + + for (auto alias_attr : operand_aliases) { + auto alias = cast(alias_attr); + auto operandIndex = alias.getOperandIndex(); + size_t nextIndex = operandIndex; + for (int index : deadArgs.set_bits()) { + if (index <= operandIndex) + nextIndex--; + } + outputAliases.push_back(OutputOperandAliasAttr::get( + call->getContext(), alias.getOutputTupleIndices(), nextIndex, + alias.getOperandTupleIndices())); + } + + rewriter.modifyOpInPlace(call, [&]() { + call->eraseOperands(nonLiveCallOperands); + call.setOutputOperandAliasesAttr( + ArrayAttr::get(call->getContext(), outputAliases)); + }); + } + return success(); + } +}; + void KernelCallOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.insert(context); + results.insert(context); } diff --git a/src/enzyme_ad/jax/Passes/LowerKernel.cpp b/src/enzyme_ad/jax/Passes/LowerKernel.cpp index daa6c6165..9137fdb9c 100644 --- a/src/enzyme_ad/jax/Passes/LowerKernel.cpp +++ b/src/enzyme_ad/jax/Passes/LowerKernel.cpp @@ -549,6 +549,16 @@ CallInfo CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc, LLVM::LLVMFuncOp funcload = builder.create( loc, "cuModuleGetFunction", funcload_ty); + LLVM::GlobalOp kernStr; + { + auto type = LLVM::LLVMArrayType::get( + mlir::IntegerType::get(builder.getContext(), 8), + legalName.size() + 1); + kernStr = builder.create( + loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, "str", + builder.getStringAttr(legalName + '\0')); + } + builder.setInsertionPointToStart(&submod.getBodyRegion().front()); LLVM::LLVMFuncOp initfn = builder.create( diff --git a/test/lit_tests/lowering/deadarg.mlir b/test/lit_tests/lowering/deadarg.mlir new file mode 100644 index 000000000..3364b89dc --- /dev/null +++ b/test/lit_tests/lowering/deadarg.mlir @@ -0,0 +1,63 @@ +// RUN: enzymexlamlir-opt %s --pass-pipeline="builtin.module(canonicalize)" | FileCheck %s + +module { + func.func @main(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor, tensor, tensor) { + %c = stablehlo.constant dense<1> : tensor + %c_0 = stablehlo.constant dense<0> : tensor + %0:3 = enzymexla.kernel_call @k1 blocks in(%c, %c, %c) threads in(%c, %c, %c) shmem = %c_0 (%arg0, %arg1, %arg2) {output_operand_aliases = [#stablehlo.output_operand_alias, #stablehlo.output_operand_alias, #stablehlo.output_operand_alias] } : (tensor, tensor, tensor) -> (tensor, tensor, tensor) + return %0#0, %0#1, %0#2 : tensor, tensor, tensor + } + llvm.func ptx_kernelcc @k1(%arg0: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: !llvm.ptr<1>) { + %a0 = llvm.load %arg0 {alignment = 1 : i64} : !llvm.ptr<1> -> i64 + llvm.store %a0, %arg2 {alignment = 1 : i64} : i64, !llvm.ptr<1> + llvm.return + } + func.func @main2(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor, tensor, tensor) { + %c = stablehlo.constant dense<1> : tensor + %c_0 = stablehlo.constant dense<0> : tensor + %0:3 = enzymexla.kernel_call @k2 blocks in(%c, %c, %c) threads in(%c, %c, %c) shmem = %c_0 (%arg0, %arg1, %arg2) {output_operand_aliases = [#stablehlo.output_operand_alias, #stablehlo.output_operand_alias, #stablehlo.output_operand_alias] } : (tensor, tensor, tensor) -> (tensor, tensor, tensor) + return %0#0, %0#1, %0#2 : tensor, tensor, tensor + } + func.func @main3(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor, tensor, tensor) { + %c = stablehlo.constant dense<1> : tensor + %c_0 = stablehlo.constant dense<0> : tensor + %0:3 = enzymexla.kernel_call @k2 blocks in(%c, %c, %c) threads in(%c, %c, %c) shmem = %c_0 (%arg0, %arg1, %arg2) {output_operand_aliases = [#stablehlo.output_operand_alias, #stablehlo.output_operand_alias, #stablehlo.output_operand_alias] } : (tensor, tensor, tensor) -> (tensor, tensor, tensor) + return %0#0, %0#1, %0#2 : tensor, tensor, tensor + } + llvm.func ptx_kernelcc @k2(%arg0: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: !llvm.ptr<1>) { + %a0 = llvm.load %arg2 {alignment = 1 : i64} : !llvm.ptr<1> -> i64 + %t = llvm.mul %a0, %a0 : i64 + llvm.store %t, %arg2 {alignment = 1 : i64} : i64, !llvm.ptr<1> + llvm.return + } +} + +// CHECK: func.func @main(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor, tensor, tensor) { +// CHECK-NEXT: %c = stablehlo.constant dense<1> : tensor +// CHECK-NEXT: %c_0 = stablehlo.constant dense<0> : tensor +// CHECK-NEXT: %0:2 = enzymexla.kernel_call @k1 blocks in(%c, %c, %c) threads in(%c, %c, %c) shmem = %c_0 (%arg0, %arg2) {output_operand_aliases = [#stablehlo.output_operand_alias, #stablehlo.output_operand_alias]} : (tensor, tensor) -> (tensor, tensor) +// CHECK-NEXT: return %0#0, %arg1, %0#1 : tensor, tensor, tensor +// CHECK-NEXT: } +// CHECK: llvm.func ptx_kernelcc @k1(%arg0: !llvm.ptr<1>, %arg1: !llvm.ptr<1>) { +// CHECK-NEXT: %0 = llvm.load %arg0 {alignment = 1 : i64} : !llvm.ptr<1> -> i64 +// CHECK-NEXT: llvm.store %0, %arg1 {alignment = 1 : i64} : i64, !llvm.ptr<1> +// CHECK-NEXT: llvm.return +// CHECK-NEXT: } +// CHECK: func.func @main2(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor, tensor, tensor) { +// CHECK-NEXT: %c = stablehlo.constant dense<1> : tensor +// CHECK-NEXT: %c_0 = stablehlo.constant dense<0> : tensor +// CHECK-NEXT: %0 = enzymexla.kernel_call @k2 blocks in(%c, %c, %c) threads in(%c, %c, %c) shmem = %c_0 (%arg2) {output_operand_aliases = [#stablehlo.output_operand_alias]} : (tensor) -> tensor +// CHECK-NEXT: return %arg0, %arg1, %0 : tensor, tensor, tensor +// CHECK-NEXT: } +// CHECK: func.func @main3(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor, tensor, tensor) { +// CHECK-NEXT: %c = stablehlo.constant dense<1> : tensor +// CHECK-NEXT: %c_0 = stablehlo.constant dense<0> : tensor +// CHECK-NEXT: %0 = enzymexla.kernel_call @k2 blocks in(%c, %c, %c) threads in(%c, %c, %c) shmem = %c_0 (%arg2) {output_operand_aliases = [#stablehlo.output_operand_alias]} : (tensor) -> tensor +// CHECK-NEXT: return %arg0, %arg1, %0 : tensor, tensor, tensor +// CHECK-NEXT: } +// CHECK: llvm.func ptx_kernelcc @k2(%arg0: !llvm.ptr<1>) { +// CHECK-NEXT: %0 = llvm.load %arg0 {alignment = 1 : i64} : !llvm.ptr<1> -> i64 +// CHECK-NEXT: %1 = llvm.mul %0, %0 : i64 +// CHECK-NEXT: llvm.store %1, %arg0 {alignment = 1 : i64} : i64, !llvm.ptr<1> +// CHECK-NEXT: llvm.return +// CHECK-NEXT: }