-
Notifications
You must be signed in to change notification settings - Fork 54
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
245 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
// clang-format off | ||
/* | ||
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. | ||
* All rights reserved. | ||
* SPDX-License-Identifier: BSD-3-Clause | ||
*/ | ||
// clang-format on | ||
#include <device_lower/pass/magic_zero.h> | ||
|
||
#include <device_lower/analysis/index_compute.h> | ||
#include <device_lower/lower2device.h> | ||
#include <dispatch.h> | ||
#include <instrumentation.h> | ||
#include <ir/utils.h> | ||
#include <kernel_ir_dispatch.h> | ||
|
||
namespace nvfuser { | ||
|
||
namespace { | ||
|
||
class RNGInserter : public kir::ExprMutator { | ||
public: | ||
static std::vector<Expr*> insert(const std::vector<Expr*>& exprs) { | ||
RNGInserter inserter(exprs); | ||
return inserter.exprs_; | ||
} | ||
|
||
private: | ||
Val* rng_subseq; | ||
Val* rng_offset; | ||
struct InsertionInfo { | ||
Scope* scope = nullptr; | ||
ForLoop* fl = nullptr; | ||
}; | ||
|
||
RNGInserter(const std::vector<Expr*>& exprs) { | ||
NVF_ERROR(!exprs.empty()); | ||
auto neg_1 = IrBuilder::create<Val>(-1, DataType::Index); | ||
auto rng_subseq = | ||
IrBuilder::create<NamedScalar>("rng_subseq", DataType::Index); | ||
auto rng_offset = | ||
IrBuilder::create<NamedScalar>("rng_offset", DataType::Index); | ||
kir::ExprMutator::registerInsertBefore( | ||
exprs.front(), | ||
IrBuilder::create<LoadStoreOp>( | ||
LoadStoreOpType::Set, rng_subseq, neg_1)); | ||
kir::ExprMutator::registerInsertBefore( | ||
exprs.front(), | ||
IrBuilder::create<LoadStoreOp>( | ||
LoadStoreOpType::Set, rng_offset, neg_1)); | ||
kir::ExprMutator::traverseAndInsert(exprs); | ||
} | ||
|
||
void handle(RNGOp* rng_op) final { | ||
std::cout << rng_op->toString() << std::endl; | ||
// auto linear_index = rng_op->getPhiloxIndex(); | ||
// auto multiple = rng_op->getPhiloxMultiple(); | ||
// auto rng_subseq = SimplifyingIrBuilder::div(linear_index, multiple); | ||
// auto rng_component = SimplifyingIrBuilder::mod(linear_index, multiple); | ||
// auto rng_offset = rng_op->getRNGOffsetVal(); | ||
|
||
// nvfuser_index_t rng_offset215 = (((ptr2 == nullptr) ? i3 : ((*ptr2) + | ||
// i3)) / 4LL); | ||
// if (rng_subseq != rng_subseq215 || rng_offset != rng_offset215) { | ||
// rng_result = philox(((ptr0 == nullptr) ? i1 : (*ptr0)), | ||
// rng_subseq215, rng_offset215); rng_subseq = rng_subseq215; rng_offset | ||
// = rng_offset215; | ||
// } | ||
// T1[i5] = rng_uniformf(rng_result, rng_component215); | ||
// } | ||
|
||
// if (fl->isUnrolled()) { | ||
// if (scope_.empty()) { | ||
// kir::ExprMutator::registerInsertAfter( | ||
// fl, IrBuilder::create<kir::UpdateMagicZero>()); | ||
// } else { | ||
// NVF_ERROR( | ||
// !scope_.back()->exprs().empty(), "Not expecting an empty loop."); | ||
// kir::ExprMutator::registerInsertAfter( | ||
// fl, IrBuilder::create<kir::UpdateMagicZero>(), scope_.back()); | ||
// } | ||
// } else { | ||
// kir::ExprMutator::handle(fl); | ||
// } | ||
// NVF_THROW("TEST"); | ||
} | ||
|
||
std::vector<InsertionInfo> insertion_list_; | ||
}; | ||
|
||
} // namespace | ||
|
||
// Lowering pass: inserts RNG subsequence/offset bookkeeping into the kernel's
// expressions. Returns `exprs` unchanged when the kernel contains no RNGOp,
// so non-RNG kernels pay no cost.
std::vector<Expr*> addRNG(const std::vector<Expr*>& exprs) {
  FUSER_PERF_SCOPE("GpuLower::Lower::addRNG");
  // Check whether any RNG op exists; if not, there is nothing to insert.
  auto kernel = GpuLower::current()->kernel();
  // Cache the expression list once: calling kernel->exprs() twice could
  // yield two distinct temporaries, and comparing begin()/end() iterators
  // from different containers is undefined behavior.
  const auto kernel_exprs = kernel->exprs();
  const bool has_rng =
      std::any_of(kernel_exprs.begin(), kernel_exprs.end(), [](Expr* expr) {
        return expr->isA<RNGOp>();
      });

  if (!has_rng) {
    return exprs;
  }

  return RNGInserter::insert(exprs);
}
|
||
} // namespace nvfuser |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
// clang-format off | ||
/* | ||
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. | ||
* All rights reserved. | ||
* SPDX-License-Identifier: BSD-3-Clause | ||
*/ | ||
// clang-format on | ||
#pragma once | ||
|
||
#include <ir/all_nodes.h> | ||
#include <kernel_ir.h> | ||
#include <kernel_ir_dispatch.h> | ||
|
||
namespace nvfuser { | ||
// Lowering pass: inserts RNG (Philox) subsequence/offset bookkeeping before
// the kernel expressions; returns `exprs` unchanged if the kernel contains
// no RNGOp.
std::vector<Expr*> addRNG(const std::vector<Expr*>& exprs);
} // namespace nvfuser |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
// clang-format off | ||
/* | ||
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. | ||
* All rights reserved. | ||
* SPDX-License-Identifier: BSD-3-Clause | ||
*/ | ||
// clang-format on | ||
#include <csrc/exceptions.h> | ||
#include <gmock/gmock-matchers.h> | ||
#include <gtest/gtest.h> | ||
|
||
#include <codegen.h> | ||
#include <debug.h> | ||
#include <device_lower/lower2device.h> | ||
#include <device_lower/pass/magic_zero.h> | ||
#include <device_lower/pass/replace_size.h> | ||
#include <disjoint_set.h> | ||
#include <expr_evaluator.h> | ||
#include <fusion.h> | ||
#include <fusion_segmenter.h> | ||
#include <grouped_reduction.h> | ||
#include <id_model/id_model.h> | ||
#include <ir/all_nodes.h> | ||
#include <ir/builder.h> | ||
#include <ir/graphviz.h> | ||
#include <ir/iostream.h> | ||
#include <ir/utils.h> | ||
#include <iter_visitor.h> | ||
#include <kernel_ir.h> | ||
#include <kernel_ir_dispatch.h> | ||
#include <logical_domain_map.h> | ||
#include <ops/all_ops.h> | ||
#include <runtime/executor.h> | ||
#include <runtime/executor_params.h> | ||
#include <runtime/fusion_executor_cache.h> | ||
#include <scheduler/all_schedulers.h> | ||
#include <scheduler/reduction_utils.h> | ||
#include <scheduler/tools/abstract_tensor.h> | ||
#include <scheduler/tools/inlining.h> | ||
#include <scheduler/utils.h> | ||
#include <tests/cpp/utils.h> | ||
#include <tests/cpp/validator.h> | ||
#include <transform_replay.h> | ||
#include <transform_rfactor.h> | ||
|
||
#include <torch/csrc/jit/api/function_impl.h> | ||
#include <torch/csrc/jit/codegen/cuda/interface.h> | ||
#include <torch/torch.h> | ||
|
||
#include <ATen/cuda/CUDAContext.h> | ||
#include <ATen/cuda/Exceptions.h> | ||
#include <c10/cuda/CUDAStream.h> | ||
|
||
#include <algorithm> | ||
#include <cmath> | ||
#include <sstream> | ||
#include "parallel_dimension_map.h" | ||
|
||
namespace nvfuser { | ||
|
||
using namespace at::indexing; | ||
|
||
// Builds and schedules a dropout-style fusion containing an RNG op
// (rand_like), then prints the math and kernel IR.
// NOTE(review): this test makes no assertions and never executes the kernel —
// it is a WIP debugging aid for the RNG lowering pass; add validation before
// relying on it in CI.
TEST_F(NVFuserTest, IntRNG_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Input extent matches the schedule below: 4 (BIDx) * 128 * 4 (vectorize).
  auto input_tv = makeContigConcreteTensor({4 * 128 * 4});
  fusion.addInput(input_tv);

  // NOTE(review): dropout normally rescales by 1/(1-p); here kScale = 1/p
  // while the mask keeps values with probability p (rand < p) — confirm this
  // is intentional.
  constexpr float kDropoutProbability = 0.9;
  constexpr float kScale = 1.0f / kDropoutProbability;

  auto prob = IrBuilder::create<Val>(kDropoutProbability);
  auto scale = IrBuilder::create<Val>(kScale);

  // dropout start
  auto rand_vals = rand_like(input_tv);
  auto mask = lt(rand_vals, prob);
  auto apply_mask = mul(input_tv, mask);
  auto output_tv = mul(apply_mask, scale);
  // dropout end
  // fusion.addOutput(mask);
  fusion.addOutput(output_tv);

  // Cache input/output so the vectorized loads/stores go through registers.
  auto inp_cache = input_tv->cacheAfter();
  output_tv->cacheBefore();

  // Schedule: [4, 128, 4] — outermost axis bound to blockIdx.x.
  output_tv->split(0, 4);
  output_tv->split(0, 128);
  output_tv->axis(0)->parallelize(ParallelType::BIDx);

  // Propagate the transform and parallelization to the whole fusion.
  TransformPropagator propagator(output_tv);
  MaxLogicalDomainInfoSpanningTree spanning_tree(output_tv);
  spanning_tree.traverse(&propagator);
  scheduler_utils::parallelizeAllLike(output_tv);

  // Innermost axis: vectorized I/O; the RNG op itself is unrolled, which is
  // the case the RNG insertion pass targets.
  inp_cache->axis(-1)->parallelize(ParallelType::Vectorize);
  rand_vals->axis(-1)->parallelize(ParallelType::Unroll);
  output_tv->axis(-1)->parallelize(ParallelType::Vectorize);

  inlineMost();

  // Debug output only — no execution, no correctness check.
  fusion.printMath();
  fusion.printKernel();
}
|
||
} // namespace nvfuser |