Tuning wgrad kernels with split-k, in practice #396
-
Hi, I have wgrad kernels with non-split-k working in TVM, and now I'm looking to leverage split-k. I'm wondering what the best strategy is for picking the best kernel among the full cartesian product of tile shapes, alignments, split-k modes, and split-k slice counts.
The first two (tile shape and alignment) are already supported in TVM, and for wgrad kernels there are about 80 variants. If I add split-k on top, the number of combinations becomes too large. What is a good, practical way to go about this? Also, which values of split-k slices (how many, and how large) should be considered? For example, how about the following: first consider only tile shape and alignment, and after picking the top performer, add split-k variations on top of it (tile and alignment fixed) for further tuning.
-
Below are some guidelines and information on finding the best tile shape, alignment, split-k-mode (serial, parallel), and split-k slice count.

1. Tile Shape: You want the largest Tile Shape for the most reuse; however, the trade-off is that a large Tile Shape may not reach full GPU utilization because of quantization effects. Thus, it is best to sweep through all possible Tile Shapes for each problem size.
2. Alignment: This one is straightforward. The largest possible alignment always wins, so for F16 input go with align8 wgrad kernels.
3. Split-k-mode: Parallel split-k-mode always surpasses serial split-k-mode. Parallel split-k-mode runs a separate reduction kernel instead of reducing the split-k chunks serially.
4. Split-k-slice: The goal here is to slice the problem in the GEMM-K dimension such that we have enough CTAs to fill the entire GPU and get maximum utilization.

Typically, we need a large split-k-slice value for wgrad since GEMM-M (K) and GEMM-N (RSC) are small, so we split GEMM-K (NPQ) to launch more CTAs. For a given Tile Shape and problem size, you can try a split-k-slice value that launches at least one full wave, i.e., 108 CTAs for GA100 or 68 CTAs for GA102 (see the sketch below). I am attaching some notes on this topic that try to analytically compute the split-k-slice value (see page 1). In practice, I have run sweeps to find the best (1) Tile Shape and (4) split-k-slice (--split-k-slice=1:128:1), fixing (2) alignment to the largest possible and (3) split-k-mode to parallel.
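To make guideline 4 concrete, here is a minimal sketch of the wave-filling arithmetic, using the SM counts quoted above (108 for GA100, 68 for GA102). The helper name and the example problem size are mine, for illustration; only the one-wave criterion and the GEMM-M/N/K mapping come from the post.

```cpp
#include <cstdio>

// Minimal sketch (not CUTLASS API): pick the smallest split-k-slice count
// that launches at least one full wave of CTAs for a given tile shape.
// For wgrad: GEMM-M = K, GEMM-N = R*S*C, GEMM-K = N*P*Q.
int splitKForOneWave(int gemm_m, int gemm_n, int tile_m, int tile_n, int num_sms) {
  // CTAs launched by a single slice: ceil-div over the M and N tile grids.
  int ctas_per_slice = ((gemm_m + tile_m - 1) / tile_m) *
                       ((gemm_n + tile_n - 1) / tile_n);
  // Smallest slice count such that ctas_per_slice * slices >= num_sms.
  // In practice you would also cap this at ceil(GEMM-K / TileK) so each
  // slice still has at least one K-tile of work.
  int slices = (num_sms + ctas_per_slice - 1) / ctas_per_slice;
  return slices < 1 ? 1 : slices;
}

int main() {
  // Example: K=64, R=S=3, C=64 -> GEMM-M=64, GEMM-N=576;
  // a 128x128 tile on GA100 (108 SMs) gives 5 CTAs/slice -> 22 slices.
  printf("split-k slices: %d\n", splitKForOneWave(64, 3 * 3 * 64, 128, 128, 108));
  return 0;
}
```

The analytic number is only a starting point; in practice you would still sweep a window around it, which is why the 1:128 sweep above remains useful.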
-
@manishucsd I have some questions on the reduction kernel block shape at https://github.com/masahi/cutlass/blob/example-wgrad-splitk/examples/26_ampere_wgrad_mainloop_fusion/ampere_wgrad_mainloop_fusion.cu#L107:
-
This parameter can impact the performance of the reduction kernel; however, the reduction time is usually much shorter than the GEMM itself.
What problem size and split_k_slice do you use?
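For readers landing here: the parameter under discussion is the threadblock shape of the parallel split-k reduction kernel. Written from memory of the linked example (so treat the exact shape and element types as assumptions, not a quote), the relevant type definitions look roughly like this:

```cpp
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/reduction/thread/reduction_operators.h"
#include "cutlass/reduction/kernel/reduce_split_k.h"
#include "cutlass/reduction/device/reduce_split_k.h"

// Element types assumed for illustration (F16 in/out, F32 accumulate).
using ElementOutput = cutlass::half_t;
using ElementAccumulator = float;
using ElementComputeEpilogue = float;

using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,
    128 / cutlass::sizeof_bits<ElementOutput>::value,  // vector width (align8 for F16)
    ElementAccumulator,
    ElementComputeEpilogue>;

using ReductionOp = cutlass::reduction::thread::ReduceAdd<
    ElementAccumulator,
    EpilogueOp::ElementAccumulator,
    EpilogueOp::kCount>;

// The MatrixShape below is the block shape being asked about: each CTA of
// the reduction kernel processes a 4 x (32 * kCount) tile of split-k
// partials, and this is what you would vary when tuning the reduction.
using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK<
    cutlass::MatrixShape<4, 32 * EpilogueOp::kCount>,
    EpilogueOp,
    ReductionOp>;

using ReductionDevice = cutlass::reduction::device::ReduceSplitK<ReductionKernel>;
```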
-
Ok, I found why. This line (cutlass/tools/library/src/handle.cu, lines 1073 to 1075 at 1e4703c) sets element C to the type of the accumulator, while what I need there is the type of the output tensor (handle.cu, line 1060 at 1e4703c). If I replace that line with …
-
You need to build both …

Here is what I got: …

What happens behind the scenes is that cutlass_profiler detects that the output type of … Here is how cutlass_profiler uses cudaEvent to measure performance: https://github.com/NVIDIA/cutlass/blob/master/tools/profiler/src/conv2d_operation_profiler.cu#L1276-L1335. You can compare the TVM one with it if you want.
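For anyone comparing against TVM's timing, the cudaEvent pattern in that link boils down to the following. This is a stripped-down sketch of mine, not the profiler's actual code, which additionally does warm-up runs and averages over many iterations:

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Stand-in for the kernel being measured.
__global__ void dummyKernel() {}

int main() {
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);

  cudaEventRecord(start);      // enqueue start marker on the stream
  dummyKernel<<<1, 1>>>();     // the work being timed
  cudaEventRecord(stop);       // enqueue stop marker on the stream
  cudaEventSynchronize(stop);  // block until the stop event completes

  float ms = 0.0f;
  cudaEventElapsedTime(&ms, start, stop);  // elapsed GPU time in milliseconds
  printf("kernel time: %f ms\n", ms);

  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return 0;
}
```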