Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass flatten-tuple : Migrate from MHLO to StableHLO #21429

Merged
merged 1 commit into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xla/mlir_hlo/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1083,6 +1083,7 @@ cc_library(
"stablehlo_ext/transforms/chlo_recompose_ops.cpp",
"stablehlo_ext/transforms/sdy_refine_shapes.cpp",
"stablehlo_ext/transforms/stablehlo_canonicalize_dynamism.cpp",
"stablehlo_ext/transforms/stablehlo_flatten_tuple.cpp",
"stablehlo_ext/transforms/stablehlo_prepare_for_hlo_export.cpp",
"stablehlo_ext/transforms/stablehlo_refine_shapes.cpp",
],
Expand Down
6 changes: 5 additions & 1 deletion xla/mlir_hlo/stablehlo_ext/transforms/passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include <string>

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassOptions.h"
#include "mlir/Transforms/DialectConversion.h"
Expand All @@ -27,10 +28,13 @@ namespace mlir {
namespace stablehlo_ext {

#define GEN_PASS_DECL
#define GEN_PASS_REGISTRATION
#include "stablehlo_ext/transforms/passes.h.inc"

void createChloLegalizeToStablehloPipeline(OpPassManager &pm);
std::unique_ptr<OperationPass<func::FuncOp>> createStablehloFlattenTuplePass();

#define GEN_PASS_REGISTRATION
#include "stablehlo_ext/transforms/passes.h.inc"

} // namespace stablehlo_ext
} // namespace mlir
Expand Down
8 changes: 7 additions & 1 deletion xla/mlir_hlo/stablehlo_ext/transforms/passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,10 @@ def StablehloPrepareForHloExportPass : Pass<"stablehlo-ext-prepare-for-hlo-expor
Note: The result of this pass need not be a module in canonical form and
canonicalization may undo transformations.
}];
}
}

def StablehloFlattenTuplePass : Pass<"stablehlo-ext-flatten-tuple", "func::FuncOp"> {
let summary = "Flatten tuples in operands and results of operators that "
"support both tuple and variadic type.";
let constructor = "createStablehloFlattenTuplePass()";
}
157 changes: 157 additions & 0 deletions xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_flatten_tuple.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for flattening tuples in HLO ops.

#include <cassert>
#include <memory>
#include <utility>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "stablehlo_ext/transforms/passes.h" // NOLINT: Used in passes.h.inc

namespace mlir {
namespace stablehlo_ext {

#define GEN_PASS_DEF_STABLEHLOFLATTENTUPLEPASS
#include "stablehlo_ext/transforms/passes.h.inc"

namespace {

// Calculates the flatten types of a value.
void flattenTupleType(Value value, llvm::SmallVectorImpl<Type> &types) {
if (!mlir::isa<TupleType>(value.getType())) {
types.push_back(value.getType());
return;
}

// This function doesn't handle nested tuple.
auto tupleType = mlir::cast<TupleType>(value.getType());
types.append(tupleType.begin(), tupleType.end());
}

// FlattenTupleValue and CreateTupleValue is a pair of functions to create and
// flatten tuples in the exact same order. CreateTupleValue returns the result
// of the root TupleOp or given value if the type is not TupleType.
Value createTupleValue(OpBuilder &builder, Location loc,
ValueRange flattenValues, Type tupleType) {
if (!mlir::isa<TupleType>(tupleType)) {
assert(flattenValues.size() == 1);
return flattenValues[0];
}

assert(mlir::cast<TupleType>(tupleType).getTypes().size() ==
flattenValues.size());
return builder.create<stablehlo::TupleOp>(loc, flattenValues);
}

void flattenTupleValue(OpBuilder &builder, Location loc, Value value,
llvm::SmallVectorImpl<Value> &flattenedValues) {
auto tupleType = mlir::dyn_cast<TupleType>(value.getType());
if (!tupleType) {
flattenedValues.push_back(value);
return;
}
int flattenIdx = 0;
for (auto innerType : tupleType.getTypes()) {
auto innerValue = builder.create<stablehlo::GetTupleElementOp>(
loc, innerType, value, builder.getI32IntegerAttr(flattenIdx++));
flattenTupleValue(builder, loc, innerValue, flattenedValues);
}
}

struct FlattenCustomCallOp : public OpRewritePattern<stablehlo::CustomCallOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(stablehlo::CustomCallOp op,
PatternRewriter &rewriter) const override {
bool flattenResult = op->getNumResults() == 1 &&
mlir::isa<TupleType>(op->getResult(0).getType());
bool flattenOperands = llvm::any_of(op.getInputs(), [](Value operand) {
return mlir::isa<TupleType>(operand.getType());
});

if (!flattenResult && !flattenOperands) return failure();

llvm::SmallVector<Value> flattenedOperands;
for (auto operand : op.getInputs())
flattenTupleValue(rewriter, op->getLoc(), operand, flattenedOperands);

llvm::SmallVector<Type, 4> flattenedResultTypes;
if (!flattenResult) {
flattenedResultTypes.push_back(op->getResult(0).getType());
} else {
// Check for nested tuples.
for (Type innerType :
mlir::cast<TupleType>(op->getResult(0).getType()).getTypes())
if (mlir::isa<TupleType>(innerType)) return failure();

for (auto result : op->getResults())
flattenTupleType(result, flattenedResultTypes);
}

auto flattenedCall = rewriter.create<stablehlo::CustomCallOp>(
op->getLoc(), flattenedResultTypes, flattenedOperands, op->getAttrs());

rewriter.replaceOp(op, flattenResult
? createTupleValue(rewriter, op->getLoc(),
flattenedCall.getResults(),
op->getResult(0).getType())
: flattenedCall.getResult(0));
return success();
}
};

class StablehloFlattenTuplePass
: public impl::StablehloFlattenTuplePassBase<StablehloFlattenTuplePass> {
public:
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.add<FlattenCustomCallOp>(context);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
signalPassFailure();
}
}
};

static PassRegistration<StablehloFlattenTuplePass> pass;

} // namespace

std::unique_ptr<OperationPass<func::FuncOp>> createStablehloFlattenTuplePass() {
return std::make_unique<StablehloFlattenTuplePass>();
}

} // namespace stablehlo_ext
} // namespace mlir
36 changes: 36 additions & 0 deletions xla/mlir_hlo/tests/stablehlo_ext/stablehlo_flatten_tuple.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// RUN: mlir-hlo-opt -split-input-file -stablehlo-ext-flatten-tuple %s | FileCheck %s

// CHECK-LABEL: @custom_call
// CHECK-SAME: %[[X:.*]]: tensor<6x3xf32>
func.func @custom_call(%x: tensor<6x3xf32>) -> (tensor<6xf32>, tensor<3xf32>) {
// CHECK: %[[CALL:.+]]:2 = stablehlo.custom_call @f(%[[X]]) {api_version = 2 : i32} : (tensor<6x3xf32>) -> (tensor<6xf32>, tensor<3xf32>)
%0 = "stablehlo.custom_call"(%x) {api_version = 2 : i32, call_target_name = "f"}
: (tensor<6x3xf32>) -> tuple<tensor<6xf32>, tensor<3xf32>>
%1 = "stablehlo.get_tuple_element"(%0) {index = 0 : i32} : (tuple<tensor<6xf32>, tensor<3xf32>>) -> tensor<6xf32>
%2 = "stablehlo.get_tuple_element"(%0) {index = 1 : i32} : (tuple<tensor<6xf32>, tensor<3xf32>>) -> tensor<3xf32>
return %1, %2 : tensor<6xf32>, tensor<3xf32>
}

// -----

// CHECK-LABEL: @custom_call_tupled_operand
func.func @custom_call_tupled_operand(%arg0: tuple<tensor<ui32>, tensor<i32>>)
-> (tensor<i32>, tensor<ui32>) {
// CHECK-NEXT: %[[C0:.*]] = stablehlo.constant dense<1> : tensor<ui32>
%0 = stablehlo.constant dense<1> : tensor<ui32>
// CHECK-NEXT: %[[C1:.*]] = stablehlo.constant dense<10> : tensor<i32>
%1 = stablehlo.constant dense<10> : tensor<i32>
// CHECK-NEXT: %[[TUPLE:.*]] = stablehlo.tuple %[[C0]], %[[C1]], %arg
%2 = stablehlo.tuple %0, %1, %arg0 : tuple<tensor<ui32>, tensor<i32>,
tuple<tensor<ui32>, tensor<i32>>>
// CHECK-NEXT: %[[VAR1:.*]] = stablehlo.get_tuple_element %[[TUPLE]][0]
// CHECK-NEXT: %[[VAR2:.*]] = stablehlo.get_tuple_element %[[TUPLE]][1]
// CHECK-NEXT: %[[VAR3:.*]] = stablehlo.get_tuple_element %[[TUPLE]][2]
// CHECK-NEXT: %[[VAR4:.*]] = stablehlo.get_tuple_element %[[VAR3]][0]
// CHECK-NEXT: %[[VAR5:.*]] = stablehlo.get_tuple_element %[[VAR3]][1]
// CHECK-NEXT: stablehlo.custom_call @ScalarProgramDummyConstant(%[[VAR1]], %[[VAR2]], %[[VAR4]], %[[VAR5]])
%3 = stablehlo.custom_call @ScalarProgramDummyConstant(%2)
: (tuple<tensor<ui32>, tensor<i32>, tuple<tensor<ui32>, tensor<i32>>>)
-> tensor<ui32>
return %1, %3 : tensor<i32>, tensor<ui32>
}
Loading