forked from NVIDIA/cutlass
-
Notifications
You must be signed in to change notification settings - Fork 21
/
sm90_epilogue_array_tma_warpspecialized.hpp
1191 lines (1015 loc) · 49.2 KB
/
sm90_epilogue_array_tma_warpspecialized.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
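/*! \file
\brief Collective epilogue for SM90 pointer-array (batched / grouped) GEMM: C is loaded and D is stored
with TMA, using warp-specialized producer-load and consumer-store paths.
*/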
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/arch/barrier.h"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/detail.hpp"
#include "cutlass/epilogue/thread/scale_type.h"
#include "cutlass/epilogue/fusion/callbacks.hpp"
#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp"
#include "cutlass/detail/collective.hpp"
#include "cutlass/detail/layout.hpp"
#include "cutlass/trace.h"
#include "cutlass/cuda_host_adapter.hpp"
#include "cute/tensor.hpp"
#include "cute/atom/copy_traits_sm90_tma.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace collective {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
int StagesC_,
int StagesD_,
int FragmentSize_,
bool ReuseSmemC_,
bool DelayTmaStore_,
int NumEpilogueWarpGroups_,
class CtaTileMNK_, // (CTA_M,CTA_N,CTA_K)
class EpilogueTile_, // (EPI_TILE_M,EPI_TILE_N)
class ElementC_,
class StrideC_,
class ElementD_,
class StrideD_,
class FusionCallbacks_,
class CopyOpG2S_,
class SmemLayoutAtomC_,
class CopyOpS2R_,
class CopyOpS2G_,
class SmemLayoutAtomD_,
class CopyOpR2S_,
class CopyAtomC_,
class CopyOpR2R_
>
class CollectiveEpilogue<
Sm90PtrArrayTmaWarpSpecialized<StagesC_,
StagesD_,
FragmentSize_,
ReuseSmemC_,
DelayTmaStore_,
NumEpilogueWarpGroups_
>,
CtaTileMNK_,
EpilogueTile_,
ElementC_,
StrideC_,
ElementD_,
StrideD_,
FusionCallbacks_,
CopyOpG2S_,
SmemLayoutAtomC_,
CopyOpS2R_,
CopyOpS2G_,
SmemLayoutAtomD_,
CopyOpR2S_,
CopyAtomC_,
CopyOpR2R_
> {
public:
//
// Type Aliases
//
// Dispatch policy that selects this partial specialization; it re-packs the policy's template arguments.
using DispatchPolicy = Sm90PtrArrayTmaWarpSpecialized<StagesC_,
StagesD_,
FragmentSize_,
ReuseSmemC_,
DelayTmaStore_,
NumEpilogueWarpGroups_
>;
using CtaTileMNK = CtaTileMNK_;
using EpilogueTile = EpilogueTile_;
using FusionCallbacks = FusionCallbacks_;
// Element/stride types of the source C and destination D. When StrideC/StrideD are pointer types
// (per-group strides, i.e. grouped GEMM), InternalStrideC/InternalStrideD strip the pointer to obtain
// the stride type of a single group.
using ElementC = ElementC_;
using StrideC = StrideC_;
using InternalStrideC = cute::remove_pointer_t<StrideC>;
using ElementD = ElementD_;
using StrideD = StrideD_;
using InternalStrideD = cute::remove_pointer_t<StrideD>;
// Copy operations and smem layout atoms, named by data direction:
//   G2S = gmem -> smem (C load),  S2R = smem -> registers (C read),
//   S2G = smem -> gmem (D store), R2S = registers -> smem (D write),
//   R2R = optional register -> register transform (void when unused).
using CopyOpG2S = CopyOpG2S_;
using SmemLayoutAtomC = SmemLayoutAtomC_;
using CopyOpS2R = CopyOpS2R_;
using CopyOpS2G = CopyOpS2G_;
using SmemLayoutAtomD = SmemLayoutAtomD_;
using CopyOpR2S = CopyOpR2S_;
using CopyAtomC = CopyAtomC_;
using CopyOpR2R = CopyOpR2R_;
using ThreadEpilogueOp = typename epilogue::fusion::FusionCallbacksTraits<FusionCallbacks>::Operation;
// Convenience aliases for the global-memory copy ops (C load, D store).
using GmemTiledCopyC = CopyOpG2S;
using GmemTiledCopyD = CopyOpS2G;
// Compile-time shape checks on the CTA tile, the epilogue subtile and the (per-group) strides.
static_assert(!is_layout<EpilogueTile>::value && is_tuple<EpilogueTile>::value, "EpilogueTile must be a cute::Tile or cute::Shape");
static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]");
static_assert(cute::rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]");
static_assert(size<0>(CtaTileMNK{}) % size<0>(shape(EpilogueTile{})) == 0, "EPI_TILE_M must divide CTA_M");
static_assert(size<1>(CtaTileMNK{}) % size<1>(shape(EpilogueTile{})) == 0, "EPI_TILE_N must divide CTA_N");
static_assert(cute::rank(InternalStrideC{}) == 3, "StrideC must be rank-3: [M, N, L]");
static_assert(cute::rank(InternalStrideD{}) == 3, "StrideD must be rank-3: [M, N, L]");
private:
// A void ElementC makes the epilogue sourceless; a void ElementD means there is no D output.
constexpr static bool is_source_supported = not cute::is_void_v<ElementC>;
constexpr static bool is_destination_supported = not cute::is_void_v<ElementD>;
// Substitute element types used to type smem buffers and TMA objects when C or D is void:
// D falls back to the fusion's aux element type, and C falls back to the (non-void) D type.
using NonVoidElementD = cute::conditional_t<not is_destination_supported,fusion::get_element_aux_t<FusionCallbacks>, ElementD>;
static_assert(not cute::is_void_v<NonVoidElementD>, "SmemElementD is void");
using NonVoidElementC = cute::conditional_t<not is_source_supported,NonVoidElementD,ElementC>; // prevents void ref breakages
using SmemElementC = typename cutlass::detail::get_unpacked_element_type<NonVoidElementC>::type;
using SmemElementD = typename cutlass::detail::get_unpacked_element_type<NonVoidElementD>::type;
constexpr static int StagesC = StagesC_;
constexpr static int StagesD = StagesD_;
// Reusing C's smem for D is only enabled when there is a D to write.
constexpr static bool ReuseSmemC = ReuseSmemC_ and is_destination_supported;
constexpr static bool DelayTmaStore = DelayTmaStore_;
constexpr static bool is_m_major_C = detail::is_m_major<InternalStrideC>();
constexpr static bool is_m_major_D = detail::is_m_major<InternalStrideD>();
// im2col TMA is not supported by this collective; the methods reject it with static_asserts.
constexpr static bool is_im2col_C = cute::is_same_v<CopyOpG2S, SM90_TMA_LOAD_IM2COL>;
constexpr static bool is_im2col_D = cute::is_same_v<CopyOpS2G, SM90_TMA_STORE_IM2COL>;
// Check if register transformation is needed before copying register to shared memory.
constexpr static bool IsUseR2R = !cute::is_void_v<CopyOpR2R>;
// Staged smem layouts of shape (EPI_TILE_M, EPI_TILE_N, stages). The Step argument chooses the order in which
// the layout atom is tiled, depending on whether the gmem stride is M-major. When C's smem is reused for D,
// the D layout gets StagesC stages.
using SmemLayoutC = decltype(tile_to_shape(
SmemLayoutAtomC{},
make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int<StagesC>{}),
cute::conditional_t<is_m_major_C, Step<_2,_1,_3>, Step<_1,_2,_3>>{} ));
using SmemLayoutD = decltype(tile_to_shape(
SmemLayoutAtomD{},
make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int<ReuseSmemC ? StagesC : StagesD>{}),
cute::conditional_t<is_m_major_D, Step<_2,_1,_3>, Step<_1,_2,_3>>{} ));
// Smem reuse requires both C and D, no more D stages than C stages, and per-stage buffers of equal cosize.
constexpr static bool support_smem_reuse = is_source_supported && is_destination_supported && StagesD <= StagesC
&& cosize(take<0,2>(SmemLayoutC{})) == cosize(take<0,2>(SmemLayoutD{}));
static_assert(not (ReuseSmemC && not support_smem_reuse), "Smem reuse requirements not met");
// Alignment each swizzled layout requires; the reuse union aligns both members to the larger one.
constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{});
constexpr static size_t SmemAlignmentC = cutlass::detail::alignment_for_swizzle(SmemLayoutC{});
constexpr static size_t MaxSmemAlignment = cute::max(SmemAlignmentC, SmemAlignmentD);
using SmemArrayTypeC = cute::ArrayEngine<SmemElementC, cosize_v<SmemLayoutC>>;
using SmemArrayTypeD = cute::ArrayEngine<SmemElementD, cosize_v<SmemLayoutD>>;
// Storage selectors that collapse to an empty tuple when the corresponding tensor owns no smem buffer.
using EmptyType = cute::tuple<>;
using SmemCStorage = cute::conditional_t<is_source_supported and (not ReuseSmemC),
SmemArrayTypeC,
EmptyType>;
using SmemDStorage = cute::conditional_t<is_destination_supported,
SmemArrayTypeD,
EmptyType>;
// Separate C and D buffers.
struct CollectiveStorageWithC {
alignas(SmemAlignmentC) ArrayEngine<SmemElementC, cosize_v<SmemLayoutC>> smem_C;
alignas(SmemAlignmentD) ArrayEngine<SmemElementD, cosize_v<SmemLayoutD>> smem_D;
};
// Sourceless epilogue: smem_C is a zero-length placeholder so code naming it still compiles.
union CollectiveStorageWithoutC {
cute::array<SmemElementC, 0> smem_C;
alignas(SmemAlignmentD) ArrayEngine<SmemElementD, cosize_v<SmemLayoutD>> smem_D;
};
// C and D overlay the same storage (ReuseSmemC).
union CollectiveStorageReuseC {
alignas(MaxSmemAlignment) ArrayEngine<SmemElementC, cosize_v<SmemLayoutC>> smem_C;
alignas(MaxSmemAlignment) ArrayEngine<SmemElementD, cosize_v<SmemLayoutD>> smem_D;
};
public:
// TMA pipeline for loading C
using LoadPipeline = cutlass::PipelineTransactionAsync<StagesC>;
using LoadPipelineState = cutlass::PipelineState<StagesC>;
// Bytes delivered by the TMA load of one (EPI_TILE_M, EPI_TILE_N) stage of C.
constexpr static uint32_t TmaTransactionBytes =
(size(take<0,2>(SmemLayoutC{})) * static_cast<uint32_t>(sizeof_bits<SmemElementC>::value)) / 8;
constexpr static bool RequiresTransactionBytes = true;
constexpr static int NumEpilogueWarpGroups = NumEpilogueWarpGroups_;
// TMA pipeline for storing D
// With ReuseSmemC the store pipeline spans the StagesC shared buffers and receives StagesD-1 as its
// second template argument; otherwise it is sized by StagesD.
using StorePipeline = cute::conditional_t<ReuseSmemC,
cutlass::PipelineTmaStore<StagesC, StagesD-1>,
cutlass::PipelineTmaStore<StagesD>>;
using StorePipelineState = cutlass::PipelineState<ReuseSmemC ? StagesC : StagesD>;
// Shared memory owned by this collective: tensor buffers (plus fusion storage), smem copies of the
// TMA descriptors, and the load-pipeline storage.
struct SharedStorage {
struct TensorStorage {
// Buffer arrangement: no C, C reused as D, or separate C and D buffers.
using CollectiveStorage = cute::conditional_t<not is_source_supported, CollectiveStorageWithoutC,
cute::conditional_t<ReuseSmemC, CollectiveStorageReuseC, CollectiveStorageWithC>>;
CollectiveStorage collective;
using FusionStorage = typename FusionCallbacks::SharedStorage;
FusionStorage thread;
} tensors;
// One TMA descriptor for C plus one per epilogue warp group for D, in a 128-byte aligned struct.
struct TensorMapStorage : cute::aligned_struct<128, _0> {
cute::TmaDescriptor smem_tensormap_C;
cute::array<cute::TmaDescriptor, NumEpilogueWarpGroups> smem_tensormap_D;
} tensormaps;
using PipelineStorage = typename LoadPipeline::SharedStorage;
PipelineStorage pipeline;
};
using TensorStorage = typename SharedStorage::TensorStorage;
using TensorMapStorage = typename SharedStorage::TensorMapStorage;
using PipelineStorage = typename SharedStorage::PipelineStorage;
// True when StrideC is a pointer type (per-group strides), i.e. grouped GEMM rather than a plain
// pointer-array batch that shares one stride.
static constexpr bool IsGroupedGemmKernel = !cute::is_same_v<InternalStrideC, StrideC>;
// Host side epilogue arguments.
//
// ptr_C / ptr_D are arrays of per-batch (per-group) element pointers. StrideC / StrideD are either a
// single stride shared by all batches, or, for grouped GEMM (IsGroupedGemmKernel), a pointer to
// per-group strides.
struct Arguments {
  typename FusionCallbacks::Arguments thread{};  // fusion-callback arguments (e.g. beta / beta_ptr)
  ElementC const** ptr_C = nullptr;
  // Value-initialized so a default-constructed Arguments never carries an indeterminate stride
  // (for grouped GEMM the stride type is a pointer and becomes nullptr). to_underlying_arguments
  // copies dC/dD into Params unconditionally, even for a void-C epilogue where callers usually
  // never set dC.
  StrideC dC{};
  ElementD ** ptr_D = nullptr;
  StrideD dD{};
};
// Device side epilogue params
struct Params {
// TMA copy types for C (load) and D (store), deduced from placeholder gmem tensors that use the
// per-group stride type; the real descriptors are built in to_underlying_arguments.
using TMA_C = decltype(make_tma_copy(
CopyOpG2S{},
make_tensor(make_gmem_ptr(static_cast<NonVoidElementC const*>(nullptr)),
repeat_like(InternalStrideC{}, int32_t(0)), InternalStrideC{}),
take<0,2>(SmemLayoutC{}),
EpilogueTile{},
_1{}));
using TMA_D = decltype(make_tma_copy(
CopyOpS2G{},
make_tensor(make_gmem_ptr(static_cast<NonVoidElementD const*>(nullptr)),
repeat_like(InternalStrideD{}, int32_t(0)), InternalStrideD{}),
take<0,2>(SmemLayoutD{}),
EpilogueTile{},
_1{}));
typename FusionCallbacks::Params thread{};
TMA_C tma_load_c;
TMA_D tma_store_d;
// Start of the gmem tensormap workspace (per-SM descriptors), placed after the fusion workspace.
cute::TmaDescriptor* tensormaps;
// Per-batch pointers and strides, copied straight from Arguments.
ElementC const** ptr_C;
StrideC dC;
ElementD** ptr_D;
StrideD dD;
uint32_t tma_transaction_bytes = TmaTransactionBytes;
};
//
// Methods
//

// Size in bytes of the fusion-callback workspace, rounded up to a multiple of 64 bytes.
//
// The epilogue workspace is laid out as [fusion workspace | tensormaps]. TMA descriptors
// (cute::TmaDescriptor / CUtensorMap) require 64-byte alignment, so the tensormap region must start on a
// 64-byte boundary. Padding the fusion region guarantees this whenever the workspace base itself is
// 64-byte aligned. Both to_underlying_arguments and get_workspace_size use this helper, so the tensormap
// offset and the reported workspace size always agree.
template <class ProblemShape>
static size_t
get_padded_fusion_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
  constexpr size_t TensorMapAlignment = 64;
  size_t const fusion_bytes = FusionCallbacks::get_workspace_size(problem_shape, args.thread);
  return ((fusion_bytes + TensorMapAlignment - 1) / TensorMapAlignment) * TensorMapAlignment;
}

// Converts host Arguments into device Params: builds the (placeholder) TMA descriptors for C and D,
// lowers the fusion-callback arguments, and points Params::tensormaps at the tensormap region that
// follows the (padded) fusion workspace.
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(
    ProblemShape const& problem_shape,
    Arguments const& args,
    [[maybe_unused]] void* workspace) {
  // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc.
  // These will be replaced with correct values before the initial tma load.
  auto init_shape = repeat_like(append<4>(typename ProblemShape::UnderlyingProblemShape{}, 1), int32_t(1));
  auto init_M = get<0>(init_shape);
  auto init_N = get<1>(init_shape);
  auto init_L = get<3>(init_shape);

  static_assert(!is_im2col_C and !is_im2col_D, "Im2Col not supported on C or D");

  InternalStrideC stride_c;
  InternalStrideD stride_d;
  if constexpr (IsGroupedGemmKernel) {
    // Strides for Grouped Gemm will be replaced prior to the first access regardless.
    stride_c = InternalStrideC{};
    stride_d = InternalStrideD{};
  }
  else {
    // Tensor shapes for Ptr-Array are initialized correctly only here.
    auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(0), 1);
    init_M = get<0>(problem_shape_MNKL);
    init_N = get<1>(problem_shape_MNKL);
    init_L = get<3>(problem_shape_MNKL);

    stride_c = args.dC;
    stride_d = args.dD;
  }

  uint32_t transaction_bytes = TmaTransactionBytes;
  typename Params::TMA_C tma_load_c = {};
  if constexpr (is_source_supported) {
    // args.ptr_C is the pointer array itself; it is reinterpreted only to give the placeholder
    // descriptor an address (see the note at the top of this function).
    ElementC const* ptr_C_first_batch = reinterpret_cast<ElementC const*>(args.ptr_C);
    Tensor tensor_c = make_tensor(ptr_C_first_batch, make_layout(make_shape(init_M,init_N,init_L), append<3>(stride_c, _0{})));
    tma_load_c = make_tma_copy(
        CopyOpG2S{},
        tensor_c,
        take<0,2>(SmemLayoutC{}),
        EpilogueTile{},
        _1{});
  }

  typename Params::TMA_D tma_store_d;
  if constexpr (is_destination_supported) {
    ElementD const* ptr_D_first_batch = reinterpret_cast<ElementD const*>(args.ptr_D);
    Tensor tensor_d = make_tensor(ptr_D_first_batch, make_layout(make_shape(init_M,init_N,init_L), append<3>(stride_d, _0{})));
    tma_store_d = make_tma_copy(
        CopyOpS2G{},
        tensor_d,
        take<0,2>(SmemLayoutD{}),
        EpilogueTile{},
        _1{});
  }

  // Workspace layout: [fusion workspace, padded to 64 B | tensormaps].
  auto fusion_workspace = static_cast<char*>(workspace);
  auto fusion_workspace_size = get_padded_fusion_workspace_size(problem_shape, args);
  auto tma_descriptor_workspace = reinterpret_cast<cute::TmaDescriptor*>(
      static_cast<char*>(workspace) + fusion_workspace_size);

  return {
    FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, fusion_workspace),
    tma_load_c,
    tma_store_d,
    tma_descriptor_workspace,
    args.ptr_C,
    args.dC,
    args.ptr_D,
    args.dD,
    transaction_bytes,
  };
}

// Total epilogue workspace in bytes: the padded fusion workspace followed by one gmem tensormap per SM
// for C (when C is non-void) and one per SM per epilogue warp group for D.
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) {
  constexpr uint32_t NumInputTensors = NumEpilogueWarpGroups + (cute::is_void_v<ElementC> ? 0 : 1);
  auto descriptors_shape = cute::make_shape(sm_count, Int<NumInputTensors>{});
  constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor);
  // Must use the same padded fusion size as to_underlying_arguments so the tensormaps fit.
  return (size(descriptors_shape) * SizeOfCuTensorMap) + get_padded_fusion_workspace_size(problem_shape, args);
}
// Initializes the epilogue workspace. Only the fusion callbacks are given the chance to initialize
// their portion (which starts at `workspace`); this function does not touch the tensormap region.
template <class ProblemShape>
static cutlass::Status
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
    CudaHostAdapter* cuda_adapter = nullptr) {
  cutlass::Status const fusion_status =
      FusionCallbacks::initialize_workspace(problem_shape, args.thread, workspace, stream, cuda_adapter);
  return fusion_status;
}
template <class ProblemShape>
static bool
can_implement(
ProblemShape problem_shape,
[[maybe_unused]] Arguments const& args) {
bool implementable = true;
bool fusion_implementable = true;
if (problem_shape.is_host_problem_shape_available()) {
for (int i = 0; i < problem_shape.groups(); ++i) {
auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(i), 1);
auto [M,N,K,L] = problem_shape_MNKL;
if constexpr (is_destination_supported) {
constexpr int tma_alignment_bits_D = cutlass::detail::get_output_alignment_bits<ElementD>();
constexpr int min_tma_aligned_elements_D = tma_alignment_bits_D / cutlass::sizeof_bits<ElementD>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_D>(cute::make_shape(M,N,L), InternalStrideD{});
}
if constexpr (not cute::is_void_v<ElementC>) {
constexpr int tma_alignment_bits_C = cutlass::detail::get_input_alignment_bits<ElementC>();
constexpr int min_tma_aligned_elements_C = tma_alignment_bits_C / cutlass::sizeof_bits<ElementC>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_C>(cute::make_shape(M,N,L), InternalStrideC{});
}
fusion_implementable = fusion_implementable && FusionCallbacks::can_implement(problem_shape_MNKL, args.thread);
}
}
else {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Ignoring check to can implement because host problem shape is not available.\n");
}
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
}
if (!fusion_implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n");
}
bool beta_implementable = true;
if constexpr (cute::is_void_v<ElementC>) {
if constexpr (detail::has_beta<Arguments>::value) {
beta_implementable = args.thread.beta == 0.0;
}
if constexpr (detail::has_beta_ptr<Arguments>::value) {
beta_implementable = beta_implementable && args.thread.beta_ptr == nullptr;
}
}
if (!beta_implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Beta/beta pointer was set, but epilogue is sourceless (void-C).\n");
}
return implementable && fusion_implementable && beta_implementable;
}
// Number of load-pipeline stages consumed per CTA tile. load() acquires and commits one stage per
// epilogue subtile, so this equals the number of EpilogueTile-sized subtiles in the CTA's (M,N) tile.
template<class TileShapeMNK>
CUTLASS_HOST_DEVICE
static constexpr int
get_load_pipe_increment(TileShapeMNK tile_shape_MNK) {
  auto cta_mn_layout = make_layout(take<0,2>(tile_shape_MNK));
  auto subtiled = zipped_divide(cta_mn_layout, EpilogueTile{});  // mode 1 enumerates the subtiles
  return size<1>(subtiled);
}
// Number of store-pipeline stages advanced per CTA tile; same value as get_load_pipe_increment,
// i.e. the number of EpilogueTile-sized subtiles in the CTA's (M,N) tile.
template<class TileShapeMNK>
CUTLASS_HOST_DEVICE
static constexpr int
get_store_pipe_increment(TileShapeMNK tile_shape_MNK) {
  return size<1>(zipped_divide(make_layout(take<0,2>(tile_shape_MNK)), EpilogueTile{}));
}
// Builds the collective from device Params: stores a copy of the params and constructs the fusion
// callbacks from their own params and their slice of the shared tensor storage.
CUTLASS_HOST_DEVICE
CollectiveEpilogue(Params const& params_, TensorStorage& shared_tensors)
: params(params_), fusion_callbacks(params_.thread, shared_tensors.thread) {}
// Whether the producer-load path has any work; the decision is delegated entirely to the fusion callbacks.
CUTLASS_DEVICE
bool
is_producer_load_needed() const {
  bool const needs_producer_load = fusion_callbacks.is_producer_load_needed();
  return needs_producer_load;
}
// Sets up the tensormap(s) used by the producer-load path (tensormaps_init with IsLoad = true)
// and returns whatever tensormaps_init produces.
CUTLASS_DEVICE auto
load_init(
    Params const& params,
    TensorMapStorage& shared_tensormaps,
    int32_t sm_count,
    int32_t sm_idx) {
  return tensormaps_init</*IsLoad=*/true>(params, shared_tensormaps, sm_count, sm_idx, 0);
}
// Producer-load path for one CTA tile.
//
// For every epilogue subtile (or only `subtile_idx` when it is not -1) this acquires a load-pipeline
// stage, runs the fusion producer-load callback on that stage's barrier, issues the TMA load of the C
// subtile into smem if C is needed (one elected thread issues it), then commits the stage.
//
// Params:
//   load_pipeline / load_pipe_producer_state - C load pipeline and the producer state to start from.
//   problem_shape_mnkl, tile_coord_mnkl      - problem extents and this CTA's tile coordinate.
//   load_tensormap                           - TMA descriptor (pointer) used for the C loads.
//   subtile_idx                              - -1 for all subtiles, else the linear index epi_n * EPI_M + epi_m.
//   wait_until_load_finishes                 - if set and any C load was issued, wait on the consumer side
//                                              of the last loaded stage before returning.
// Returns the producer state advanced by the number of stages committed.
template<
class ProblemShapeMNKL,
class TileShapeMNK,
class TileCoordMNKL,
class TiledMma,
class TensorMapC,
__CUTE_REQUIRES(std::is_pointer_v<TensorMapC>)
>
CUTLASS_DEVICE auto
load(
    LoadPipeline load_pipeline,
    LoadPipelineState load_pipe_producer_state,
    ProblemShapeMNKL problem_shape_mnkl,
    TileShapeMNK tile_shape_MNK,
    TileCoordMNKL tile_coord_mnkl,
    TiledMma tiled_mma,
    int thread_idx,
    TensorStorage& shared_tensors,
    TensorMapC const& load_tensormap,
    int subtile_idx=-1,
    bool wait_until_load_finishes = false) {
  using namespace cute;

  // Indexing variables
  auto [M, N, K, L] = problem_shape_mnkl;
  auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl;

  // This function only builds and issues TMA loads of C, so C's copy op is the one that must not be im2col.
  static_assert(!is_im2col_C, "Do not support im2col");

  // The C tensor below is built with a unit L extent, so the L tile coordinate is always 0.
  auto coord_shape = append<3>(make_shape(m_coord, n_coord), Int<0>{});

  // Represent the full source tensor, slice to get the tile this CTA is currently responsible for
  Tensor mC_mn = params.tma_load_c.get_tma_tensor(append<3>(make_shape(M,N), Int<1>{}));   // (M,N,L)
  Tensor mC = coalesce(mC_mn, take<0,2>(CtaTileMNK{}));
  Tensor gC = local_tile(mC, take<0,2>(CtaTileMNK{}), coord_shape);                         // (CTA_M,CTA_N)

  // Apply epilogue subtile, get matching smem tensor
  auto ptr_sC = shared_tensors.collective.smem_C.begin();
  Tensor gC_epi = flat_divide(gC, EpilogueTile{});                                          // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
  Tensor sC_epi = make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{});                         // (EPI_TILE_M,EPI_TILE_N,PIPE_C)

  // Prepare the thread(b)lock's (G)mem to (S)mem TMA tiled copy (bGS_)
  ThrCopy thrblk_g2s = params.tma_load_c.get_slice(Int<0>{});
  Tensor bGS_gC = thrblk_g2s.partition_S(gC_epi);                                            // (G2S,G2S_M,G2S_N,EPI_M,EPI_N)
  Tensor bGS_sC = thrblk_g2s.partition_D(sC_epi);                                            // (G2S,G2S_M,G2S_N,PIPE_C)

  // Get the fusion callbacks for the producer load warp
  auto pld_args = cutlass::epilogue::fusion::detail::ProducerLoadArgs{
    problem_shape_mnkl,
    CtaTileMNK{},
    tile_coord_mnkl,
    tiled_mma,
    EpilogueTile{},
    thread_idx
  };
  auto pld_callbacks = fusion_callbacks.get_producer_load_callbacks(pld_args);
  bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed();

  // Tracks the stage of the most recent C load, used by the optional wait at the end.
  LoadPipelineState last_load_producer_state = load_pipe_producer_state;

  // Predication for TMA load (one thread issues TMA load)
  bool issue_tma_load = cute::elect_one_sync();

  // Pre-loop fusion callback entry point
  pld_callbacks.begin();

  bool did_load = false;

  CUTLASS_PRAGMA_UNROLL
  for (int epi_n = 0; epi_n < size<3>(gC_epi); ++epi_n) {
    CUTLASS_PRAGMA_UNROLL
    for (int epi_m = 0; epi_m < size<2>(gC_epi); ++epi_m) {
      if (subtile_idx != -1 && (epi_n * static_cast<int>(size<2>(gC_epi)) + epi_m) != subtile_idx) {
        continue;
      }
      // Acquire the lock for this stage
      constexpr uint16_t mcast_mask = 0;
      uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state);
      load_pipeline.producer_acquire(load_pipe_producer_state);

      // Loop fusion callback entry point. The stage is acquired and committed even when C is not loaded,
      // since the callback receives this stage's barrier (presumably for its own loads).
      pld_callbacks.step(tma_barrier, epi_m, epi_n, load_pipe_producer_state.count(), issue_tma_load);

      // Execute the TMA load for C if needed
      if (is_C_load_needed) {
        if (issue_tma_load) {
          copy(params.tma_load_c.with(load_tensormap, *tma_barrier, mcast_mask),
              bGS_gC(_,_,_,epi_m,epi_n), bGS_sC(_,_,_,load_pipe_producer_state.index()));
          load_pipeline.producer_expect_transaction(load_pipe_producer_state);
        }
        last_load_producer_state = load_pipe_producer_state;
        did_load = true;
      }

      // Commit TMA loads for this stage and release the lock
      load_pipeline.producer_commit(load_pipe_producer_state);
      ++load_pipe_producer_state;
    }
  }

  // Post-loop fusion callback entry point
  pld_callbacks.end();

  if (wait_until_load_finishes && did_load) {
    // Consumer view of the last loaded stage: same index and count, with the producer phase flipped.
    typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_tma_consumer_state =
        {last_load_producer_state.index(), !last_load_producer_state.phase(), last_load_producer_state.count()};
    load_pipeline.consumer_wait(epi_load_pipe_tma_consumer_state);
  }

  return load_pipe_producer_state;
}
// Finishes the producer-load path. When the fusion callbacks report that a producer load is needed,
// one elected thread calls producer_tail on the load pipeline. The producer state is returned unchanged.
CUTLASS_DEVICE auto
load_tail(
    LoadPipeline load_pipeline,
    LoadPipelineState load_pipe_producer_state) {
  if (fusion_callbacks.is_producer_load_needed()) {
    bool const is_elected_thread = cute::elect_one_sync();
    if (is_elected_thread) {
      load_pipeline.producer_tail(load_pipe_producer_state);
    }
  }
  return load_pipe_producer_state;
}
template<
class ProblemShapeMNKL,
class TileShapeMNK,
class TileCoordMNKL,
class AccEngine, class AccLayout,
class TiledMma,
class TensorMapD
>
CUTLASS_DEVICE auto
store(
LoadPipeline load_pipeline,
LoadPipelineState load_pipe_consumer_state,
StorePipeline store_pipeline,
StorePipelineState store_pipe_producer_state,
ProblemShapeMNKL problem_shape_mnkl,
TileShapeMNK tile_shape_MNK,
TileCoordMNKL tile_coord_mnkl,
cute::Tensor<AccEngine,AccLayout> accumulators,
TiledMma tiled_mma,
int thread_idx,
TensorStorage& shared_tensors,
TensorMapD const& store_tensormap,
int subtile_idx=-1) {
using namespace cute;
using ElementAccumulator = typename AccEngine::value_type;
using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits<FusionCallbacks>::ElementCompute;
using ElementCompute = cute::conditional_t<cute::is_void_v<ElementCompute_>,ElementAccumulator,ElementCompute_>;
static_assert(is_rmem<AccEngine>::value, "Accumulator must be RF resident.");
static_assert(rank(AccLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA,MMA_M,MMA_N)");
static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
static_assert(is_static<TileShapeMNK>::value, "TileShapeMNK must be static");
static_assert(rank(TileShapeMNK{}) == 3, "TileShapeMNK must be rank 3");
static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4");
// Indexing variables
auto [M, N, K, L] = problem_shape_mnkl;
auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl;
static_assert(!is_im2col_D, "Do not support im2col");
auto coord_shape = append<3>(make_shape(m_coord, n_coord), Int<0>{});
// Represent the full output tensor, slice to get the tile this CTA is responsible for
Tensor mD_mn = params.tma_store_d.get_tma_tensor(append<3>(make_shape(M,N), Int<1>{})); // (M,N,L)
Tensor mD = coalesce(mD_mn, take<0,2>(CtaTileMNK{}));
Tensor gD = local_tile(mD, take<0,2>(CtaTileMNK{}), coord_shape); // (CTA_M,CTA_N)
// Apply epilogue subtiling
Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
// Construct the corresponding pipelined smem tensors
auto ptr_sC = shared_tensors.collective.smem_C.begin();
auto ptr_sD = shared_tensors.collective.smem_D.begin();
Tensor sC_epi = cute::as_position_independent_swizzle_tensor(
make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_C)
Tensor sD_epi = cute::as_position_independent_swizzle_tensor(
make_tensor(make_smem_ptr(ptr_sD), SmemLayoutD{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_D)
TiledCopy tiled_copy_C_atom = make_tiled_copy_C_atom(CopyAtomC{}, tiled_mma);
// (t)hread-partition for (r)egister to (r)egister copy (tRR_)
TiledCopy tiled_r2r = [&]() {
if constexpr (IsUseR2R) {
return make_tiled_copy_S(Copy_Atom<CopyOpR2R, ElementCompute>{}, tiled_copy_C_atom);
}
else {
return make_tiled_copy_S(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>,
ElementCompute>{}, tiled_copy_C_atom);
}
}();
ThrCopy thread_r2r = tiled_r2r.get_slice(thread_idx);
// (t)hread-partition for (r)egister to (s)mem copy (tRS_)
TiledCopy tiled_r2s = [&]() {
if constexpr (IsUseR2R) {
return make_tiled_copy_D(Copy_Atom<CopyOpR2S,SmemElementD>{}, tiled_r2r);
}
else {
return make_tiled_copy_S(Copy_Atom<CopyOpR2S,SmemElementD>{}, tiled_copy_C_atom);
}
}();
ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx);
Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N)
Tensor tRS_sD = thread_r2s.partition_D(sD_epi); // (R2S,R2S_M,R2S_N,PIPE_D)
auto mma_tile_m = size<0>(TileShapeMNK{}) / size<1>(tRS_rAcc);
auto mma_tile_n = size<1>(TileShapeMNK{}) / size<2>(tRS_rAcc);
auto epi_tile_m = size<0>(EpilogueTile{});
auto epi_tile_n = size<1>(EpilogueTile{});
// Allocate D registers
Layout tRS_rD_layout = make_layout(take<0,3>(shape(thread_r2s.partition_S(sD_epi))));
Tensor tRS_rD = make_tensor<SmemElementD>(tRS_rD_layout); // (R2S,R2S_M,R2S_N)
// Vectorized fragment view
constexpr int FragmentSize = DispatchPolicy::FragmentSize;
Tensor tRS_rAcc_frg = recast<Array<ElementAccumulator, FragmentSize>>(tRS_rAcc);
Tensor tRS_rD_frg = recast<Array<SmemElementD , FragmentSize>>(tRS_rD);
CUTE_STATIC_ASSERT(size<0>(tRS_rAcc) % FragmentSize == 0, "Fragment size does not vectorize properly");
// (t)hread-partition for (s)mem to (r)egister copy (tSR_)
TiledCopy tiled_s2r = make_tiled_copy_S(Copy_Atom<CopyOpS2R, SmemElementC>{}, tiled_copy_C_atom);
ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx);
Tensor tSR_sC = thread_s2r.partition_S(sC_epi); // (S2R,S2R_M,S2R_N,PIPE_C)
Layout tSR_rC_layout = thread_s2r.retile_D(tRS_rD).layout(); // (S2R,S2R_M,S2R_N)
// Allocate C registers
// If C smem load is a non-vectorized dst(i) = src(i) then we can allocate C registers directly in the compute type
// to eliminate some redundant pack+unpack instruction sequences for sub-word types
constexpr bool IsDirectS2R = cute::is_same_v<CopyOpS2R, AutoVectorizingCopyWithAssumedAlignment<128>>
&& decltype(max_common_vector(tSR_rC_layout, tSR_sC.layout()))::value <= 1;
using RegisterElementC = cute::conditional_t<IsDirectS2R, ElementCompute, SmemElementC>;
Tensor tRS_rC = make_tensor<RegisterElementC>(tRS_rD_layout); // (R2S,R2S_M,R2S_N)
Tensor tSR_rC = thread_s2r.retile_D(tRS_rC); // (S2R,S2R_M,S2R_N)
// thread(b)lock-partition for (s)mem to (g)mem copy (bSG_)
ThrCopy thrblk_s2g = params.tma_store_d.get_slice(Int<0>{});
Tensor bSG_sD = thrblk_s2g.partition_S(sD_epi); // (S2G,S2G_M,S2G_N,PIPE_D)
Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N)
// OOB predication for tile quantization "residue"
// Absolute coordinate tensors (dynamic)
Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N)
Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N)
Tensor tRS_cD_mn = thread_r2s.partition_S(flat_divide(cD_mn, EpilogueTile{})); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N)
// Relative coordinate tensors (static)
Tensor cD = make_counting_tensor(cD_mn.layout()); // (CTA_M,CTA_N)
Tensor tRS_cD = make_counting_tensor(tRS_cD_mn.layout()); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N)
// Subtract the global "bottom right" corner from the local "top left" corner to get the max relative coordinate
auto residue_cD = make_coord(M,N) - cD_mn(_0{}); // (m,n)
auto residue_tRS_cD = make_coord(M,N) - tRS_cD_mn(_0{}); // (m,n)
CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M");
CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N");
// Get TiledCopy for partition reference when consumer store.
TiledCopy tiled_copy_partition_ref = make_tiled_copy_S(Copy_Atom<CopyOpR2S,SmemElementD>{}, tiled_copy_C_atom);
// Get the fusion callbacks for the consumer store warps
constexpr bool RefSrc = true; // Register tensors reference R2S copy src layout
auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{
problem_shape_mnkl,
CtaTileMNK{},
tile_coord_mnkl,
tiled_mma,
EpilogueTile{},
tiled_copy_partition_ref,
cD,
residue_cD,
tRS_cD,
residue_tRS_cD,
tRS_rC,
thread_idx
};
auto cst_callbacks = fusion_callbacks.get_consumer_store_callbacks<RefSrc>(cst_args);
bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed();
bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed();
// Thread synchronizer for previously issued waits or fences
// to ensure visibility of smem reads/writes to threads or TMA unit
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
// Predication for TMA store (one warp issues TMA store)
bool issue_tma_store = (thread_idx / NumThreadsPerWarp) == 0;
// In the reuse smem configuration we have StagesC smem buffers and at most StagesD committed TMA stores in flight.
// The TMA store pipeline producer acquire returns when at most StagesD-1 committed stores are in-flight, so we can
// only guarantee store completion after StagesD iterations, then we can begin issuing releases on the smem buffer locks.
// store_pipe_producer_state tracks the acquire and load_pipe_consumer_state tracks the release, in circular buffer fashion.
LoadPipelineState load_wait_state = load_pipe_consumer_state;
if constexpr (ReuseSmemC) {
load_wait_state = store_pipe_producer_state;
load_wait_state.phase_ ^= 1;
}
// We can delay issue of TMA store by one iteration to achieve better interleaving of non-TMA instructions
// Sync requirements of smem reuse may preclude this optimization
// Delayed stores cause delayed stage releases which causes deadlock when StagesC == StagesD
int epi_m_prev = 0, epi_n_prev = 0;
static_assert(not (DelayTmaStore and ReuseSmemC and StagesC <= StagesD), "This TMA epilogue configuration will deadlock");
// The TMA store sequence for one subtile iteration
auto tma_store_fn = [&] (int epi_m, int epi_n) {
// Write the tile from smem to gmem with TMA
cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
synchronize(); // ensure all threads have issued their async fence
if constexpr (is_destination_supported) {
if (issue_tma_store) {
copy(params.tma_store_d.with(store_tensormap), bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n));
}
}
// Post async fence, pre TMA commit callback entry point
cst_callbacks.tma_store(epi_m, epi_n, store_pipe_producer_state.count(), issue_tma_store);
// Commit the TMA stores for this stage
if (issue_tma_store) {
store_pipeline.producer_commit(store_pipe_producer_state);
}
++store_pipe_producer_state;
++issued_stores;
// Wait for the next smem buffer to be available
if (issue_tma_store) {
store_pipeline.producer_acquire(store_pipe_producer_state);
}
synchronize();
if constexpr (ReuseSmemC) {
// producer_acquire returns when at most StagesD-1 committed stores are pending
bool store_finished = issued_stores > StorePipeline::UnacquiredStages;
// Let dma warp know earliest smem buffer is consumed and empty after StagesD producer commits
if (store_finished) {
if (is_producer_load_needed) {
load_pipeline.consumer_release(load_pipe_consumer_state);
}
++load_pipe_consumer_state;
}
}
};
//
// BEGIN EPILOGUE
//
// Pre-loop fusion callback entry point
cst_callbacks.begin();
// For each output tile
CUTLASS_PRAGMA_UNROLL
for (int epi_n = 0; epi_n < size<3>(gD_epi); ++epi_n) {
CUTLASS_PRAGMA_UNROLL
for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m) {
bool is_first_iteration = epi_m == 0 && epi_n == 0;
bool is_last_iteration = epi_m == size<2>(gD_epi)-1 && epi_n == size<3>(gD_epi)-1;
if (subtile_idx != -1 && (epi_n * static_cast<int>(size<2>(gD_epi)) + epi_m) != subtile_idx) {
continue;
}
cst_callbacks.begin_loop(epi_m, epi_n);
if (is_producer_load_needed) {
// Wait for the producer load to fill smem
load_pipeline.consumer_wait(load_wait_state);
if (is_C_load_needed) {
// Copy source tile from smem to register
copy(tiled_s2r, tSR_sC(_,_,_,load_wait_state.index()), tSR_rC);
}
}
// First loop fusion callback entry point
cst_callbacks.previsit(epi_m, epi_n, load_wait_state.count(), is_producer_load_needed);
if (is_producer_load_needed) {
if constexpr (not ReuseSmemC) {
// Let producer load warp know smem buffers are consumed and empty
cutlass::arch::fence_view_async_shared();
load_pipeline.consumer_release(load_pipe_consumer_state);
++load_pipe_consumer_state;
}
++load_wait_state;
}
int mma_m = epi_m;
int mma_n = (epi_n * size<1>(EpilogueTile{})) / mma_tile_n;
Tensor tRS_rAcc_frg_mn = tRS_rAcc_frg(_,mma_m,mma_n);
// Vectorized fragment loop with visitor callback entry point
int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n);
int r2s_v = epi_n_in_mma * size(tRS_rD_frg);
CUTLASS_PRAGMA_UNROLL
for (int epi_v = 0; epi_v < size(tRS_rD_frg); ++epi_v) {
tRS_rD_frg(epi_v) = cst_callbacks.visit(tRS_rAcc_frg_mn(r2s_v + epi_v), epi_v, epi_m, epi_n);
}
// The latest we can delay the TMA store is right before the smem store of the next iteration
// since the current TMA store needs to be committed before we can acquire the next smem buffer
if constexpr (DelayTmaStore) {
// Issue TMA stores for the previous subtile
if (not is_first_iteration and subtile_idx == -1) {
tma_store_fn(epi_m_prev, epi_n_prev);
}
epi_m_prev = epi_m;
epi_n_prev = epi_n;
}
// Smem reduction callback entry point using current store buffer for workspace
cst_callbacks.reduce(sD_epi(_,_,store_pipe_producer_state.index()),
synchronize, epi_m, epi_n, is_last_iteration, tRS_rD_frg);
// Copy tile from register to regiser if needed
if constexpr (IsUseR2R) {
// retile source and destination for tiled_r2r
Tensor tRR_rD_src = thread_r2r.retile_S(tRS_rD); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N)
Tensor tRR_rD_dst = thread_r2r.retile_D(tRS_rD); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N)
// Output needs register shuffling before copying to shared memory.
copy(tiled_r2r, tRR_rD_src, tRR_rD_dst);
}
// Copy tile from register to smem
if constexpr (is_destination_supported) {
copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index()));
}
// Post reduction, pre TMA store callback entry point
constexpr bool issue_smem_store = true; // No smem store predication
cst_callbacks.postreduce(epi_m, epi_n, store_pipe_producer_state.count(), issue_smem_store);
if constexpr (not DelayTmaStore) {
// Issue TMA stores for this subtile
tma_store_fn(epi_m, epi_n);
}
cst_callbacks.end_loop(epi_m, epi_n);
} // for epi_m
} // for epi_n
if constexpr (DelayTmaStore) {
// Issue TMA stores for the last subtile
tma_store_fn(epi_m_prev, epi_n_prev);
}
// Post-loop fusion callback entry point
cst_callbacks.end();
return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state);
}
CUTLASS_DEVICE auto
store_tail(
    LoadPipeline load_pipeline,
    LoadPipelineState load_pipe_consumer_state,
    StorePipeline store_pipeline,
    StorePipelineState store_pipe_producer_state) {
  // End-of-work drain for the epilogue store path.
  //
  // load_pipeline / load_pipe_consumer_state:   pipeline and consumer state on which any still-deferred
  //                                             releases are issued (ReuseSmemC only).
  // store_pipeline / store_pipe_producer_state: pipeline and producer state whose committed TMA stores are drained.
  // Returns (load_pipe_consumer_state, store_pipe_producer_state) after the drain/releases.

  // Block until all committed TMA stores have completed.
  store_pipeline.producer_tail(store_pipe_producer_state);

  // store() counts committed stores in issued_stores to decide when a deferred release is safe;
  // start that count over.
  issued_stores = 0;

  if constexpr (ReuseSmemC) {
    // With ReuseSmemC, store() only releases a load-pipeline stage once issued_stores exceeds
    // StorePipeline::UnacquiredStages. The releases still owed at this point (up to StagesD-1 of them)
    // are issued here, now that every store has drained.
    bool const has_pending_releases = fusion_callbacks.is_producer_load_needed();
    if (has_pending_releases) {
      constexpr int num_pending_releases =
          cute::min(StorePipeline::UnacquiredStages, get_load_pipe_increment(CtaTileMNK{}));
      CUTLASS_PRAGMA_UNROLL
      for (int release_idx = 0; release_idx < num_pending_releases; ++release_idx) {
        load_pipeline.consumer_release(load_pipe_consumer_state);
        ++load_pipe_consumer_state;
      }
    }
  }

  return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state);
}
CUTLASS_DEVICE auto
store_init(
    Params const& params,
    TensorMapStorage& shared_tensormaps,
    int32_t sm_count,
    int32_t sm_idx,
    int32_t warp_group_idx) {
  // Prepares the TMA store descriptors for this warp.
  //
  // Only one warp issues TMA stores, so only warp 0 of the warp group
  // (canonical_warp_idx_sync() % NumWarpsPerWarpGroup == 0) calls tensormaps_init for store tensormaps.
  // Every other warp receives a tuple holding a null TmaDescriptor pointer.
  bool const is_store_issuing_warp = (canonical_warp_idx_sync() % NumWarpsPerWarpGroup) == 0;

  if (not is_store_issuing_warp) {
    TmaDescriptor* null_tma_desc = nullptr;
    return cute::make_tuple(null_tma_desc);
  }

  // Request the store (non-load) flavor of the tensormaps.
  constexpr bool IsLoad = false;
  return tensormaps_init<IsLoad>(params, shared_tensormaps, sm_count, sm_idx, warp_group_idx);
}