Skip to content

Commit

Permalink
[XLA:GPU] Preserve metadata in SortRewriter pass.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 694557433
  • Loading branch information
akuegel authored and Google-ML-Automation committed Nov 8, 2024
1 parent 993068f commit 21da776
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 6 deletions.
11 changes: 5 additions & 6 deletions xla/service/gpu/transforms/sort_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,14 +210,13 @@ absl::StatusOr<std::unique_ptr<CubSortRunnerInterface>> CreateRunner(
// The trailing argument is the scratch buffer which should be discarded.
HloInstruction* UnpackResultPair(HloSortInstruction* sort_op,
HloInstruction* custom_call, bool swap) {
HloComputation* parent = sort_op->parent();
HloInstruction* gte0 =
parent->AddInstruction(HloInstruction::CreateGetTupleElement(
sort_op->AddInstruction(HloInstruction::CreateGetTupleElement(
sort_op->operand(0)->shape(), custom_call, swap ? 1 : 0));
HloInstruction* gte1 =
parent->AddInstruction(HloInstruction::CreateGetTupleElement(
sort_op->AddInstruction(HloInstruction::CreateGetTupleElement(
sort_op->operand(1)->shape(), custom_call, swap ? 0 : 1));
return parent->AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
return sort_op->AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
}

} // namespace
Expand Down Expand Up @@ -264,7 +263,7 @@ absl::StatusOr<bool> SortRewriter::RunOnInstruction(

// Build the custom call instruction.
HloInstruction* custom_call =
sort_op->parent()->AddInstruction(HloInstruction::CreateCustomCall(
sort_op->AddInstruction(HloInstruction::CreateCustomCall(
call_shape, absl::MakeSpan(operands), kCubDeviceRadixSortTarget));

xla::SortOptions backend_config;
Expand All @@ -274,7 +273,7 @@ absl::StatusOr<bool> SortRewriter::RunOnInstruction(
// Build the replacement instruction.
HloInstruction* replacement;
if (sort_op->operand_count() == 1 || values == nullptr) {
replacement = sort_op->parent()->AddInstruction(
replacement = sort_op->AddInstruction(
HloInstruction::CreateGetTupleElement(keys->shape(), custom_call, 0));
} else {
replacement = UnpackResultPair(sort_op, custom_call,
Expand Down
32 changes: 32 additions & 0 deletions xla/service/gpu/transforms/sort_rewriter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,38 @@ ENTRY %main {
GmockMatch(m::GetTupleElement(m::CustomCall(), 0)));
}

TEST_F(SortRewriterTest, SortPairsIotaComparerSimplePreservesMetadata) {
constexpr char kHlo[] = R"(
HloModule TestModule
%compare {
%lhs = u16[] parameter(0)
%rhs = u16[] parameter(1)
%lhs_index = s32[] parameter(2)
%rhs_index = s32[] parameter(3)
cmp_indices = pred[] compare(%lhs_index, %rhs_index), direction=LT
cmp_lr = pred[] compare(%lhs, %rhs), direction=GT
cmp_eq = pred[] compare(%lhs, %rhs), direction=EQ
ROOT %lt = pred[] select(cmp_eq, cmp_indices, cmp_lr)
}
ENTRY %main {
%inputs = u16[1000] parameter(0)
%iota = s32[1000] iota(), iota_dimension=0
ROOT %sort = (u16[1000], s32[1000]) sort(%inputs, %iota),
dimensions={0}, to_apply=%compare, metadata={op_type="sort" op_name="sort" source_file="path/to/test.cc" source_line=68}
})";
constexpr char kExpectedPattern[] = R"(
// CHECK: %[[CC:.*]] = (u16[1000]{0}, s32[1000]{0}, u8[1]{0}) custom-call({{.*}}), custom_call_target="__cub$DeviceRadixSort", metadata={op_type="sort" op_name="sort" source_file="path/to/test.cc" source_line=68}, backend_config={"descending":true}
// CHECK: %[[GTE0:.*]] = u16[1000]{0} get-tuple-element(%[[CC]]), index=0, metadata={op_type="sort" op_name="sort" source_file="path/to/test.cc" source_line=68}
// CHECK: %[[GTE1:.*]] = s32[1000]{0} get-tuple-element(%[[CC]]), index=1, metadata={op_type="sort" op_name="sort" source_file="path/to/test.cc" source_line=68}
// CHECK: ROOT %{{.*}} = (u16[1000]{0}, s32[1000]{0}) tuple(%[[GTE0]], %[[GTE1]]), metadata={op_type="sort" op_name="sort" source_file="path/to/test.cc" source_line=68}
)";
RunAndFilecheckHloRewrite(kHlo, SortRewriter(), kExpectedPattern);
}

// Sort a pair of tensors (values, indices generated by iota) with a complex
// compare computation that matches the output of the StableSortExpander pass.
TEST_F(SortRewriterTest, SortPairsIotaComparerLikeStableSortExpander) {
Expand Down

0 comments on commit 21da776

Please sign in to comment.