diff --git a/xla/service/BUILD b/xla/service/BUILD index d35ded4822129..ced14c645fcf3 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -256,22 +256,23 @@ cc_library( ) cc_library( - name = "collective_permute_utils", - srcs = ["collective_permute_utils.cc"], - hdrs = ["collective_permute_utils.h"], + name = "source_target_pairs", + srcs = ["source_target_pairs.cc"], + hdrs = ["source_target_pairs.h"], deps = [ - "//xla/hlo/ir:hlo", "//xla/service/graphcycles", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", ], ) xla_cc_test( - name = "collective_permute_utils_test", - srcs = ["collective_permute_utils_test.cc"], + name = "source_target_pairs_test", + srcs = ["source_target_pairs_test.cc"], deps = [ - ":collective_permute_utils", + ":source_target_pairs", "//xla:shape_util", "//xla/hlo/ir:hlo", "@com_google_googletest//:gtest_main", @@ -284,7 +285,7 @@ cc_library( hdrs = ["collective_permute_decomposer.h"], deps = [ ":collective_ops_utils", - ":collective_permute_utils", + ":source_target_pairs", "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", @@ -296,7 +297,6 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/service/collective_permute_decomposer.cc b/xla/service/collective_permute_decomposer.cc index 9f051576e5fc0..49be15c4e130a 100644 --- a/xla/service/collective_permute_decomposer.cc +++ b/xla/service/collective_permute_decomposer.cc @@ -34,20 +34,16 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/collective_ops_utils.h" -#include "xla/service/collective_permute_utils.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/source_target_pairs.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tsl/platform/errors.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/statusor.h" namespace xla { namespace { -using SourceTargetPair = std::pair; -using SourceTargetPairs = std::vector; - // Returns true if the CollectivePermute instruction should be transformed // to Send/Recv. We currently limit the transformation to CollectivePermute // operations without any cycle in their (source, target) relationship, @@ -65,7 +61,8 @@ bool ShouldDecompose(const HloCollectivePermuteInstruction& collective_permute, if (ShapeUtil::ByteSizeOf(result_shape) < threshold_in_bytes) { return false; } - return !cp_utils::HasCycles(collective_permute.source_target_pairs()); + return !SourceTargetPairs(collective_permute.source_target_pairs()) + .HasCycles(); } // Returns true for a pipelineable collective-permute. As a simple heuristic, @@ -82,7 +79,7 @@ bool MayPipeline(const HloCollectivePermuteInstruction& collective_permute) { struct DecomposedCp { HloInstruction* send; HloInstruction* recv; - SourceTargetPairs source_target_pairs; + std::vector> source_target_pairs; }; xla::FrontendAttributes ExtractFrontendAttributes( @@ -92,7 +89,7 @@ xla::FrontendAttributes ExtractFrontendAttributes( attributes.mutable_map()->insert(old_attributes.map().begin(), old_attributes.map().end()); (*attributes.mutable_map())[kSendRecvSourceTargetPairsAttr] = - cp_utils::SourceTargetPairsString(cp); + SourceTargetPairs(cp.source_target_pairs()).ToString(); return attributes; } @@ -170,21 +167,17 @@ std::optional> CheckCyclePatterns(HloCollectivePermuteInstruction* cp0, HloCollectivePermuteInstruction* cp1) { - const SourceTargetPairs& cp0_pairs = cp0->source_target_pairs(); - const SourceTargetPairs& cp1_pairs = cp1->source_target_pairs(); - if (cp0_pairs.size() == 1) { - if (cp_utils::IsForwardCycle(cp0_pairs.front(), cp1_pairs) || - cp_utils::IsBackwardCycle(cp0_pairs.front(), cp1_pairs)) { - // cp0 represents the backedge for the cycle. - return std::make_pair(cp0, cp1); - } + const SourceTargetPairs cp0_pairs(cp0->source_target_pairs()); + const SourceTargetPairs cp1_pairs(cp1->source_target_pairs()); + if (SourceTargetPairs::IsForwardCycle(cp0_pairs, cp1_pairs) || + SourceTargetPairs::IsBackwardCycle(cp0_pairs, cp1_pairs)) { + // cp0 represents the backedge for the cycle. + return std::make_pair(cp0, cp1); } - if (cp1_pairs.size() == 1) { - if (cp_utils::IsForwardCycle(cp1_pairs.front(), cp0_pairs) || - cp_utils::IsBackwardCycle(cp1_pairs.front(), cp0_pairs)) { - // cp1 represents the forward edge for the cycle. - return std::make_pair(cp1, cp0); - } + if (SourceTargetPairs::IsForwardCycle(cp1_pairs, cp0_pairs) || + SourceTargetPairs::IsBackwardCycle(cp1_pairs, cp0_pairs)) { + // cp1 represents the forward edge for the cycle. + return std::make_pair(cp1, cp0); } return std::nullopt; } diff --git a/xla/service/collective_permute_utils.h b/xla/service/collective_permute_utils.h deleted file mode 100644 index 46c62ea25bb38..0000000000000 --- a/xla/service/collective_permute_utils.h +++ /dev/null @@ -1,54 +0,0 @@ -/* Copyright 2025 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_COLLECTIVE_PERMUTE_UTILS_H_ -#define XLA_SERVICE_COLLECTIVE_PERMUTE_UTILS_H_ - -#include -#include -#include -#include - -#include "xla/hlo/ir/hlo_instructions.h" - -namespace xla { -namespace cp_utils { - -using SourceTargetPair = std::pair; -using SourceTargetPairs = std::vector; - -// Source Targe Pairs to a cannoical string such as {{0,1},{1,2},{2,3},{3,0}}. -std::string SourceTargetPairsString(const HloCollectivePermuteInstruction& cp); - -// Returns true if the (source, target) relationship has a cycle. -bool HasCycles(const SourceTargetPairs& pairs); - -// Returns true if the (source, target) pairs form a forward cycle with all -// participants in the cycle, such as {{0,1},{1,2},{2,3},{3,0}}. We assume that -// the (source, target) pairs are ordered via increasing source IDs, as they are -// currently generated by SPMD partitioning. -bool IsForwardCycle(const SourceTargetPair& backedge, - const SourceTargetPairs& others); - -// Returns true if the (source, target) pairs form a backward cycle with all -// participants in the cycle, such as {{0,3},{1,0},{2,1},{3,2}}. We assume that -// the (source, target) pairs are ordered via increasing source IDs, as they are -// currently generated by SPMD partitioning. -bool IsBackwardCycle(const SourceTargetPair& backedge, - const SourceTargetPairs& others); - -} // namespace cp_utils -} // namespace xla -#endif // XLA_SERVICE_COLLECTIVE_PERMUTE_UTILS_H_ diff --git a/xla/service/collective_permute_utils_test.cc b/xla/service/collective_permute_utils_test.cc deleted file mode 100644 index 0ada6402758b0..0000000000000 --- a/xla/service/collective_permute_utils_test.cc +++ /dev/null @@ -1,107 +0,0 @@ -/* Copyright 2025 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/collective_permute_utils.h" - -#include - -#include -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/shape_util.h" - -namespace xla { -namespace cp_utils { - -struct Cannonical { - SourceTargetPairs cycle; - SourceTargetPairs fwd_edge; - SourceTargetPairs bwd_edge; -}; - -class CollectivePermuteUtilsTest : public ::testing::Test { - protected: - Cannonical fwd2_ = { - .cycle = {{0, 1}, {1, 0}}, .fwd_edge = {{0, 1}}, .bwd_edge = {{1, 0}}}; - Cannonical bwd2_ = { - .cycle = {{1, 0}, {0, 1}}, .fwd_edge = {{1, 0}}, .bwd_edge = {{0, 1}}}; - Cannonical fwd4_ = {.cycle = {{0, 1}, {1, 2}, {2, 3}, {3, 0}}, - .fwd_edge = {{0, 1}, {1, 2}, {2, 3}}, - .bwd_edge = {{3, 0}}}; - Cannonical bwd4_ = {.cycle = {{0, 3}, {1, 0}, {2, 1}, {3, 2}}, - .fwd_edge = {{1, 0}, {2, 1}, {3, 2}}, - .bwd_edge = {{0, 3}}}; - std::unique_ptr simple_input_ = HloInstruction::CreateToken(); - - HloCollectivePermuteInstruction CreateCollectivePermute( - const SourceTargetPairs& pairs) { - return HloCollectivePermuteInstruction(HloOpcode::kCollectivePermute, - ShapeUtil::MakeShape(U32, {8, 8}), - {simple_input_.get()}, pairs, 1); - } -}; - -TEST_F(CollectivePermuteUtilsTest, HasCycles) { - EXPECT_TRUE(HasCycles(fwd2_.cycle)); - EXPECT_TRUE(HasCycles(bwd2_.cycle)); - EXPECT_TRUE(HasCycles(fwd4_.cycle)); - EXPECT_TRUE(HasCycles(bwd4_.cycle)); - - EXPECT_TRUE(HasCycles({{0, 1}, {1, 2}, {2, 3}, {3, 2}})) << "Lasso 3->2"; - EXPECT_TRUE(HasCycles({{0, 1}, {1, 2}, {2, 3}, {3, 1}})) << "Lasso 3->1"; - - EXPECT_FALSE(HasCycles({{1, 2}, {2, 3}, {3, 0}})) << "Forward only"; - EXPECT_FALSE(HasCycles({{1, 2}})) << "Single edge"; -} - -bool IsForwardCycle(Cannonical& canonical) { - return IsForwardCycle(canonical.bwd_edge[0], canonical.fwd_edge); -} -bool IsBackwardCycle(Cannonical& canonical) { - return IsBackwardCycle(canonical.bwd_edge[0], canonical.fwd_edge); -} - -TEST_F(CollectivePermuteUtilsTest, IsForwardCycle) { - EXPECT_TRUE(IsForwardCycle(fwd2_)); - EXPECT_TRUE(IsForwardCycle(fwd4_)); - - EXPECT_FALSE(IsForwardCycle(bwd2_)); - EXPECT_FALSE(IsForwardCycle(bwd4_)); - - EXPECT_FALSE(IsForwardCycle({3, 0}, {{0, 2}, {2, 3}, {3, 0}})) << "Skip 1"; -} - -TEST_F(CollectivePermuteUtilsTest, IsBackwardCycle) { - EXPECT_TRUE(IsBackwardCycle(bwd2_)); - EXPECT_TRUE(IsBackwardCycle(bwd4_)); - - EXPECT_FALSE(IsBackwardCycle(fwd2_)); - EXPECT_FALSE(IsBackwardCycle(fwd4_)); -} - -TEST_F(CollectivePermuteUtilsTest, SourceTargetPairsString) { - EXPECT_EQ(SourceTargetPairsString(CreateCollectivePermute(fwd2_.cycle)), - "{{0,1},{1,0}}"); - EXPECT_EQ(SourceTargetPairsString(CreateCollectivePermute(bwd2_.cycle)), - "{{1,0},{0,1}}"); - EXPECT_EQ(SourceTargetPairsString(CreateCollectivePermute(fwd4_.cycle)), - "{{0,1},{1,2},{2,3},{3,0}}"); - EXPECT_EQ(SourceTargetPairsString(CreateCollectivePermute(bwd4_.cycle)), - "{{0,3},{1,0},{2,1},{3,2}}"); -} - -} // namespace cp_utils -} // namespace xla diff --git a/xla/service/gpu/transforms/BUILD b/xla/service/gpu/transforms/BUILD index 7c76a072fc8d8..4095301928aa9 100644 --- a/xla/service/gpu/transforms/BUILD +++ b/xla/service/gpu/transforms/BUILD @@ -538,14 +538,14 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service:collective_ops_utils", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:statusor", ], ) @@ -556,14 +556,14 @@ xla_cc_test( ":collective_select_folder", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:filecheck", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", - "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/service/gpu/transforms/collective_select_folder.cc b/xla/service/gpu/transforms/collective_select_folder.cc index 5b6c4c008ee89..768bcb0576ea7 100644 --- a/xla/service/gpu/transforms/collective_select_folder.cc +++ b/xla/service/gpu/transforms/collective_select_folder.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/comparison_util.h" @@ -34,9 +35,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/collective_ops_utils.h" #include "xla/shape_util.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" namespace xla { namespace { diff --git a/xla/service/gpu/transforms/collective_select_folder_test.cc b/xla/service/gpu/transforms/collective_select_folder_test.cc index 42ecc87717cff..5d7453abba64d 100644 --- a/xla/service/gpu/transforms/collective_select_folder_test.cc +++ b/xla/service/gpu/transforms/collective_select_folder_test.cc @@ -28,10 +28,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/testlib/filecheck.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/statusor.h" namespace xla { namespace { @@ -39,7 +39,7 @@ namespace { namespace op = ::xla::testing::opcode_matchers; using ::testing::HasSubstr; -class CollectiveSelectFolderTest : public HloTestBase { +class CollectiveSelectFolderTest : public HloHardwareIndependentTestBase { public: absl::Status ExpectNoTranform(absl::string_view hlo_template) { return RunAndCheckHloRewrite(hlo_template, CollectiveSelectFolder(), diff --git a/xla/service/collective_permute_utils.cc b/xla/service/source_target_pairs.cc similarity index 59% rename from xla/service/collective_permute_utils.cc rename to xla/service/source_target_pairs.cc index 3ee67e3d86096..e2b35cc4dd5d2 100644 --- a/xla/service/collective_permute_utils.cc +++ b/xla/service/source_target_pairs.cc @@ -13,31 +13,24 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/collective_permute_utils.h" +#include "xla/service/source_target_pairs.h" #include #include +#include #include "absl/container/flat_hash_map.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" -#include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/graphcycles/graphcycles.h" namespace xla { -namespace cp_utils { -using ::xla::HloCollectivePermuteInstruction; - -std::string SourceTargetPairsString(const HloCollectivePermuteInstruction& cp) { - auto formatter = absl::PairFormatter( - [](std::string* out, int64_t value) { absl::StrAppend(out, "{", value); }, - ",", - [](std::string* out, int64_t value) { - absl::StrAppend(out, value, "}"); - }); - const std::string pairs_str = - absl::StrJoin(cp.source_target_pairs(), ",", formatter); +std::string SourceTargetPairs::ToString() const { + auto formatter = [](std::string* out, const SourceTargetPair& pair) { + absl::StrAppend(out, "{", pair.source, ",", pair.target, "}"); + }; + const std::string pairs_str = absl::StrJoin(pairs_, ",", formatter); return absl::StrCat("{", pairs_str, "}"); } @@ -51,12 +44,12 @@ int32_t GetNodeId(int64_t replica, GraphCycles& graph, } } // namespace -bool HasCycles(const SourceTargetPairs& pairs) { +bool SourceTargetPairs::HasCycles() { GraphCycles graph; absl::flat_hash_map replica_to_node_id; - for (const SourceTargetPair& pair : pairs) { - const int source = GetNodeId(pair.first, graph, replica_to_node_id); - const int target = GetNodeId(pair.second, graph, replica_to_node_id); + for (const SourceTargetPair& pair : pairs_) { + const int source = GetNodeId(pair.source, graph, replica_to_node_id); + const int target = GetNodeId(pair.target, graph, replica_to_node_id); if (!graph.InsertEdge(source, target)) { return true; } @@ -65,35 +58,40 @@ bool HasCycles(const SourceTargetPairs& pairs) { } // TODO: b/388623407 - remove assumptions that pairs are ordered and 0 based. -bool IsForwardCycle(const SourceTargetPair& backedge, - const SourceTargetPairs& others) { +bool SourceTargetPairs::IsForwardCycle(const SourceTargetPairs& backedge, + const SourceTargetPairs& others) { + if (backedge.size() != 1) { + return false; + } const int64_t num_pairs = others.size() + 1; - if (backedge.first != num_pairs - 1 || backedge.second != 0) { + if (backedge[0].source != num_pairs - 1 || backedge[0].target != 0) { return false; } for (int64_t i = 0; i < num_pairs - 1; ++i) { const SourceTargetPair& pair = others[i]; - if (pair.first != i || pair.second != i + 1) { + if (pair.source != i || pair.target != i + 1) { return false; } } return true; } -bool IsBackwardCycle(const SourceTargetPair& backedge, - const SourceTargetPairs& others) { +bool SourceTargetPairs::IsBackwardCycle(const SourceTargetPairs& backedge, + const SourceTargetPairs& others) { + if (backedge.size() != 1) { + return false; + } const int64_t num_pairs = others.size() + 1; - if (backedge.first != 0 || backedge.second != num_pairs - 1) { + if (backedge[0].source != 0 || backedge[0].target != num_pairs - 1) { return false; } for (int64_t i = 0; i < num_pairs - 1; ++i) { const SourceTargetPair& pair = others[i]; - if (pair.first != i + 1 || pair.second != i) { + if (pair.source != i + 1 || pair.target != i) { return false; } } return true; } -} // namespace cp_utils } // namespace xla diff --git a/xla/service/source_target_pairs.h b/xla/service/source_target_pairs.h new file mode 100644 index 0000000000000..5d37b2b2a786f --- /dev/null +++ b/xla/service/source_target_pairs.h @@ -0,0 +1,92 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_SOURCE_TARGET_PAIRS_H_ +#define XLA_SERVICE_SOURCE_TARGET_PAIRS_H_ + +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" + +namespace xla { + +class SourceTargetPairs { + struct SourceTargetPair { + int64_t source; + int64_t target; + }; + + public: + SourceTargetPairs() = default; + + explicit SourceTargetPairs( + const std::vector>& pairs) { + for (const auto& pair : pairs) { + pairs_.push_back({.source = pair.first, .target = pair.second}); + } + } + + // Returns a cannoical string such as {{0,1},{1,2},{2,3},{3,0}}. + std::string ToString() const; + + SourceTargetPair& operator[](int64_t i) { + CHECK_LT(i, pairs_.size()) + << "Index out of bounds. Size: " << pairs_.size() << " Index: " << i; + return pairs_[i]; + } + const SourceTargetPair& operator[](int64_t i) const { + CHECK_LT(i, pairs_.size()) + << "Index out of bounds. Size: " << pairs_.size() << " Index: " << i; + return pairs_[i]; + } + + int64_t size() const { return pairs_.size(); } + + std::vector> data() const { + std::vector> data; + for (const auto& pair : pairs_) { + data.push_back({pair.source, pair.target}); + } + return data; + } + + // Returns true if the (source, target) relationship has a cycle. + bool HasCycles(); + + // Returns true if the (source, target) pairs form a forward cycle with all + // participants in the cycle, such as {{0,1},{1,2},{2,3},{3,0}}. We assume + // that the (source, target) pairs are ordered via increasing source IDs, as + // they are currently generated by SPMD partitioning. + static bool IsForwardCycle(const SourceTargetPairs& backedge, + const SourceTargetPairs& others); + + // Returns true if the (source, target) pairs form a backward cycle with all + // participants in the cycle, such as {{0,3},{1,0},{2,1},{3,2}}. We assume + // that the (source, target) pairs are ordered via increasing source IDs, as + // they are currently generated by SPMD partitioning. + static bool IsBackwardCycle(const SourceTargetPairs& backedge, + const SourceTargetPairs& others); + + private: + static constexpr int64_t kInlineFactor = 8; + absl::InlinedVector pairs_; +}; + +} // namespace xla +#endif // XLA_SERVICE_SOURCE_TARGET_PAIRS_H_ diff --git a/xla/service/source_target_pairs_test.cc b/xla/service/source_target_pairs_test.cc new file mode 100644 index 0000000000000..0989972f5c64e --- /dev/null +++ b/xla/service/source_target_pairs_test.cc @@ -0,0 +1,117 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/source_target_pairs.h" + +#include + +#include +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/shape_util.h" + +namespace xla { +namespace { + +struct Cannonical { + SourceTargetPairs cycle; + SourceTargetPairs fwd_edge; + SourceTargetPairs bwd_edge; +}; + +class CollectivePermuteUtilsTest : public ::testing::Test { + protected: + Cannonical fwd2_ = {.cycle = SourceTargetPairs({{0, 1}, {1, 0}}), + .fwd_edge = SourceTargetPairs({{0, 1}}), + .bwd_edge = SourceTargetPairs({{1, 0}})}; + + Cannonical bwd2_ = {.cycle = SourceTargetPairs({{1, 0}, {0, 1}}), + .fwd_edge = SourceTargetPairs({{1, 0}}), + .bwd_edge = SourceTargetPairs({{0, 1}})}; + + Cannonical fwd4_ = { + .cycle = SourceTargetPairs({{0, 1}, {1, 2}, {2, 3}, {3, 0}}), + .fwd_edge = SourceTargetPairs({{0, 1}, {1, 2}, {2, 3}}), + .bwd_edge = SourceTargetPairs({{3, 0}})}; + + Cannonical bwd4_ = { + .cycle = SourceTargetPairs({{0, 3}, {1, 0}, {2, 1}, {3, 2}}), + .fwd_edge = SourceTargetPairs({{1, 0}, {2, 1}, {3, 2}}), + .bwd_edge = SourceTargetPairs({{0, 3}})}; + std::unique_ptr simple_input_ = HloInstruction::CreateToken(); + + HloCollectivePermuteInstruction CreateCollectivePermute( + const SourceTargetPairs& pairs) { + return HloCollectivePermuteInstruction( + HloOpcode::kCollectivePermute, ShapeUtil::MakeShape(U32, {8, 8}), + simple_input_.get(), pairs.data(), 1); + } +}; + +TEST_F(CollectivePermuteUtilsTest, HasCycles) { + EXPECT_TRUE(fwd2_.cycle.HasCycles()); + EXPECT_TRUE(bwd2_.cycle.HasCycles()); + EXPECT_TRUE(fwd4_.cycle.HasCycles()); + EXPECT_TRUE(bwd4_.cycle.HasCycles()); + + EXPECT_TRUE(SourceTargetPairs({{0, 1}, {1, 2}, {2, 3}, {3, 2}}).HasCycles()) + << "Lasso 3->2"; + EXPECT_TRUE(SourceTargetPairs({{0, 1}, {1, 2}, {2, 3}, {3, 1}}).HasCycles()) + << "Lasso 3->1"; + + EXPECT_FALSE(SourceTargetPairs({{1, 2}, {2, 3}, {3, 0}}).HasCycles()) + << "Forward only"; + EXPECT_FALSE(SourceTargetPairs({{1, 2}}).HasCycles()) << "Single edge"; +} + +bool IsForwardCycle(Cannonical& canonical) { + return SourceTargetPairs::IsForwardCycle(canonical.bwd_edge, + canonical.fwd_edge); +} +bool IsBackwardCycle(Cannonical& canonical) { + return SourceTargetPairs::IsBackwardCycle(canonical.bwd_edge, + canonical.fwd_edge); +} + +TEST_F(CollectivePermuteUtilsTest, IsForwardCycle) { + EXPECT_TRUE(IsForwardCycle(fwd2_)); + EXPECT_TRUE(IsForwardCycle(fwd4_)); + + EXPECT_FALSE(IsForwardCycle(bwd2_)); + EXPECT_FALSE(IsForwardCycle(bwd4_)); + + EXPECT_FALSE(SourceTargetPairs::IsForwardCycle( + SourceTargetPairs({{3, 0}}), SourceTargetPairs({{0, 2}, {2, 3}, {3, 0}}))) + << "Skip 1"; +} + +TEST_F(CollectivePermuteUtilsTest, IsBackwardCycle) { + EXPECT_TRUE(IsBackwardCycle(bwd2_)); + EXPECT_TRUE(IsBackwardCycle(bwd4_)); + + EXPECT_FALSE(IsBackwardCycle(fwd2_)); + EXPECT_FALSE(IsBackwardCycle(fwd4_)); +} + +TEST_F(CollectivePermuteUtilsTest, SourceTargetPairsString) { + EXPECT_EQ(fwd2_.cycle.ToString(), "{{0,1},{1,0}}"); + EXPECT_EQ(bwd2_.cycle.ToString(), "{{1,0},{0,1}}"); + EXPECT_EQ(fwd4_.cycle.ToString(), "{{0,1},{1,2},{2,3},{3,0}}"); + EXPECT_EQ(bwd4_.cycle.ToString(), "{{0,3},{1,0},{2,1},{3,2}}"); +} + +} // namespace +} // namespace xla