From 77b062a325458964e42d5d7129d7a5b19d84bda8 Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Sat, 18 Jan 2025 12:50:04 +0100 Subject: [PATCH 1/6] while simplify --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 67 ++++++++++++++++++- .../jax/TransformOps/TransformOps.td | 5 ++ src/enzyme_ad/jax/primitives.py | 4 ++ test/lit_tests/while_simplify.mlir | 44 ++++++++++++ 4 files changed, 119 insertions(+), 1 deletion(-) create mode 100644 test/lit_tests/while_simplify.mlir diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 3ee123eb7..274a6a1d8 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -6961,6 +6961,71 @@ struct IfToSelect final : public OpRewritePattern { } }; +// Replace while op iteration variables which are not updated with their upcoming value +struct WhileSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(stablehlo::WhileOp op, + PatternRewriter &rewriter) const override { + SmallVector operands; + + Block *cond = &op.getCond().front(), *body = &op.getBody().front(); + Operation *bodyTerm = body->getTerminator(); + + int deleted = 0; + for (auto &opOperand : op->getOpOperands()) { + Value inputValue = opOperand.get(); + + auto i = opOperand.getOperandNumber() - deleted; + Value bodyArg = body->getArgument(i); + Value condArg = cond->getArgument(i); + + if (bodyArg == bodyTerm->getOperand(i)) { + // This variable is not updated during iterations + rewriter.replaceAllUsesWith(bodyArg, inputValue); + rewriter.replaceAllUsesWith(condArg, inputValue); + rewriter.modifyOpInPlace(bodyTerm, + [&] { bodyTerm->setOperands(i, 1, {}); }); + rewriter.replaceAllUsesWith(op.getResult(opOperand.getOperandNumber()), + inputValue); + + body->eraseArgument(i); + cond->eraseArgument(i); + } else { + operands.push_back(opOperand.getOperandNumber()); + } + } + + if (operands.size() == op->getNumOperands()) + return failure(); + + SmallVector newOperands; + newOperands.reserve(operands.size()); + + for (auto opOperand : operands) { + newOperands.push_back(op->getOperand(opOperand)); + } + + auto newWhile = + rewriter.create(op.getLoc(), newOperands); + newWhile.getCond().takeBody(op.getCond()); + newWhile.getBody().takeBody(op.getBody()); + + // Replace uses for remaining results. + for (const auto &it : llvm::enumerate(operands)) +  { + Value oldRes = op->getResult(it.value()); + Value newRes = newWhile->getResult(it.index()); + + rewriter.replaceAllUsesWith(oldRes, newRes); + } + + rewriter.eraseOp(op); + + return success(); + } +}; + struct DynamicGatherOpIsNotDynamic : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -7339,7 +7404,7 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase { GetTupleElementOpCanon, RealOpCanon, ImagOpCanon, ConjComplexNegate, GetDimensionSizeOpCanon, GatherOpCanon, ReshapeOpCanon, MergeConsecutiveReshapes, TransposeIsReshape, - SelectOpUsedWithinIf, IfInline, IfToSelect, + SelectOpUsedWithinIf, IfInline, IfToSelect, WhileSimplify, ZeroExtentTensorCanon, ReorderElementwiseAndShapeOp, DynamicGatherOpIsNotDynamic, DivideSqrtToMultiplyRsqrt>( context); diff --git a/src/enzyme_ad/jax/TransformOps/TransformOps.td b/src/enzyme_ad/jax/TransformOps/TransformOps.td index 9ba82526f..f664ef7ca 100644 --- a/src/enzyme_ad/jax/TransformOps/TransformOps.td +++ b/src/enzyme_ad/jax/TransformOps/TransformOps.td @@ -709,6 +709,11 @@ def IfToSelect : EnzymeHLOPatternOp< let patterns = ["IfToSelect"]; } +def WhileSimplify : EnzymeHLOPatternOp< + "while_simplify"> { + let patterns = ["WhileSimplify"]; +} + def SelectOpUsedWithinIf : EnzymeHLOPatternOp< "select_op_used_within_if"> { let patterns = ["SelectOpUsedWithinIf"]; diff --git a/src/enzyme_ad/jax/primitives.py b/src/enzyme_ad/jax/primitives.py index 48de7d1c0..ce32c583e 100644 --- a/src/enzyme_ad/jax/primitives.py +++ b/src/enzyme_ad/jax/primitives.py @@ -276,6 +276,10 @@ def hlo_opts(): dot_reshape_pad<1>; pad_dot_general<1>(0); +if_inline<1>; +if_to_select<1>; +while_simplify<1>; + dot_reshape_pad<1>; pad_dot_general<1>(1); }, diff --git a/test/lit_tests/while_simplify.mlir b/test/lit_tests/while_simplify.mlir new file mode 100644 index 000000000..932d8f2b2 --- /dev/null +++ b/test/lit_tests/while_simplify.mlir @@ -0,0 +1,44 @@ +// RUN: enzymexlamlir-opt %s --enzyme-hlo-opt | FileCheck %s + +module { + func.func @main(%x: tensor) -> tensor { + %init_i = stablehlo.constant dense<0> : tensor + %init_sum = stablehlo.constant dense<0.0> : tensor + %one = stablehlo.constant dense<1> : tensor + %one_f = stablehlo.constant dense<2.0> : tensor + %ten = stablehlo.constant dense<3> : tensor + %constant = stablehlo.constant dense<42.0> : tensor + %results0, %results1, %results2 = "stablehlo.while"(%init_i, %x, %constant) ({ + ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor): + %cond = "stablehlo.compare"(%arg0, %ten) { + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + stablehlo.return %cond : tensor + }, { + ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor): + %new_sum = stablehlo.multiply %arg1, %arg2 : tensor + %new_i = stablehlo.add %arg0, %one : tensor + stablehlo.return %new_i, %new_sum, %arg2 : tensor, tensor, tensor + }) : (tensor, tensor, tensor) -> (tensor, tensor, tensor) + %new_result = stablehlo.add %results1, %results2 : tensor + return %new_result : tensor + } +} + +// CHECK: func.func @main(%arg0: tensor) -> tensor { +// CHECK-NEXT: %c = stablehlo.constant dense<0> : tensor +// CHECK-NEXT: %c_0 = stablehlo.constant dense<1> : tensor +// CHECK-NEXT: %c_1 = stablehlo.constant dense<3> : tensor +// CHECK-NEXT: %cst = stablehlo.constant dense<4.200000e+01> : tensor +// CHECK-NEXT: %0:2 = stablehlo.while(%iterArg = %c, %iterArg_2 = %arg0) : tensor, tensor +// CHECK-NEXT: cond { +// CHECK-NEXT: %2 = stablehlo.compare LT, %iterArg, %c_1 : (tensor, tensor) -> tensor +// CHECK-NEXT: stablehlo.return %2 : tensor +// CHECK-NEXT: } do { +// CHECK-NEXT: %2 = stablehlo.multiply %iterArg_2, %cst : tensor +// CHECK-NEXT: %3 = stablehlo.add %iterArg, %c_0 : tensor +// CHECK-NEXT: stablehlo.return %3, %2 : tensor, tensor +// CHECK-NEXT: } +// CHECK-NEXT: %1 = stablehlo.add %0#1, %cst : tensor +// CHECK-NEXT: return %1 : tensor +// CHECK-NEXT: } From 3f0397e61d7eefc287e4391396ecb8ee48e88fc2 Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Sat, 18 Jan 2025 13:03:29 +0100 Subject: [PATCH 2/6] SmallDenseSet -> SetVector for deterministic values --- .../jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp index 58f1cac8a..74becdbf5 100644 --- a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp +++ b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp @@ -1825,7 +1825,7 @@ class AutoDiffSort static void removalBlockExplore(Block *block, IRMapping &mapping, OpBuilder &builder, - llvm::SmallDenseSet &gradients, + llvm::SetVector &gradients, llvm::MapVector &caches) { for (auto it = block->begin(), e = block->end(); it != e;) { Operation *op = &*it; @@ -1928,7 +1928,7 @@ struct IfOpEnzymeOpsRemover } // Gradients whose value is set in either branches. - llvm::SmallDenseSet gradients; + llvm::SetVector gradients; // We assume pushes are exclusive. llvm::MapVector pushedCaches; From 378fa1df47273994959e09b2002b19f2949cca2d Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Sat, 18 Jan 2025 13:10:59 +0100 Subject: [PATCH 3/6] fixup --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 274a6a1d8..7339dcfcc 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -6991,6 +6991,8 @@ struct WhileSimplify : public OpRewritePattern { body->eraseArgument(i); cond->eraseArgument(i); + + deleted++; } else { operands.push_back(opOperand.getOperandNumber()); } From 010db77894b68ee8cfa5acbfe96188237d0968cf Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Sat, 18 Jan 2025 13:12:12 +0100 Subject: [PATCH 4/6] fmt --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 7339dcfcc..acb236e36 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -6961,7 +6961,8 @@ struct IfToSelect final : public OpRewritePattern { } }; -// Replace while op iteration variables which are not updated with their upcoming value +// Replace while op iteration variables which are not updated with their +// upcoming value struct WhileSimplify : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; From 3af1e78609b9b7ded9fd05b0d9d5f742f92d7223 Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Sat, 18 Jan 2025 21:11:27 +0100 Subject: [PATCH 5/6] apply to constants only for now --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index acb236e36..ced50f08f 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -6981,7 +6981,8 @@ struct WhileSimplify : public OpRewritePattern { Value bodyArg = body->getArgument(i); Value condArg = cond->getArgument(i); - if (bodyArg == bodyTerm->getOperand(i)) { + if (isa(inputValue.getDefiningOp()) && + bodyArg == bodyTerm->getOperand(i)) { // This variable is not updated during iterations rewriter.replaceAllUsesWith(bodyArg, inputValue); rewriter.replaceAllUsesWith(condArg, inputValue); From 73dd8283808257c805c1d684b52f30ad2d435ee9 Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Sat, 18 Jan 2025 21:40:18 +0100 Subject: [PATCH 6/6] fixup --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index ced50f08f..5ff8edf44 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -6981,7 +6981,7 @@ struct WhileSimplify : public OpRewritePattern { Value bodyArg = body->getArgument(i); Value condArg = cond->getArgument(i); - if (isa(inputValue.getDefiningOp()) && + if (inputValue.getDefiningOp() && bodyArg == bodyTerm->getOperand(i)) { // This variable is not updated during iterations rewriter.replaceAllUsesWith(bodyArg, inputValue);