Skip to content

Commit

Permalink
PR #21511: Fix bitcast transposes in layout normalization.
Browse files Browse the repository at this point in the history
Imported from GitHub PR #21511

Currently, transposes that are bitcasts are converted to a bitcast that does not satisfy the invariants of the layout normalization pass.

As far as I can tell, the special handling of bitcast transposes does nothing useful, so we can simply remove it.

While we're here, we can stop emitting identity transposes.

This fixes jax-ml/jax#25759.
Copybara import of the project:

--
90aab32 by Johannes Reifferscheid <[email protected]>:

Fix bitcast transposes in layout normalization.

Currently, transposes that are bitcasts are converted to a bitcast
that does not satisfy the invariants of the layout normalization
pass.

As far as I can tell, the special handling of bitcast transposes
does nothing useful, so we can simply remove it.

While we're here, we can stop emitting identity transposes.

Merging this change closes #21511

COPYBARA_INTEGRATE_REVIEW=#21511 from jreiffers:main 90aab32
PiperOrigin-RevId: 716193335
  • Loading branch information
jreiffers authored and Google-ML-Automation committed Jan 16, 2025
1 parent 93b15af commit 1149399
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 24 deletions.
39 changes: 17 additions & 22 deletions xla/service/layout_normalization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -515,15 +515,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor {
return absl::OkStatus();
}

// For bitcasting transposes, converts:
//
// A{I} -> bitcast[S]{L} -> transpose{L2}
//
// Into:
//
// A{I} -> bitcast{L2}
//
// For non-bitcasting ones, converts:
// Converts:
//
// A{I} -> bitcast[S0]{L} -> transpose[S]{L2}
//
Expand All @@ -547,25 +539,28 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor {
auto normalized_shape = Normalize(s);
VLOG(3) << "Input transpose: " << hlo->ToString();

if (!ShapeUtil::TransposeIsBitcast(s, operand_s, hlo->dimensions())) {
auto l0_perm =
InversePermutation(ToTransposeDimensions(operand_s.layout()));
auto l_perm = ToTransposeDimensions(s.layout());
auto l0_perm =
InversePermutation(ToTransposeDimensions(operand_s.layout()));
auto l_perm = ToTransposeDimensions(s.layout());

auto t = ComposePermutations(l0_perm, hlo->dimensions());
auto dimensions = ComposePermutations(t, l_perm);
HloInstruction* normalized_transpose;

auto t = ComposePermutations(l0_perm, hlo->dimensions());
auto dimensions = ComposePermutations(t, l_perm);
auto normalized_transpose = hlo->AddInstruction(
if (IsIdentityPermutation(dimensions)) {
// If we're dealing with an identity transposition, there's no need to
// actually create the transpose.
normalized_transpose = a0;
} else {
normalized_transpose = hlo->AddInstruction(
HloInstruction::CreateTranspose(normalized_shape, a0, dimensions));
SetVisited(*normalized_transpose);
VLOG(3) << "Generated normalized physical transpose: "
<< normalized_transpose->ToString();
auto bc_to_orig = MakeBitcastHlo(normalized_transpose, s);
TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig));
} else {
auto bc_to_orig = MakeBitcastHlo(a0, s, &hlo->metadata());
TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig));
}
return absl::OkStatus();

auto bc_to_orig = MakeBitcastHlo(normalized_transpose, s);
return ReplaceInstruction(hlo, bc_to_orig);
}

// Converts a purely physical copy into a physical+logical transposition.
Expand Down
20 changes: 18 additions & 2 deletions xla/service/layout_normalization_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,7 @@ ENTRY main {

CheckLayoutNormalization(hlo, R"(
// CHECK: [[bitcast_0:%[^ ]+]] = f32[1,5,4,3]{3,2,1,0} bitcast([[p_1:%[^ ]+]])
// CHECK: [[transpose_2:%[^ ]+]] = f32[1,5,4,3]{3,2,1,0} transpose([[bitcast_0]]), dimensions={0,1,2,3}
// CHECK: [[abs_3:%[^ ]+]] = f32[1,5,4,3]{3,2,1,0} abs([[transpose_2]])
// CHECK: [[abs_3:%[^ ]+]] = f32[1,5,4,3]{3,2,1,0} abs([[bitcast_0]])
)");
}

Expand Down Expand Up @@ -937,5 +936,22 @@ ENTRY main {
)");
}

TEST_F(LayoutNormalizationTest, RegressionJaxB25759) {
const char* hlo = R"(
HloModule repro
ENTRY main {
p0 = f32[2,3,2,2]{2,1,3,0} parameter(0)
p1 = f32[2,2,2,3] parameter(1)
transpose = f32[2,3,2,2]{2,1,3,0} transpose(p1), dimensions={0,3,1,2}
ROOT multiply = f32[2,3,2,2]{2,1,3,0} multiply(p0, transpose)
})";

CheckLayoutNormalization(hlo, R"(
// CHECK: %[[TRANSPOSE:.*]] = f32[2,2,3,2]{3,2,1,0} transpose
// CHECK: multiply({{.*}}, %[[TRANSPOSE]])
)");
}

} // namespace
} // namespace xla

0 comments on commit 1149399

Please sign in to comment.