Lowering vectorized pad #3261

Merged 51 commits on Nov 5, 2024
Commits (51)
8f9708f  relaxing check (jjsjann123, Sep 2, 2024)
54826aa  allow cache on inputs for pad (jjsjann123, Sep 3, 2024)
e54938c  Merge remote-tracking branch 'origin/main' into jjsjann123/resize_vec (jjsjann123, Sep 3, 2024)
2bc3c7a  cpp example (jjsjann123, Sep 24, 2024)
d04e8c3  Merge branch 'jjsjann123/pad_vec' into jjsjann123/resize_vec (jjsjann123, Sep 24, 2024)
d0addc4  reverting earlier changes (jjsjann123, Sep 24, 2024)
490fdbe  Revert "reverting earlier changes" (jjsjann123, Sep 24, 2024)
51c3022  cherry-pick my revert (jjsjann123, Sep 24, 2024)
1158ef0  Merge remote-tracking branch 'origin/main' into jjsjann123/resize_vec (jjsjann123, Oct 2, 2024)
fdc6a9a  debug print (jjsjann123, Oct 3, 2024)
9a6c03a  Merge remote-tracking branch 'origin/main' into jjsjann123/resize_vec (jjsjann123, Oct 4, 2024)
a9d16ce  removing comments (jjsjann123, Oct 7, 2024)
3401119  removing assert (jjsjann123, Oct 8, 2024)
5d05284  Merge remote-tracking branch 'origin/main' into jjsjann123/resize_vec (jjsjann123, Oct 8, 2024)
b6587ee  patching test (jjsjann123, Oct 10, 2024)
28decac  Merge remote-tracking branch 'origin/main' into jjsjann123/resize_vec (jjsjann123, Oct 10, 2024)
3e53feb  Merge remote-tracking branch 'origin/main' into HEAD (jjsjann123, Oct 20, 2024)
ad61ecb  fixing test (jjsjann123, Oct 20, 2024)
a8edc56  fixing (jjsjann123, Oct 20, 2024)
9cdeb64  fixing test (jjsjann123, Oct 21, 2024)
09a2aee  does this work to replace Ternary(where) with IfThenElse (jjsjann123, Oct 21, 2024)
895d0bf  fixing build (jjsjann123, Oct 21, 2024)
7a15e22  removing print (jjsjann123, Oct 22, 2024)
a6e8fb1  restore lower to ternary:where; restore vectorization on tests (jjsjann123, Oct 22, 2024)
fe0f263  testing water (jjsjann123, Oct 23, 2024)
baa7b09  fixing syntax (jjsjann123, Oct 23, 2024)
ca5ced1  now it's functional (jjsjann123, Oct 23, 2024)
e0492d3  better formatting on printed code (jjsjann123, Oct 23, 2024)
b528429  adding a tab (jjsjann123, Oct 23, 2024)
a23e010  supporting local memory (jjsjann123, Oct 23, 2024)
57b90d1  Merge remote-tracking branch 'origin/main' into HEAD (jjsjann123, Oct 23, 2024)
7a976c7  clangformat (jjsjann123, Oct 23, 2024)
f11d662  apparently there are ternary operations on scalars (jjsjann123, Oct 23, 2024)
5a83fc6  Merge remote-tracking branch 'origin/main' into HEAD (jjsjann123, Oct 23, 2024)
39f83f7  fixing (jjsjann123, Oct 23, 2024)
07eafd1  fixing (jjsjann123, Oct 23, 2024)
986b361  clangformat (jjsjann123, Oct 23, 2024)
7409913  clangformat (jjsjann123, Oct 23, 2024)
76cbcd8  clangformat again (jjsjann123, Oct 23, 2024)
5f996fc  Merge branch 'main' into jjsjann123/resize_vec (jjsjann123, Oct 31, 2024)
803a95b  Merge remote-tracking branch 'origin/main' into jjsjann123/resize_vec (jjsjann123, Nov 4, 2024)
a67fb57  polish PR for review (jjsjann123, Nov 4, 2024)
11cd4d1  Merge remote-tracking branch 'origin/jjsjann123/resize_vec' into jjsj… (jjsjann123, Nov 4, 2024)
65aa77d  Merge remote-tracking branch 'origin/main' into jjsjann123/resize_vec (jjsjann123, Nov 4, 2024)
1f75d7a  missed one arg (jjsjann123, Nov 4, 2024)
4c92371  oops, fixing the generated code (jjsjann123, Nov 4, 2024)
3ec2a6b  review comments (jjsjann123, Nov 4, 2024)
d2864ab  fixing code (jjsjann123, Nov 5, 2024)
1b4f2c1  I think this is fixed now (jjsjann123, Nov 5, 2024)
0e4e61f  adding comments per review request (jjsjann123, Nov 5, 2024)
4d4f747  another comment (jjsjann123, Nov 5, 2024)
162 changes: 115 additions & 47 deletions csrc/codegen.cpp
@@ -402,6 +402,55 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
}
}

void generateVectorizedLdSt(
Val* in,
Val* out,
CacheOp cache_op,
int64_t vector_word_size) {
auto out_tv = out->as<kir::TensorIndex>()->view();
auto in_tv = in->as<kir::TensorIndex>()->view();

bool localToGlobal = out_tv->getMemoryType() == MemoryType::Global &&
in_tv->getMemoryType() == MemoryType::Local;

bool globalToLocal = out_tv->getMemoryType() == MemoryType::Local &&
in_tv->getMemoryType() == MemoryType::Global;

bool globalToGlobal = out_tv->getMemoryType() == MemoryType::Global &&
in_tv->getMemoryType() == MemoryType::Global;

bool is_volatile_to = out_tv->getMemoryType() == MemoryType::Global &&
kernel_->summary().sync_map->needsRawSync(out_tv).hasBID();

bool is_volatile_from = in_tv->getMemoryType() == MemoryType::Global &&
kernel_->summary().sync_map->needsRawSync(in_tv).hasBID();

if (localToGlobal) {
code_ << "loadLocalToGlobal<" << out->dtype() << ", /*vec_size=*/"
<< vector_word_size << ", /*is_volatile=*/"
<< (is_volatile_to ? "true" : "false") << ">(";
code_ << " &" << gen(out) << ", &" << gen(in) << ")";
} else if (globalToLocal) {
code_ << "loadGlobalToLocal<" << out->dtype() << ", /*vec_size=*/"
<< vector_word_size << ", /*is_volatile=*/"
<< (is_volatile_from ? "true" : "false") << ", "
<< "CacheOp::" << cache_op << ">(&" << gen(out) << ", ";
code_ << " &" << gen(in) << ")";
} else if (globalToGlobal) {
code_ << "loadGlobalToGlobal<" << out->dtype() << ", /*vec_size=*/"
<< vector_word_size << ", /*is_volatile_to=*/"
<< (is_volatile_to ? "true" : "false") << ", /*is_volatile_from=*/"
<< (is_volatile_from ? "true" : "false") << ">(";
code_ << " &" << gen(out) << ", ";
code_ << " &" << gen(in) << ")";
} else {
code_ << "loadGeneric<" << out->dtype() << ", " << vector_word_size
<< ">(";
code_ << " &" << gen(out) << ", ";
code_ << " &" << gen(in) << ")";
}
}

// Cannot just use ConstIrVisitor::handle as it expects a vector of
// const Expr*, whereas most of the IR API returns a vector of
// non-const Expr*.
@@ -1001,6 +1050,68 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
}

void handle(const TernaryOp* top) final {
// Note: vectorized TernaryOp looks something like:
// ```
// predicate
// ? LoadGlobalToLocal(&dst[0], &in2[index])
// : arraySet(&dst[0], in3);
// ```
//
// Current limitation:
// 1. only TernaryOpType::Where is supported;
// 2. predicate needs to be a scalar;
// 3. output needs to be a TensorView;
// 4. one and only one of the inputs needs to be a TensorView. (This is
// coming from validation analysis.)
if (top->out()->isA<kir::TensorIndex>()) {
// Get vectorization information
auto out_tv = top->out()->as<kir::TensorIndex>()->view();
int64_t vector_word_size = ir_utils::getVectorizeSize(out_tv);
bool is_vector_op = vectorize_scope_ && vector_word_size != 1;

if (is_vector_op) {
NVF_CHECK(
top->in1()->isScalar(),
"predicate should be a scalar for vectorized TernaryOp::where");
NVF_CHECK(
!top->out()->isScalar(),
"scalar output in vectorization isn't supported");
NVF_CHECK(
top->getTernaryOpType() == TernaryOpType::Where,
"vectorization only works on TernaryOp::where");
indent() << gen(top->in1()) << "\n";
indent() << kTab << "? ";
auto vec_load = [&out_tv, &top, &vector_word_size, this](Val* in) {
if (in->isScalar()) {
if (out_tv->getMemoryType() == MemoryType::Local &&
!out_tv->isCircularBuffered()) {
// Vectorized initialization, explicit type conversion is needed
// for complex numbers
code_ << genVariableName(out_tv) << ".set("
<< genCall(out_tv->dtype(), gen(in)) << ")";
} else {
// Note: currently arraySet option is not vectorized, so it will
// rely on auto vectorization pass of cuda compiler.
code_ << "arraySet<" << out_tv->getDataType().value() << ", "
<< vector_word_size << ">(&" << gen(top->out()) << ", ("
<< out_tv->getDataType().value() << ")" << gen(in) << ")";
}
} else {
generateVectorizedLdSt(
in, top->out(), CacheOp::AllLevels, vector_word_size);
}
};

// TODO: should we have the option to specify cache level?
vec_load(top->in2());
code_ << "\n";
indent() << kTab << ": ";
vec_load(top->in3());
code_ << ";\n";
return;
}
}

if (!print_inline_) {
indent() << gen(top->out());
if (!top->out()->isScalar()) {
@@ -1338,53 +1449,10 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
"Invalid input to unary op with tensor output, found: ",
ldst->in()->toString());

auto in_tv = ldst->in()->as<kir::TensorIndex>()->view();
bool localToGlobal = out_tv->getMemoryType() == MemoryType::Global &&
in_tv->getMemoryType() == MemoryType::Local;

bool globalToLocal = out_tv->getMemoryType() == MemoryType::Local &&
in_tv->getMemoryType() == MemoryType::Global;

bool globalToGlobal = out_tv->getMemoryType() == MemoryType::Global &&
in_tv->getMemoryType() == MemoryType::Global;

bool is_volatile_to = out_tv->getMemoryType() == MemoryType::Global &&
kernel_->summary().sync_map->needsRawSync(out_tv).hasBID();

bool is_volatile_from =
in_tv->getMemoryType() == MemoryType::Global &&
kernel_->summary().sync_map->needsRawSync(in_tv).hasBID();

if (localToGlobal) {
indent() << "loadLocalToGlobal<" << ldst->out()->dtype()
<< ", /*vec_size=*/" << vector_word_size
<< ", /*is_volatile=*/"
<< (is_volatile_to ? "true" : "false") << ">(";
code_ << " &" << gen(ldst->out()) << ", &" << gen(ldst->in())
<< ");\n";
} else if (globalToLocal) {
indent() << "loadGlobalToLocal<" << ldst->out()->dtype()
<< ", /*vec_size=*/" << vector_word_size
<< ", /*is_volatile=*/"
<< (is_volatile_from ? "true" : "false") << ", "
<< "CacheOp::" << ldst->cacheOp() << ">(&"
<< gen(ldst->out()) << ", ";
code_ << " &" << gen(ldst->in()) << ");\n";
} else if (globalToGlobal) {
indent() << "loadGlobalToGlobal<" << ldst->out()->dtype()
<< ", /*vec_size=*/" << vector_word_size
<< ", /*is_volatile_to=*/"
<< (is_volatile_to ? "true" : "false")
<< ", /*is_volatile_from=*/"
<< (is_volatile_from ? "true" : "false") << ">(";
code_ << " &" << gen(ldst->out()) << ", ";
code_ << " &" << gen(ldst->in()) << ");\n";
} else {
indent() << "loadGeneric<" << ldst->out()->dtype() << ", "
<< vector_word_size << ">(";
code_ << " &" << gen(ldst->out()) << ", ";
code_ << " &" << gen(ldst->in()) << ");\n";
}
indent();
generateVectorizedLdSt(
ldst->in(), ldst->out(), ldst->cacheOp(), vector_word_size);
code_ << ";\n";
}
return;
}
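For orientation, here is a rough sketch of the device code this new TernaryOp path emits for a vectorized where (the same pattern a vectorized pad lowers to). The tensor names (T0, T1), the index, the predicate name, and the vector width are illustrative assumptions, not output copied from a real generated kernel:

```cpp
// Hypothetical generated snippet: T1 is a 4-wide local (register) buffer,
// T0 lives in global memory, and b_pred is the scalar in-bounds predicate.
b_pred
    ? loadGlobalToLocal<float, /*vec_size=*/4, /*is_volatile=*/false,
                        CacheOp::AllLevels>(&T1[0], &T0[base_index])
    : arraySet<float, 4>(&T1[0], (float)0.0);  // fill value for the padded region
```

Both branches are void calls, so the C++ conditional expression is legal; the taken branch either performs the vectorized load or fills the local buffer with the pad value.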
4 changes: 0 additions & 4 deletions csrc/device_lower/lower2device.h
@@ -45,10 +45,6 @@

namespace nvfuser {

// TODO: we frequently use pairwise root mapping from consumers to producers.
// This information is implicitly in the computeAtMaps, but there's no isolated
// container for this information that we can reuse. Would be nice to generate
// such a structure and propagate it through lowering.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class GpuLower : public NonCopyable {
class KernelIrMapper;
3 changes: 2 additions & 1 deletion csrc/device_lower/pass/predicate.cpp
@@ -103,7 +103,8 @@ class ConditionalFromPredicateModifier : public kir::ExprMutator {
"Expecting predicated body to only have one vectorized expression.");
auto vec_expr = ite->thenBody()[0];
NVF_ERROR(
vec_expr->isA<UnaryOp>() || vec_expr->isA<LoadStoreOp>(),
vec_expr->isA<UnaryOp>() || vec_expr->isA<LoadStoreOp>() ||
vec_expr->isA<TernaryOp>(),
"Vectorize predicate exprs only supported on set operations.");
NVF_ERROR(
ir_utils::isTvOp(vec_expr),
38 changes: 28 additions & 10 deletions csrc/device_lower/validation.cpp
@@ -668,17 +668,31 @@ class VectorizeValidator : public OptInDispatch {
tv_def != nullptr,
"Tv has no definition, cannot validate vectorization:",
tv);
auto producer_tv = tv_def->inputs().at(0)->as<TensorView>();
auto producer_word_size_it =
GpuLower::current()->vectorizedAccesses().find(producer_tv);
if (producer_word_size_it !=
GpuLower::current()->vectorizedAccesses().end()) {
producer_word_size_it->second =
std::max(vector_word_size, producer_word_size_it->second);
} else {
GpuLower::current()->vectorizedAccesses().emplace(
producer_tv, vector_word_size);
// TernaryOp(where) could have multiple inputs, but we only support a
// single TensorView input for vectorization.
TensorView* producer_tv = nullptr;
for (auto input : tv_def->inputs()) {
if (!input->isA<TensorView>()) {
continue;
}
NVF_ERROR(
producer_tv == nullptr,
"Vectorization validation only support op with a single TensorView input");
producer_tv = input->as<TensorView>();
auto producer_word_size_it =
GpuLower::current()->vectorizedAccesses().find(producer_tv);
if (producer_word_size_it !=
GpuLower::current()->vectorizedAccesses().end()) {
producer_word_size_it->second =
std::max(vector_word_size, producer_word_size_it->second);
} else {
GpuLower::current()->vectorizedAccesses().emplace(
producer_tv, vector_word_size);
}
}
NVF_ERROR(
producer_tv != nullptr,
"Vectorization validation requires a TensorView input");

VectorizedSetInfo vectorized_set_info;
vectorized_set_info.consumer_tv = tv;
@@ -798,6 +812,10 @@ void validateAndCollectVectorizeInfo(Fusion* fusion) {
Expr* def = tv->definition();
NVF_ERROR(
def == nullptr || def->isA<LoadStoreOp>() || def->isA<SliceOp>() ||
def->isA<PadOp>() ||
(def->isA<TernaryOp>() &&
def->as<TernaryOp>()->getTernaryOpType() ==
TernaryOpType::Where) ||
(def->isA<ReductionOp>() &&
def->as<ReductionOp>()->serialGridReductionRequested()),
"Vectorized accesses cannot be inline with computation: ",
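To make the new producer check concrete, here is a minimal sketch of what the validator now accepts and rejects when the vectorized output is produced by a where. It assumes the fusion-definition API used in the tests below (pred, tv0, tv_a, tv_b are placeholder names); it is not code from this PR:

```cpp
// Accepted: scalar predicate, scalar value, exactly one TensorView input,
// so the producer of the vectorized access is unambiguous.
auto pred = IrBuilder::create<Val>(DataType::Bool);
auto tv1 = where(pred, IrBuilder::create<Val>(2.0), tv0);
tv1->axis(-1)->parallelize(ParallelType::Vectorize);

// Rejected (hypothetical): two TensorView inputs would trip the
// "single TensorView input" error above.
// auto tv2 = where(pred, tv_a, tv_b);
```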
70 changes: 70 additions & 0 deletions tests/cpp/test_resize.cpp
@@ -4041,4 +4041,74 @@ TEST_F(ResizeTest, SliceSliceConcatConcat) {
NVF_CHECK(ref.equal(cg_outputs[0]));
}

// manual scheduling that should have vectorized load on padded inputs.
TEST_F(ResizeTest, VectorizePadLowering) {
Review comment (Collaborator): Should we have a test for vectorizing where without using pad?

Reply (Author): good call. almost forgot that we have where directly 🤕

auto fusion_ptr = std::make_unique<Fusion>();
auto& fusion = *fusion_ptr;
FusionGuard fg(fusion_ptr.get());

const std::vector<int64_t> shape({1024L * 1024L});

auto tv0 = makeContigConcreteTensor(shape);
fusion.addInput(tv0);

auto tv1 = pad(tv0, {IrBuilder::create<Val>(4L), IrBuilder::create<Val>(4L)});
fusion.addOutput(tv1);

tv1->split(0, 4);
tv1->split(0, 128);

tv1->axis(0)->parallelize(ParallelType::BIDx);
tv1->axis(1)->parallelize(ParallelType::TIDx);
tv1->axis(2)->parallelize(ParallelType::Vectorize);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto t0 = at::randn(shape, options);
std::vector<c10::IValue> aten_inputs({t0});

FusionExecutor fe;
fe.compileFusion(&fusion, aten_inputs);
auto cg_outputs = fe.runFusion(aten_inputs);

auto ref = at::pad(t0, {4, 4});
ASSERT_TRUE(ref.equal(cg_outputs[0]));
}

// manual scheduling that should have vectorized load.
TEST_F(ResizeTest, VectorizeWhereLowering) {
auto fusion_ptr = std::make_unique<Fusion>();
auto& fusion = *fusion_ptr;
FusionGuard fg(fusion_ptr.get());

const std::vector<int64_t> shape({1024L * 1024L});

// Note: nvfuser currently only supports vectorization with a single
// TensorView input.
auto s0 = IrBuilder::create<Val>(DataType::Bool);
fusion.addInput(s0);
auto tv0 = makeContigConcreteTensor(shape);
fusion.addInput(tv0);
auto tv1 = where(s0, IrBuilder::create<Val>(2.0), tv0);
fusion.addOutput(tv1);

tv1->split(0, 4);
tv1->split(0, 128);

tv1->axis(0)->parallelize(ParallelType::BIDx);
tv1->axis(1)->parallelize(ParallelType::TIDx);
tv1->axis(2)->parallelize(ParallelType::Vectorize);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto t0 = at::randn(shape, options);
std::vector<c10::IValue> aten_inputs({at::Scalar(false), t0});

FusionExecutor fe;
fe.compileFusion(&fusion, aten_inputs);
auto cg_outputs = fe.runFusion(aten_inputs);

// Note: we cannot use at::where, because aten only support tensor as
// predicate.
ASSERT_TRUE(t0.equal(cg_outputs[0]));
}

} // namespace nvfuser
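As an aside on the reference computation in VectorizeWhereLowering: since at::where only accepts a tensor predicate, one hypothetical workaround is to wrap the scalar predicate in a 0-dim bool tensor and rely on broadcasting. This is a sketch under that assumption, not part of the PR; with the predicate fixed to false it reduces to the t0.equal(...) check the test already uses:

```cpp
// Sketch only: build a 0-dim bool predicate tensor so at::where can be used.
auto pred = at::zeros({}, options.dtype(at::kBool));
auto ref = at::where(pred, at::full_like(t0, 2.0), t0);
ASSERT_TRUE(ref.equal(cg_outputs[0]));
```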