diff --git a/xla/service/gpu/transforms/sort_rewriter.cc b/xla/service/gpu/transforms/sort_rewriter.cc index 0d7bac8019429..7377f28aad4e7 100644 --- a/xla/service/gpu/transforms/sort_rewriter.cc +++ b/xla/service/gpu/transforms/sort_rewriter.cc @@ -210,14 +210,13 @@ absl::StatusOr<std::unique_ptr<CubSortRunnerInterface>> CreateRunner( // The trailing argument is the scratch buffer which should be discarded. HloInstruction* UnpackResultPair(HloSortInstruction* sort_op, HloInstruction* custom_call, bool swap) { - HloComputation* parent = sort_op->parent(); HloInstruction* gte0 = - parent->AddInstruction(HloInstruction::CreateGetTupleElement( + sort_op->AddInstruction(HloInstruction::CreateGetTupleElement( sort_op->operand(0)->shape(), custom_call, swap ? 1 : 0)); HloInstruction* gte1 = - parent->AddInstruction(HloInstruction::CreateGetTupleElement( + sort_op->AddInstruction(HloInstruction::CreateGetTupleElement( sort_op->operand(1)->shape(), custom_call, swap ? 0 : 1)); - return parent->AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); + return sort_op->AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); } } // namespace @@ -264,7 +263,7 @@ absl::StatusOr<bool> SortRewriter::RunOnInstruction( // Build the custom call instruction. HloInstruction* custom_call = - sort_op->parent()->AddInstruction(HloInstruction::CreateCustomCall( + sort_op->AddInstruction(HloInstruction::CreateCustomCall( call_shape, absl::MakeSpan(operands), kCubDeviceRadixSortTarget)); xla::SortOptions backend_config; @@ -274,7 +273,7 @@ absl::StatusOr<bool> SortRewriter::RunOnInstruction( // Build the replacement instruction.
HloInstruction* replacement; if (sort_op->operand_count() == 1 || values == nullptr) { - replacement = sort_op->parent()->AddInstruction( + replacement = sort_op->AddInstruction( HloInstruction::CreateGetTupleElement(keys->shape(), custom_call, 0)); } else { replacement = UnpackResultPair(sort_op, custom_call, diff --git a/xla/service/gpu/transforms/sort_rewriter_test.cc b/xla/service/gpu/transforms/sort_rewriter_test.cc index 46e9d92b19d3f..db2d727a40f22 100644 --- a/xla/service/gpu/transforms/sort_rewriter_test.cc +++ b/xla/service/gpu/transforms/sort_rewriter_test.cc @@ -505,6 +505,38 @@ ENTRY %main { GmockMatch(m::GetTupleElement(m::CustomCall(), 0))); } +TEST_F(SortRewriterTest, SortPairsIotaComparerSimplePreservesMetadata) { + constexpr char kHlo[] = R"( +HloModule TestModule + +%compare { + %lhs = u16[] parameter(0) + %rhs = u16[] parameter(1) + %lhs_index = s32[] parameter(2) + %rhs_index = s32[] parameter(3) + + cmp_indices = pred[] compare(%lhs_index, %rhs_index), direction=LT + cmp_lr = pred[] compare(%lhs, %rhs), direction=GT + cmp_eq = pred[] compare(%lhs, %rhs), direction=EQ + + ROOT %lt = pred[] select(cmp_eq, cmp_indices, cmp_lr) +} + +ENTRY %main { + %inputs = u16[1000] parameter(0) + %iota = s32[1000] iota(), iota_dimension=0 + ROOT %sort = (u16[1000], s32[1000]) sort(%inputs, %iota), + dimensions={0}, to_apply=%compare, metadata={op_type="sort" op_name="sort" source_file="path/to/test.cc" source_line=68} +})"; + constexpr char kExpectedPattern[] = R"( + // CHECK: %[[CC:.*]] = (u16[1000]{0}, s32[1000]{0}, u8[1]{0}) custom-call({{.*}}), custom_call_target="__cub$DeviceRadixSort", metadata={op_type="sort" op_name="sort" source_file="path/to/test.cc" source_line=68}, backend_config={"descending":true} + // CHECK: %[[GTE0:.*]] = u16[1000]{0} get-tuple-element(%[[CC]]), index=0, metadata={op_type="sort" op_name="sort" source_file="path/to/test.cc" source_line=68} + // CHECK: %[[GTE1:.*]] = s32[1000]{0} get-tuple-element(%[[CC]]), index=1, 
metadata={op_type="sort" op_name="sort" source_file="path/to/test.cc" source_line=68} + // CHECK: ROOT %{{.*}} = (u16[1000]{0}, s32[1000]{0}) tuple(%[[GTE0]], %[[GTE1]]), metadata={op_type="sort" op_name="sort" source_file="path/to/test.cc" source_line=68} + )"; + RunAndFilecheckHloRewrite(kHlo, SortRewriter(), kExpectedPattern); +} + // Sort a pair of tensors (values, indices generated by iota) with a complex // compare computation that matches the output of the StableSortExpander pass. TEST_F(SortRewriterTest, SortPairsIotaComparerLikeStableSortExpander) {