Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

While simplify opt #248

Merged
merged 6 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1825,7 +1825,7 @@ class AutoDiffSort

static void removalBlockExplore(Block *block, IRMapping &mapping,
OpBuilder &builder,
llvm::SmallDenseSet<Value> &gradients,
llvm::SetVector<Value> &gradients,
llvm::MapVector<Value, CacheInfo> &caches) {
for (auto it = block->begin(), e = block->end(); it != e;) {
Operation *op = &*it;
Expand Down Expand Up @@ -1928,7 +1928,7 @@ struct IfOpEnzymeOpsRemover
}

// Gradients whose value is set in either branch.
llvm::SmallDenseSet<Value> gradients;
llvm::SetVector<Value> gradients;

// We assume pushes are exclusive.
llvm::MapVector<Value, CacheInfo> pushedCaches;
Expand Down
70 changes: 69 additions & 1 deletion src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6961,6 +6961,74 @@ struct IfToSelect final : public OpRewritePattern<mlir::stablehlo::IfOp> {
}
};

// Replace while-op iteration variables that are never updated in the loop
// body with their incoming (initial) value.
struct WhileSimplify : public OpRewritePattern<stablehlo::WhileOp> {
  using OpRewritePattern::OpRewritePattern;

  // Detects iteration variables that the loop body yields unchanged (the
  // body terminator returns the corresponding block argument itself) and
  // replaces every use of them -- in the condition region, in the body
  // region, and on the op's results -- with the initial operand value.
  // The surviving variables are then moved onto a new, narrower while op.
  LogicalResult matchAndRewrite(stablehlo::WhileOp op,
                                PatternRewriter &rewriter) const override {
    // Original operand indices of the variables that really are updated and
    // therefore must be kept on the replacement while op.
    SmallVector<unsigned> operands;

    Block *cond = &op.getCond().front(), *body = &op.getBody().front();
    Operation *bodyTerm = body->getTerminator();

    // Number of variables erased so far; maps an operand's original index
    // to its current index after the earlier in-place erasures below.
    int deleted = 0;
    for (auto &opOperand : op->getOpOperands()) {
      Value inputValue = opOperand.get();

      // Current index of this variable inside the (already shrunk) blocks.
      auto i = opOperand.getOperandNumber() - deleted;
      Value bodyArg = body->getArgument(i);
      Value condArg = cond->getArgument(i);

      if (bodyArg == bodyTerm->getOperand(i)) {
        // This variable is not updated during iterations: forward the init
        // value into both regions and to the corresponding result.
        // NOTE(review): this relies on the while regions being able to
        // reference `inputValue` from the enclosing scope (i.e. not
        // isolated-from-above) -- confirm against the stablehlo spec.
        rewriter.replaceAllUsesWith(bodyArg, inputValue);
        rewriter.replaceAllUsesWith(condArg, inputValue);
        // Drop the yielded operand for this variable from the terminator.
        rewriter.modifyOpInPlace(bodyTerm,
                                 [&] { bodyTerm->setOperands(i, 1, {}); });
        // Results are indexed by the *original* operand number, not `i`.
        rewriter.replaceAllUsesWith(op.getResult(opOperand.getOperandNumber()),
                                    inputValue);

        // Uses were already rewritten above, so erasing is safe here.
        body->eraseArgument(i);
        cond->eraseArgument(i);

        deleted++;
      } else {
        operands.push_back(opOperand.getOperandNumber());
      }
    }

    // Nothing was removable -> pattern does not apply.
    if (operands.size() == op->getNumOperands())
      return failure();

    SmallVector<Value> newOperands;
    newOperands.reserve(operands.size());

    // Gather the init values of the surviving iteration variables, in order.
    for (auto opOperand : operands) {
      newOperands.push_back(op->getOperand(opOperand));
    }

    // Build the slimmed-down while op and steal the already-rewritten
    // regions from the old op; the new result list follows `newOperands`.
    auto newWhile =
        rewriter.create<stablehlo::WhileOp>(op.getLoc(), newOperands);
    newWhile.getCond().takeBody(op.getCond());
    newWhile.getBody().takeBody(op.getBody());

    // Replace uses for remaining results: old result index is the original
    // operand number (it.value()), new result index is the packed position
    // (it.index()).
    for (const auto &it : llvm::enumerate(operands))
    {
      Value oldRes = op->getResult(it.value());
      Value newRes = newWhile->getResult(it.index());

      rewriter.replaceAllUsesWith(oldRes, newRes);
    }

    rewriter.eraseOp(op);

    return success();
  }
};

struct DynamicGatherOpIsNotDynamic
: public OpRewritePattern<stablehlo::DynamicGatherOp> {
using OpRewritePattern<stablehlo::DynamicGatherOp>::OpRewritePattern;
Expand Down Expand Up @@ -7339,7 +7407,7 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
GetTupleElementOpCanon, RealOpCanon, ImagOpCanon,
ConjComplexNegate, GetDimensionSizeOpCanon, GatherOpCanon,
ReshapeOpCanon, MergeConsecutiveReshapes, TransposeIsReshape,
SelectOpUsedWithinIf, IfInline, IfToSelect,
SelectOpUsedWithinIf, IfInline, IfToSelect, WhileSimplify,
ZeroExtentTensorCanon, ReorderElementwiseAndShapeOp,
DynamicGatherOpIsNotDynamic, DivideSqrtToMultiplyRsqrt>(
context);
Expand Down
5 changes: 5 additions & 0 deletions src/enzyme_ad/jax/TransformOps/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,11 @@ def IfToSelect : EnzymeHLOPatternOp<
let patterns = ["IfToSelect"];
}

// Exposes the C++ `WhileSimplify` pattern (drops stablehlo.while iteration
// variables that the body never updates) under the transform-dialect op
// name `while_simplify`.
def WhileSimplify : EnzymeHLOPatternOp<
    "while_simplify"> {
  let patterns = ["WhileSimplify"];
}

def SelectOpUsedWithinIf : EnzymeHLOPatternOp<
"select_op_used_within_if"> {
let patterns = ["SelectOpUsedWithinIf"];
Expand Down
4 changes: 4 additions & 0 deletions src/enzyme_ad/jax/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,10 @@ def hlo_opts():
dot_reshape_pad<1>;
pad_dot_general<1>(0);

if_inline<1>;
if_to_select<1>;
while_simplify<1>;

dot_reshape_pad<1>;
pad_dot_general<1>(1);
},
Expand Down
44 changes: 44 additions & 0 deletions test/lit_tests/while_simplify.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// RUN: enzymexlamlir-opt %s --enzyme-hlo-opt | FileCheck %s

module {
  // Exercises WhileSimplify: %constant (42.0) is threaded through the loop
  // unchanged, so the pattern should remove it as an iteration variable and
  // rewrite its uses to the constant directly (see CHECK lines below).
  func.func @main(%x: tensor<f64>) -> tensor<f64> {
    %init_i = stablehlo.constant dense<0> : tensor<i64>
    // %init_sum and %one_f are unused here -- presumably left to verify
    // that dead constants are cleaned up by the pass pipeline.
    %init_sum = stablehlo.constant dense<0.0> : tensor<f64>
    %one = stablehlo.constant dense<1> : tensor<i64>
    %one_f = stablehlo.constant dense<2.0> : tensor<f64>
    %ten = stablehlo.constant dense<3> : tensor<i64>
    %constant = stablehlo.constant dense<42.0> : tensor<f64>
    %results0, %results1, %results2 = "stablehlo.while"(%init_i, %x, %constant) ({
    ^bb0(%arg0: tensor<i64>, %arg1: tensor<f64>, %arg2: tensor<f64>):
      // Loop while the induction variable %arg0 < 3.
      %cond = "stablehlo.compare"(%arg0, %ten) {
        comparison_direction = #stablehlo<comparison_direction LT>
      } : (tensor<i64>, tensor<i64>) -> tensor<i1>
      stablehlo.return %cond : tensor<i1>
    }, {
    ^bb0(%arg0: tensor<i64>, %arg1: tensor<f64>, %arg2: tensor<f64>):
      %new_sum = stablehlo.multiply %arg1, %arg2 : tensor<f64>
      %new_i = stablehlo.add %arg0, %one : tensor<i64>
      // %arg2 is yielded back unchanged -> loop-invariant, removable.
      stablehlo.return %new_i, %new_sum, %arg2 : tensor<i64>, tensor<f64>, tensor<f64>
    }) : (tensor<i64>, tensor<f64>, tensor<f64>) -> (tensor<i64>, tensor<f64>, tensor<f64>)
    // %results2 should simplify to %constant after the pattern runs.
    %new_result = stablehlo.add %results1, %results2 : tensor<f64>
    return %new_result : tensor<f64>
  }
}

// CHECK: func.func @main(%arg0: tensor<f64>) -> tensor<f64> {
// CHECK-NEXT: %c = stablehlo.constant dense<0> : tensor<i64>
// CHECK-NEXT: %c_0 = stablehlo.constant dense<1> : tensor<i64>
// CHECK-NEXT: %c_1 = stablehlo.constant dense<3> : tensor<i64>
// CHECK-NEXT: %cst = stablehlo.constant dense<4.200000e+01> : tensor<f64>
// CHECK-NEXT: %0:2 = stablehlo.while(%iterArg = %c, %iterArg_2 = %arg0) : tensor<i64>, tensor<f64>
// CHECK-NEXT: cond {
// CHECK-NEXT: %2 = stablehlo.compare LT, %iterArg, %c_1 : (tensor<i64>, tensor<i64>) -> tensor<i1>
// CHECK-NEXT: stablehlo.return %2 : tensor<i1>
// CHECK-NEXT: } do {
// CHECK-NEXT: %2 = stablehlo.multiply %iterArg_2, %cst : tensor<f64>
// CHECK-NEXT: %3 = stablehlo.add %iterArg, %c_0 : tensor<i64>
// CHECK-NEXT: stablehlo.return %3, %2 : tensor<i64>, tensor<f64>
// CHECK-NEXT: }
// CHECK-NEXT: %1 = stablehlo.add %0#1, %cst : tensor<f64>
// CHECK-NEXT: return %1 : tensor<f64>
// CHECK-NEXT: }
Loading