Skip to content

Commit

Permalink
Create a SourceTargetPairs class.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 714253785
  • Loading branch information
toli-y authored and Google-ML-Automation committed Jan 17, 2025
1 parent a7e0233 commit fba376d
Show file tree
Hide file tree
Showing 10 changed files with 269 additions and 230 deletions.
18 changes: 9 additions & 9 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -256,22 +256,23 @@ cc_library(
)

cc_library(
name = "collective_permute_utils",
srcs = ["collective_permute_utils.cc"],
hdrs = ["collective_permute_utils.h"],
name = "source_target_pairs",
srcs = ["source_target_pairs.cc"],
hdrs = ["source_target_pairs.h"],
deps = [
"//xla/hlo/ir:hlo",
"//xla/service/graphcycles",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/strings",
],
)

xla_cc_test(
name = "collective_permute_utils_test",
srcs = ["collective_permute_utils_test.cc"],
name = "source_target_pairs_test",
srcs = ["source_target_pairs_test.cc"],
deps = [
":collective_permute_utils",
":source_target_pairs",
"//xla:shape_util",
"//xla/hlo/ir:hlo",
"@com_google_googletest//:gtest_main",
Expand All @@ -284,7 +285,7 @@ cc_library(
hdrs = ["collective_permute_decomposer.h"],
deps = [
":collective_ops_utils",
":collective_permute_utils",
":source_target_pairs",
"//xla:shape_util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
Expand All @@ -296,7 +297,6 @@ cc_library(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@tsl//tsl/platform:statusor",
],
)

Expand Down
37 changes: 15 additions & 22 deletions xla/service/collective_permute_decomposer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,16 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/service/collective_permute_utils.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/source_target_pairs.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tsl/platform/errors.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/statusor.h"

namespace xla {
namespace {

using SourceTargetPair = std::pair<int64_t, int64_t>;
using SourceTargetPairs = std::vector<SourceTargetPair>;

// Returns true if the CollectivePermute instruction should be transformed
// to Send/Recv. We currently limit the transformation to CollectivePermute
// operations without any cycle in their (source, target) relationship,
Expand All @@ -65,7 +61,8 @@ bool ShouldDecompose(const HloCollectivePermuteInstruction& collective_permute,
if (ShapeUtil::ByteSizeOf(result_shape) < threshold_in_bytes) {
return false;
}
return !cp_utils::HasCycles(collective_permute.source_target_pairs());
return !SourceTargetPairs(collective_permute.source_target_pairs())
.HasCycles();
}

// Returns true for a pipelineable collective-permute. As a simple heuristic,
Expand All @@ -82,7 +79,7 @@ bool MayPipeline(const HloCollectivePermuteInstruction& collective_permute) {
struct DecomposedCp {
HloInstruction* send;
HloInstruction* recv;
SourceTargetPairs source_target_pairs;
std::vector<std::pair<int64_t, int64_t>> source_target_pairs;
};

xla::FrontendAttributes ExtractFrontendAttributes(
Expand All @@ -92,7 +89,7 @@ xla::FrontendAttributes ExtractFrontendAttributes(
attributes.mutable_map()->insert(old_attributes.map().begin(),
old_attributes.map().end());
(*attributes.mutable_map())[kSendRecvSourceTargetPairsAttr] =
cp_utils::SourceTargetPairsString(cp);
SourceTargetPairs(cp.source_target_pairs()).ToString();
return attributes;
}

Expand Down Expand Up @@ -170,21 +167,17 @@ std::optional<std::pair<HloCollectivePermuteInstruction*,
HloCollectivePermuteInstruction*>>
CheckCyclePatterns(HloCollectivePermuteInstruction* cp0,
HloCollectivePermuteInstruction* cp1) {
const SourceTargetPairs& cp0_pairs = cp0->source_target_pairs();
const SourceTargetPairs& cp1_pairs = cp1->source_target_pairs();
if (cp0_pairs.size() == 1) {
if (cp_utils::IsForwardCycle(cp0_pairs.front(), cp1_pairs) ||
cp_utils::IsBackwardCycle(cp0_pairs.front(), cp1_pairs)) {
// cp0 represents the backedge for the cycle.
return std::make_pair(cp0, cp1);
}
const SourceTargetPairs cp0_pairs(cp0->source_target_pairs());
const SourceTargetPairs cp1_pairs(cp1->source_target_pairs());
if (SourceTargetPairs::IsForwardCycle(cp0_pairs, cp1_pairs) ||
SourceTargetPairs::IsBackwardCycle(cp0_pairs, cp1_pairs)) {
// cp0 represents the backedge for the cycle.
return std::make_pair(cp0, cp1);
}
if (cp1_pairs.size() == 1) {
if (cp_utils::IsForwardCycle(cp1_pairs.front(), cp0_pairs) ||
cp_utils::IsBackwardCycle(cp1_pairs.front(), cp0_pairs)) {
// cp1 represents the forward edge for the cycle.
return std::make_pair(cp1, cp0);
}
if (SourceTargetPairs::IsForwardCycle(cp1_pairs, cp0_pairs) ||
SourceTargetPairs::IsBackwardCycle(cp1_pairs, cp0_pairs)) {
// cp1 represents the forward edge for the cycle.
return std::make_pair(cp1, cp0);
}
return std::nullopt;
}
Expand Down
54 changes: 0 additions & 54 deletions xla/service/collective_permute_utils.h

This file was deleted.

107 changes: 0 additions & 107 deletions xla/service/collective_permute_utils_test.cc

This file was deleted.

10 changes: 5 additions & 5 deletions xla/service/gpu/transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -538,14 +538,14 @@ cc_library(
"//xla/hlo/ir:hlo",
"//xla/hlo/pass:hlo_pass",
"//xla/service:collective_ops_utils",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:statusor",
],
)

Expand All @@ -556,14 +556,14 @@ xla_cc_test(
":collective_select_folder",
"//xla/hlo/ir:hlo",
"//xla/hlo/testlib:filecheck",
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
"//xla/hlo/utils:hlo_matchers",
"//xla/tests:hlo_test_base",
"//xla/tsl/lib/core:status_test_util",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:statusor",
],
)

Expand Down
6 changes: 3 additions & 3 deletions xla/service/gpu/transforms/collective_select_folder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/comparison_util.h"
Expand All @@ -34,9 +35,8 @@ limitations under the License.
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/shape_util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"

namespace xla {
namespace {
Expand Down
6 changes: 3 additions & 3 deletions xla/service/gpu/transforms/collective_select_folder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,18 @@ limitations under the License.
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/testlib/filecheck.h"
#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
#include "xla/hlo/utils/hlo_matchers.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "tsl/platform/statusor.h"
#include "xla/tsl/platform/statusor.h"

namespace xla {
namespace {

namespace op = ::xla::testing::opcode_matchers;
using ::testing::HasSubstr;

class CollectiveSelectFolderTest : public HloTestBase {
class CollectiveSelectFolderTest : public HloHardwareIndependentTestBase {
public:
absl::Status ExpectNoTranform(absl::string_view hlo_template) {
return RunAndCheckHloRewrite(hlo_template, CollectiveSelectFolder(),
Expand Down
Loading

0 comments on commit fba376d

Please sign in to comment.