Scheduling ending repeat (#3716)
In RoPE, it seems common to have repeat ops at the end of fusions. For
example, `k` and `v` are repeated in the Mistral case as shown
[here](https://github.com/NVIDIA/Fuser/blob/main/benchmarks/python/rope_ops.py#L914-L915).
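
To make the pattern concrete, such a repeat is typically expressed as a broadcast-expand-reshape chain, which is exactly the pattern detected by `getMaybeStaticRepeatInfo` added below in `static_repeat.cpp`. Here is a minimal C++ sketch, assuming the usual nvFuser ops API (`broadcast`/`expand`/`reshape`); the shapes and tensor names are illustrative, not taken from the benchmark:

```
// Hedged sketch: repeat tv0 4x along its first axis.
// tv0: [8, 4096, 128]
TensorView* tv1 = broadcast(tv0, {true, false, false, false});
// tv1: [b1, 8, 4096, 128] -- a new broadcast ID at the front
TensorView* tv2 = expand(
    tv1,
    {IrBuilder::create<Val>(4L), // static repeat factor
     tv0->axis(0)->extent(),
     tv0->axis(1)->extent(),
     tv0->axis(2)->extent()});
// tv2: [4, 8, 4096, 128] -- the broadcast ID expanded by the repeat factor
TensorView* tv3 = reshape(tv2, {4, 8, 4096, 128}, {32, 4096, 128});
// tv3: [32, 4096, 128] -- the repeat ID merged into the repeated axis
fusion->addOutput(tv3);
```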

This can be problematic for performance since we typically choose a fusion
output as the reference tensor for scheduling. Suppose the final repeat has a
factor of 4; the output tensor is then 4x larger than the tensors that are
actually computed. Since we apply scheduling to the 4x larger reference, we
would, for example, launch 4x as many threads or blocks as necessary, doing
redundant work by a factor of 4.

This PR attempts to alleviate the perf issue by detecting an ending
repeat pattern. The idea is to factor out the iter domain that
corresponds to the repetition and move it to the outermost position. The
reference scheduling is then done only for the remaining iter domains.
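
Roughly, the scheduling of the reference tensor looks as follows. This is a simplified sketch of the actual `resize.cpp` change below, which locates the repeat ID through the exact graph of an `IdModel`; `repeat_id_pos` is assumed to be the position of the detected repeat ID:

```
// Suppose ref_tv's loop domain is [..., repeat_id{4}, ...].
// Move the repeat ID to the outermost position; it receives no further
// scheduling and simply remains a serial loop of extent 4.
ref_tv->reorder(std::unordered_map<int64_t, int64_t>{{repeat_id_pos, 0}});
// Schedule only the remaining IDs, e.g. (vectorization omitted):
ref_tv->split(1, 128); // [repeat_id{4}, I0/128, 128]
ref_tv->axis(2)->parallelize(ParallelType::TIDx);
ref_tv->axis(1)->parallelize(ParallelType::BIDx); // grid excludes the 4x
```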

For example, for one of the resize segments of the Mistral forward case
(I think it corresponds to the computation of the `K` input), we get this
code sequence:

```
  #pragma unroll
  for(nvfuser_index_t i14 = 0LL; i14 < 8LL; ++i14) {
    nvfuser_index_t i15;
    i15 = -i14;
    __bfloat T5[1LL];
    T5[0LL]
       = T59[i14];
    __bfloat T6[1LL];
    T6[0LL]
       = T5[0LL];
    float T36[1LL];
    T36[0LL]
       = __bfloat2float(T6[0LL]);
    float T37[1LL];
    T37[0LL]
       = __bfloat2float(T60[i14]);
    float T38[1LL];
    T38[0LL]
      = T36[0LL]
      * T37[0LL];
    __bfloat T67[1LL];
    T67[0LL]
       = T66[i14];
    __bfloat T68[1LL];
    T68[0LL]
       = T67[0LL];
    __bfloat T39[1LL];
    T39[0LL]
       = T68[0LL];
    __bfloat T45[1LL];
    T45[0LL]
       = ((i10 >= i15) && (i11 < i15)) ? T39[0LL] : 0.0000e+00f;
    float T49[1LL];
    T49[0LL]
       = __bfloat2float(T61[i14]);
    __bfloat T64[1LL];
    T64[0LL]
       = T63[i14];
    __bfloat T65[1LL];
    T65[0LL]
       = T64[0LL];
    __bfloat T40[1LL];
    T40[0LL]
       = T65[0LL];
    float T41[1LL];
    T41[0LL]
       = __bfloat2float(T40[0LL]);
    float T42[1LL];
    T42[0LL]
       = -T41[0LL];
    __bfloat T43[1LL];
    T43[0LL]
       = __float2bfloat(T42[0LL]);
    __bfloat T44[1LL];
    T44[0LL]
       = (i10 < i15) ? T43[0LL] : 0.0000e+00f;
    __bfloat T46[1LL];
    T46[0LL]
      = T44[0LL]
      | T45[0LL];
    float T48[1LL];
    T48[0LL]
       = __bfloat2float(T46[0LL]);
    float T50[1LL];
    T50[0LL]
      = T48[0LL]
      * T49[0LL];
    float T51[1LL];
    T51[0LL]
      = T38[0LL]
      + T50[0LL];
    __bfloat T52[1LL];
    T52[0LL]
       = __float2bfloat(T51[0LL]);
    __bfloat T53[1LL];
    T53[0LL]
       = T52[0LL];
    __bfloat T54[1LL];
    T54[0LL]
       = T53[0LL];
    T62[i14]
       = T54[0LL];
  }
  if ((b12 && (((4LL * i6) + (i3 / 4096LL)) < 32LL))) {
    loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T55[((((128LL * i2) + (2097152LL * i6)) + i1) + (128LL * ((nvfuser_index_t)blockIdx.x)))], &T62[0LL]);
  }
```

With this PR, it looks like:

```
  for(nvfuser_index_t i16 = 0LL; i16 < 8LL; ++i16) {
    nvfuser_index_t i17;
    i17 = -i16;
    __bfloat T67[1LL];
    T67[0LL]
       = T66[i16];
    __bfloat T68[1LL];
    T68[0LL]
       = T67[0LL];
    __bfloat T39[1LL];
    T39[0LL]
       = T68[0LL];
    __bfloat T45[1LL];
    T45[0LL]
       = ((i7 >= i17) && (i8 < i17)) ? T39[0LL] : 0.0000e+00f;
    __bfloat T64[1LL];
    T64[0LL]
       = T63[i16];
    __bfloat T65[1LL];
    T65[0LL]
       = T64[0LL];
    __bfloat T40[1LL];
    T40[0LL]
       = T65[0LL];
    float T41[1LL];
    T41[0LL]
       = __bfloat2float(T40[0LL]);
    float T42[1LL];
    T42[0LL]
       = -T41[0LL];
    __bfloat T43[1LL];
    T43[0LL]
       = __float2bfloat(T42[0LL]);
    __bfloat T44[1LL];
    T44[0LL]
       = (i7 < i17) ? T43[0LL] : 0.0000e+00f;
    __bfloat T46[1LL];
    T46[0LL]
      = T44[0LL]
      | T45[0LL];
    __bfloat T5[1LL];
    T5[0LL]
       = T59[i16];
    __bfloat T6[1LL];
    T6[0LL]
       = T5[0LL];
    float T36[1LL];
    T36[0LL]
       = __bfloat2float(T6[0LL]);
    float T37[1LL];
    T37[0LL]
       = __bfloat2float(T60[i16]);
    float T38[1LL];
    T38[0LL]
      = T36[0LL]
      * T37[0LL];
    float T49[1LL];
    T49[0LL]
       = __bfloat2float(T61[i16]);
    float T48[1LL];
    T48[0LL]
       = __bfloat2float(T46[0LL]);
    float T50[1LL];
    T50[0LL]
      = T48[0LL]
      * T49[0LL];
    float T51[1LL];
    T51[0LL]
      = T38[0LL]
      + T50[0LL];
    T52[i16]
       = __float2bfloat(T51[0LL]);
  }
  #pragma unroll
  for(nvfuser_index_t i18 = 0LL; i18 < 4LL; ++i18) {
    Array<__bfloat, 8LL, 8> T62;
    #pragma unroll
    for(nvfuser_index_t i19 = 0LL; i19 < 8LL; ++i19) {
      __bfloat T53[1LL];
      T53[0LL]
         = T52[i19];
      __bfloat T54[1LL];
      T54[0LL]
         = T53[0LL];
      T62[i19]
         = T54[0LL];
    }
    if ((b12 && (i14 < (-i18)))) {
      loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>( &T55[(i10 + (524288LL * i18))], &T62[0LL]);
    }
  }
```

Notice that the final store now has its own loop with extent 4, which is
the repetition factor.

The launch configuration before:
```
Launch Parameters: BlockDim.x = 128, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 16384, GridDim.y = -1, GridDim.z = -1, Smem Size = 0
```

The launch configuration after:
```
Launch Parameters: BlockDim.x = 128, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 4096, GridDim.y = -1, GridDim.z = -1, Smem Size = 0
```

The number of blocks is reduced by a factor of 4.
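
As a sanity check, assuming the output is the repeated `K` tensor of shape `[32, 4096, 128]` (which is consistent with the indexing constants in the kernels above):

```
elements         = 32 * 4096 * 128 = 16,777,216 (bf16)
stored per block = vec_size(8) * bdimx(128) = 1,024 elements per iteration
before: 16,777,216 / 1,024 = 16,384 blocks (repeat scheduled like everything else)
after:  16,384 / 4         =  4,096 blocks (repeat runs as a serial loop of 4)
```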

While this pattern can appear with any scheduler, I have currently only
added the detection to the resize scheduler. In RoPE, there's indeed a
pointwise segment with an ending repeat as well, but that's not addressed in
this PR.

### Performance benefit

Will update
naoyam authored Jan 17, 2025
1 parent 265f78d commit 902375e
Showing 5 changed files with 427 additions and 11 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -240,6 +240,7 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/scheduler/tools/loop_domain_scheduler.cpp
${NVFUSER_SRCS_DIR}/scheduler/tools/maxinfo_propagator.cpp
${NVFUSER_SRCS_DIR}/scheduler/tools/resize_utils.cpp
+ ${NVFUSER_SRCS_DIR}/scheduler/tools/static_repeat.cpp
${NVFUSER_SRCS_DIR}/scheduler/transpose.cpp
${NVFUSER_SRCS_DIR}/scheduler/utils.cpp
${NVFUSER_SRCS_DIR}/scheduler/vectorize_helper.cpp
114 changes: 103 additions & 11 deletions csrc/scheduler/resize.cpp
@@ -20,8 +20,11 @@
#include <scheduler/tools/inlining.h>
#include <scheduler/tools/loop_domain_scheduler.h>
#include <scheduler/tools/resize_utils.h>
+ #include <scheduler/tools/static_repeat.h>
#include <val_graph_visitor.h>

+ #include <memory>
+
namespace nvfuser {

namespace {
@@ -257,18 +260,22 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {

scheduler_utils::clearMemorySpace(fusion);

+ auto ref_tv = getReferenceTensor(fusion);
+ NVF_ERROR(ref_tv != nullptr);
+
scheduler_utils::cacheInputs(fusion, true);
scheduler_utils::cacheAndForkOutputs(fusion, true);

auto resize_tensor_ops = ir_utils::getOpsOfType<SliceOp, PadOp>(fusion);

- IdModel id_model(fusion, /*build_graphs=*/false);
- const auto& exact_graph = id_model.buildExactGraph();
+ std::unique_ptr<IdModel> id_model =
+ std::make_unique<IdModel>(fusion, /*build_graphs=*/false);
+ id_model->buildExactGraph();

// Replicate resize inputs if necessary to avoid conflicting
// propagations
const auto exclusivity_info_map = scheduler_tools::getNonExclusiveResizeInfo(
- resize_tensor_ops, exact_graph);
+ resize_tensor_ops, id_model->idGraph(IdMappingMode::EXACT));
for (auto resize_tensor_op : resize_tensor_ops) {
auto out_tv = resize_tensor_op->output(0)->as<TensorView>();
if (exclusivity_info_map.count(out_tv) == 0) {
@@ -304,8 +311,12 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {
scheduler_tools::propagateResizeToInputs(expr);
}

- auto ref_tv = getReferenceTensor(fusion);
- NVF_ERROR(ref_tv != nullptr);
+ // Update the IdModel
+ id_model = std::make_unique<IdModel>(fusion, /*build_graphs=*/false);
+ id_model->buildExactGraph();
+
+ // Detect an ending repeat
+ auto static_repeat_info = scheduler_tools::getMaybeStaticRepeatInfo(ref_tv);

// Just simple scheduling for now.
// TODO: Do something smarter. Can just use the pointwise scheduler?
@@ -335,7 +346,48 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {
const int64_t bdimx = 128;

// Make sure the DID ID located at the outermost position
- const auto outermost_pos = scheduler_utils::reorderDevicesToOuter(ref_tv);
+ auto outermost_pos = scheduler_utils::reorderDevicesToOuter(ref_tv);
+
+ // [DID, ..., ...]
+ //       ^
+ //       +--- outermost_pos
+
+ // Move the static repeat ID to the outermost position if
+ // detected. The repeat ID then just remains there with no
+ // scheduling.
+ bool repeat_id_moved_to_outermost = false;
+ if (static_repeat_info.has_value()) {
+ NVF_ERROR(ref_tv == static_repeat_info->repeat_output_tv);
+ auto ref_repeat_id_it = std::find_if(
+ ref_tv->getLoopDomain().begin(),
+ ref_tv->getLoopDomain().end(),
+ [&](IterDomain* loop_id) {
+ return id_model->idGraph(IdMappingMode::EXACT)
+ .disjointValSets()
+ .strictAreMapped(loop_id, static_repeat_info->reshape_repeat_id);
+ });
+ // Gives up if the repeat ID is not found. Unclear if this could
+ // actually happen, though.
+ if (ref_repeat_id_it != ref_tv->getLoopDomain().end()) {
+ auto repeat_id_pos =
+ std::distance(ref_tv->getLoopDomain().begin(), ref_repeat_id_it);
+ NVF_ERROR(
+ repeat_id_pos >= outermost_pos,
+ "Unexpected to have DID-parallelized repeat axis: ",
+ static_repeat_info->reshape_repeat_id->toString());
+
+ // [DID, ..., repeat_id, ...]
+ //       ^
+ //       +--- outermost_pos
+ ref_tv->reorder(std::unordered_map<int64_t, int64_t>{{repeat_id_pos, 0}});
+ ++outermost_pos;
+ // [repeat_id, DID, ...]
+ //                  ^
+ //                  +--- outermost_pos
+
+ repeat_id_moved_to_outermost = true;
+ }
+ }

const int64_t vec_factor = resize_params->vectorization_factor;

@@ -373,15 +425,55 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {
// [..., I0/bdimx(BIDx), bdimx(TIDx), vec_factor]

// Propagate the reference to the other tensors. Note that the
- // update flag is enabled so to workaround the resize propagation
+ // update flag is enabled to workaround the resize propagation
// issue. This may not work if there's a tensor that is reshaped
// from the reference tensor, but that should not be the case as the
// reference is picked by the same routine used for the pointwise
// scheduler.
- scheduler_tools::scheduleLoopDomainsLike(
- fusion->allTvs(),
- ref_tv->getLoopDomain(),
- /*update_loop_domain_only=*/true);
+ //
+ // When an ending static repeat is detected and the repeat ID is
+ // moved to the outermost position, propagation is done separately
+ // between the tensors before the repeat and after the repeat. The
+ // tensors are first grouped into the pre-repeat group and the
+ // post-repeat group, where only the latter group has the repeat
+ // IDs. When propagating the loop domain of the reference tensor,
+ // which has the repeat ID, the full loop domain is propagated only
+ // to the post-repeat group. For the pre-repeat group, the repeat ID
+ // is dropped and only the remaining loop domain is propagated.
+ if (repeat_id_moved_to_outermost) {
+ // Divide all tvs into the pre- and post-repeat groups
+ auto all_tvs = fusion->allTvs();
+ std::vector<TensorView*> post_repeat_tvs;
+ post_repeat_tvs.reserve(static_repeat_info->repeat_tvs.size());
+ std::vector<TensorView*> pre_repeat_tvs;
+ pre_repeat_tvs.reserve(
+ all_tvs.size() - static_repeat_info->repeat_tvs.size());
+ for (auto tv : all_tvs) {
+ if (static_repeat_info->repeat_tvs.count(tv)) {
+ post_repeat_tvs.push_back(tv);
+ } else {
+ pre_repeat_tvs.push_back(tv);
+ }
+ }
+
+ // The repeat ID should be located at the outermost position
+ std::vector<IterDomain*> non_repeated_loop{
+ ref_tv->getLoopDomain().begin() + 1, ref_tv->getLoopDomain().end()};
+
+ scheduler_tools::scheduleLoopDomainsLike(
+ pre_repeat_tvs,
+ non_repeated_loop,
+ /*update_loop_domain_only=*/true);
+ scheduler_tools::scheduleLoopDomainsLike(
+ post_repeat_tvs,
+ ref_tv->getLoopDomain(),
+ /*update_loop_domain_only=*/true);
+ } else {
+ scheduler_tools::scheduleLoopDomainsLike(
+ fusion->allTvs(),
+ ref_tv->getLoopDomain(),
+ /*update_loop_domain_only=*/true);
+ }

if (vec_factor > 1) {
auto vec_ref_tv = largest_input != nullptr ? largest_input : ref_tv;
166 changes: 166 additions & 0 deletions csrc/scheduler/tools/static_repeat.cpp
@@ -0,0 +1,166 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on

#include <ir/all_nodes.h>
#include <ir/utils.h>
#include <scheduler/tools/static_repeat.h>

namespace nvfuser {
namespace scheduler_tools {

std::optional<StaticRepeatInfo> getMaybeStaticRepeatInfo(
TensorView* maybe_repeat_out) {
// The pattern to detect:
//
// broadcast_out = broadcast(input)
// expand_out = expand(broadcast_out)
// repeat_out = reshape(expand_out)
//
// Additionally, since maybe_repeat_out is commonly a fusion
// output, it is likely there's a cache tv between expand_out and
// repeat_out, so the following pattern should also be detected.
//
// broadcast_out = broadcast(input)
// expand_out = expand(broadcast_out)
// cache_of_repeat_out = reshape(expand_out)
// repeat_out = set(cache_of_repeat_out)

std::unordered_set<TensorView*> repeat_tvs;
repeat_tvs.insert(maybe_repeat_out);

auto reshape_out = maybe_repeat_out;

// Check if there's a cache
if (auto ldst = dynamic_cast<LoadStoreOp*>(maybe_repeat_out->definition());
ldst != nullptr && ldst->opType() == LoadStoreOpType::Set) {
reshape_out = ldst->in()->as<TensorView>();
repeat_tvs.insert(reshape_out);
}

// Detect reshape
auto reshape = dynamic_cast<ViewOp*>(reshape_out->definition());
if (reshape == nullptr) {
return std::nullopt;
}

// Detect expand
auto expand_out = reshape->in();
repeat_tvs.insert(expand_out);
auto expand = dynamic_cast<ExpandOp*>(expand_out->definition());
if (expand == nullptr) {
return std::nullopt;
}

// Detect broadcast
auto broadcast_out = expand->in();
repeat_tvs.insert(broadcast_out);
auto broadcast = dynamic_cast<BroadcastOp*>(broadcast_out->definition());
if (broadcast == nullptr) {
return std::nullopt;
}

auto inp_tv = broadcast->in();

// Not sure if this is really necessary to check, but assume there's
// only a single chain of ops and tensors from inp_tv to
// maybe_repeat_out
if (inp_tv->uses().size() > 1 ||
std::any_of(repeat_tvs.begin(), repeat_tvs.end(), [](TensorView* tv) {
return tv->uses().size() > 1;
})) {
return std::nullopt;
}

// Check if the ops match with the repeat pattern. Currently only
// one iter domain can be repeated
IterDomain* broadcast_id = nullptr;
int64_t broadcast_pos = -1;
for (const auto i : c10::irange(broadcast_out->getLogicalDomain().size())) {
if (broadcast->getBroadcastDimFlags().at(i)) {
if (broadcast_id != nullptr) {
// Multiple broadcast IDs not supported
return std::nullopt;
}
broadcast_id = broadcast_out->getLogicalDomain().at(i);
broadcast_pos = (int64_t)i;
}
}

if (broadcast_id == nullptr) {
return std::nullopt;
}

// Check that the broadcast ID, and only the broadcast ID, is expanded
IterDomain* expanded_id = nullptr;
for (const auto i : c10::irange(broadcast_out->getLogicalDomain().size())) {
auto p_id = broadcast_out->getLogicalDomain().at(i);
auto c_id = expand_out->getLogicalDomain().at(i);
if (p_id == broadcast_id && c_id->isBroadcast() &&
c_id->hasExpandedExtent()) {
expanded_id = c_id;
} else if (
p_id->isBroadcast() && !p_id->hasExpandedExtent() &&
c_id->isBroadcast() && c_id->hasExpandedExtent()) {
// Expanded but this broadcast was not introduced by the
// preceding broadcast op
return std::nullopt;
}
}

if (expanded_id == nullptr) {
return std::nullopt;
}

// Only a static repeat factor is considered
if (!expanded_id->expandedExtent()->isConstInt()) {
return std::nullopt;
}

// The expanded ID should be merged with the iter domain next to it,
// and that should be the only reshape expr
auto reshape_exprs = DependencyCheck::getAllExprsBetween(
{reshape_out->getRootDomain().begin(),
reshape_out->getRootDomain().end()},
{reshape_out->getLogicalDomain().begin(),
reshape_out->getLogicalDomain().end()});
if (reshape_exprs.size() != 1) {
return std::nullopt;
}

auto reshape_merge = dynamic_cast<Merge*>(reshape_exprs.at(0));
if (reshape_merge == nullptr) {
return std::nullopt;
}

// The corresponding root ID of the output tv should be one of the
// inputs of the merge
auto reshape_root_broadcast = reshape_out->getRootDomain().at(broadcast_pos);
if (reshape_merge->outer() != reshape_root_broadcast &&
reshape_merge->inner() != reshape_root_broadcast) {
return std::nullopt;
}

// Reshape of an expanded broadcast always generates a concrete
// non-broadcast ID, so this check should not be necessary, but it is
// kept in case that changes in the future.
if (reshape_merge->out()->isBroadcast() ||
reshape_merge->out()->hasExpandedExtent()) {
return std::nullopt;
}

StaticRepeatInfo info;
info.repeat_output_tv = maybe_repeat_out;
info.reshape_output_tv = reshape_out;
info.reshape_repeat_id = reshape_out->getRootDomain().at(broadcast_pos);
info.repeat_tvs = repeat_tvs;

return info;
}

} // namespace scheduler_tools
} // namespace nvfuser