diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index bef6b2b7def..9048792dc4a 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -860,15 +860,9 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) { grouped_inputs_outputs[1].begin(), grouped_inputs_outputs[1].end()); for (auto tv : grouped_inputs_outputs[1]) { if (tv->isFusionInput()) { - auto existing_cache = ir_utils::consumerTvsOf(tv)[0]; - if (ir_utils::consumerTvsOf(existing_cache).size() > 1) { - auto new_cache = tv->cacheAfter(); - new_cache->setMemoryType(MemoryType::Shared); - group2_and_cached_inputs.emplace(new_cache); - } else { - existing_cache->setMemoryType(MemoryType::Shared); - group2_and_cached_inputs.emplace(existing_cache); - } + auto new_cache = tv->cacheAfter(/*propagate_allocation_domain=*/true); + new_cache->setMemoryType(MemoryType::Shared); + group2_and_cached_inputs.emplace(new_cache); } } // set cached outputs of group 2 to shared memory