Skip to content

Commit

Permalink
Modernize and tighten CollectivePermuteDecomposerTest
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 708462256
  • Loading branch information
toli-y authored and Google-ML-Automation committed Dec 21, 2024
1 parent dc700af commit d29d8ea
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 110 deletions.
7 changes: 4 additions & 3 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ cc_library(
"//xla/hlo/pass:hlo_pass",
"//xla/service/gpu:backend_configs_cc",
"//xla/service/graphcycles",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
Expand All @@ -283,14 +284,14 @@ xla_cc_test(
":collective_ops_utils",
":collective_permute_decomposer",
"//xla/hlo/ir:hlo",
"//xla/hlo/parser:hlo_parser",
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
"//xla/hlo/utils:hlo_matchers",
"//xla/hlo/utils:hlo_query",
"//xla/service/gpu:backend_configs_cc",
"//xla/tests:hlo_test_base",
"//xla/tsl/lib/core:status_test_util",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test_main",
],
)
Expand Down
5 changes: 5 additions & 0 deletions xla/service/collective_permute_decomposer.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ limitations under the License.
#ifndef XLA_SERVICE_COLLECTIVE_PERMUTE_DECOMPOSER_H_
#define XLA_SERVICE_COLLECTIVE_PERMUTE_DECOMPOSER_H_

#include <cstdint>

#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/pass/hlo_pass_interface.h"

Expand Down
167 changes: 60 additions & 107 deletions xla/service/collective_permute_decomposer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,72 +23,54 @@ limitations under the License.
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
#include "xla/hlo/utils/hlo_matchers.h"
#include "xla/hlo/utils/hlo_query.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/tests/hlo_test_base.h"
#include "tsl/platform/statusor.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/tsl/platform/statusor.h"

namespace xla {
namespace {

using ::testing::ElementsAre;
using ::testing::HasSubstr;
namespace op = xla::testing::opcode_matchers;
using CollectivePermuteDecomposerTest = HloTestBase;

// A collective-permute whose source-target pairs are {0,1} and {1,0} is run
// through the decomposer with a zero byte threshold; the test expects the pass
// to report that it changed nothing.
TEST_F(CollectivePermuteDecomposerTest, WithCycleNotTransformed) {
  const absl::string_view hlo_text = R"(
HloModule test
ENTRY test_computation {
p = u32[] replica-id()
ROOT cp = u32[] collective-permute(p), channel_id=1,
source_target_pairs={{0,1}, {1,0}}
}
)";

  // Parse without verification, then run the pass directly on the module.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed_module,
                          ParseAndReturnUnverifiedModule(hlo_text));
  CollectivePermuteDecomposer pass(/*threshold_in_bytes=*/0);
  TF_ASSERT_OK_AND_ASSIGN(bool module_changed, pass.Run(parsed_module.get()));
  EXPECT_FALSE(module_changed);
}

TEST_F(CollectivePermuteDecomposerTest, WithContextDataNotTransformed) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
p = u32[] replica-id()
ROOT cp = (u32[], u32[], u32[], u32[]) collective-permute(p), channel_id=1,
source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}}
}
)";
namespace op = xla::testing::opcode_matchers;
using Pass = CollectivePermuteDecomposer;

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule((kModuleStr)));
CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/0);
TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
EXPECT_FALSE(changed);
// Test fixture for CollectivePermuteDecomposer (aliased as `Pass`). Both
// helpers run the pass with a threshold of 0 bytes through
// RunAndCheckHloRewrite; they differ only in the boolean passed as the third
// argument.
class DecomposerTest : public HloHardwareIndependentTestBase {
 protected:
  // Asserts that RunAndCheckHloRewrite returns an OK status when invoked on
  // `hlo` with `false` as its third argument (presumably "no rewrite
  // expected" -- confirm against RunAndCheckHloRewrite's declaration).
  // NOTE(review): "Tranform" is a misspelling of "Transform"; the name is kept
  // as-is because tests in this file call the helper by this name.
  void AssertNoTranform(absl::string_view hlo) {
    TF_ASSERT_OK(
        RunAndCheckHloRewrite(hlo, Pass(/*threshold_in_bytes=*/0), false));
  }

  // Invokes RunAndCheckHloRewrite on `hlo` with `true` as its third argument
  // and returns the result unchanged, so callers can unwrap it with
  // TF_ASSERT_OK_AND_ASSIGN.
  auto Transform(absl::string_view hlo) {
    return RunAndCheckHloRewrite(hlo, Pass(/*threshold_in_bytes=*/0), true);
  }
};

// Feeds a collective-permute whose source-target pairs are {0,1} and {1,0} to
// the fixture's AssertNoTranform helper.
TEST_F(DecomposerTest, WithCycleNotTransformed) {
  const absl::string_view cyclic_permute_hlo = R"(HloModule test
ENTRY test_computation {
p = u32[] replica-id()
ROOT cp = u32[] collective-permute(p), channel_id=1,
source_target_pairs={{0,1}, {1,0}}
}
)";
  AssertNoTranform(cyclic_permute_hlo);
}

TEST_F(CollectivePermuteDecomposerTest, TransformedExplicitChannelId) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
p = u32[] replica-id()
ROOT cp = u32[] collective-permute(p), channel_id=1,
source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}},
metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35}
}
TEST_F(DecomposerTest, TransformedExplicitChannelId) {
absl::string_view hlo = R"(
HloModule test
ENTRY test_computation {
p = u32[] replica-id()
ROOT cp = u32[] collective-permute(p), channel_id=1,
source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}},
metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35}
}
)";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule((kModuleStr)));
CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/0);
TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
EXPECT_TRUE(changed);
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, Transform(hlo));

auto check_metadata = [](const HloInstruction* inst) {
EXPECT_EQ(inst->metadata().op_name(), "op1/op2/add");
Expand Down Expand Up @@ -131,8 +113,8 @@ TEST_F(CollectivePermuteDecomposerTest, TransformedExplicitChannelId) {
EXPECT_THAT(root, op::GetTupleElement(recv_done, 0));
}

TEST_F(CollectivePermuteDecomposerTest, NotTransformedDefaultChannelId) {
const char* const kModuleStr = R"(
TEST_F(DecomposerTest, TransformedDefaultNoChannelId) {
absl::string_view hlo = R"(
HloModule test
ENTRY test_computation {
p = u32[] replica-id()
Expand All @@ -141,11 +123,7 @@ TEST_F(CollectivePermuteDecomposerTest, NotTransformedDefaultChannelId) {
}
)";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule((kModuleStr)));
CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/0);
TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
EXPECT_TRUE(changed);
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, Transform(hlo));

HloInstruction* after_all = FindInstruction(module.get(), "after-all");
HloInstruction* recv = FindInstruction(module.get(), "recv");
Expand All @@ -172,26 +150,20 @@ TEST_F(CollectivePermuteDecomposerTest, NotTransformedDefaultChannelId) {
EXPECT_THAT(root, op::GetTupleElement(recv_done, 0));
}

TEST_F(CollectivePermuteDecomposerTest, ThresholdNotTransformed) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
p = u32[] replica-id()
ROOT cp = u32[] collective-permute(p), channel_id=1,
source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}},
metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35}
}
)";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule((kModuleStr)));
CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/8);
TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
EXPECT_FALSE(changed);
// Runs the pass with an 8-byte threshold on a collective-permute of a single
// u32[] value and asserts that RunAndCheckHloRewrite, called with `false` as
// its third argument, returns an OK result.
TEST_F(DecomposerTest, ThresholdNotTransformed) {
  const absl::string_view module_text = R"(HloModule test
ENTRY test_computation {
p = u32[] replica-id()
ROOT cp = u32[] collective-permute(p), channel_id=1,
source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}},
metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35}
})";

  const auto rewrite_result =
      RunAndCheckHloRewrite(module_text, Pass(/*threshold_in_bytes=*/8), false);
  TF_ASSERT_OK(rewrite_result);
}

TEST_F(CollectivePermuteDecomposerTest, Pipeline1) {
const char* const kModuleStr = R"(
TEST_F(DecomposerTest, Pipeline1) {
absl::string_view hlo = R"(
HloModule module
cond {
param = (u32[], u32[2]) parameter(0)
Expand Down Expand Up @@ -229,11 +201,7 @@ TEST_F(CollectivePermuteDecomposerTest, Pipeline1) {
ROOT result = u32[2] get-tuple-element(while_result), index=1
})";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule((kModuleStr)));
CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/0);
TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
EXPECT_TRUE(changed);
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, Transform(hlo));
HloInstruction* recv = FindInstruction(module.get(), "recv");
EXPECT_EQ(recv->channel_id().value(), 1);
EXPECT_THAT(
Expand Down Expand Up @@ -262,7 +230,7 @@ TEST_F(CollectivePermuteDecomposerTest, Pipeline1) {
EXPECT_THAT(recv_done->control_predecessors(), ElementsAre(send));
}

TEST_F(CollectivePermuteDecomposerTest, ForwardPipeline2) {
TEST_F(DecomposerTest, ForwardPipeline2) {
const char* const kModuleStr = R"(
HloModule module
cond {
Expand Down Expand Up @@ -310,10 +278,8 @@ TEST_F(CollectivePermuteDecomposerTest, ForwardPipeline2) {
})";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule((kModuleStr)));
CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/0);
TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
EXPECT_TRUE(changed);
Transform(kModuleStr));

HloInstruction* recv = FindInstruction(module.get(), "recv");
EXPECT_EQ(recv->channel_id().value(), 1);
EXPECT_THAT(recv->ToString(),
Expand Down Expand Up @@ -347,7 +313,7 @@ TEST_F(CollectivePermuteDecomposerTest, ForwardPipeline2) {
EXPECT_THAT(send1->control_predecessors(), ElementsAre(recv1));
}

TEST_F(CollectivePermuteDecomposerTest, ForwardPipelineWithMatmul) {
TEST_F(DecomposerTest, ForwardPipelineWithMatmul) {
// The HLO module below is generated by passing the HLO in
// CollectiveOpsTest.CollectivePermute_CircularPipelinePreOptimization through
// the collective_permute_cycle_decomposer transformation.
Expand Down Expand Up @@ -401,10 +367,7 @@ TEST_F(CollectivePermuteDecomposerTest, ForwardPipelineWithMatmul) {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule((kModuleStr)));
CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/0);
TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
EXPECT_TRUE(changed);
Transform(kModuleStr));
HloModule* transformed_module = module.get();
// Check the annotations and ordering of the decomposed send-recv pairs.
// We expect the recv to come before the send in the while body, both for the
Expand Down Expand Up @@ -458,8 +421,8 @@ TEST_F(CollectivePermuteDecomposerTest, ForwardPipelineWithMatmul) {
EXPECT_THAT(recv_done_bwd->control_predecessors(), ElementsAre(send_fwd));
}

TEST_F(CollectivePermuteDecomposerTest, BackwardPipeline2) {
const char* const kModuleStr = R"(
TEST_F(DecomposerTest, BackwardPipeline2) {
absl::string_view hlo = R"(
HloModule module
cond {
param = (u32[], u32[2]) parameter(0)
Expand Down Expand Up @@ -505,11 +468,7 @@ TEST_F(CollectivePermuteDecomposerTest, BackwardPipeline2) {
ROOT result = u32[2] get-tuple-element(while_result), index=1
})";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule((kModuleStr)));
CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/0);
TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
EXPECT_TRUE(changed);
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, Transform(hlo));
HloInstruction* recv = FindInstruction(module.get(), "recv");
EXPECT_EQ(recv->channel_id().value(), 1);
EXPECT_THAT(
Expand Down Expand Up @@ -537,22 +496,16 @@ TEST_F(CollectivePermuteDecomposerTest, BackwardPipeline2) {
EXPECT_THAT(send->control_predecessors(), ElementsAre(recv));
}

TEST_F(CollectivePermuteDecomposerTest,
DecomposeCrossReplicaCollectivePermute) {
const char* const kModuleStr = R"(
TEST_F(DecomposerTest, DecomposeCrossReplicaCollectivePermute) {
absl::string_view hlo = R"(
HloModule module
ENTRY body {
data = f32[16] parameter(0)
ROOT data_ = f32[16] collective-permute(data),
source_target_pairs={{0,1}, {1,2}, {2,3}}
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule((kModuleStr)));

CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/0);
TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
EXPECT_TRUE(changed);
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, Transform(hlo));

HloComputation* comp = module->entry_computation();
HloInstruction* root = comp->root_instruction();
Expand Down

0 comments on commit d29d8ea

Please sign in to comment.