Merge up-cast, ops, down-cast sequences as minimal units of segments #3699

Merged: 8 commits, Jan 16, 2025
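This PR makes the segmenter treat an up-cast, the ops computed in the wider precision, and the closing down-cast as a single minimal unit when forming segments. As a rough illustration of the pattern it targets, here is a minimal sketch (hypothetical; written with the same C++ API as the new CastPrecision test below, not code from this PR):

Fusion fusion;
FusionGuard fg(&fusion);
auto tv0 = makeSymbolicTensor(2, DataType::BFloat16);
fusion.addInput(tv0);
auto tv1 = castOp(DataType::Float, tv0);    // up-cast to fp32
auto tv2 = neg(tv1);                        // compute in fp32
auto tv3 = castOp(DataType::BFloat16, tv2); // down-cast back to bf16
fusion.addOutput(tv3);

Without the preprocessing added here, the segmenter could place a boundary right after tv1, materializing the full-precision intermediate across segments; the MergeUpAndDownCast pass below tries to keep tv1 through tv3 in one segment.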
191 changes: 191 additions & 0 deletions csrc/fusion_segmenter.cpp
@@ -3586,6 +3586,194 @@ bool CombineReductions::shouldRun(
return false;
}

// This preprocessing attempts to find groups of exprs that consist
// of an up-cast, followed by some ops, and ended by a down-cast. It
// is highly likely that such sequences of ops should never be
// segmented out. This pattern is particularly common in fusions
// given by Thunder, as it inserts fine-grained downcasting and
// upcasting ops. Without this preprocessing, a fusion may be
// segmented right after an up-cast op, for example, which in fact
// happened quite frequently in some of the RoPE cases. This
// preprocessing does not completely avoid such segmentation
// boundaries, but it makes them less likely. See also
// https://github.com/NVIDIA/Fuser/pull/3699.
class MergeUpAndDownCast {
public:
static void run(SegmentCandidateFinder* segment_candidate_finder) {
MergeUpAndDownCast group_cast(segment_candidate_finder);
}

private:
MergeUpAndDownCast(SegmentCandidateFinder* segment_candidate_finder)
: segment_candidate_finder_(segment_candidate_finder) {
merge();
}

void merge() {
bool merged = true;
while (merged) {
merged = false;
std::unordered_set<SegmentedGroup*> considered_groups;

for (SegmentedGroup* group : segment_candidate_finder_->groups()) {
// If the group is an up-cast group, see if there's a
// candidate group starting with the group.
if (!isUpCast(group) || considered_groups.count(group)) {
continue;
}

auto groups_to_merge = getCandidateCastGroup(group);
if (groups_to_merge.size() < 2) {
continue;
}

for (auto group : groups_to_merge) {
considered_groups.insert(group);
}

// Try merging the detected group. If the merge succeeds, restart
// the scan, as the set of groups has changed.
if (mergeCastGroup(groups_to_merge)) {
merged = true;
break;
}
}
}
}

// Try to detect a set of groups that could be merged as a cast
// group. The analysis starts with an initial group that solely
// consists of an up-cast expression. From the initial group, it
// traverses its neighbor groups. If the group is a down-cast group,
// it only traverses through the consumer edges. If it's an up-cast
// group, it only traverses through the producer edges.
//
// Additionally, this traversal has several safeguards to keep the
// DAG property intact:
//
// - For a given group, it does not visit its consumers if it has
// multiple consumers, even if the group is not a down-cast
// group.
// - Similarly, it does not visit a producer if the producer has
// multiple consumers.
//
// The basic form of this set of groups should look like an up-cast
// group, followed by some op groups and ended by a down-cast
// group. However, this is not always the case because of the above
// safeguards. For example, the following groups would be detected
// as a cast group.
//
// t1 = bf16ToFp32(t0)
// t2 = neg(t1)
// t3 = sin(t2)
// t4 = cos(t2)
// t5 = fp32ToBf16(t3)
// t6 = fp32ToBf16(t4)
//
// In this case, t1 and t2 would be detected as a candidate group,
// but t3 and t4 would not be included since t2 has multiple
// consumer edges. While we could certainly extend the analysis, it
// would need to make sure the DAG property is not violated.
std::vector<SegmentedGroup*> getCandidateCastGroup(
SegmentedGroup* initial_group) {
std::vector<SegmentedGroup*> groups_to_merge;
std::unordered_set<SegmentedGroup*> groups_to_merge_set;

std::deque<SegmentedGroup*> to_visit;
to_visit.push_back(initial_group);

while (!to_visit.empty()) {
SegmentedGroup* group = to_visit.front();
to_visit.pop_front();

if (groups_to_merge_set.count(group)) {
continue;
}

// For simplicity, all groups are assumed to be the initial
// single-expr groups. Skip if not.

groups_to_merge.push_back(group);
groups_to_merge_set.insert(group);

// Consumer traversal. Stop if this group is a down-cast
// group. Also stop if there are multiple consumer edges to
// simplify keeping the DAG property.
if (!isDownCast(group) && group->consumer_edges.size() == 1) {
auto consumer_edge = group->consumer_edges.at(0);
SegmentedGroup* consumer_group = consumer_edge->to;
if (!groups_to_merge_set.count(consumer_group)) {
to_visit.push_back(consumer_group);
}
}

if (!isUpCast(group)) {
for (const auto producer_edge : group->producer_edges) {
SegmentedGroup* producer_group = producer_edge->from;
// Don't add producers that have more than one consumer
if (producer_group->consumer_edges.size() > 1) {
continue;
}
if (!groups_to_merge_set.count(producer_group)) {
to_visit.push_back(producer_group);
}
}
}
}

return groups_to_merge;
}

// Try merging a candidate cast group. Return true if merged.
bool mergeCastGroup(const std::vector<SegmentedGroup*>& groups) {
auto sched_type = tryMerge(
segment_candidate_finder_->segmented_fusion_.get(),
segment_candidate_finder_->runtimeInfo(),
groups);

if (sched_type == SchedulerType::None) {
return false;
}

segment_candidate_finder_->mergeAllGivenGroups(groups);

return true;
}

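// A cast group is an up-cast if its producer precision (in bytes)
// is smaller than its consumer precision, e.g., bf16 (2 bytes) ->
// fp32 (4 bytes); it is a down-cast in the opposite case.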
bool isUpCast(SegmentedGroup* group) const {
if (auto precision_bits = getProducerConsumerPrecision(group);
precision_bits.has_value()) {
return precision_bits->first < precision_bits->second;
} else {
return false;
}
}

bool isDownCast(SegmentedGroup* group) const {
if (auto precision_bits = getProducerConsumerPrecision(group);
precision_bits.has_value()) {
return precision_bits->first > precision_bits->second;
} else {
return false;
}
}

std::optional<std::pair<int64_t, int64_t>> getProducerConsumerPrecision(
SegmentedGroup* group) const {
if (group->exprs().size() != 1) {
return std::nullopt;
}

auto uop = dynamic_cast<UnaryOp*>(group->exprs().front());
if (uop == nullptr || uop->getUnaryOpType() != UnaryOpType::Cast) {
return std::nullopt;
}

return ir_utils::getPrecisionOfProducerConsumerTensors(uop);
}

private:
SegmentCandidateFinder* segment_candidate_finder_ = nullptr;
};

namespace {

//! Returns true if group1 and group2 are an immediate producer-consumer pair.
@@ -3945,6 +4133,9 @@ void SegmentCandidateFinder::findSegments()
removeScalarEdges();

// Run pre-merge heuristics
MergeUpAndDownCast::run(this);
segmented_fusion_->validateIfDebug(true);

if (options_.run_combine_reductions && CombineReductions::shouldRun(this)) {
CombineReductions::run(this);
}
2 changes: 2 additions & 0 deletions csrc/fusion_segmenter.h
@@ -488,6 +488,7 @@ class GroupDependencyAnalysis;

// Manual node merging passes
class CombineReductions;
class MergeUpAndDownCast;

//! Options to configure/debug candidate finder
struct SegmentCandidateFinderOptions {
@@ -691,6 +692,7 @@ class SegmentCandidateFinder {
//! eventually should have a dedicated interface
//! instead of keeping adding friends
friend class CombineReductions;
friend class MergeUpAndDownCast;

//! options to configure and debug the segment process
SegmentCandidateFinderOptions options_;
28 changes: 28 additions & 0 deletions csrc/ir/utils.cpp
@@ -1524,4 +1524,32 @@ std::vector<IterDomain*> strideOrderToAllocation(
return allocation_domain;
}

std::optional<std::pair<int64_t, int64_t>> getPrecisionOfProducerConsumerTensors(
UnaryOp* uop) {
NVF_CHECK(
uop != nullptr && uop->getUnaryOpType() == UnaryOpType::Cast,
"Invalid expr: ",
uop->toString());

auto inp_tv = ir_utils::getTvInput(uop);
auto out_tv = ir_utils::getTvOutput(uop);
if (inp_tv == nullptr || out_tv == nullptr) {
return std::nullopt;
}

auto inp_dtype = inp_tv->dtype().type;
auto out_dtype = out_tv->dtype().type;
auto inp_prim_type = std::get_if<PrimDataType>(&inp_dtype);
auto out_prim_type = std::get_if<PrimDataType>(&out_dtype);

if (inp_prim_type == nullptr || out_prim_type == nullptr ||
*inp_prim_type == PrimDataType::Index ||
*out_prim_type == PrimDataType::Index) {
return std::nullopt;
}

return std::make_pair(
primDataTypeSize(*inp_prim_type), primDataTypeSize(*out_prim_type));
}

} // namespace nvfuser::ir_utils
5 changes: 5 additions & 0 deletions csrc/ir/utils.h
@@ -803,4 +803,9 @@ std::vector<IterDomain*> strideOrderToAllocation(
const std::vector<IterDomain*>& logical_domain,
const std::vector<int64_t>& stride_order);

// Returns the sizes in bytes of the data types of the producer and
// consumer tensors of a cast unary op. For example, a Float ->
// BFloat16 cast yields {4, 2}. Returns nullopt if the sizes cannot
// be determined, e.g., for the Index type, whose size is not known
// until lowering.
std::optional<std::pair<int64_t, int64_t>> getPrecisionOfProducerConsumerTensors(
UnaryOp* cast_op);

} // namespace nvfuser::ir_utils
35 changes: 35 additions & 0 deletions tests/cpp/test_gpu3.cpp
@@ -9349,6 +9349,41 @@ TEST_F(NVFuserTest, RepeatBroadcastAndNonBroadcast) {
testValidate(&fusion, outputs, inputs, __LINE__, __FILE__);
}

TEST_F(NVFuserTest, CastPrecision) {
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeSymbolicTensor(2, DataType::Half);
fusion.addInput(tv0);

auto tv1 = castOp(DataType::Float, tv0);
auto tv2 = castOp(DataType::BFloat16, tv1);
fusion.addOutput(tv2);

auto tv3 = makeSymbolicTensor(2, DataType::Index);
fusion.addInput(tv3);

auto tv4 = castOp(DataType::Int, tv3);
fusion.addOutput(tv4);

auto tv1_precision = ir_utils::getPrecisionOfProducerConsumerTensors(
tv1->definition()->as<UnaryOp>());
ASSERT_TRUE(tv1_precision.has_value());
EXPECT_EQ(tv1_precision->first, 2);
EXPECT_EQ(tv1_precision->second, 4);

auto tv2_precision = ir_utils::getPrecisionOfProducerConsumerTensors(
tv2->definition()->as<UnaryOp>());
ASSERT_TRUE(tv2_precision.has_value());
EXPECT_EQ(tv2_precision->first, 4);
EXPECT_EQ(tv2_precision->second, 2);

// The precision of the Index type cannot be determined until lowering
auto tv4_precision = ir_utils::getPrecisionOfProducerConsumerTensors(
tv4->definition()->as<UnaryOp>());
ASSERT_FALSE(tv4_precision.has_value());
}

// Test file size should be up to 10K LoC. Create a new file for more tests.

} // namespace nvfuser