Reapply #3621 (#3714)

Draft: wants to merge 3 commits into main

csrc/scheduler/ampere_multi_matmul.cpp (2 additions, 1 deletion)

@@ -489,7 +489,8 @@ void AmpereMultipleMatmulScheduler::cacheInputsAndOutputs() {
   scheduler_utils::clearMemorySpace(fusion_);
 
   // Cache inputs
-  scheduler_utils::cacheInputs(fusion_, /*unroll=*/true);
+  scheduler_utils::cacheInputs(
+      fusion_, /*unroll=*/true, /*propagate_allocation=*/true);
 
   // Cache and fork outputs
   cached_outputs_ =

csrc/scheduler/hopper_multi_matmul.cpp (2 additions, 1 deletion)

@@ -101,7 +101,8 @@ void HopperMultipleMatmulScheduler::cacheInputsAndOutputs() {
   scheduler_utils::clearMemorySpace(fusion_);
 
   // Cache inputs
-  scheduler_utils::cacheInputs(fusion_, /*unroll=*/true);
+  scheduler_utils::cacheInputs(
+      fusion_, /*unroll=*/true, /*propagate_allocation=*/true);
 
   // Cache and fork outputs
   scheduler_utils::cacheAndForkOutputs(fusion_, /*unroll=*/true);

csrc/scheduler/transpose.cpp (6 additions, 9 deletions)

@@ -860,15 +860,12 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) {
       grouped_inputs_outputs[1].begin(), grouped_inputs_outputs[1].end());
   for (auto tv : grouped_inputs_outputs[1]) {
     if (tv->isFusionInput()) {
-      auto existing_cache = ir_utils::consumerTvsOf(tv)[0];
-      if (ir_utils::consumerTvsOf(existing_cache).size() > 1) {
-        auto new_cache = tv->cacheAfter();
-        new_cache->setMemoryType(MemoryType::Shared);
-        group2_and_cached_inputs.emplace(new_cache);
-      } else {
-        existing_cache->setMemoryType(MemoryType::Shared);
-        group2_and_cached_inputs.emplace(existing_cache);
-      }
+      auto new_cache = tv->cacheAfter(
+          LoadStoreOpType::Set,
+          CacheOp::Unspecified,
+          /*propagate_allocation_domain=*/true);
+      new_cache->setMemoryType(MemoryType::Shared);
+      group2_and_cached_inputs.emplace(new_cache);
     }
   }
   // set cached outputs of group 2 to shared memory

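The old code reused an input's existing cache unless that cache had multiple consumers; the new code unconditionally creates a fresh cache and asks cacheAfter to propagate the input's allocation domain onto it. A minimal sketch of the intended effect, assuming only the cacheAfter overload used in this diff (the tensor names are hypothetical):

    // An input with a non-trivial allocation domain (e.g. a transposed
    // layout) keeps that layout on its shared-memory staging copy.
    TensorView* in = /* fusion input */;
    TensorView* staged = in->cacheAfter(
        LoadStoreOpType::Set,
        CacheOp::Unspecified,
        /*propagate_allocation_domain=*/true);
    staged->setMemoryType(MemoryType::Shared);
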
csrc/scheduler/utils.cpp (8 additions, 5 deletions)

@@ -1187,7 +1187,10 @@ void clearMemorySpace(Fusion* fusion) {
 
 // Returns cached after tensors of the fusion inputs if unrolled. Otherwise
 // return empty vector.
-std::vector<TensorView*> cacheInputs(Fusion* fusion, bool unroll) {
+std::vector<TensorView*> cacheInputs(
+    Fusion* fusion,
+    bool unroll,
+    bool propagate_allocation) {
   if (!unroll) {
     return {};
   }
@@ -1224,10 +1227,10 @@ std::vector<TensorView*> cacheInputs(Fusion* fusion, bool unroll) {
     }
 
     auto cached_tv = tv->cacheAfter(
-        /*op_type=*/LoadStoreOpType::Set,
-        /*cache_op=*/CacheOp::Unspecified,
-        /*propagate_allocation_domain=*/true,
-        /*cached_uses=*/cached_uses);
+        LoadStoreOpType::Set,
+        CacheOp::Unspecified,
+        propagate_allocation,
+        cached_uses);
     cached_inputs.emplace_back(cached_tv);
   }
   return cached_inputs;

csrc/scheduler/utils.h (4 additions, 1 deletion)

@@ -334,7 +334,10 @@ void clearMemorySpace(Fusion* fusion);
 
 // Returns cached after tensors of the fusion inputs if unrolled. Otherwise
 // return empty vector.
-std::vector<TensorView*> cacheInputs(Fusion* fusion, bool unroll);
+std::vector<TensorView*> cacheInputs(
+    Fusion* fusion,
+    bool unroll,
+    bool propagate_allocation = false);
 
 // Returns the pairs of <cache of each fusion output, corresponding output> for
 // all outputs.

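Because propagate_allocation defaults to false, existing callers of scheduler_utils::cacheInputs keep their current behavior; the matmul schedulers above opt in explicitly. A minimal sketch of the two call forms this declaration allows, grounded in the signatures in this diff (fusion is a hypothetical Fusion*):

    // Default form: caches unrolled inputs, allocation domains untouched.
    std::vector<TensorView*> caches =
        scheduler_utils::cacheInputs(fusion, /*unroll=*/true);
    // Opt-in form: each cached tensor also inherits its input's
    // allocation domain via cacheAfter (see utils.cpp above).
    std::vector<TensorView*> caches_prop = scheduler_utils::cacheInputs(
        fusion, /*unroll=*/true, /*propagate_allocation=*/true);
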
tests/cpp/test_allocation_domain.cpp (1 addition, 3 deletions)

@@ -1426,11 +1426,9 @@ TEST_F(AllocationDomainTest, InputAllocationIsSplit_Concrete) {
   fusion->addInput(in);
   fusion->addOutput(out);
 
-  // Ideally, loop should stay the same as logical because a fusion input comes
-  // from outside and isn't generated by a loop in the containing kernel (cf.
-  // #3479).
   in->split(0, 2);
   in->setAllocationDomain(in->getLoopDomain(), true);
+  in->setLoopDomain(in->getLogicalDomain());
 
   FusionExecutorCache executor_cache(std::move(fusion));
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA);

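The updated test drops the comment about the ideal loop domain and instead enforces it: after the edit, only the allocation domain carries the split, while the loop domain is explicitly reset to the logical domain. An annotated restatement of the three calls (the comments are my reading of the diff, not from the source):

    in->split(0, 2);                                    // loop domain: [i0/2, 2, ...]
    in->setAllocationDomain(in->getLoopDomain(), true); // allocation keeps the split
    in->setLoopDomain(in->getLogicalDomain());          // loop domain back to logical [i0, ...]
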
tests/python/test_pointwise.py (41 additions, 0 deletions)

@@ -421,3 +421,44 @@ def fusion_func(fd: FusionDefinition):
 
     with pytest.raises(RuntimeError, match="No executor supports provided fusion."):
         _ = fd.execute(inputs)
+
+
+def test_issue_3071():
+    def fusion_func(fd: FusionDefinition) -> None:
+        T0 = fd.define_tensor(
+            shape=[1, 1024, 128],
+            contiguity=[None, True, True],
+            dtype=DataType.Float,
+            is_cpu=False,
+            stride_order=[2, 0, 1],
+        )
+        T1 = fd.define_tensor(
+            shape=[1, 32, 1024, 128],
+            contiguity=[None, True, True, True],
+            dtype=DataType.Float,
+            is_cpu=False,
+            stride_order=[3, 1, 2, 0],
+        )
+        T2 = fd.ops.broadcast(T0, is_broadcast_dim=[False, True, False, False])
+        S3 = fd.ops.size(T2, dim=0)
+        S4 = fd.define_scalar(32, dtype=DataType.Int)
+        S5 = fd.ops.size(T2, dim=2)
+        S6 = fd.ops.size(T2, dim=3)
+        V7 = fd.define_vector([S3, S4, S5, S6], dtype=DataType.Int)
+        T8 = fd.ops.expand(T2, shape=V7)
+        T9 = fd.ops.mul(T8, T1)
+        fd.add_output(T9)
+        fd.add_output(T2)
+
+    with FusionDefinition() as fd:
+        fusion_func(fd)
+
+    t0 = torch.randn(131072, dtype=torch.float32, device="cuda").as_strided(
+        (1, 1024, 128), (131072, 1, 1024)
+    )
+    t1 = torch.randn(4194304, dtype=torch.float32, device="cuda").as_strided(
+        (1, 32, 1024, 128), (4194304, 128, 4096, 1)
+    )
+    t9, t2 = fd.execute([t0, t1])
+    torch.testing.assert_close(t9, t0.unsqueeze(1) * t1)
+    torch.testing.assert_close(t2, t0.unsqueeze(1))