diff --git a/csrc/device_lower/pass/rng.cpp b/csrc/device_lower/pass/rng.cpp index 5930f6a2d32..4eb96fbe163 100644 --- a/csrc/device_lower/pass/rng.cpp +++ b/csrc/device_lower/pass/rng.cpp @@ -141,6 +141,14 @@ class RNGInserter : public kir::ExprMutator { rng_subseq, std::get<0>(rop_subseq_tuple)), SimplifyingIrBuilder::neExpr( rng_offset, std::get<0>(rop_offset_tuple))))); + + ite->thenBody().push_back(IrBuilder::create( + TernaryOpType::Philox, + rng_result, + rop->getRNGSeedVal(), + rng_subseq, + rng_offset)); + ite->thenBody().push_back(IrBuilder::create( LoadStoreOpType::Set, rng_subseq, std::get<0>(rop_subseq_tuple))); @@ -148,47 +156,6 @@ class RNGInserter : public kir::ExprMutator { LoadStoreOpType::Set, rng_offset, std::get<0>(rop_offset_tuple))); kir::ExprMutator::registerInsertBefore(rop, ite); - - kir::ExprMutator::registerInsertBefore( - rop, - IrBuilder::create( - TernaryOpType::Philox, - IrBuilder::create("rng_result", DataType::Index), - rop->getRNGSeedVal(), - rng_subseq, - rng_offset)); - - // auto rop_component = - // createAndAllocNS("rng_component" + std::to_string(rop->name())); - - // auto rng_subseq = SimplifyingIrBuilder::div(linear_index, multiple); - // auto rng_component = SimplifyingIrBuilder::mod(linear_index, multiple); - // auto rng_offset = rop->getRNGOffsetVal(); - - // nvfuser_index_t rng_offset215 = (((ptr2 == nullptr) ? i3 : ((*ptr2) + - // i3)) / 4LL); - // if (rng_subseq != rng_subseq215 || rng_offset != rng_offset215) { - // rng_result = philox(((ptr0 == nullptr) ? i1 : (*ptr0)), - // rng_subseq215, rng_offset215); rng_subseq = rng_subseq215; rng_offset - // = rng_offset215; - // } - // T1[i5] = rng_uniformf(rng_result, rng_component215); - // } - - // if (fl->isUnrolled()) { - // if (scope_.empty()) { - // kir::ExprMutator::registerInsertAfter( - // fl, IrBuilder::create()); - // } else { - // NVF_ERROR( - // !scope_.back()->exprs().empty(), "Not expecting an empty loop."); - // kir::ExprMutator::registerInsertAfter( - // fl, IrBuilder::create(), scope_.back()); - // } - // } else { - // kir::ExprMutator::handle(fl); - // } - // NVF_THROW("TEST"); } std::vector insertion_list_;