Skip to content

Commit

Permalink
Transpose propagates allocation to input caches that live in shared memory.
Browse files Browse the repository at this point in the history
  • Loading branch information
wujingyue committed Jan 16, 2025
1 parent de303db commit 9209c24
Showing 1 changed file with 3 additions and 9 deletions.
12 changes: 3 additions & 9 deletions csrc/scheduler/transpose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -860,15 +860,9 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) {
grouped_inputs_outputs[1].begin(), grouped_inputs_outputs[1].end());
for (auto tv : grouped_inputs_outputs[1]) {
if (tv->isFusionInput()) {
auto existing_cache = ir_utils::consumerTvsOf(tv)[0];
if (ir_utils::consumerTvsOf(existing_cache).size() > 1) {
auto new_cache = tv->cacheAfter();
new_cache->setMemoryType(MemoryType::Shared);
group2_and_cached_inputs.emplace(new_cache);
} else {
existing_cache->setMemoryType(MemoryType::Shared);
group2_and_cached_inputs.emplace(existing_cache);
}
auto new_cache = tv->cacheAfter(/*propagate_allocation_domain=*/true);
new_cache->setMemoryType(MemoryType::Shared);
group2_and_cached_inputs.emplace(new_cache);
}
}
// set cached outputs of group 2 to shared memory
Expand Down

0 comments on commit 9209c24

Please sign in to comment.