In RoPE, it seems common to have repeat ops at the end of fusions. For example, `k` and `v` are repeated in the Mistral case as shown [here](https://github.com/NVIDIA/Fuser/blob/main/benchmarks/python/rope_ops.py#L914-L915). This can be problematic for performance since we typically choose a fusion output as the reference tensor for scheduling. Suppose the final repeat has a factor of 4; the output tensor is then 4x larger than the actual computed tensors. Since scheduling is applied to the 4x larger reference, we would, for example, launch 4x as many threads or blocks, and the work would be redundant by a factor of 4.

This PR attempts to alleviate the perf issue by detecting an ending repeat pattern. The idea is to factor out the iter domain that corresponds to the repetition and move it to the outermost position. Reference scheduling is then done only for the remaining iter domains.
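For illustration, here is a minimal PyTorch-style sketch of the pattern, assuming the GQA-style key/value head repetition used in Mistral-like models (the function name and shapes are made up and not taken from the benchmark): the tensor is computed once per KV head and only repeated at the very end, so the fusion output is `n_rep` times larger than what is actually computed.

```python
import torch

def repeat_kv_sketch(k: torch.Tensor, n_rep: int) -> torch.Tensor:
    # k: [batch, n_kv_heads, seq_len, head_dim]
    # Trailing repeat: each KV head is duplicated n_rep times, so only
    # 1/n_rep of the output corresponds to unique computation.
    return torch.repeat_interleave(k, n_rep, dim=1)

# Hypothetical sizes, for illustration only.
k = torch.randn(1, 8, 4096, 128)       # 8 KV heads
k_rep = repeat_kv_sketch(k, 4)         # [1, 32, 4096, 128]: a 4x larger output
```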
For example, for one of the resize segments of the Mistral forward case (I think this should correspond to the computation of the `K` input), we have this code sequence:

```
#pragma unroll
for(nvfuser_index_t i14 = 0LL; i14 < 8LL; ++i14) {
  nvfuser_index_t i15;
  i15 = -i14;
  __bfloat T5[1LL];
  T5[0LL] = T59[i14];
  __bfloat T6[1LL];
  T6[0LL] = T5[0LL];
  float T36[1LL];
  T36[0LL] = __bfloat2float(T6[0LL]);
  float T37[1LL];
  T37[0LL] = __bfloat2float(T60[i14]);
  float T38[1LL];
  T38[0LL] = T36[0LL] * T37[0LL];
  __bfloat T67[1LL];
  T67[0LL] = T66[i14];
  __bfloat T68[1LL];
  T68[0LL] = T67[0LL];
  __bfloat T39[1LL];
  T39[0LL] = T68[0LL];
  __bfloat T45[1LL];
  T45[0LL] = ((i10 >= i15) && (i11 < i15)) ? T39[0LL] : 0.0000e+00f;
  float T49[1LL];
  T49[0LL] = __bfloat2float(T61[i14]);
  __bfloat T64[1LL];
  T64[0LL] = T63[i14];
  __bfloat T65[1LL];
  T65[0LL] = T64[0LL];
  __bfloat T40[1LL];
  T40[0LL] = T65[0LL];
  float T41[1LL];
  T41[0LL] = __bfloat2float(T40[0LL]);
  float T42[1LL];
  T42[0LL] = -T41[0LL];
  __bfloat T43[1LL];
  T43[0LL] = __float2bfloat(T42[0LL]);
  __bfloat T44[1LL];
  T44[0LL] = (i10 < i15) ? T43[0LL] : 0.0000e+00f;
  __bfloat T46[1LL];
  T46[0LL] = T44[0LL] | T45[0LL];
  float T48[1LL];
  T48[0LL] = __bfloat2float(T46[0LL]);
  float T50[1LL];
  T50[0LL] = T48[0LL] * T49[0LL];
  float T51[1LL];
  T51[0LL] = T38[0LL] + T50[0LL];
  __bfloat T52[1LL];
  T52[0LL] = __float2bfloat(T51[0LL]);
  __bfloat T53[1LL];
  T53[0LL] = T52[0LL];
  __bfloat T54[1LL];
  T54[0LL] = T53[0LL];
  T62[i14] = T54[0LL];
}
if ((b12 && (((4LL * i6) + (i3 / 4096LL)) < 32LL))) {
  loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>(
    &T55[((((128LL * i2) + (2097152LL * i6)) + i1) + (128LL * ((nvfuser_index_t)blockIdx.x)))], &T62[0LL]);
}
```

With this PR, it looks like:

```
for(nvfuser_index_t i16 = 0LL; i16 < 8LL; ++i16) {
  nvfuser_index_t i17;
  i17 = -i16;
  __bfloat T67[1LL];
  T67[0LL] = T66[i16];
  __bfloat T68[1LL];
  T68[0LL] = T67[0LL];
  __bfloat T39[1LL];
  T39[0LL] = T68[0LL];
  __bfloat T45[1LL];
  T45[0LL] = ((i7 >= i17) && (i8 < i17)) ? T39[0LL] : 0.0000e+00f;
  __bfloat T64[1LL];
  T64[0LL] = T63[i16];
  __bfloat T65[1LL];
  T65[0LL] = T64[0LL];
  __bfloat T40[1LL];
  T40[0LL] = T65[0LL];
  float T41[1LL];
  T41[0LL] = __bfloat2float(T40[0LL]);
  float T42[1LL];
  T42[0LL] = -T41[0LL];
  __bfloat T43[1LL];
  T43[0LL] = __float2bfloat(T42[0LL]);
  __bfloat T44[1LL];
  T44[0LL] = (i7 < i17) ? T43[0LL] : 0.0000e+00f;
  __bfloat T46[1LL];
  T46[0LL] = T44[0LL] | T45[0LL];
  __bfloat T5[1LL];
  T5[0LL] = T59[i16];
  __bfloat T6[1LL];
  T6[0LL] = T5[0LL];
  float T36[1LL];
  T36[0LL] = __bfloat2float(T6[0LL]);
  float T37[1LL];
  T37[0LL] = __bfloat2float(T60[i16]);
  float T38[1LL];
  T38[0LL] = T36[0LL] * T37[0LL];
  float T49[1LL];
  T49[0LL] = __bfloat2float(T61[i16]);
  float T48[1LL];
  T48[0LL] = __bfloat2float(T46[0LL]);
  float T50[1LL];
  T50[0LL] = T48[0LL] * T49[0LL];
  float T51[1LL];
  T51[0LL] = T38[0LL] + T50[0LL];
  T52[i16] = __float2bfloat(T51[0LL]);
}
#pragma unroll
for(nvfuser_index_t i18 = 0LL; i18 < 4LL; ++i18) {
  Array<__bfloat, 8LL, 8> T62;
  #pragma unroll
  for(nvfuser_index_t i19 = 0LL; i19 < 8LL; ++i19) {
    __bfloat T53[1LL];
    T53[0LL] = T52[i19];
    __bfloat T54[1LL];
    T54[0LL] = T53[0LL];
    T62[i19] = T54[0LL];
  }
  if ((b12 && (i14 < (-i18)))) {
    loadLocalToGlobal<__bfloat, /*vec_size=*/8, /*is_volatile=*/false>(
      &T55[(i10 + (524288LL * i18))], &T62[0LL]);
  }
}
```

Notice that the final store now has its own loop with extent 4, which is the repetition factor.

The launch configuration before:

```
Launch Parameters: BlockDim.x = 128, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 16384, GridDim.y = -1, GridDim.z = -1, Smem Size = 0
```

The launch configuration after:

```
Launch Parameters: BlockDim.x = 128, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 4096, GridDim.y = -1, GridDim.z = -1, Smem Size = 0
```

The number of blocks is reduced by a factor of 4.

While this pattern can appear with any scheduler, I have currently only added the handling to the resize scheduler. In RoPE, there is indeed a pointwise segment with an ending repeat, but that is not addressed in this PR.
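To summarize the transformation conceptually, here is a small Python sketch (illustrative only; this is not the scheduler code, and all names and extents are made up): the repeated output is treated as a `[repeat_factor, n]` view of the non-repeated result, with the repeat factor as the outermost loop, so the computation is done once and only the final store is replicated.

```python
import torch

n = 8              # extent of the computed (non-repeated) iter domain
repeat_factor = 4  # repetition factor detected at the end of the fusion

x = torch.arange(n, dtype=torch.float32)
computed = x * 2.0                       # stand-in for the actual fused computation

# Naive reference: schedule over the full repeated output (repeat_factor * n work).
naive = computed.repeat(repeat_factor)

# With the repeat iter domain factored out to the outermost position, the
# computation covers only n elements; only the store loop runs repeat_factor times.
out = torch.empty(repeat_factor, n)
for r in range(repeat_factor):           # outermost "store" loop, extent 4
    out[r] = computed                    # redundant write, no redundant compute

assert torch.equal(out.reshape(-1), naive)
```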
### Performance benefit

Will update.