Skip to content

Commit

Permalink
Extend the shortcut logic for loop promotion with cyclic graphs
Browse files Browse the repository at this point in the history
  • Loading branch information
naoyam committed Jan 16, 2025
1 parent 0d0402f commit 6e8db1f
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 19 deletions.
7 changes: 5 additions & 2 deletions csrc/id_model/id_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,7 @@ void IdModel::initializeLoopGraph(const StatefulInliningInfo& info) {
}
}

ValGraph& IdModel::buildLoopGraph() {
ValGraph& IdModel::buildLoopGraph(bool force_full_loop_promotion_analysis) {
// Make sure the depedent graphs are already built
maybeBuildGraph(IdMappingMode::EXACT);
maybeBuildGraph(IdMappingMode::PERMISSIVE);
Expand All @@ -767,7 +767,10 @@ ValGraph& IdModel::buildLoopGraph() {
validateLoopGraphHasNoSelfMappedLeafDomains();

loop_promotion_map_ = LoopPromotionMapBuilder::get(
*this, inlining_info, loop_promotion_map_builder_callback_);
*this,
inlining_info,
loop_promotion_map_builder_callback_,
force_full_loop_promotion_analysis);

// New domains are added. Make sure there's still no self mapping in
// the loop domains
Expand Down
6 changes: 5 additions & 1 deletion csrc/id_model/id_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,11 @@ class IdModel : public PolymorphicBase {
// Fills disjoint_ids_[IdMappingMode::LOOP]. Map only inlined
// domains that are mapped in the permissive graph. Build the Exact
// and Permissive graphs as well if not yet done.
ValGraph& buildLoopGraph();
//
// (For debugging only) When force_full_loop_promotion_analysis is
// true, it always performs the full loop promotion analysis even
// when it's possible to take a quicker shortcut.
ValGraph& buildLoopGraph(bool force_full_loop_promotion_analysis = false);

// Build a graph. Dependent graphs are also built if not yet done.
void buildGraph(IdMappingMode mode);
Expand Down
55 changes: 42 additions & 13 deletions csrc/id_model/loop_promotion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,12 @@ namespace nvfuser {
LoopPromotionMapBuilder::LoopPromotionMapBuilder(
IdModel& id_model,
const StatefulInliningInfo& inlining_info,
LoopPromotionMapBuilderCallback* callback)
: id_model_(id_model), inlining_info_(inlining_info), callback_(callback) {}
LoopPromotionMapBuilderCallback* callback,
bool force_full_loop_promotion_analysis)
: id_model_(id_model),
inlining_info_(inlining_info),
callback_(callback),
force_full_loop_promotion_analysis_(force_full_loop_promotion_analysis) {}

ValGraph& LoopPromotionMapBuilder::idGraph(IdMappingMode mode) {
return id_model_.idGraph(mode);
Expand Down Expand Up @@ -97,11 +101,12 @@ std::unordered_map<ValGroup, IterDomain*> LoopPromotionMapBuilder::

namespace {

// Check if all the domains of each loop group are exactly mapped. If
// so, the full promotion analysis should not be necessary. Only the
// Check if each loop group has at most one group of concrete domains. If
// so, the full promotion analysis should not be necessary since
// finding the promotion ID is a trivial probelm. Only the
// loop groups of the loop domains need to be checked as loop
// promotion does not matter for the other domains.
bool isLoopGraphUniform(const IdModel& id_model) {
bool isLoopGraphAlmostUniform(const IdModel& id_model) {
for (const auto tv : id_model.tvs()) {
if (tv->isFusionInput()) {
continue;
Expand All @@ -111,8 +116,22 @@ bool isLoopGraphUniform(const IdModel& id_model) {
id_model.idGraph(IdMappingMode::LOOP).toGroup(loop_id);
const auto all_exact_groups =
id_model.idGraph(IdMappingMode::EXACT).toGroups(*loop_group);
if (all_exact_groups.size() > 1) {
return false;
if (all_exact_groups.size() == 1) {
continue;
}

// Even when multiple exact groups are found, if there's only
// one concrete group and all the others are broadcast, it's
// obvious that the concrete group represents the promotion.
bool concrete_group_found = false;
for (const auto& exact_group : all_exact_groups) {
if (!exact_group->front()->as<IterDomain>()->isBroadcast()) {
if (concrete_group_found) {
// multiple concrete groups
return false;
}
concrete_group_found = true;
}
}
}
}
Expand All @@ -126,8 +145,9 @@ std::unordered_map<ValGroup, IterDomain*> LoopPromotionMapBuilder::build() {
// Some quick shortcut conditions to skip the full loop promotion
// analysis. These are not comprehensive. Should add more conditions
// if necessary.
if (inlining_info_.p2c_root_broadcast_resolution_map.empty() ||
isLoopGraphUniform(id_model_)) {
if (!force_full_loop_promotion_analysis_ &&
(inlining_info_.p2c_root_broadcast_resolution_map.empty() ||
isLoopGraphAlmostUniform(id_model_))) {
return buildWithNoBroadcast();
}

Expand Down Expand Up @@ -936,8 +956,10 @@ void LoopPromotionMapBuilder::sanityCheckLoopPromotionMap(
std::unordered_map<ValGroup, IterDomain*> LoopPromotionMapBuilder::get(
IdModel& id_model,
const StatefulInliningInfo& inlining_info,
LoopPromotionMapBuilderCallback* callback) {
LoopPromotionMapBuilder builder(id_model, inlining_info, callback);
LoopPromotionMapBuilderCallback* callback,
bool force_full_loop_promotion_analysis) {
LoopPromotionMapBuilder builder(
id_model, inlining_info, callback, force_full_loop_promotion_analysis);
return builder.build();
}

Expand Down Expand Up @@ -967,14 +989,21 @@ std::unordered_map<ValGroup, IterDomain*> LoopPromotionMapBuilder::
(int64_t)StmtSort::getExprsTo({loop_id->extent()}).size();
auto this_is_const = loop_id->extent()->isConstInt();

// First ID
if (promotion == nullptr) {
// A group is allowed to have one single exact group of concrete
// IDs with a broadcast group.
if (promotion == nullptr ||
(promotion->isBroadcast() && !loop_id->isBroadcast())) {
is_const = this_is_const;
promotion = loop_id;
num_exprs = this_num_exprs;
continue;
}

// Ignore broadcast if a concrete ID is already found
if (!promotion->isBroadcast() && loop_id->isBroadcast()) {
continue;
}

// If new ID is non-const while the current promotion is const,
// or if both IDs are const or non-const and the number of
// expressions is not smaller, keep the current promotion
Expand Down
15 changes: 13 additions & 2 deletions csrc/id_model/loop_promotion.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,22 @@ class LoopPromotionMapBuilder {
// Build a map of loop groups to IterDomains that represent actual
// loops. The map is built based on the broadcast resolution with
// root domains between inlined producer and consumer tensors.
//
// (For debugging only) When force_full_loop_promotion_analysis is
// true, it always performs the full loop promotion analysis even
// when it's possible to take a quicker shortcut.
static std::unordered_map<ValGroup, IterDomain*> get(
IdModel& id_model,
const StatefulInliningInfo& inlining_info,
LoopPromotionMapBuilderCallback* callback = nullptr);
LoopPromotionMapBuilderCallback* callback = nullptr,
bool force_full_loop_promotion_analysis = false);

private:
LoopPromotionMapBuilder(
IdModel& id_model,
const StatefulInliningInfo& inlining_info,
LoopPromotionMapBuilderCallback* callback = nullptr);
LoopPromotionMapBuilderCallback* callback = nullptr,
bool force_full_loop_promotion_analysis = false);

std::unordered_map<ValGroup, IterDomain*> build();

Expand Down Expand Up @@ -164,6 +170,11 @@ class LoopPromotionMapBuilder {
IdModel& id_model_;
const StatefulInliningInfo& inlining_info_;
LoopPromotionMapBuilderCallback* callback_ = nullptr;

// (For debugging only) When force_full_loop_promotion_analysis_ is
// true, it always performs the full loop promotion analysis even
// when it's possible to take a quicker shortcut.
bool force_full_loop_promotion_analysis_ = false;
};

} // namespace nvfuser
2 changes: 1 addition & 1 deletion tests/cpp/test_id_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class IdModelTester : public LoopPromotionMapBuilderCallback {
/*loop_promotion_map_builder_callback=*/this);

// Only build the loop graph
id_model->buildLoopGraph();
id_model->buildLoopGraph(/*force_full_loop_promotion_analysis=*/true);
}

void postStep1(
Expand Down

0 comments on commit 6e8db1f

Please sign in to comment.