Skip to content

Commit

Permalink
Add readnone kernel arg (#251)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Jan 18, 2025
1 parent 07a8928 commit b7c192b
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 1 deletion.
104 changes: 103 additions & 1 deletion src/enzyme_ad/jax/Dialect/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class ReadOnlyKernelArg final

auto operand = fn.front().getArgument(operandIndex);
bool readonly =
operand.use_empty() ||
fn.getArgAttr(operandIndex, LLVMDialect::getReadonlyAttrName()) ||
fn.getArgAttr(operandIndex, LLVMDialect::getReadnoneAttrName());

Expand All @@ -108,6 +109,7 @@ class ReadOnlyKernelArg final
assert(launchOp.getInputs()[operandIndex].getType() ==
launchOp.getResultTypes()[idx]);
bool readonly =
operand.use_empty() ||
fn.getArgAttr(operandIndex, LLVMDialect::getReadonlyAttrName()) ||
fn.getArgAttr(operandIndex, LLVMDialect::getReadnoneAttrName());

Expand Down Expand Up @@ -144,6 +146,7 @@ class ReadOnlyKernelArg final

auto operand = fn.front().getArgument(operandIndex);
bool readonly =
operand.use_empty() ||
fn.getArgAttr(operandIndex, LLVMDialect::getReadonlyAttrName()) ||
fn.getArgAttr(operandIndex, LLVMDialect::getReadnoneAttrName());

Expand All @@ -160,7 +163,106 @@ class ReadOnlyKernelArg final
}
};

class ReadNoneKernelArg final
: public OpRewritePattern<enzymexla::KernelCallOp> {
public:
using OpRewritePattern<enzymexla::KernelCallOp>::OpRewritePattern;

LogicalResult matchAndRewrite(enzymexla::KernelCallOp launchOp,
PatternRewriter &rewriter) const override {
SymbolTableCollection symbolTable;
auto mod = launchOp->getParentOfType<ModuleOp>();
symbolTable.getSymbolTable(mod);
auto fn = cast<FunctionOpInterface>(
symbolTable.lookupNearestSymbolFrom(launchOp, launchOp.getFnAttr()));

bool changed = false;

SmallVector<enzymexla::KernelCallOp> calls;
auto use_opt = symbolTable.getSymbolTable(mod).getSymbolUses(fn, mod);
if (!use_opt)
return failure();
for (auto u : *use_opt) {
auto launch2 = dyn_cast<enzymexla::KernelCallOp>(u.getUser());
if (!launch2)
return failure();
calls.push_back(launch2);
auto operand_aliases2 = launchOp.getOutputOperandAliases();
assert(operand_aliases2.size() == launchOp.getNumResults());
}

BitVector deadArgs(fn.front().getNumArguments(), false);
for (auto arg : fn.front().getArguments()) {
auto operandIndex = arg.getArgNumber();
bool readnone = arg.use_empty();
// fn.getArgAttr(operandIndex, LLVMDialect::getReadnoneAttrName());
if (!readnone)
continue;

for (auto call : calls) {
auto operand_aliases = call.getOutputOperandAliases();
for (auto alias_attr : operand_aliases) {
auto alias = cast<OutputOperandAliasAttr>(alias_attr);
auto aliasOperandIndex = alias.getOperandIndex();
if (aliasOperandIndex == operandIndex) {
return failure();
}
}
}
changed = true;
deadArgs[operandIndex] = true;
}

if (!changed)
return failure();

rewriter.modifyOpInPlace(fn, [&]() {
// fn.eraseArguments(deadArgs);
if (auto T = dyn_cast<LLVMFunctionType>(fn.getFunctionType())) {
SmallVector<Type> argStorage;
mlir::filterTypesOut(fn.getArgumentTypes(), deadArgs, argStorage);
auto fty2 =
LLVMFunctionType::get(T.getReturnType(), argStorage, T.getVarArg());
mlir::function_interface_impl::eraseFunctionArguments(fn, deadArgs,
fty2);
} else {
fn.eraseArguments(deadArgs);
}
});

for (auto call : calls) {
BitVector nonLiveCallOperands(call.getNumOperands(), false);
for (int index : deadArgs.set_bits())
nonLiveCallOperands.set(call.getInputs().getBeginOperandIndex() +
index);

SmallVector<Attribute> outputAliases;
auto operand_aliases = call.getOutputOperandAliases();

for (auto alias_attr : operand_aliases) {
auto alias = cast<OutputOperandAliasAttr>(alias_attr);
auto operandIndex = alias.getOperandIndex();
size_t nextIndex = operandIndex;
for (int index : deadArgs.set_bits()) {
if (index <= operandIndex)
nextIndex--;
}
outputAliases.push_back(OutputOperandAliasAttr::get(
call->getContext(), alias.getOutputTupleIndices(), nextIndex,
alias.getOperandTupleIndices()));
}

rewriter.modifyOpInPlace(call, [&]() {
call->eraseOperands(nonLiveCallOperands);
call.setOutputOperandAliasesAttr(
ArrayAttr::get(call->getContext(), outputAliases));
});
}
return success();
}
};

void KernelCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.insert<ReadOnlyKernelArg>(context);
results.insert<ReadOnlyKernelArg, ReadNoneKernelArg>(context);
}
10 changes: 10 additions & 0 deletions src/enzyme_ad/jax/Passes/LowerKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,16 @@ CallInfo CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
LLVM::LLVMFuncOp funcload = builder.create<LLVM::LLVMFuncOp>(
loc, "cuModuleGetFunction", funcload_ty);

LLVM::GlobalOp kernStr;
{
auto type = LLVM::LLVMArrayType::get(
mlir::IntegerType::get(builder.getContext(), 8),
legalName.size() + 1);
kernStr = builder.create<LLVM::GlobalOp>(
loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, "str",
builder.getStringAttr(legalName + '\0'));
}

builder.setInsertionPointToStart(&submod.getBodyRegion().front());

LLVM::LLVMFuncOp initfn = builder.create<LLVM::LLVMFuncOp>(
Expand Down
63 changes: 63 additions & 0 deletions test/lit_tests/lowering/deadarg.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// RUN: enzymexlamlir-opt %s --pass-pipeline="builtin.module(canonicalize)" | FileCheck %s

module {
func.func @main(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> (tensor<i64>, tensor<i64>, tensor<i64>) {
%c = stablehlo.constant dense<1> : tensor<i64>
%c_0 = stablehlo.constant dense<0> : tensor<i64>
%0:3 = enzymexla.kernel_call @k1 blocks in(%c, %c, %c) threads in(%c, %c, %c) shmem = %c_0 (%arg0, %arg1, %arg2) {output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [1], operand_index = 1, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [2], operand_index = 2, operand_tuple_indices = []>] } : (tensor<i64>, tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>, tensor<i64>)
return %0#0, %0#1, %0#2 : tensor<i64>, tensor<i64>, tensor<i64>
}
llvm.func ptx_kernelcc @k1(%arg0: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: !llvm.ptr<1>) {
%a0 = llvm.load %arg0 {alignment = 1 : i64} : !llvm.ptr<1> -> i64
llvm.store %a0, %arg2 {alignment = 1 : i64} : i64, !llvm.ptr<1>
llvm.return
}
func.func @main2(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> (tensor<i64>, tensor<i64>, tensor<i64>) {
%c = stablehlo.constant dense<1> : tensor<i64>
%c_0 = stablehlo.constant dense<0> : tensor<i64>
%0:3 = enzymexla.kernel_call @k2 blocks in(%c, %c, %c) threads in(%c, %c, %c) shmem = %c_0 (%arg0, %arg1, %arg2) {output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [1], operand_index = 1, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [2], operand_index = 2, operand_tuple_indices = []>] } : (tensor<i64>, tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>, tensor<i64>)
return %0#0, %0#1, %0#2 : tensor<i64>, tensor<i64>, tensor<i64>
}
func.func @main3(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> (tensor<i64>, tensor<i64>, tensor<i64>) {
%c = stablehlo.constant dense<1> : tensor<i64>
%c_0 = stablehlo.constant dense<0> : tensor<i64>
%0:3 = enzymexla.kernel_call @k2 blocks in(%c, %c, %c) threads in(%c, %c, %c) shmem = %c_0 (%arg0, %arg1, %arg2) {output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [1], operand_index = 1, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [2], operand_index = 2, operand_tuple_indices = []>] } : (tensor<i64>, tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>, tensor<i64>)
return %0#0, %0#1, %0#2 : tensor<i64>, tensor<i64>, tensor<i64>
}
llvm.func ptx_kernelcc @k2(%arg0: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: !llvm.ptr<1>) {
%a0 = llvm.load %arg2 {alignment = 1 : i64} : !llvm.ptr<1> -> i64
%t = llvm.mul %a0, %a0 : i64
llvm.store %t, %arg2 {alignment = 1 : i64} : i64, !llvm.ptr<1>
llvm.return
}
}

// CHECK: func.func @main(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> (tensor<i64>, tensor<i64>, tensor<i64>) {
// CHECK-NEXT: %c = stablehlo.constant dense<1> : tensor<i64>
// CHECK-NEXT: %c_0 = stablehlo.constant dense<0> : tensor<i64>
// CHECK-NEXT: %0:2 = enzymexla.kernel_call @k1 blocks in(%c, %c, %c) threads in(%c, %c, %c) shmem = %c_0 (%arg0, %arg2) {output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [1], operand_index = 1, operand_tuple_indices = []>]} : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// CHECK-NEXT: return %0#0, %arg1, %0#1 : tensor<i64>, tensor<i64>, tensor<i64>
// CHECK-NEXT: }
// CHECK: llvm.func ptx_kernelcc @k1(%arg0: !llvm.ptr<1>, %arg1: !llvm.ptr<1>) {
// CHECK-NEXT: %0 = llvm.load %arg0 {alignment = 1 : i64} : !llvm.ptr<1> -> i64
// CHECK-NEXT: llvm.store %0, %arg1 {alignment = 1 : i64} : i64, !llvm.ptr<1>
// CHECK-NEXT: llvm.return
// CHECK-NEXT: }
// CHECK: func.func @main2(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> (tensor<i64>, tensor<i64>, tensor<i64>) {
// CHECK-NEXT: %c = stablehlo.constant dense<1> : tensor<i64>
// CHECK-NEXT: %c_0 = stablehlo.constant dense<0> : tensor<i64>
// CHECK-NEXT: %0 = enzymexla.kernel_call @k2 blocks in(%c, %c, %c) threads in(%c, %c, %c) shmem = %c_0 (%arg2) {output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 0, operand_tuple_indices = []>]} : (tensor<i64>) -> tensor<i64>
// CHECK-NEXT: return %arg0, %arg1, %0 : tensor<i64>, tensor<i64>, tensor<i64>
// CHECK-NEXT: }
// CHECK: func.func @main3(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> (tensor<i64>, tensor<i64>, tensor<i64>) {
// CHECK-NEXT: %c = stablehlo.constant dense<1> : tensor<i64>
// CHECK-NEXT: %c_0 = stablehlo.constant dense<0> : tensor<i64>
// CHECK-NEXT: %0 = enzymexla.kernel_call @k2 blocks in(%c, %c, %c) threads in(%c, %c, %c) shmem = %c_0 (%arg2) {output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 0, operand_tuple_indices = []>]} : (tensor<i64>) -> tensor<i64>
// CHECK-NEXT: return %arg0, %arg1, %0 : tensor<i64>, tensor<i64>, tensor<i64>
// CHECK-NEXT: }
// CHECK: llvm.func ptx_kernelcc @k2(%arg0: !llvm.ptr<1>) {
// CHECK-NEXT: %0 = llvm.load %arg0 {alignment = 1 : i64} : !llvm.ptr<1> -> i64
// CHECK-NEXT: %1 = llvm.mul %0, %0 : i64
// CHECK-NEXT: llvm.store %1, %arg0 {alignment = 1 : i64} : i64, !llvm.ptr<1>
// CHECK-NEXT: llvm.return
// CHECK-NEXT: }

0 comments on commit b7c192b

Please sign in to comment.