diff --git a/xla/service/BUILD b/xla/service/BUILD index 2adf484d53efa..d2cec90701211 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -268,6 +268,7 @@ cc_library( "//xla/hlo/pass:hlo_pass", "//xla/service/gpu:backend_configs_cc", "//xla/service/graphcycles", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -283,14 +284,14 @@ xla_cc_test( ":collective_ops_utils", ":collective_permute_decomposer", "//xla/hlo/ir:hlo", - "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/hlo/utils:hlo_matchers", "//xla/hlo/utils:hlo_query", "//xla/service/gpu:backend_configs_cc", - "//xla/tests:hlo_test_base", + "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", - "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test_main", ], ) diff --git a/xla/service/collective_permute_decomposer.h b/xla/service/collective_permute_decomposer.h index 11e96e5005e11..daffaecf58c2d 100644 --- a/xla/service/collective_permute_decomposer.h +++ b/xla/service/collective_permute_decomposer.h @@ -16,6 +16,11 @@ limitations under the License. #ifndef XLA_SERVICE_COLLECTIVE_PERMUTE_DECOMPOSER_H_ #define XLA_SERVICE_COLLECTIVE_PERMUTE_DECOMPOSER_H_ +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" diff --git a/xla/service/collective_permute_decomposer_test.cc b/xla/service/collective_permute_decomposer_test.cc index cc0634472ecf1..85e13e8085411 100644 --- a/xla/service/collective_permute_decomposer_test.cc +++ b/xla/service/collective_permute_decomposer_test.cc @@ -23,72 +23,54 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/tests/hlo_test_base.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/statusor.h" namespace xla { namespace { using ::testing::ElementsAre; using ::testing::HasSubstr; -namespace op = xla::testing::opcode_matchers; -using CollectivePermuteDecomposerTest = HloTestBase; - -TEST_F(CollectivePermuteDecomposerTest, WithCycleNotTransformed) { - const absl::string_view kModuleStr = R"( - HloModule test - ENTRY test_computation { - p = u32[] replica-id() - ROOT cp = u32[] collective-permute(p), channel_id=1, - source_target_pairs={{0,1}, {1,0}} - } - )"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule((kModuleStr))); - CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/0); - TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); - EXPECT_FALSE(changed); -} -TEST_F(CollectivePermuteDecomposerTest, WithContextDataNotTransformed) { - const char* const kModuleStr = R"( - HloModule test - ENTRY test_computation { - p = u32[] replica-id() - ROOT cp = (u32[], u32[], u32[], u32[]) collective-permute(p), channel_id=1, - source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}} - } - )"; +namespace op = xla::testing::opcode_matchers; +using Pass = CollectivePermuteDecomposer; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule((kModuleStr))); - CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/0); - TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); - EXPECT_FALSE(changed); +class DecomposerTest : public HloHardwareIndependentTestBase { + protected: + void AssertNoTranform(absl::string_view hlo) { + TF_ASSERT_OK(RunAndCheckHloRewrite(hlo, Pass(0), false)); + }; + auto Transform(absl::string_view hlo) { + return RunAndCheckHloRewrite(hlo, Pass(0), true); + }; +}; + +TEST_F(DecomposerTest, WithCycleNotTransformed) { + AssertNoTranform(R"(HloModule test + ENTRY test_computation { + p = u32[] replica-id() + ROOT cp = u32[] collective-permute(p), channel_id=1, + source_target_pairs={{0,1}, {1,0}} + } + )"); } -TEST_F(CollectivePermuteDecomposerTest, TransformedExplicitChannelId) { - const char* const kModuleStr = R"( - HloModule test - ENTRY test_computation { - p = u32[] replica-id() - ROOT cp = u32[] collective-permute(p), channel_id=1, - source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}}, - metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35} - } +TEST_F(DecomposerTest, TransformedExplicitChannelId) { + absl::string_view hlo = R"( + HloModule test + ENTRY test_computation { + p = u32[] replica-id() + ROOT cp = u32[] collective-permute(p), channel_id=1, + source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}}, + metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35} + } )"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule((kModuleStr))); - CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/0); - TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); - EXPECT_TRUE(changed); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, Transform(hlo)); auto check_metadata = [](const HloInstruction* inst) { EXPECT_EQ(inst->metadata().op_name(), "op1/op2/add"); @@ -131,8 +113,8 @@ TEST_F(CollectivePermuteDecomposerTest, TransformedExplicitChannelId) { EXPECT_THAT(root, op::GetTupleElement(recv_done, 0)); } -TEST_F(CollectivePermuteDecomposerTest, NotTransformedDefaultChannelId) { - const char* const kModuleStr = R"( +TEST_F(DecomposerTest, TransformedDefaultNoChannelId) { + absl::string_view hlo = R"( HloModule test ENTRY test_computation { p = u32[] replica-id() @@ -141,11 +123,7 @@ TEST_F(CollectivePermuteDecomposerTest, NotTransformedDefaultChannelId) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule((kModuleStr))); - CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/0); - TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); - EXPECT_TRUE(changed); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, Transform(hlo)); HloInstruction* after_all = FindInstruction(module.get(), "after-all"); HloInstruction* recv = FindInstruction(module.get(), "recv"); @@ -172,26 +150,20 @@ TEST_F(CollectivePermuteDecomposerTest, NotTransformedDefaultChannelId) { EXPECT_THAT(root, op::GetTupleElement(recv_done, 0)); } -TEST_F(CollectivePermuteDecomposerTest, ThresholdNotTransformed) { - const char* const kModuleStr = R"( - HloModule test - ENTRY test_computation { - p = u32[] replica-id() - ROOT cp = u32[] collective-permute(p), channel_id=1, - source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}}, - metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35} - } - )"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule((kModuleStr))); - CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/8); - TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); - EXPECT_FALSE(changed); +TEST_F(DecomposerTest, ThresholdNotTransformed) { + absl::string_view hlo = R"(HloModule test + ENTRY test_computation { + p = u32[] replica-id() + ROOT cp = u32[] collective-permute(p), channel_id=1, + source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}}, + metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35} + })"; + TF_ASSERT_OK( + RunAndCheckHloRewrite(hlo, Pass(/*threshold_in_bytes=*/8), false)); } -TEST_F(CollectivePermuteDecomposerTest, Pipeline1) { - const char* const kModuleStr = R"( +TEST_F(DecomposerTest, Pipeline1) { + absl::string_view hlo = R"( HloModule module cond { param = (u32[], u32[2]) parameter(0) @@ -229,11 +201,7 @@ TEST_F(CollectivePermuteDecomposerTest, Pipeline1) { ROOT result = u32[2] get-tuple-element(while_result), index=1 })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule((kModuleStr))); - CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/0); - TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); - EXPECT_TRUE(changed); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, Transform(hlo)); HloInstruction* recv = FindInstruction(module.get(), "recv"); EXPECT_EQ(recv->channel_id().value(), 1); EXPECT_THAT( @@ -262,7 +230,7 @@ TEST_F(CollectivePermuteDecomposerTest, Pipeline1) { EXPECT_THAT(recv_done->control_predecessors(), ElementsAre(send)); } -TEST_F(CollectivePermuteDecomposerTest, ForwardPipeline2) { +TEST_F(DecomposerTest, ForwardPipeline2) { const char* const kModuleStr = R"( HloModule module cond { @@ -310,10 +278,8 @@ TEST_F(CollectivePermuteDecomposerTest, ForwardPipeline2) { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule((kModuleStr))); - CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/0); - TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); - EXPECT_TRUE(changed); + Transform(kModuleStr)); + HloInstruction* recv = FindInstruction(module.get(), "recv"); EXPECT_EQ(recv->channel_id().value(), 1); EXPECT_THAT(recv->ToString(), @@ -347,7 +313,7 @@ TEST_F(CollectivePermuteDecomposerTest, ForwardPipeline2) { EXPECT_THAT(send1->control_predecessors(), ElementsAre(recv1)); } -TEST_F(CollectivePermuteDecomposerTest, ForwardPipelineWithMatmul) { +TEST_F(DecomposerTest, ForwardPipelineWithMatmul) { // The HLO module below is generated by passing the HLO in // CollectiveOpsTest.CollectivePermute_CircularPipelinePreOptimization through // the collective_permute_cycle_decomposer.transformation. @@ -401,10 +367,7 @@ TEST_F(CollectivePermuteDecomposerTest, ForwardPipelineWithMatmul) { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule((kModuleStr))); - CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/0); - TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); - EXPECT_TRUE(changed); + Transform(kModuleStr)); HloModule* transformed_module = module.get(); // Check the annotations and ordering of the decomposed send-recv pairs. // We expect the recv to come before the send in the while body, both for the @@ -458,8 +421,8 @@ TEST_F(CollectivePermuteDecomposerTest, ForwardPipelineWithMatmul) { EXPECT_THAT(recv_done_bwd->control_predecessors(), ElementsAre(send_fwd)); } -TEST_F(CollectivePermuteDecomposerTest, BackwardPipeline2) { - const char* const kModuleStr = R"( +TEST_F(DecomposerTest, BackwardPipeline2) { + absl::string_view hlo = R"( HloModule module cond { param = (u32[], u32[2]) parameter(0) @@ -505,11 +468,7 @@ TEST_F(CollectivePermuteDecomposerTest, BackwardPipeline2) { ROOT result = u32[2] get-tuple-element(while_result), index=1 })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule((kModuleStr))); - CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/0); - TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); - EXPECT_TRUE(changed); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, Transform(hlo)); HloInstruction* recv = FindInstruction(module.get(), "recv"); EXPECT_EQ(recv->channel_id().value(), 1); EXPECT_THAT( @@ -537,9 +496,8 @@ TEST_F(CollectivePermuteDecomposerTest, BackwardPipeline2) { EXPECT_THAT(send->control_predecessors(), ElementsAre(recv)); } -TEST_F(CollectivePermuteDecomposerTest, - DecomposeCrossReplicaCollectivePermute) { - const char* const kModuleStr = R"( +TEST_F(DecomposerTest, DecomposeCrossReplicaCollectivePermute) { + absl::string_view hlo = R"( HloModule module ENTRY body { data = f32[16] parameter(0) @@ -547,12 +505,7 @@ TEST_F(CollectivePermuteDecomposerTest, source_target_pairs={{0,1}, {1,2}, {2,3}} } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule((kModuleStr))); - - CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/0); - TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); - EXPECT_TRUE(changed); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, Transform(hlo)); HloComputation* comp = module->entry_computation(); HloInstruction* root = comp->root_instruction();