Skip to content

Commit

Permalink
make sure only rfactor-affected ids are changed
Browse files Browse the repository at this point in the history
  • Loading branch information
naoyam committed Nov 2, 2024
1 parent 654b064 commit 8321cfd
Showing 1 changed file with 21 additions and 9 deletions.
30 changes: 21 additions & 9 deletions csrc/transform_rfactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,15 @@ class ReplayRFactor : public ReplayTransformations {
bool static_logical_outputs = static_logical_ids_.count(s->outer()) ||
static_logical_ids_.count(s->inner());

auto outer_iter_type = rfactor_axes_.count(s->outer())
? IterType::Reduction
: IterType::Iteration;
auto inner_iter_type = rfactor_axes_.count(s->inner())
? IterType::Reduction
: IterType::Iteration;
std::optional<IterType> outer_iter_type;
if (s->outer()->isReduction() && !rfactor_dep_ids_.count(s->outer())) {
outer_iter_type = IterType::Iteration;
}

std::optional<IterType> inner_iter_type;
if (s->inner()->isReduction() && !rfactor_dep_ids_.count(s->inner())) {
inner_iter_type = IterType::Iteration;
}

auto [ido, idi] = IterDomain::split(
mapped,
Expand Down Expand Up @@ -171,9 +174,7 @@ class ReplayRFactor : public ReplayTransformations {
// when m->out is a reduction domain. If it isn't involved in
// the rfactor, it's no longer a redunction domain
std::optional<IterType> iter_type;
if (std::find(target_domain_.begin(), target_domain_.end(), m->out()) !=
target_domain_.end() &&
m->out()->isReduction() && !rfactor_axes_.count(m->out())) {
if (m->out()->isReduction() && !rfactor_dep_ids_.count(m->out())) {
iter_type = IterType::Iteration;
}

Expand Down Expand Up @@ -219,6 +220,9 @@ class ReplayRFactor : public ReplayTransformations {
// The IterDomains in the original_domain that are being factored into the
// first stage of the two stage reduction (the producer).
std::unordered_set<IterDomain*> rfactor_axes_;
// All iter domains between the logical and the loop that the
// rfactor_axes_ depend on
std::unordered_set<IterDomain*> rfactor_dep_ids_;
// Iter domains whose history cannot be changed as it would break rfactor
// dependencies.
std::unordered_set<IterDomain*> static_logical_ids_;
Expand All @@ -245,6 +249,14 @@ class ReplayRFactor : public ReplayTransformations {
rfactor_axes_(std::move(rfactor_axes)),
static_logical_ids_(std::move(static_logical_ids)),
logical_domain_(original_domain->logical()) {
const auto all_dep_vals = DependencyCheck::getAllValsBetween(
{original_domain->maybeRoot().begin(),
original_domain->maybeRoot().end()},
{rfactor_axes_.begin(), rfactor_axes_.end()});

auto all_dep_ids = ir_utils::filterByType<IterDomain>(all_dep_vals);
rfactor_dep_ids_.insert(all_dep_ids.begin(), all_dep_ids.end());

setErrorOnFailure(false);
}
};
Expand Down

0 comments on commit 8321cfd

Please sign in to comment.