From ff78f0df0e41f7a8a8eea95701c05018874ad1c7 Mon Sep 17 00:00:00 2001 From: Vimarsh Sathia <39610523+vimarsh6739@users.noreply.github.com> Date: Tue, 14 Jan 2025 12:23:08 -0600 Subject: [PATCH] chlo.lgamma const prop (#182) * skeleton transform * lgamma expansion * Legalize all CHLO ops * Fix call to materialize * Add lit test template * lgamma expansion * lgamma expansion * fix the diff * disabled div-sqrt for now * Added LIT test lowering check * Modified lgamma const prop test --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 50 +++++++++++++------ .../jax/TransformOps/TransformOps.td | 4 ++ src/enzyme_ad/jax/primitives.py | 1 + test/lit_tests/chlo_lgamma_prop.mlir | 15 ++++++ 4 files changed, 55 insertions(+), 15 deletions(-) create mode 100644 test/lit_tests/chlo_lgamma_prop.mlir diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 23fbfbe83..9ccfc6f64 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -11,23 +11,26 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/CommonFolders.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Builders.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/PassManager.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h" #include "src/enzyme_ad/jax/Passes/PassDetails.h" #include "src/enzyme_ad/jax/Passes/Passes.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/dialect/TypeInference.h" #include "stablehlo/reference/Ops.h" +#include "stablehlo/transforms/ChloDecompositionUtils.h" +#include "stablehlo/transforms/PassUtils.h" #include "stablehlo/transforms/Passes.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "stablehlo/dialect/TypeInference.h" - #define DEBUG_TYPE "enzyme" using namespace mlir; @@ -2154,6 +2157,23 @@ struct ConcatToBroadcast final } }; +struct GammaConstProp final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::chlo::LgammaOp op, + PatternRewriter &rewriter) const override { + // return if not constant + DenseElementsAttr inputAttr; + if (!matchPattern(op.getOperand(), m_Constant(&inputAttr))) + return failure(); + Value result = mlir::stablehlo::materializeLgamma(rewriter, op.getLoc(), + op->getOperands()); + rewriter.replaceOp(op, result); + + return success(); + } +}; + struct DynamicUpdateSliceConstProp final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -7103,6 +7123,7 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase { void runOnOperation() override { auto context = getOperation()->getContext(); + RewritePatternSet patterns(context); patterns.add { GatherSimplify, ReshapeEmptyBroadcast, BroadcastReshape, ConstPropThroughBarrier, ReplaceNegAddWithSubtract>( context, PatternBenefit(65000)); - patterns.add( max_constant_expansion, context, PatternBenefit(65000)); - patterns.add, - ConcatPushBinop, ScatterToDynamicUpdateSlice, - ReduceConcat, ConcatSlice, SliceConcat, SliceReshapeConcat, - BinBroadcastSplat, - BinBroadcastSplat, - BinBroadcastSplat, - BinBroadcastSplat>(context); + patterns.add< + ConvertConcat, DynamicUpdateToConcat, SliceOfDynamicUpdate, + SliceElementwise, SliceReshapeElementwise, SlicePad, SliceReshapePad, + DotReshapeDot, ConcatConstProp, DynamicUpdateSliceConstProp, + LogConstProp, LogPlusConstProp, ChloInfConstProp, GammaConstProp, + ConcatFuse, ConcatToBroadcast, PadPad, PadReshapePad, + ConcatPushBinop, ConcatPushBinop, + ScatterToDynamicUpdateSlice, ReduceConcat, ConcatSlice, SliceConcat, + SliceReshapeConcat, BinBroadcastSplat, + BinBroadcastSplat, + BinBroadcastSplat, + BinBroadcastSplat>(context); patterns.add, BinaryOpTransposeSimplify, diff --git a/src/enzyme_ad/jax/TransformOps/TransformOps.td b/src/enzyme_ad/jax/TransformOps/TransformOps.td index 04791b9d9..7a5cfbc90 100644 --- a/src/enzyme_ad/jax/TransformOps/TransformOps.td +++ b/src/enzyme_ad/jax/TransformOps/TransformOps.td @@ -165,6 +165,10 @@ def ApplyChloInfConstProp : EnzymeHLOPatternOp< "chlo_inf_const_prop">{ let patterns = ["ChloInfConstProp"]; } +def ApplyGammaConstProp : EnzymeHLOPatternOp< + "gamma_const_prop">{ + let patterns = ["GammaConstProp"]; +} // regular benefit def ApplyConvertConcatPatterns : EnzymeHLOPatternOp< diff --git a/src/enzyme_ad/jax/primitives.py b/src/enzyme_ad/jax/primitives.py index 69cd279a9..65425ec8c 100644 --- a/src/enzyme_ad/jax/primitives.py +++ b/src/enzyme_ad/jax/primitives.py @@ -175,6 +175,7 @@ def hlo_opts(): log_const_prop<1>; log_plus_one_const_prop<1>; chlo_inf_const_prop<1>; +gamma_const_prop<1>; concat_fuse<1>; pad_reshape_pad<1>; pad_pad<1>; diff --git a/test/lit_tests/chlo_lgamma_prop.mlir b/test/lit_tests/chlo_lgamma_prop.mlir new file mode 100644 index 000000000..7359a028c --- /dev/null +++ b/test/lit_tests/chlo_lgamma_prop.mlir @@ -0,0 +1,15 @@ +// RUN: enzymexlamlir-opt %s --pass-pipeline="builtin.module(enzyme-hlo-opt)" | FileCheck %s + +module { + func.func @lgamma_f32() -> tensor { + %arg = stablehlo.constant dense<1.000000e+00> : tensor + %1 = chlo.lgamma %arg : tensor -> tensor + func.return %1 : tensor + } +} + + +// CHECK: func.func @lgamma_f32() -> tensor { +// CHECK-NEXT: %cst = stablehlo.constant dense<4.76837158E-7> : tensor +// CHECK-NEXT: return %cst : tensor +// CHECK-NEXT: }