From 3158d844b823e0d3a399334fd30fe0ab2187b2ec Mon Sep 17 00:00:00 2001
From: Christian Sarofeen
Date: Mon, 20 Jan 2025 12:10:50 -0800
Subject: [PATCH] Prefer Array class over register arrays. i.e. Array rather
 than float[2].

---
 csrc/codegen.cpp                   | 13 ++++++------
 tests/cpp/test_loop_rotation.cpp   | 32 +++++++++++++++---------------
 tests/cpp/test_scalar_hoisting.cpp |  2 +-
 3 files changed, 23 insertions(+), 24 deletions(-)

diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp
index fe00c1f8105..643c9b67838 100644
--- a/csrc/codegen.cpp
+++ b/csrc/codegen.cpp
@@ -265,7 +265,9 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
     } else if (v->isA<TensorView>()) {
       tv = v->as<TensorView>();
     }
-    if (tv && aligned_array_of_regs_.count(tv)) {
+    if (tv &&
+        (aligned_array_of_regs_.count(tv) ||
+         tv->getMemoryType() == MemoryType::Local)) {
       return genVariableName(tv).append(".array");
     } else {
       return genVariableName(v);
@@ -3169,14 +3171,11 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
         break;
       case MemoryType::Local: {
         auto va = kernel_->summary().vectorized_accesses;
+        indent() << "Array<" << buffer_dtype << ", " << genInline(size)
+                 << ", " << (va.find(tv) != va.end() ? va.at(tv) : 1) << "> "
+                 << genVariableName(tv) << ";\n";
         if (va.find(tv) != va.end()) {
-          indent() << "Array<" << buffer_dtype << ", " << genInline(size)
-                   << ", " << va.at(tv) << "> " << genVariableName(tv)
-                   << ";\n";
           aligned_array_of_regs_.insert(tv);
-        } else {
-          indent() << buffer_dtype << " " << genVariableName(tv) << "["
-                   << genInline(size) << "];\n";
         }
       } break;
       default:
diff --git a/tests/cpp/test_loop_rotation.cpp b/tests/cpp/test_loop_rotation.cpp
index 4b5122a66fd..d8ae9d49bda 100644
--- a/tests/cpp/test_loop_rotation.cpp
+++ b/tests/cpp/test_loop_rotation.cpp
@@ -41,8 +41,8 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor
     i1 = T0.alloc_stride[0LL] * i0;
     nvfuser_index_t i2;
     i2 = 3LL * i0;
-    float T1[1LL];
-    float T2[1LL];
+    Array<float, 1LL, 1> T1;
+    Array<float, 1LL, 1> T2;
     T1[0LL] = 0LL;
     T1[0LL]
        = T0[i1];
@@ -53,7 +53,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor
     for(nvfuser_index_t i3 = 0LL; i3 < 3LL; ++i3) {
       nvfuser_index_t i4;
      i4 = (1LL + i3) + nvfuser_zero;
-      float T3[1LL];
+      Array<float, 1LL, 1> T3;
       T3[0LL]
          = T2[0LL];
       T4[(i2 + (i3 + nvfuser_zero))]
          = T2[0LL];
     }
@@ -101,8 +101,8 @@ TEST_F(LoopRotationTest, RotateOuter) {
   const std::string expected_kernel = R"(
 __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T4) {
   NVFUSER_DEFINE_MAGIC_ZERO;
-  float T1[3LL];
-  float T2[3LL];
+  Array<float, 3LL, 1> T1;
+  Array<float, 3LL, 1> T2;
   #pragma unroll
   for(nvfuser_index_t i0 = 0LL; i0 < 3LL; ++i0) {
     T1[i0] = 0LL;
   }
@@ -202,8 +202,8 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor
   i0 = T0.logical_size[0LL] * T0.logical_size[1LL];
   nvfuser_index_t i1;
   i1 = ceilDiv(i0, 5LL);
-  float T1[5LL];
-  float T2[5LL];
+  Array<float, 5LL, 1> T1;
+  Array<float, 5LL, 1> T2;
   #pragma unroll
   for(nvfuser_index_t i2 = 0LL; i2 < 5LL; ++i2) {
     T1[i2] = 0LL;
   }
@@ -306,7 +306,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor
   NVFUSER_DEFINE_MAGIC_ZERO;
   nvfuser_index_t i0;
   i0 = 4LL * T0.alloc_stride[0LL];
-  float T1[15LL];
+  Array<float, 15LL, 1> T1;
   #pragma unroll 4
   for(nvfuser_index_t i1 = 0LL; i1 < 4LL; ++i1) {
     nvfuser_index_t i2;
@@ -328,7 +328,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor
     }
   }
   NVFUSER_UPDATE_MAGIC_ZERO;
-  float T2[3LL];
+  Array<float, 3LL, 1> T2;
   #pragma unroll
   for(nvfuser_index_t i6 = 0LL; i6 < 3LL; ++i6) {
     T2[i6]
@@ -362,7 +362,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor
     }
   }
   NVFUSER_UPDATE_MAGIC_ZERO;
-  float T3[3LL];
+  Array<float, 3LL, 1> T3;
   #pragma unroll
   for(nvfuser_index_t i14 = 0LL; i14 < 3LL; ++i14) {
     T3[i14]
@@ -421,7 +421,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor
   i1 = 5LL * T0.alloc_stride[0LL];
   bool b2;
   b2 = 4LL < T0.logical_size[0LL];
-  float T1[15LL];
+  Array<float, 15LL, 1> T1;
   #pragma unroll
   for(nvfuser_index_t i3 = 0LL; i3 < 3LL; ++i3) {
     T1[i3] = 0LL;
   }
@@ -454,7 +454,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor
     }
   }
   NVFUSER_UPDATE_MAGIC_ZERO;
-  float T2[3LL];
+  Array<float, 3LL, 1> T2;
   #pragma unroll
   for(nvfuser_index_t i3 = 0LL; i3 < 3LL; ++i3) {
     T1[(12LL + i3)] = 0LL;
   }
@@ -486,7 +486,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor
     i13 = 3LL * ((1LL + i9) % 5LL);
     bool b14;
     b14 = (5LL + i9) < T0.logical_size[0LL];
-    float T3[3LL];
+    Array<float, 3LL, 1> T3;
     #pragma unroll
     for(nvfuser_index_t i15 = 0LL; i15 < 3LL; ++i15) {
       T3[i15]
@@ -599,7 +599,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor
   }
   NVFUSER_UPDATE_MAGIC_ZERO;
   asm volatile("cp.async.wait_group %0;\n"::"n"(3LL));
-  float T1[2LL];
+  Array<float, 2LL, 1> T1;
   T1[0LL]
      = T4[0LL];
   #pragma unroll 4
@@ -637,14 +637,14 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor
     for(nvfuser_index_t i14 = 0LL; i14 < 2LL; ++i14) {
       T1[((1LL + i14) % 2LL)]
         = T4[(i11 + i14)];
-      float T2[1LL];
+      Array<float, 1LL, 1> T2;
       T2[0LL]
         = T1[i14];
       T3[(i12 + (i14 + nvfuser_zero))]
        = T2[0LL];
     }
     NVFUSER_UPDATE_MAGIC_ZERO;
-    float T2[1LL];
+    Array<float, 1LL, 1> T2;
     T2[0LL]
       = T1[0LL];
     T3[(2LL + i12)]
diff --git a/tests/cpp/test_scalar_hoisting.cpp b/tests/cpp/test_scalar_hoisting.cpp
index d0295aa20f3..ae23b3e5593 100644
--- a/tests/cpp/test_scalar_hoisting.cpp
+++ b/tests/cpp/test_scalar_hoisting.cpp
@@ -316,7 +316,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor
   b7 = i0 < i6;
   float f8;
   f8 = (float)(i6);
-  float T1[1LL];
+  Array<float, 1LL, 1> T1;
   if (b7) {
     T1[0LL]
        = sinf(T0[i0]);
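
Note for reviewers: with this change every MemoryType::Local buffer is declared
through the Array template, with a vectorization/alignment factor of 1 when the
tensor has no vectorized access (the (va.find(tv) != va.end() ? va.at(tv) : 1)
argument above), which is why the expected test strings change from
float T1[1LL]; to Array<float, 1LL, 1> T1; while every element access stays
identical. The sketch below is only a mental model of such a wrapper, not the
actual nvFuser runtime definition (the member and parameter names here are
assumptions); it illustrates why the codegen path above can uniformly append
".array" to reach the raw registers and why the indexing syntax in the
generated kernels does not change:

// Hypothetical sketch of an aligned register-array wrapper (CUDA C++).
// The real Array class ships with the nvFuser runtime sources and differs in detail.
template <typename scalar_t, int size, int align_size = 1>
struct alignas(sizeof(scalar_t) * align_size) Array {
  scalar_t array[size];  // raw storage; codegen reaches it via the ".array" suffix

  __device__ scalar_t& operator[](int i) {
    return array[i];
  }
  __device__ const scalar_t& operator[](int i) const {
    return array[i];
  }
};

// Under this model, the two declaration forms the old codegen emitted,
//   float T1[2LL];            // plain register array (non-vectorized case)
//   Array<float, 2LL, 2> T1;  // aligned array (vectorized case)
// collapse into the single Array<...> form produced by the new branch, and
// statements such as T1[0LL] = 0LL; keep compiling unchanged through operator[].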