diff --git a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp
index 58f1cac8a..74becdbf5 100644
--- a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp
+++ b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp
@@ -1825,7 +1825,7 @@ class AutoDiffSort
 
 static void removalBlockExplore(Block *block, IRMapping &mapping,
                                 OpBuilder &builder,
-                                llvm::SmallDenseSet<Value> &gradients,
+                                llvm::SetVector<Value> &gradients,
                                 llvm::MapVector<Value, CacheInfo> &caches) {
   for (auto it = block->begin(), e = block->end(); it != e;) {
     Operation *op = &*it;
@@ -1928,7 +1928,7 @@ struct IfOpEnzymeOpsRemover
   }
 
   // Gradients whose value is set in either branches.
-  llvm::SmallDenseSet<Value> gradients;
+  llvm::SetVector<Value> gradients;
 
   // We assume pushes are exclusive.
   llvm::MapVector<Value, CacheInfo> pushedCaches;
diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
index 3ee123eb7..5ff8edf44 100644
--- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
+++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
@@ -6961,6 +6961,75 @@ struct IfToSelect final : public OpRewritePattern<stablehlo::IfOp> {
   }
 };
 
+// Replace while op iteration variables which are not updated with their
+// upcoming value
+struct WhileSimplify : public OpRewritePattern<stablehlo::WhileOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(stablehlo::WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<unsigned> operands;
+
+    Block *cond = &op.getCond().front(), *body = &op.getBody().front();
+    Operation *bodyTerm = body->getTerminator();
+
+    int deleted = 0;
+    for (auto &opOperand : op->getOpOperands()) {
+      Value inputValue = opOperand.get();
+
+      auto i = opOperand.getOperandNumber() - deleted;
+      Value bodyArg = body->getArgument(i);
+      Value condArg = cond->getArgument(i);
+
+      if (inputValue.getDefiningOp() &&
+          bodyArg == bodyTerm->getOperand(i)) {
+        // This variable is not updated during iterations
+        
rewriter.replaceAllUsesWith(bodyArg, inputValue);
+        rewriter.replaceAllUsesWith(condArg, inputValue);
+        rewriter.modifyOpInPlace(bodyTerm,
+                                 [&] { bodyTerm->setOperands(i, 1, {}); });
+        rewriter.replaceAllUsesWith(op.getResult(opOperand.getOperandNumber()),
+                                    inputValue);
+
+        body->eraseArgument(i);
+        cond->eraseArgument(i);
+
+        deleted++;
+      } else {
+        operands.push_back(opOperand.getOperandNumber());
+      }
+    }
+
+    if (operands.size() == op->getNumOperands())
+      return failure();
+
+    SmallVector<Value> newOperands;
+    newOperands.reserve(operands.size());
+
+    for (auto opOperand : operands) {
+      newOperands.push_back(op->getOperand(opOperand));
+    }
+
+    auto newWhile =
+        rewriter.create<stablehlo::WhileOp>(op.getLoc(), newOperands);
+    newWhile.getCond().takeBody(op.getCond());
+    newWhile.getBody().takeBody(op.getBody());
+
+    // Replace uses for remaining results.
+    for (const auto &it : llvm::enumerate(operands))
+    {
+      Value oldRes = op->getResult(it.value());
+      Value newRes = newWhile->getResult(it.index());
+
+      rewriter.replaceAllUsesWith(oldRes, newRes);
+    }
+
+    rewriter.eraseOp(op);
+
+    return success();
+  }
+};
+
 struct DynamicGatherOpIsNotDynamic
     : public OpRewritePattern<stablehlo::DynamicGatherOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -7339,7 +7408,7 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
         GetTupleElementOpCanon, RealOpCanon, ImagOpCanon, ConjComplexNegate,
         GetDimensionSizeOpCanon, GatherOpCanon, ReshapeOpCanon,
         MergeConsecutiveReshapes, TransposeIsReshape,
-        SelectOpUsedWithinIf, IfInline, IfToSelect,
+        SelectOpUsedWithinIf, IfInline, IfToSelect, WhileSimplify,
         ZeroExtentTensorCanon, ReorderElementwiseAndShapeOp,
         DynamicGatherOpIsNotDynamic, DivideSqrtToMultiplyRsqrt>(
         context);
diff --git a/src/enzyme_ad/jax/TransformOps/TransformOps.td b/src/enzyme_ad/jax/TransformOps/TransformOps.td
index 9ba82526f..f664ef7ca 100644
--- a/src/enzyme_ad/jax/TransformOps/TransformOps.td
+++ b/src/enzyme_ad/jax/TransformOps/TransformOps.td
@@ -709,6 +709,11 @@ def IfToSelect : EnzymeHLOPatternOp<
   let 
patterns = ["IfToSelect"];
 }
 
+def WhileSimplify : EnzymeHLOPatternOp<
+    "while_simplify"> {
+  let patterns = ["WhileSimplify"];
+}
+
 def SelectOpUsedWithinIf : EnzymeHLOPatternOp<
     "select_op_used_within_if"> {
   let patterns = ["SelectOpUsedWithinIf"];
diff --git a/src/enzyme_ad/jax/primitives.py b/src/enzyme_ad/jax/primitives.py
index 48de7d1c0..ce32c583e 100644
--- a/src/enzyme_ad/jax/primitives.py
+++ b/src/enzyme_ad/jax/primitives.py
@@ -276,6 +276,10 @@ def hlo_opts():
             dot_reshape_pad<1>;
             pad_dot_general<1>(0);
 
+            if_inline<1>;
+            if_to_select<1>;
+            while_simplify<1>;
+
             dot_reshape_pad<1>;
             pad_dot_general<1>(1);
         },
diff --git a/test/lit_tests/while_simplify.mlir b/test/lit_tests/while_simplify.mlir
new file mode 100644
index 000000000..932d8f2b2
--- /dev/null
+++ b/test/lit_tests/while_simplify.mlir
@@ -0,0 +1,44 @@
+// RUN: enzymexlamlir-opt %s --enzyme-hlo-opt | FileCheck %s
+
+module {
+  func.func @main(%x: tensor<f32>) -> tensor<f32> {
+    %init_i = stablehlo.constant dense<0> : tensor<i64>
+    %init_sum = stablehlo.constant dense<0.0> : tensor<f32>
+    %one = stablehlo.constant dense<1> : tensor<i64>
+    %one_f = stablehlo.constant dense<2.0> : tensor<f32>
+    %ten = stablehlo.constant dense<3> : tensor<i64>
+    %constant = stablehlo.constant dense<42.0> : tensor<f32>
+    %results0, %results1, %results2 = "stablehlo.while"(%init_i, %x, %constant) ({
+    ^bb0(%arg0: tensor<i64>, %arg1: tensor<f32>, %arg2: tensor<f32>):
+      %cond = "stablehlo.compare"(%arg0, %ten) {
+        comparison_direction = #stablehlo<comparison_direction LT>
+      } : (tensor<i64>, tensor<i64>) -> tensor<i1>
+      stablehlo.return %cond : tensor<i1>
+    }, {
+    ^bb0(%arg0: tensor<i64>, %arg1: tensor<f32>, %arg2: tensor<f32>):
+      %new_sum = stablehlo.multiply %arg1, %arg2 : tensor<f32>
+      %new_i = stablehlo.add %arg0, %one : tensor<i64>
+      stablehlo.return %new_i, %new_sum, %arg2 : tensor<i64>, tensor<f32>, tensor<f32>
+    }) : (tensor<i64>, tensor<f32>, tensor<f32>) -> (tensor<i64>, tensor<f32>, tensor<f32>)
+    %new_result = stablehlo.add %results1, %results2 : tensor<f32>
+    return %new_result : tensor<f32>
+  }
+}
+
+// CHECK: func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
+// CHECK-NEXT:    %c = 
stablehlo.constant dense<0> : tensor<i64>
+// CHECK-NEXT:    %c_0 = stablehlo.constant dense<1> : tensor<i64>
+// CHECK-NEXT:    %c_1 = stablehlo.constant dense<3> : tensor<i64>
+// CHECK-NEXT:    %cst = stablehlo.constant dense<4.200000e+01> : tensor<f32>
+// CHECK-NEXT:    %0:2 = stablehlo.while(%iterArg = %c, %iterArg_2 = %arg0) : tensor<i64>, tensor<f32>
+// CHECK-NEXT:    cond {
+// CHECK-NEXT:      %2 = stablehlo.compare LT, %iterArg, %c_1 : (tensor<i64>, tensor<i64>) -> tensor<i1>
+// CHECK-NEXT:      stablehlo.return %2 : tensor<i1>
+// CHECK-NEXT:    } do {
+// CHECK-NEXT:      %2 = stablehlo.multiply %iterArg_2, %cst : tensor<f32>
+// CHECK-NEXT:      %3 = stablehlo.add %iterArg, %c_0 : tensor<i64>
+// CHECK-NEXT:      stablehlo.return %3, %2 : tensor<i64>, tensor<f32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    %1 = stablehlo.add %0#1, %cst : tensor<f32>
+// CHECK-NEXT:    return %1 : tensor<f32>
+// CHECK-NEXT:  }