Skip to content

Commit

Permalink
Fix Philox expr
Browse files Browse the repository at this point in the history
  • Loading branch information
csarofeen committed Jan 20, 2025
1 parent 9e7a4cd commit 2fa293d
Showing 1 changed file with 8 additions and 41 deletions.
49 changes: 8 additions & 41 deletions csrc/device_lower/pass/rng.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,54 +141,21 @@ class RNGInserter : public kir::ExprMutator {
rng_subseq, std::get<0>(rop_subseq_tuple)),
SimplifyingIrBuilder::neExpr(
rng_offset, std::get<0>(rop_offset_tuple)))));

ite->thenBody().push_back(IrBuilder::create<TernaryOp>(
TernaryOpType::Philox,
rng_result,
rop->getRNGSeedVal(),
rng_subseq,
rng_offset));

ite->thenBody().push_back(IrBuilder::create<LoadStoreOp>(
LoadStoreOpType::Set, rng_subseq, std::get<0>(rop_subseq_tuple)));

ite->thenBody().push_back(IrBuilder::create<LoadStoreOp>(
LoadStoreOpType::Set, rng_offset, std::get<0>(rop_offset_tuple)));

kir::ExprMutator::registerInsertBefore(rop, ite);

kir::ExprMutator::registerInsertBefore(
rop,
IrBuilder::create<TernaryOp>(
TernaryOpType::Philox,
IrBuilder::create<NamedScalar>("rng_result", DataType::Index),
rop->getRNGSeedVal(),
rng_subseq,
rng_offset));

// auto rop_component =
// createAndAllocNS("rng_component" + std::to_string(rop->name()));

// auto rng_subseq = SimplifyingIrBuilder::div(linear_index, multiple);
// auto rng_component = SimplifyingIrBuilder::mod(linear_index, multiple);
// auto rng_offset = rop->getRNGOffsetVal();

// nvfuser_index_t rng_offset215 = (((ptr2 == nullptr) ? i3 : ((*ptr2) +
// i3)) / 4LL);
// if (rng_subseq != rng_subseq215 || rng_offset != rng_offset215) {
// rng_result = philox(((ptr0 == nullptr) ? i1 : (*ptr0)),
// rng_subseq215, rng_offset215); rng_subseq = rng_subseq215; rng_offset
// = rng_offset215;
// }
// T1[i5] = rng_uniformf(rng_result, rng_component215);
// }

// if (fl->isUnrolled()) {
// if (scope_.empty()) {
// kir::ExprMutator::registerInsertAfter(
// fl, IrBuilder::create<kir::UpdateMagicZero>());
// } else {
// NVF_ERROR(
// !scope_.back()->exprs().empty(), "Not expecting an empty loop.");
// kir::ExprMutator::registerInsertAfter(
// fl, IrBuilder::create<kir::UpdateMagicZero>(), scope_.back());
// }
// } else {
// kir::ExprMutator::handle(fl);
// }
// NVF_THROW("TEST");
}

std::vector<InsertionInfo> insertion_list_;
Expand Down

0 comments on commit 2fa293d

Please sign in to comment.