Skip to content

Commit

Permalink
In the permissive bfs traversal, don't allow reverse traversal
Browse files Browse the repository at this point in the history
  • Loading branch information
naoyam committed Jan 16, 2025
1 parent ef6f169 commit bdf526f
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 4 deletions.
71 changes: 67 additions & 4 deletions csrc/bfs.h
Original file line number Diff line number Diff line change
Expand Up @@ -561,9 +561,8 @@ template <
class BFSWithPermissiveDependence
: public BFS<ExprT, ValT, DefinitionT, UsesT, InputsT, OutputsT> {
public:
using NodeType =
typename BFS<ExprT, ValT, DefinitionT, UsesT, InputsT, OutputsT>::
NodeType;
using BFSBaseType = BFS<ExprT, ValT, DefinitionT, UsesT, InputsT, OutputsT>;
using NodeType = typename BFSBaseType::NodeType;

BFSWithPermissiveDependence(
DefinitionT definition,
Expand All @@ -574,7 +573,7 @@ class BFSWithPermissiveDependence
std::vector<NodeType> to,
bool require_all_to_visited = true,
Direction allowed_direction = Direction::Undefined)
: BFS<ExprT, ValT, DefinitionT, UsesT, InputsT, OutputsT>(
: BFSBaseType(
definition,
uses,
inputs,
Expand Down Expand Up @@ -618,6 +617,70 @@ class BFSWithPermissiveDependence
}
return std::nullopt;
}

// When adding new neighbors of an expr node, if any of its inputs
// is a previous node of this expr, then don't add the remaining
// inputs to the to-visit list. Similarly, if any of its outputs is
// a previous node of this expr, don't add the remaining
// outputs. See BFSTest.IRBFSPermissiveTraversal2 for a concrete
// example.
void addNewNeighbors(const NodeType& node) override {
const ExprT* e = std::get_if<ExprT>(&node);
if (e == nullptr) {
BFSBaseType::addNewNeighbors(node);
return;
}

auto add_to_visit_list = [&](const NodeType& n) -> void {
if (this->isVisited(n) || this->excludeFromTraversal(n)) {
return;
}
this->to_visit_.emplace_back(n);
};

auto prev_nodes_it = this->prev_nodes_.find(node);

auto is_any_already_visited = [&](const auto& inputs_or_outputs) -> bool {
if (prev_nodes_it == this->prev_nodes_.end()) {
return false;
}

const std::vector<NodeType>& prev_nodes = prev_nodes_it->second.second;

return std::any_of(
inputs_or_outputs.begin(),
inputs_or_outputs.end(),
[&](const auto& input_or_output) {
return std::find(
prev_nodes.begin(),
prev_nodes.end(),
NodeType(input_or_output)) != prev_nodes.end();
});
};

if (this->allowed_direction_ == Direction::Backward ||
this->allowed_direction_ == Direction::Undefined) {
// There's an input node that is marked as a previous node of
// this node. Since this is permissive traversal, some of the
// other inputs may not be visited yet, but going back to
// the input nodes doesn't seem to make sense
auto input_nodes = this->inputs_(*e);
if (!is_any_already_visited(input_nodes)) {
for (const auto& v : input_nodes) {
add_to_visit_list(v);
}
}
}
if (this->allowed_direction_ == Direction::Forward ||
this->allowed_direction_ == Direction::Undefined) {
auto output_nodes = this->outputs_(*e);
if (!is_any_already_visited(output_nodes)) {
for (const auto& v : output_nodes) {
add_to_visit_list(v);
}
}
}
}
};

// Find the shortest path from the from vals to the to
Expand Down
41 changes: 41 additions & 0 deletions tests/cpp/test_bfs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -537,4 +537,45 @@ TEST_F(BFSTest, IRBFSPermissiveTraversal) {
}
}

TEST_F(BFSTest, IRBFSPermissiveTraversal2) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto in_tv = makeSymbolicTensor(2);
  fusion.addInput(in_tv);

  auto out_tv = set(in_tv);
  fusion.addOutput(out_tv);

  // Merge the two logical axes, then split the merged axis by 4.
  out_tv->merge(0);
  out_tv->split(0, 4);

  // Resulting tensor, as printed below:
  //
  // T1_g_float[iS5{( ceilDiv(( i0 * i2 ), 4) )}, iS6{4}]
  // logical domain : (iS2{i0}, iS3{i2})
  // contiguity: t t
  // Merge: iS2{i0} and iS3{i2} -> iS4{( i0 * i2 )}
  // Split: iS4{( i0 * i2 )} by factor 4 -> iS5{( ceilDiv(( i0 * i2 ), 4) )},
  // iS6{4}
  // loop domain : (iS5{( ceilDiv(( i0 * i2 ), 4) )}, iS6{4})
  fusion.print();

  // The two outputs of the split: iS5 and iS6, respectively.
  auto split_outer = out_tv->axis(0);
  auto split_inner = out_tv->axis(1);

  // Starting from iS5 alone, without iS6, the permissive traversal
  // is still allowed to visit the split expr node even though iS6 is
  // missing. The neighbors of the split would then be the next nodes
  // to visit, and those include iS6. That does not make intuitive
  // sense: the split was reached only because one of its outputs,
  // iS5, was visited, so it should not in turn grant a visit to its
  // other, missing output iS6.

  // iS6 must therefore not be reachable from iS5.
  EXPECT_FALSE(getExprsBetween<IRPermissiveBFS>(
                   {split_outer},
                   {split_inner},
                   /*require_all_to_visited=*/false)
                   .second);
}

} // namespace nvfuser

0 comments on commit bdf526f

Please sign in to comment.