diff --git a/xla/mlir_hlo/BUILD b/xla/mlir_hlo/BUILD index 9f77bab2b71d2..6cf3950527150 100644 --- a/xla/mlir_hlo/BUILD +++ b/xla/mlir_hlo/BUILD @@ -1083,6 +1083,7 @@ cc_library( "stablehlo_ext/transforms/chlo_recompose_ops.cpp", "stablehlo_ext/transforms/sdy_refine_shapes.cpp", "stablehlo_ext/transforms/stablehlo_canonicalize_dynamism.cpp", + "stablehlo_ext/transforms/stablehlo_flatten_tuple.cpp", "stablehlo_ext/transforms/stablehlo_prepare_for_hlo_export.cpp", "stablehlo_ext/transforms/stablehlo_refine_shapes.cpp", ], diff --git a/xla/mlir_hlo/stablehlo_ext/transforms/passes.h b/xla/mlir_hlo/stablehlo_ext/transforms/passes.h index e2b77594141f0..243267e044cb3 100644 --- a/xla/mlir_hlo/stablehlo_ext/transforms/passes.h +++ b/xla/mlir_hlo/stablehlo_ext/transforms/passes.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassOptions.h" #include "mlir/Transforms/DialectConversion.h" @@ -27,10 +28,13 @@ namespace mlir { namespace stablehlo_ext { #define GEN_PASS_DECL -#define GEN_PASS_REGISTRATION #include "stablehlo_ext/transforms/passes.h.inc" void createChloLegalizeToStablehloPipeline(OpPassManager &pm); +std::unique_ptr> createStablehloFlattenTuplePass(); + +#define GEN_PASS_REGISTRATION +#include "stablehlo_ext/transforms/passes.h.inc" } // namespace stablehlo_ext } // namespace mlir diff --git a/xla/mlir_hlo/stablehlo_ext/transforms/passes.td b/xla/mlir_hlo/stablehlo_ext/transforms/passes.td index 5e329ff7f06b1..eae60e5d3e62d 100644 --- a/xla/mlir_hlo/stablehlo_ext/transforms/passes.td +++ b/xla/mlir_hlo/stablehlo_ext/transforms/passes.td @@ -50,4 +50,10 @@ def StablehloPrepareForHloExportPass : Pass<"stablehlo-ext-prepare-for-hlo-expor Note: The result of this pass need not be a module in canonical form and canonicalization may undo transformations. }]; -} \ No newline at end of file +} + +def StablehloFlattenTuplePass : Pass<"stablehlo-ext-flatten-tuple", "func::FuncOp"> { + let summary = "Flatten tuples in operands and results of operators that " + "support both tuple and variadic type."; + let constructor = "createStablehloFlattenTuplePass()"; +} diff --git a/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_flatten_tuple.cpp b/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_flatten_tuple.cpp new file mode 100644 index 0000000000000..7e67d796a6784 --- /dev/null +++ b/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_flatten_tuple.cpp @@ -0,0 +1,157 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements logic for flattening tuples in HLO ops. + +#include +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo_ext/transforms/passes.h" // NOLINT: Used in passes.h.inc + +namespace mlir { +namespace stablehlo_ext { + +#define GEN_PASS_DEF_STABLEHLOFLATTENTUPLEPASS +#include "stablehlo_ext/transforms/passes.h.inc" + +namespace { + +// Calculates the flatten types of a value. +void flattenTupleType(Value value, llvm::SmallVectorImpl &types) { + if (!mlir::isa(value.getType())) { + types.push_back(value.getType()); + return; + } + + // This function doesn't handle nested tuple. + auto tupleType = mlir::cast(value.getType()); + types.append(tupleType.begin(), tupleType.end()); +} + +// FlattenTupleValue and CreateTupleValue is a pair of functions to create and +// flatten tuples in the exact same order. CreateTupleValue returns the result +// of the root TupleOp or given value if the type is not TupleType. +Value createTupleValue(OpBuilder &builder, Location loc, + ValueRange flattenValues, Type tupleType) { + if (!mlir::isa(tupleType)) { + assert(flattenValues.size() == 1); + return flattenValues[0]; + } + + assert(mlir::cast(tupleType).getTypes().size() == + flattenValues.size()); + return builder.create(loc, flattenValues); +} + +void flattenTupleValue(OpBuilder &builder, Location loc, Value value, + llvm::SmallVectorImpl &flattenedValues) { + auto tupleType = mlir::dyn_cast(value.getType()); + if (!tupleType) { + flattenedValues.push_back(value); + return; + } + int flattenIdx = 0; + for (auto innerType : tupleType.getTypes()) { + auto innerValue = builder.create( + loc, innerType, value, builder.getI32IntegerAttr(flattenIdx++)); + flattenTupleValue(builder, loc, innerValue, flattenedValues); + } +} + +struct FlattenCustomCallOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(stablehlo::CustomCallOp op, + PatternRewriter &rewriter) const override { + bool flattenResult = op->getNumResults() == 1 && + mlir::isa(op->getResult(0).getType()); + bool flattenOperands = llvm::any_of(op.getInputs(), [](Value operand) { + return mlir::isa(operand.getType()); + }); + + if (!flattenResult && !flattenOperands) return failure(); + + llvm::SmallVector flattenedOperands; + for (auto operand : op.getInputs()) + flattenTupleValue(rewriter, op->getLoc(), operand, flattenedOperands); + + llvm::SmallVector flattenedResultTypes; + if (!flattenResult) { + flattenedResultTypes.push_back(op->getResult(0).getType()); + } else { + // Check for nested tuples. + for (Type innerType : + mlir::cast(op->getResult(0).getType()).getTypes()) + if (mlir::isa(innerType)) return failure(); + + for (auto result : op->getResults()) + flattenTupleType(result, flattenedResultTypes); + } + + auto flattenedCall = rewriter.create( + op->getLoc(), flattenedResultTypes, flattenedOperands, op->getAttrs()); + + rewriter.replaceOp(op, flattenResult + ? createTupleValue(rewriter, op->getLoc(), + flattenedCall.getResults(), + op->getResult(0).getType()) + : flattenedCall.getResult(0)); + return success(); + } +}; + +class StablehloFlattenTuplePass + : public impl::StablehloFlattenTuplePassBase { + public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + patterns.add(context); + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +static PassRegistration pass; + +} // namespace + +std::unique_ptr> createStablehloFlattenTuplePass() { + return std::make_unique(); +} + +} // namespace stablehlo_ext +} // namespace mlir diff --git a/xla/mlir_hlo/tests/stablehlo_ext/stablehlo_flatten_tuple.mlir b/xla/mlir_hlo/tests/stablehlo_ext/stablehlo_flatten_tuple.mlir new file mode 100644 index 0000000000000..4bcf9caab8b1e --- /dev/null +++ b/xla/mlir_hlo/tests/stablehlo_ext/stablehlo_flatten_tuple.mlir @@ -0,0 +1,36 @@ +// RUN: mlir-hlo-opt -split-input-file -stablehlo-ext-flatten-tuple %s | FileCheck %s + +// CHECK-LABEL: @custom_call +// CHECK-SAME: %[[X:.*]]: tensor<6x3xf32> +func.func @custom_call(%x: tensor<6x3xf32>) -> (tensor<6xf32>, tensor<3xf32>) { + // CHECK: %[[CALL:.+]]:2 = stablehlo.custom_call @f(%[[X]]) {api_version = 2 : i32} : (tensor<6x3xf32>) -> (tensor<6xf32>, tensor<3xf32>) + %0 = "stablehlo.custom_call"(%x) {api_version = 2 : i32, call_target_name = "f"} + : (tensor<6x3xf32>) -> tuple, tensor<3xf32>> + %1 = "stablehlo.get_tuple_element"(%0) {index = 0 : i32} : (tuple, tensor<3xf32>>) -> tensor<6xf32> + %2 = "stablehlo.get_tuple_element"(%0) {index = 1 : i32} : (tuple, tensor<3xf32>>) -> tensor<3xf32> + return %1, %2 : tensor<6xf32>, tensor<3xf32> +} + +// ----- + +// CHECK-LABEL: @custom_call_tupled_operand +func.func @custom_call_tupled_operand(%arg0: tuple, tensor>) + -> (tensor, tensor) { + // CHECK-NEXT: %[[C0:.*]] = stablehlo.constant dense<1> : tensor + %0 = stablehlo.constant dense<1> : tensor + // CHECK-NEXT: %[[C1:.*]] = stablehlo.constant dense<10> : tensor + %1 = stablehlo.constant dense<10> : tensor + // CHECK-NEXT: %[[TUPLE:.*]] = stablehlo.tuple %[[C0]], %[[C1]], %arg + %2 = stablehlo.tuple %0, %1, %arg0 : tuple, tensor, + tuple, tensor>> + // CHECK-NEXT: %[[VAR1:.*]] = stablehlo.get_tuple_element %[[TUPLE]][0] + // CHECK-NEXT: %[[VAR2:.*]] = stablehlo.get_tuple_element %[[TUPLE]][1] + // CHECK-NEXT: %[[VAR3:.*]] = stablehlo.get_tuple_element %[[TUPLE]][2] + // CHECK-NEXT: %[[VAR4:.*]] = stablehlo.get_tuple_element %[[VAR3]][0] + // CHECK-NEXT: %[[VAR5:.*]] = stablehlo.get_tuple_element %[[VAR3]][1] + // CHECK-NEXT: stablehlo.custom_call @ScalarProgramDummyConstant(%[[VAR1]], %[[VAR2]], %[[VAR4]], %[[VAR5]]) + %3 = stablehlo.custom_call @ScalarProgramDummyConstant(%2) + : (tuple, tensor, tuple, tensor>>) + -> tensor + return %1, %3 : tensor, tensor +}