Skip to content

Commit

Permalink
Move random numbers to use Array instead of uint4/uint2.
Browse files Browse the repository at this point in the history
  • Loading branch information
csarofeen committed Jan 20, 2025
1 parent 3158d84 commit 6465c8b
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 54 deletions.
2 changes: 1 addition & 1 deletion csrc/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
const auto& kernel_summary = kernel_->summary();

if (kernel_summary.has_philox_op) {
indent() << "uint4 rng_result;\n";
indent() << "Array<uint32_t, 4> rng_result;\n";
indent() << "nvfuser_index_t rng_subseq = -1;\n";
indent() << "nvfuser_index_t rng_offset = -1;\n";
}
Expand Down
106 changes: 53 additions & 53 deletions runtime/random_numbers.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,39 +13,39 @@ __device__ unsigned int mulhilo32(
return a * b;
}

__device__ uint4 single_round(uint4 ctr, uint2 key) {
__device__ Array<uint32_t, 4> single_round(Array<uint32_t, 4> ctr, Array<uint32_t, 2> key) {
constexpr unsigned long kPhiloxSA = 0xD2511F53;
constexpr unsigned long kPhiloxSB = 0xCD9E8D57;
unsigned int hi0;
unsigned int hi1;
unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0);
unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1);
uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0};
unsigned int lo0 = mulhilo32(kPhiloxSA, ctr[0], &hi0);
unsigned int lo1 = mulhilo32(kPhiloxSB, ctr[2], &hi1);
Array<uint32_t, 4> ret = {hi1 ^ ctr[1] ^ key[0], lo1, hi0 ^ ctr[3] ^ key[1], lo0};
return ret;
}

__device__ uint4 philox(
__device__ Array<uint32_t, 4> philox(
unsigned long long seed,
unsigned long long subsequence,
unsigned long long offset) {
constexpr unsigned long kPhilox10A = 0x9E3779B9;
constexpr unsigned long kPhilox10B = 0xBB67AE85;
uint2 key = {};
key.x = (unsigned int)seed;
key.y = (unsigned int)(seed >> 32);
uint4 counter = make_uint4(0, 0, 0, 0);
counter.x = (unsigned int)(offset);
counter.y = (unsigned int)(offset >> 32);
counter.z = (unsigned int)(subsequence);
counter.w = (unsigned int)(subsequence >> 32);

uint4 output = {};
uint2 key_ = key;
uint4 counter_ = counter;
Array<uint32_t, 2> key;
key[0] = (unsigned int)seed;
key[1] = (unsigned int)(seed >> 32);
Array<uint32_t, 4> counter;
counter[0] = (unsigned int)(offset);
counter[1] = (unsigned int)(offset >> 32);
counter[2] = (unsigned int)(subsequence);
counter[3] = (unsigned int)(subsequence >> 32);

Array<uint32_t, 4> output = {};
Array<uint32_t, 2> key_ = key;
Array<uint32_t, 4> counter_ = counter;
for (int i = 0; i < 9; i++) {
counter_ = single_round(counter_, key_);
key_.x += (kPhilox10A);
key_.y += (kPhilox10B);
key_[0] += (kPhilox10A);
key_[1] += (kPhilox10B);
}
output = single_round(counter_, key_);
return output;
Expand Down Expand Up @@ -85,27 +85,27 @@ __device__ double uniform(unsigned int x, unsigned int y) {
return result == 1.0 ? 0.0 : result;
}

__device__ double rng_uniform(const uint4& rng_result, int rng_component) {
__device__ double rng_uniform(const Array<uint32_t, 4>& rng_result, int rng_component) {
return uniform(
(&rng_result.x)[rng_component * 2],
(&rng_result.x)[rng_component * 2 + 1]);
rng_result[rng_component * 2],
rng_result[rng_component * 2 + 1]);
}

__device__ float rng_uniformf(const uint4& rng_result, int rng_component) {
return uniformf((&rng_result.x)[rng_component]);
__device__ float rng_uniformf(const Array<uint32_t, 4>& rng_result, int rng_component) {
return uniformf(rng_result[rng_component]);
}

__device__ __half rng_uniform_half(const uint4& rng_result, int rng_component) {
return uniform_half((&rng_result.x)[rng_component]);
__device__ __half rng_uniform_half(const Array<uint32_t, 4>& rng_result, int rng_component) {
return uniform_half(rng_result[rng_component]);
}

__device__ __bfloat
rng_uniform_bfloat(const uint4& rng_result, int rng_component) {
return uniform_bfloat((&rng_result.x)[rng_component]);
rng_uniform_bfloat(const Array<uint32_t, 4>& rng_result, int rng_component) {
return uniform_bfloat(rng_result[rng_component]);
}

__device__ double rng_uniform_range(
const uint4& rng_result,
const Array<uint32_t, 4>& rng_result,
int rng_component,
double from,
double to) {
Expand All @@ -115,7 +115,7 @@ __device__ double rng_uniform_range(
}

__device__ float rng_uniform_rangef(
const uint4& rng_result,
const Array<uint32_t, 4>& rng_result,
int rng_component,
float from,
float to) {
Expand All @@ -125,23 +125,23 @@ __device__ float rng_uniform_rangef(
}

__device__ __half rng_uniform_range_half(
const uint4& rng_result,
const Array<uint32_t, 4>& rng_result,
int rng_component,
float from,
float to) {
auto range = to - from;
float uniform01 = raw_uniform_float((&rng_result.x)[rng_component]);
float uniform01 = raw_uniform_float(rng_result[rng_component]);
__half result = __float2half(from + range * uniform01);
return __heq(result, __float2half(to)) ? __float2half(from) : result;
}

__device__ __bfloat rng_uniform_range_bfloat(
const uint4& rng_result,
const Array<uint32_t, 4>& rng_result,
int rng_component,
float from,
float to) {
auto range = to - from;
float uniform01 = raw_uniform_float((&rng_result.x)[rng_component]);
float uniform01 = raw_uniform_float(rng_result[rng_component]);
__bfloat result = __float2bfloat(from + range * uniform01);
return __heq(result, __float2bfloat(to)) ? __float2bfloat(from) : result;
}
Expand Down Expand Up @@ -174,39 +174,39 @@ __device__ double normal(
}

__device__ double rng_normal_standard(
const uint4& rng_result,
const Array<uint32_t, 4>& rng_result,
int rng_component) {
return normal(
rng_result.x, rng_result.y, rng_result.z, rng_result.w, rng_component);
rng_result[0], rng_result[1], rng_result[2], rng_result[3], rng_component);
}

__device__ float rng_normal_standardf(
const uint4& rng_result,
const Array<uint32_t, 4>& rng_result,
int rng_component) {
return normalf(
(&rng_result.x)[rng_component / 2 * 2],
(&rng_result.y)[rng_component / 2 * 2],
rng_result[rng_component / 2 * 2],
rng_result[1 + rng_component / 2 * 2],
rng_component);
}

__device__ __half
rng_normal_standard_half(const uint4& rng_result, int rng_component) {
rng_normal_standard_half(const Array<uint32_t, 4>& rng_result, int rng_component) {
return __float2half(normalf(
(&rng_result.x)[rng_component / 2 * 2],
(&rng_result.y)[rng_component / 2 * 2],
rng_result[rng_component / 2 * 2],
rng_result[1 + rng_component / 2 * 2],
rng_component));
}

__device__ __bfloat
rng_normal_standard_bfloat(const uint4& rng_result, int rng_component) {
rng_normal_standard_bfloat(const Array<uint32_t, 4>& rng_result, int rng_component) {
return __float2bfloat(normalf(
(&rng_result.x)[rng_component / 2 * 2],
(&rng_result.y)[rng_component / 2 * 2],
rng_result[rng_component / 2 * 2],
rng_result[1 + rng_component / 2 * 2],
rng_component));
}

__device__ double rng_normal_general(
const uint4& rng_result,
const Array<uint32_t, 4>& rng_result,
int rng_component,
double mean,
double std) {
Expand All @@ -215,7 +215,7 @@ __device__ double rng_normal_general(
}

__device__ float rng_normal_generalf(
const uint4& rng_result,
const Array<uint32_t, 4>& rng_result,
int rng_component,
float mean,
float std) {
Expand All @@ -224,25 +224,25 @@ __device__ float rng_normal_generalf(
}

__device__ __half rng_normal_general_half(
const uint4& rng_result,
const Array<uint32_t, 4>& rng_result,
int rng_component,
float mean,
float std) {
auto normal01 = normalf(
(&rng_result.x)[rng_component / 2 * 2],
(&rng_result.y)[rng_component / 2 * 2],
rng_result[rng_component / 2 * 2],
rng_result[1 + rng_component / 2 * 2],
rng_component);
return __float2half(normal01 * std + mean);
}

__device__ __bfloat rng_normal_general_bfloat(
const uint4& rng_result,
const Array<uint32_t, 4>& rng_result,
int rng_component,
float mean,
float std) {
auto normal01 = normalf(
(&rng_result.x)[rng_component / 2 * 2],
(&rng_result.y)[rng_component / 2 * 2],
rng_result[rng_component / 2 * 2],
rng_result[1 + rng_component / 2 * 2],
rng_component);
return __float2bfloat(normal01 * std + mean);
}

0 comments on commit 6465c8b

Please sign in to comment.