llm decode shapes fp8 rowwise gemm tuning (#3565)
Summary:
Pull Request resolved: #3565

X-link: facebookresearch/FBGEMM#650

Add more decode shape tunings for the FP8 rowwise GEMM

Reviewed By: xw285cornell, jwfromm

Differential Revision: D68104224

fbshipit-source-id: 71098f467161a3b3ac73fabd3c59a4ebbbbd2101
mxz297 authored and facebook-github-bot committed Jan 13, 2025
1 parent a93eed4 commit d31c312
Showing 1 changed file with 18 additions and 0 deletions.
@@ -38,6 +38,24 @@ struct IntTupleHash {
 // For certain high priority shapes, we directly map to the best kernel rather
 // than use heuristics.
 static const std::unordered_map<std::tuple<int, int, int>, RowwiseKernel, IntTupleHash> rowwise_lookup_dispatch = {
+// Support for decode for [1024, 5120]
+{{16, 1024, 5120},
+fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2},
+{{32, 1024, 5120},
+fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2},
+{{64, 1024, 5120},
+fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2},
+{{128, 1024, 5120},
+fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2},
+// Support for decode for [5120, 1024]
+{{16, 5120, 1024},
+fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2},
+{{32, 5120, 1024},
+fp8_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2},
+{{64, 5120, 1024},
+fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2},
+{{128, 5120, 1024},
+fp8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2},
 // LLama 70B Decode shapes.
 // Support for decode across batch sizes for [1280, 8192]
 {{16, 1280, 8192},
