Skip to content

Commit

Permalink
adding one more test
Browse files Browse the repository at this point in the history
  • Loading branch information
jjsjann123 committed Oct 31, 2024
1 parent 04c9db0 commit 7816a93
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 20 deletions.
30 changes: 12 additions & 18 deletions csrc/scheduler/vectorize_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -398,24 +398,18 @@ std::vector<IterDomain*> ContiguousInnerDimensionsMapper::projectId(
frontier.erase(frontier.begin(), it);

if (recording_) {
if (it+1 == frontier.end()) {
// FIXME: real analysis is needed here
// TODO: test on a single sided pad.
auto consumer_factor = getProjectedExtent(id_from);
auto comp = [](Val* factor, Val* extent) {
return SimplifyingIrBuilder::whereExpr(
SimplifyingIrBuilder::eqExpr(extent, extent->container()->zeroVal()),
factor,
SimplifyingIrBuilder::gcdExpr(factor, extent));
};
consumer_factor = comp(consumer_factor, resize_op->leftExpand());
consumer_factor = comp(consumer_factor, resize_op->rightExpand());
addProjectedExtent(id_to, consumer_factor);
} else {
// pad vectorization can only be done at fastest dimension, project it to 0 I believe would avoid that.
// FIXME: add a test case for me
addProjectedExtent(id_to, id_to->container()->zeroVal());
}
// FIXME: real analysis is needed here
// TODO: test on a single sided pad.
auto consumer_factor = getProjectedExtent(id_from);
auto comp = [](Val* factor, Val* extent) {
return SimplifyingIrBuilder::whereExpr(
SimplifyingIrBuilder::eqExpr(extent, extent->container()->zeroVal()),
factor,
SimplifyingIrBuilder::gcdExpr(factor, extent));
};
consumer_factor = comp(consumer_factor, resize_op->leftExpand());
consumer_factor = comp(consumer_factor, resize_op->rightExpand());
addProjectedExtent(id_to, consumer_factor);
}
} else {
frontier.erase(frontier.begin(), it + 1);
Expand Down
37 changes: 35 additions & 2 deletions tests/cpp/test_resize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4146,7 +4146,7 @@ TEST_F(ResizeTest, UnrollNonInnermost) {
auto tv0 = makeContigConcreteTensor(shape);
fusion.addInput(tv0);

auto tv1 = pad(tv0, {IrBuilder::create<Val>(4L), IrBuilder::create<Val>(4L), IrBuilder::create<Val>(0L), IrBuilder::create<Val>(0L)});
auto tv1 = pad(tv0, {IrBuilder::create<Val>(0L), IrBuilder::create<Val>(0L), IrBuilder::create<Val>(4L), IrBuilder::create<Val>(4L)});
fusion.addOutput(tv1);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
Expand All @@ -4156,7 +4156,7 @@ TEST_F(ResizeTest, UnrollNonInnermost) {
FusionExecutorCache executor_cache(std::move(fusion_ptr));
auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs);

auto ref = at::pad(t0, {4, 4, 0, 0});
auto ref = at::pad(t0, {0, 0, 4, 4});

NVF_CHECK(ref.equal(cg_outputs[0]));
}
Expand Down Expand Up @@ -4190,4 +4190,37 @@ TEST_F(ResizeTest, PadAndCacheUses) {
auto ref_1 = at::relu(t0);
NVF_CHECK(ref_1.equal(cg_outputs[1]));
}

// Builds a fusion that consumes one 1-D input in two ways — a symmetric
// pad and a symmetric slice — and validates both fusion outputs against
// their ATen references.
TEST_F(ResizeTest, Playground) {
  auto fusion_uptr = std::make_unique<Fusion>();
  FusionGuard fg(fusion_uptr.get());
  auto& fusion = *fusion_uptr;

  const std::vector<int64_t> shape{1024L * 1024L};

  // Concrete sizes sidestep dynamic-reshape handling.
  auto in_tv = makeContigConcreteTensor(shape);
  fusion.addInput(in_tv);

  // Output 0: pad the input by 4 elements on each side.
  auto padded_tv =
      pad(in_tv, {IrBuilder::create<Val>(4L), IrBuilder::create<Val>(4L)});
  fusion.addOutput(padded_tv);

  // Output 1: trim 2 elements from each end of the same input.
  auto sliced_tv = slice(
      in_tv,
      {{IrBuilder::create<Val>(2L),
        sub(in_tv->axis(0)->extent(), IrBuilder::create<Val>(2L))}});
  fusion.addOutput(sliced_tv);

  auto opts = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto t0 = at::randn(shape, opts);
  std::vector<c10::IValue> fusion_inputs{t0};

  FusionExecutorCache executor_cache(std::move(fusion_uptr));
  auto outputs = executor_cache.runFusionWithInputs(fusion_inputs);

  // Check the padded output first, then the sliced one, matching the
  // order the outputs were registered above.
  auto pad_ref = at::pad(t0, {4, 4});
  NVF_CHECK(pad_ref.equal(outputs[0]));

  auto slice_ref = t0.index({at::indexing::Slice(2, shape[0] - 2)});
  NVF_CHECK(slice_ref.equal(outputs[1]));
}
} // namespace nvfuser

0 comments on commit 7816a93

Please sign in to comment.