diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 075987d16fe..12f84e2d163 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -754,7 +754,7 @@ void IdModel::initializeLoopGraph(const StatefulInliningInfo& info) { } } -ValGraph& IdModel::buildLoopGraph() { +ValGraph& IdModel::buildLoopGraph(bool force_full_loop_promotion_analysis) { // Make sure the depedent graphs are already built maybeBuildGraph(IdMappingMode::EXACT); maybeBuildGraph(IdMappingMode::PERMISSIVE); @@ -767,7 +767,10 @@ ValGraph& IdModel::buildLoopGraph() { validateLoopGraphHasNoSelfMappedLeafDomains(); loop_promotion_map_ = LoopPromotionMapBuilder::get( - *this, inlining_info, loop_promotion_map_builder_callback_); + *this, + inlining_info, + loop_promotion_map_builder_callback_, + force_full_loop_promotion_analysis); // New domains are added. Make sure there's still no self mapping in // the loop domains diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index 3708fa942bf..32c206dda6d 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -198,7 +198,11 @@ class IdModel : public PolymorphicBase { // Fills disjoint_ids_[IdMappingMode::LOOP]. Map only inlined // domains that are mapped in the permissive graph. Build the Exact // and Permissive graphs as well if not yet done. - ValGraph& buildLoopGraph(); + // + // (For debugging only) When force_full_loop_promotion_analysis is + // true, it always performs the full loop promotion analysis even + // when it's possible to take a quicker shortcut. + ValGraph& buildLoopGraph(bool force_full_loop_promotion_analysis = false); // Build a graph. Dependent graphs are also built if not yet done. 
void buildGraph(IdMappingMode mode); diff --git a/csrc/id_model/loop_promotion.cpp b/csrc/id_model/loop_promotion.cpp index 1c055943eda..08ae225e6eb 100644 --- a/csrc/id_model/loop_promotion.cpp +++ b/csrc/id_model/loop_promotion.cpp @@ -17,8 +17,12 @@ namespace nvfuser { LoopPromotionMapBuilder::LoopPromotionMapBuilder( IdModel& id_model, const StatefulInliningInfo& inlining_info, - LoopPromotionMapBuilderCallback* callback) - : id_model_(id_model), inlining_info_(inlining_info), callback_(callback) {} + LoopPromotionMapBuilderCallback* callback, + bool force_full_loop_promotion_analysis) + : id_model_(id_model), + inlining_info_(inlining_info), + callback_(callback), + force_full_loop_promotion_analysis_(force_full_loop_promotion_analysis) {} ValGraph& LoopPromotionMapBuilder::idGraph(IdMappingMode mode) { return id_model_.idGraph(mode); @@ -97,11 +101,12 @@ std::unordered_map LoopPromotionMapBuilder:: namespace { -// Check if all the domains of each loop group are exactly mapped. If -// so, the full promotion analysis should not be necessary. Only the +// Check if each loop group has at most one group of concrete domains. If +// so, the full promotion analysis should not be necessary since +// finding the promotion ID is a trivial problem. Only the // loop groups of the loop domains need to be checked as loop // promotion does not matter for the other domains. 
-bool isLoopGraphUniform(const IdModel& id_model) { +bool isLoopGraphAlmostUniform(const IdModel& id_model) { for (const auto tv : id_model.tvs()) { if (tv->isFusionInput()) { continue; @@ -111,8 +116,22 @@ bool isLoopGraphUniform(const IdModel& id_model) { id_model.idGraph(IdMappingMode::LOOP).toGroup(loop_id); const auto all_exact_groups = id_model.idGraph(IdMappingMode::EXACT).toGroups(*loop_group); - if (all_exact_groups.size() > 1) { - return false; + if (all_exact_groups.size() == 1) { + continue; + } + + // Even when multiple exact groups are found, if there's only + // one concrete group and all the others are broadcast, it's + // obvious that the concrete group represents the promotion. + bool concrete_group_found = false; + for (const auto& exact_group : all_exact_groups) { + if (!exact_group->front()->as()->isBroadcast()) { + if (concrete_group_found) { + // multiple concrete groups + return false; + } + concrete_group_found = true; + } } } } @@ -126,8 +145,9 @@ std::unordered_map LoopPromotionMapBuilder::build() { // Some quick shortcut conditions to skip the full loop promotion // analysis. These are not comprehensive. Should add more conditions // if necessary. 
- if (inlining_info_.p2c_root_broadcast_resolution_map.empty() || - isLoopGraphUniform(id_model_)) { + if (!force_full_loop_promotion_analysis_ && + (inlining_info_.p2c_root_broadcast_resolution_map.empty() || + isLoopGraphAlmostUniform(id_model_))) { return buildWithNoBroadcast(); } @@ -936,8 +956,10 @@ void LoopPromotionMapBuilder::sanityCheckLoopPromotionMap( std::unordered_map LoopPromotionMapBuilder::get( IdModel& id_model, const StatefulInliningInfo& inlining_info, - LoopPromotionMapBuilderCallback* callback) { - LoopPromotionMapBuilder builder(id_model, inlining_info, callback); + LoopPromotionMapBuilderCallback* callback, + bool force_full_loop_promotion_analysis) { + LoopPromotionMapBuilder builder( + id_model, inlining_info, callback, force_full_loop_promotion_analysis); return builder.build(); } @@ -967,14 +989,21 @@ std::unordered_map LoopPromotionMapBuilder:: (int64_t)StmtSort::getExprsTo({loop_id->extent()}).size(); auto this_is_const = loop_id->extent()->isConstInt(); - // First ID - if (promotion == nullptr) { + // A group is allowed to have one single exact group of concrete + // IDs with a broadcast group. + if (promotion == nullptr || + (promotion->isBroadcast() && !loop_id->isBroadcast())) { is_const = this_is_const; promotion = loop_id; num_exprs = this_num_exprs; continue; } + // Ignore broadcast if a concrete ID is already found + if (!promotion->isBroadcast() && loop_id->isBroadcast()) { + continue; + } + // If new ID is non-const while the current promotion is const, // or if both IDs are const or non-const and the number of // expressions is not smaller, keep the current promotion diff --git a/csrc/id_model/loop_promotion.h b/csrc/id_model/loop_promotion.h index 88ff26a5d6f..1c6aa486c97 100644 --- a/csrc/id_model/loop_promotion.h +++ b/csrc/id_model/loop_promotion.h @@ -48,16 +48,22 @@ class LoopPromotionMapBuilder { // Build a map of loop groups to IterDomains that represent actual // loops. 
The map is built based on the broadcast resolution with // root domains between inlined producer and consumer tensors. + // + // (For debugging only) When force_full_loop_promotion_analysis is + // true, it always performs the full loop promotion analysis even + // when it's possible to take a quicker shortcut. static std::unordered_map get( IdModel& id_model, const StatefulInliningInfo& inlining_info, - LoopPromotionMapBuilderCallback* callback = nullptr); + LoopPromotionMapBuilderCallback* callback = nullptr, + bool force_full_loop_promotion_analysis = false); private: LoopPromotionMapBuilder( IdModel& id_model, const StatefulInliningInfo& inlining_info, - LoopPromotionMapBuilderCallback* callback = nullptr); + LoopPromotionMapBuilderCallback* callback = nullptr, + bool force_full_loop_promotion_analysis = false); std::unordered_map build(); @@ -164,6 +170,11 @@ class LoopPromotionMapBuilder { IdModel& id_model_; const StatefulInliningInfo& inlining_info_; LoopPromotionMapBuilderCallback* callback_ = nullptr; + + // (For debugging only) When force_full_loop_promotion_analysis_ is + // true, it always performs the full loop promotion analysis even + // when it's possible to take a quicker shortcut. + bool force_full_loop_promotion_analysis_ = false; }; } // namespace nvfuser diff --git a/tests/cpp/test_id_model.cpp b/tests/cpp/test_id_model.cpp index 4acead00286..246e677db35 100644 --- a/tests/cpp/test_id_model.cpp +++ b/tests/cpp/test_id_model.cpp @@ -135,7 +135,7 @@ class IdModelTester : public LoopPromotionMapBuilderCallback { /*loop_promotion_map_builder_callback=*/this); // Only build the loop graph - id_model->buildLoopGraph(); + id_model->buildLoopGraph(/*force_full_loop_promotion_analysis=*/true); } void postStep1(