diff --git a/csrc/transform_rfactor.cpp b/csrc/transform_rfactor.cpp index 982f7cc62ec..90a0bac0085 100644 --- a/csrc/transform_rfactor.cpp +++ b/csrc/transform_rfactor.cpp @@ -116,12 +116,15 @@ class ReplayRFactor : public ReplayTransformations { bool static_logical_outputs = static_logical_ids_.count(s->outer()) || static_logical_ids_.count(s->inner()); - auto outer_iter_type = rfactor_axes_.count(s->outer()) - ? IterType::Reduction - : IterType::Iteration; - auto inner_iter_type = rfactor_axes_.count(s->inner()) - ? IterType::Reduction - : IterType::Iteration; + std::optional outer_iter_type; + if (s->outer()->isReduction() && !rfactor_dep_ids_.count(s->outer())) { + outer_iter_type = IterType::Iteration; + } + + std::optional inner_iter_type; + if (s->inner()->isReduction() && !rfactor_dep_ids_.count(s->inner())) { + inner_iter_type = IterType::Iteration; + } auto [ido, idi] = IterDomain::split( mapped, @@ -171,9 +174,7 @@ class ReplayRFactor : public ReplayTransformations { // when m->out is a reduction domain. If it isn't involved in // the rfactor, it's no longer a redunction domain std::optional iter_type; - if (std::find(target_domain_.begin(), target_domain_.end(), m->out()) != - target_domain_.end() && - m->out()->isReduction() && !rfactor_axes_.count(m->out())) { + if (m->out()->isReduction() && !rfactor_dep_ids_.count(m->out())) { iter_type = IterType::Iteration; } @@ -219,6 +220,9 @@ class ReplayRFactor : public ReplayTransformations { // The IterDomains in the original_domain that are being factored into the // first stage of the two stage reduction (the producer). std::unordered_set rfactor_axes_; + // All iter domains between the original domain's maybeRoot() and + // rfactor_axes_, i.e., the iter domains that rfactor_axes_ depend on. + std::unordered_set rfactor_dep_ids_; // Iter domains whose history cannot be changed as it would break rfactor // dependencies. 
std::unordered_set static_logical_ids_; @@ -245,6 +249,14 @@ class ReplayRFactor : public ReplayTransformations { rfactor_axes_(std::move(rfactor_axes)), static_logical_ids_(std::move(static_logical_ids)), logical_domain_(original_domain->logical()) { + const auto all_dep_vals = DependencyCheck::getAllValsBetween( + {original_domain->maybeRoot().begin(), + original_domain->maybeRoot().end()}, + {rfactor_axes_.begin(), rfactor_axes_.end()}); + + auto all_dep_ids = ir_utils::filterByType(all_dep_vals); + rfactor_dep_ids_.insert(all_dep_ids.begin(), all_dep_ids.end()); + setErrorOnFailure(false); } };