Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Int based RNG #3733

Draft
wants to merge 11 commits into
base: int_types
Choose a base branch
from
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/device_lower/pass/misaligned_vectorization.cpp
${NVFUSER_SRCS_DIR}/device_lower/pass/predicate.cpp
${NVFUSER_SRCS_DIR}/device_lower/pass/replace_size.cpp
${NVFUSER_SRCS_DIR}/device_lower/pass/rng.cpp
${NVFUSER_SRCS_DIR}/device_lower/pass/scalar_hoist.cpp
${NVFUSER_SRCS_DIR}/device_lower/pass/unroll.cpp
${NVFUSER_SRCS_DIR}/device_lower/pass/vectorize_welford.cpp
Expand Down Expand Up @@ -555,6 +556,7 @@ list(APPEND JIT_TEST_SRCS
${NVFUSER_ROOT}/tests/cpp/test_gpu1.cpp
${NVFUSER_ROOT}/tests/cpp/test_gpu2.cpp
${NVFUSER_ROOT}/tests/cpp/test_gpu3.cpp
${NVFUSER_ROOT}/tests/cpp/test_gpu4.cpp
${NVFUSER_ROOT}/tests/cpp/test_gpu_compute_with.cpp
${NVFUSER_ROOT}/tests/cpp/test_gpu_fused_reduction.cpp
${NVFUSER_ROOT}/tests/cpp/test_gpu_indexing_ops.cpp
Expand Down
15 changes: 7 additions & 8 deletions csrc/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,9 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
} else if (v->isA<TensorView>()) {
tv = v->as<TensorView>();
}
if (tv && aligned_array_of_regs_.count(tv)) {
if (tv &&
(aligned_array_of_regs_.count(tv) ||
tv->getMemoryType() == MemoryType::Local)) {
return genVariableName(tv).append(".array");
} else {
return genVariableName(v);
Expand Down Expand Up @@ -358,7 +360,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
const auto& kernel_summary = kernel_->summary();

if (kernel_summary.has_philox_op) {
indent() << "uint4 rng_result;\n";
indent() << "Array<uint32_t, 4> rng_result;\n";
indent() << "nvfuser_index_t rng_subseq = -1;\n";
indent() << "nvfuser_index_t rng_offset = -1;\n";
}
Expand Down Expand Up @@ -3169,14 +3171,11 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
break;
case MemoryType::Local: {
auto va = kernel_->summary().vectorized_accesses;
indent() << "Array<" << buffer_dtype << ", " << genInline(size)
<< ", " << (va.find(tv) != va.end() ? va.at(tv) : 1) << "> "
<< genVariableName(tv) << ";\n";
if (va.find(tv) != va.end()) {
indent() << "Array<" << buffer_dtype << ", " << genInline(size)
<< ", " << va.at(tv) << "> " << genVariableName(tv)
<< ";\n";
aligned_array_of_regs_.insert(tv);
} else {
indent() << buffer_dtype << " " << genVariableName(tv) << "["
<< genInline(size) << "];\n";
}
} break;
default:
Expand Down
2 changes: 2 additions & 0 deletions csrc/device_lower/lower2device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <device_lower/pass/misaligned_vectorization.h>
#include <device_lower/pass/predicate.h>
#include <device_lower/pass/replace_size.h>
#include <device_lower/pass/rng.h>
#include <device_lower/pass/unroll.h>
#include <device_lower/pass/vectorize_welford.h>
#include <device_lower/pass/warp_reduce.h>
Expand Down Expand Up @@ -282,6 +283,7 @@ GpuLower::GpuLower(Fusion* fusion, const CompileParams& cparams)
generateConditionalFromPredicate},
{"vectorizeWelford", vectorizeWelford},
{"allocateCommonScalars", allocateCommonScalars},
{"addRNG", addRNG},
{"insertMagicZero", insertMagicZero},
{"KIRCleaner", KIRCleaner::cleanUp},
{"instrumentKernel", instrumentKernel},
Expand Down
190 changes: 190 additions & 0 deletions csrc/device_lower/pass/rng.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include <device_lower/pass/magic_zero.h>

#include <device_lower/analysis/index_compute.h>
#include <device_lower/lower2device.h>
#include <dispatch.h>
#include <instrumentation.h>
#include <ir/utils.h>
#include <kernel_ir_dispatch.h>

namespace nvfuser {

namespace {

std::tuple<Val*, Expr*> createAndAllocNS(
std::string name,
DataType dtype = DataType::Index) {
Val* val = IrBuilder::create<NamedScalar>(name, dtype);
auto alloc = IrBuilder::create<kir::Allocate>(
val, MemoryType::Local, GpuLower::current()->kernel()->oneVal());
return std::make_tuple(val, alloc);
}

// Inserts the scalar bookkeeping required by RNG ops:
//  - once per kernel: the `rng_result` buffer plus `rng_subseq` /
//    `rng_offset` trackers (initialized to -1 so the first RNG op always
//    triggers a Philox call);
//  - per RNGOp: its linear index, subsequence (= index / multiple),
//    component (= index % multiple) and offset scalars, and an
//    if-then-else that re-runs Philox only when (subseq, offset) differs
//    from the previously computed pair.
class RNGInserter : public kir::ExprMutator {
 public:
  // Entry point: runs the mutator over `exprs` and returns the mutated
  // expression list.
  static std::vector<Expr*> insert(const std::vector<Expr*>& exprs) {
    RNGInserter inserter(exprs);
    return inserter.exprs_;
  }

 private:
  // Kernel-wide RNG state; created lazily when the first RNGOp is seen.
  Val* rng_subseq = nullptr;
  Val* rng_offset = nullptr;
  TensorView* rng_result = nullptr;
  const std::vector<Expr*>& exprs;

  RNGInserter(const std::vector<Expr*>& _exprs) : exprs(_exprs) {
    kir::ExprMutator::traverseAndInsert(exprs);
  }

  void handle(RNGOp* rop) final {
    // One-time prologue, inserted before the first kernel expression.
    if (rng_subseq == nullptr) {
      NVF_ERROR(!exprs.empty());
      auto neg_1 = IrBuilder::create<Val>(-1, DataType::Index);
      auto subseq_tuple = createAndAllocNS("rng_subseq");
      kir::ExprMutator::registerInsertBefore(
          exprs.front(), std::get<1>(subseq_tuple), nullptr);
      kir::ExprMutator::registerInsertBefore(
          exprs.front(),
          IrBuilder::create<LoadStoreOp>(
              LoadStoreOpType::Set, std::get<0>(subseq_tuple), neg_1),
          nullptr);

      rng_subseq = std::get<0>(subseq_tuple);

      auto offset_tuple = createAndAllocNS("rng_offset");
      kir::ExprMutator::registerInsertBefore(
          exprs.front(), std::get<1>(offset_tuple), nullptr);
      kir::ExprMutator::registerInsertBefore(
          exprs.front(),
          IrBuilder::create<LoadStoreOp>(
              LoadStoreOpType::Set, std::get<0>(offset_tuple), neg_1),
          nullptr);

      rng_offset = std::get<0>(offset_tuple);

      // Buffer holding the four Philox outputs.
      // NOTE(review): codegen declares this buffer as
      // `Array<uint32_t, 4>` (see has_philox_op handling), but the dtype
      // here is UInt64 — confirm whether this should be DataType::UInt32.
      rng_result = TensorViewBuilder()
                       .shape(std::vector<int64_t>{4})
                       .dtype(DataType::UInt64)
                       .contiguity(true)
                       .build();
      rng_result->setMemoryType(MemoryType::Local);

      auto rng_result_alloc =
          IrBuilder::create<kir::Allocate>(rng_result, MemoryType::Local);
      kir::ExprMutator::registerInsertBefore(
          exprs.front(), rng_result_alloc, nullptr);
    }

    // Per-op scalars. (Fixed typo: the generated name was "liner_index".)
    auto index_tuple =
        createAndAllocNS("linear_index" + std::to_string(rop->name()));
    kir::ExprMutator::registerInsertBefore(rop, std::get<1>(index_tuple));
    kir::ExprMutator::registerInsertBefore(
        rop,
        IrBuilder::create<LoadStoreOp>(
            LoadStoreOpType::Set,
            std::get<0>(index_tuple),
            rop->getPhiloxIndex()));

    auto multiple =
        IrBuilder::create<Val>(rop->getPhiloxMultiple(), DataType::Index);

    // rng_subseq<op> = linear_index<op> / multiple
    auto rop_subseq_tuple =
        createAndAllocNS("rng_subseq" + std::to_string(rop->name()));
    kir::ExprMutator::registerInsertBefore(rop, std::get<1>(rop_subseq_tuple));
    kir::ExprMutator::registerInsertBefore(
        rop,
        IrBuilder::create<BinaryOp>(
            BinaryOpType::Div,
            std::get<0>(rop_subseq_tuple),
            std::get<0>(index_tuple),
            multiple));

    // rng_component<op> = linear_index<op> % multiple
    auto rop_component_tuple =
        createAndAllocNS("rng_component" + std::to_string(rop->name()));
    kir::ExprMutator::registerInsertBefore(
        rop, std::get<1>(rop_component_tuple));
    kir::ExprMutator::registerInsertBefore(
        rop,
        IrBuilder::create<BinaryOp>(
            BinaryOpType::Mod,
            std::get<0>(rop_component_tuple),
            std::get<0>(index_tuple),
            multiple));

    // rng_offset<op> = this op's RNG offset
    auto rop_offset_tuple =
        createAndAllocNS("rng_offset" + std::to_string(rop->name()));
    kir::ExprMutator::registerInsertBefore(rop, std::get<1>(rop_offset_tuple));
    kir::ExprMutator::registerInsertBefore(
        rop,
        IrBuilder::create<LoadStoreOp>(
            LoadStoreOpType::Set,
            std::get<0>(rop_offset_tuple),
            rop->getRNGOffsetVal()));

    // Re-run Philox only when (subseq, offset) changed since the last call,
    // then record the new pair in the kernel-wide trackers.
    kir::IfThenElse* ite = IrBuilder::create<kir::IfThenElse>(
        IrBuilder::create<kir::Predicate>(SimplifyingIrBuilder::logicalOrExpr(
            SimplifyingIrBuilder::neExpr(
                rng_subseq, std::get<0>(rop_subseq_tuple)),
            SimplifyingIrBuilder::neExpr(
                rng_offset, std::get<0>(rop_offset_tuple)))));

    ite->thenBody().push_back(IrBuilder::create<TernaryOp>(
        TernaryOpType::Philox,
        rng_result,
        rop->getRNGSeedVal(),
        rng_subseq,
        rng_offset));

    ite->thenBody().push_back(IrBuilder::create<LoadStoreOp>(
        LoadStoreOpType::Set, rng_subseq, std::get<0>(rop_subseq_tuple)));

    ite->thenBody().push_back(IrBuilder::create<LoadStoreOp>(
        LoadStoreOpType::Set, rng_offset, std::get<0>(rop_offset_tuple)));

    kir::ExprMutator::registerInsertBefore(rop, ite);
  }
};

} // namespace

// Lowering pass entry point: inserts RNG bookkeeping (see RNGInserter).
// Returns `exprs` untouched when the kernel contains no RNGOp, so
// RNG-free kernels pay neither the traversal nor the extra scalars.
std::vector<Expr*> addRNG(const std::vector<Expr*>& exprs) {
  FUSER_PERF_SCOPE("GpuLower::Lower::addRNG");
  const auto gpu_lower = GpuLower::current();
  auto kernel = gpu_lower->kernel();
  const bool has_rng = std::any_of(
      kernel->exprs().begin(), kernel->exprs().end(), [](Expr* expr) {
        return expr->isA<RNGOp>();
      });

  if (!has_rng) {
    return exprs;
  }
  return RNGInserter::insert(exprs);
}

} // namespace nvfuser
16 changes: 16 additions & 0 deletions csrc/device_lower/pass/rng.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#pragma once

#include <ir/all_nodes.h>
#include <kernel_ir.h>
#include <kernel_ir_dispatch.h>

namespace nvfuser {
//! Lowering pass that inserts the scalars and conditional Philox calls
//! required by RNG ops. Returns `exprs` unchanged when the kernel
//! contains no RNGOp.
std::vector<Expr*> addRNG(const std::vector<Expr*>& exprs);
} // namespace nvfuser
4 changes: 4 additions & 0 deletions csrc/ir/internal_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -2001,6 +2001,10 @@ class NVF_API NamedScalar : public Val {
p == ParallelType::BIDz);
}

//! True if this scalar names any parallel-dimension extent (gridDim.*,
//! blockDim.*) or parallel index (blockIdx.*, threadIdx.*).
bool isParallelScalar() const {
return isGridDim() || isBlockDim() || isBlockIdx() || isThreadIdx();
}

//! Return the named scalar extent of a parallel dimension (e.g. blockDim.x)
//! WARNING: Only works with Fusion container at the moment
static NamedScalar* getParallelDim(ParallelType p_type);
Expand Down
11 changes: 10 additions & 1 deletion csrc/kernel_ir_dispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ void ConstIrVisitor::handle(const IfThenElse* ite) {

std::vector<Expr*> ExprMutator::mutate(bool reverse_order) {
if (insertions_.empty() && replacements_.empty() && removal_.empty()) {
std::cout << "ExprMutator::Empty" << std::endl;
return exprs_;
}

Expand All @@ -104,7 +105,7 @@ std::vector<Expr*> ExprMutator::mutate(bool reverse_order) {
}
auto pos_it = std::find(exprs_.begin(), exprs_.end(), info.reference);
NVF_ERROR(
pos_it != exprs_.end(),
pos_it >= exprs_.begin() && pos_it != exprs_.end(),
"Issue finding reference expression for insertion.");
if (info.mode == MutationMode::BEFORE) {
exprs_.insert(pos_it, info.new_expr);
Expand Down Expand Up @@ -132,6 +133,7 @@ std::vector<Expr*> ExprMutator::mutate(bool reverse_order) {
}
} else {
for (auto insertion_info : insertions_) {
std::cout << "ExprMutator::run_insertion" << std::endl;
run_insertion(insertion_info);
}
}
Expand Down Expand Up @@ -173,6 +175,12 @@ std::vector<Expr*> ExprMutator::mutate(bool reverse_order) {
insertions_.clear();
replacements_.clear();

std::cout << "------------" << std::endl;
for (auto expr : exprs_) {
std::cout << expr->toString() << std::endl;
}
std::cout << "------------" << std::endl;

return exprs_;
}

Expand Down Expand Up @@ -208,6 +216,7 @@ void ExprMutator::registerInsertBefore(
Expr* reference,
Expr* new_expr,
Scope* scope) {
std::cout << "Register insert before" << std::endl;
registerMutation(reference, new_expr, scope, MutationMode::BEFORE);
}

Expand Down
2 changes: 2 additions & 0 deletions csrc/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,8 @@ static const char* ternary_op_type2string(TernaryOpType t) {
return "threshold";
case TernaryOpType::Where:
return "where";
case TernaryOpType::Philox:
return "philox";
default:
NVF_THROW("Unexpected TernaryOpType");
}
Expand Down
2 changes: 1 addition & 1 deletion csrc/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,7 @@ bool isIntegerOp(const BinaryOpType bopt);
// Return if output of operator should be a boolean
bool isLogicalOp(const BinaryOpType bopt);

enum class TernaryOpType { Clamp, Lerp, Threshold, Where };
enum class TernaryOpType { Clamp, Lerp, Threshold, Where, Philox };

enum class ParallelType {
DIDx,
Expand Down
Loading
Loading