Skip to content

Commit

Permalink
chlo.lgamma const prop (#182)
Browse files Browse the repository at this point in the history
* skeleton transform

* lgamma expansion

* Legalize all CHLO ops

* Fix call to materialize

* Add lit test template

* lgamma expansion

* lgamma expansion

* fix the diff

* disabled div-sqrt for now

* Added LIT test lowering check

* Modified lgamma const prop test
  • Loading branch information
vimarsh6739 authored Jan 14, 2025
1 parent 14672d8 commit ff78f0d
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 15 deletions.
50 changes: 35 additions & 15 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,26 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h"
#include "src/enzyme_ad/jax/Passes/PassDetails.h"
#include "src/enzyme_ad/jax/Passes/Passes.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "stablehlo/dialect/TypeInference.h"
#include "stablehlo/reference/Ops.h"
#include "stablehlo/transforms/ChloDecompositionUtils.h"
#include "stablehlo/transforms/PassUtils.h"
#include "stablehlo/transforms/Passes.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"

#include "stablehlo/dialect/TypeInference.h"

#define DEBUG_TYPE "enzyme"

using namespace mlir;
Expand Down Expand Up @@ -2154,6 +2157,23 @@ struct ConcatToBroadcast final
}
};

struct GammaConstProp final : OpRewritePattern<mlir::chlo::LgammaOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::chlo::LgammaOp op,
PatternRewriter &rewriter) const override {
// return if not constant
DenseElementsAttr inputAttr;
if (!matchPattern(op.getOperand(), m_Constant(&inputAttr)))
return failure();
Value result = mlir::stablehlo::materializeLgamma(rewriter, op.getLoc(),
op->getOperands());
rewriter.replaceOp(op, result);

return success();
}
};

struct DynamicUpdateSliceConstProp final
: OpRewritePattern<mlir::stablehlo::DynamicUpdateSliceOp> {
using OpRewritePattern::OpRewritePattern;
Expand Down Expand Up @@ -7103,6 +7123,7 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {

void runOnOperation() override {
auto context = getOperation()->getContext();

RewritePatternSet patterns(context);
patterns.add<AddSimplify, SubSimplify, AndSimplify, MaxSimplify,
MinSimplify, OrSimplify, NegateSimplify, MulSimplify,
Expand All @@ -7115,22 +7136,21 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
GatherSimplify, ReshapeEmptyBroadcast, BroadcastReshape,
ConstPropThroughBarrier, ReplaceNegAddWithSubtract>(
context, PatternBenefit(65000));

patterns.add<IotaSimplify, BroadcastInDimSimplify>(
max_constant_expansion, context, PatternBenefit(65000));

patterns.add<ConvertConcat, DynamicUpdateToConcat, SliceOfDynamicUpdate,
SliceElementwise, SliceReshapeElementwise, SlicePad,
SliceReshapePad, DotReshapeDot, ConcatConstProp,
DynamicUpdateSliceConstProp, LogConstProp, LogPlusConstProp,
ChloInfConstProp, ConcatFuse, ConcatToBroadcast, PadPad,
PadReshapePad, ConcatPushBinop<stablehlo::AddOp>,
ConcatPushBinop<stablehlo::MulOp>, ScatterToDynamicUpdateSlice,
ReduceConcat, ConcatSlice, SliceConcat, SliceReshapeConcat,
BinBroadcastSplat<stablehlo::AddOp>,
BinBroadcastSplat<stablehlo::SubtractOp>,
BinBroadcastSplat<stablehlo::DivOp>,
BinBroadcastSplat<stablehlo::MulOp>>(context);
patterns.add<
ConvertConcat, DynamicUpdateToConcat, SliceOfDynamicUpdate,
SliceElementwise, SliceReshapeElementwise, SlicePad, SliceReshapePad,
DotReshapeDot, ConcatConstProp, DynamicUpdateSliceConstProp,
LogConstProp, LogPlusConstProp, ChloInfConstProp, GammaConstProp,
ConcatFuse, ConcatToBroadcast, PadPad, PadReshapePad,
ConcatPushBinop<stablehlo::AddOp>, ConcatPushBinop<stablehlo::MulOp>,
ScatterToDynamicUpdateSlice, ReduceConcat, ConcatSlice, SliceConcat,
SliceReshapeConcat, BinBroadcastSplat<stablehlo::AddOp>,
BinBroadcastSplat<stablehlo::SubtractOp>,
BinBroadcastSplat<stablehlo::DivOp>,
BinBroadcastSplat<stablehlo::MulOp>>(context);

patterns.add<BinaryOpTransposeSimplify<stablehlo::AddOp>,
BinaryOpTransposeSimplify<stablehlo::SubtractOp>,
Expand Down
4 changes: 4 additions & 0 deletions src/enzyme_ad/jax/TransformOps/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,10 @@ def ApplyChloInfConstProp : EnzymeHLOPatternOp<
"chlo_inf_const_prop">{
let patterns = ["ChloInfConstProp"];
}
def ApplyGammaConstProp : EnzymeHLOPatternOp<
"gamma_const_prop">{
let patterns = ["GammaConstProp"];
}

// regular benefit
def ApplyConvertConcatPatterns : EnzymeHLOPatternOp<
Expand Down
1 change: 1 addition & 0 deletions src/enzyme_ad/jax/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def hlo_opts():
log_const_prop<1>;
log_plus_one_const_prop<1>;
chlo_inf_const_prop<1>;
gamma_const_prop<1>;
concat_fuse<1>;
pad_reshape_pad<1>;
pad_pad<1>;
Expand Down
15 changes: 15 additions & 0 deletions test/lit_tests/chlo_lgamma_prop.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: enzymexlamlir-opt %s --pass-pipeline="builtin.module(enzyme-hlo-opt)" | FileCheck %s

module {
func.func @lgamma_f32() -> tensor<f32> {
%arg = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%1 = chlo.lgamma %arg : tensor<f32> -> tensor<f32>
func.return %1 : tensor<f32>
}
}


// CHECK: func.func @lgamma_f32() -> tensor<f32> {
// CHECK-NEXT: %cst = stablehlo.constant dense<4.76837158E-7> : tensor<f32>
// CHECK-NEXT: return %cst : tensor<f32>
// CHECK-NEXT: }

0 comments on commit ff78f0d

Please sign in to comment.