forked from codeplaysoftware/cutlass-fork
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathxe_gemm_cooperative.hpp
322 lines (269 loc) · 13.3 KB
/
xe_gemm_cooperative.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
/***************************************************************************************************
* Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/workspace.h"
#include "cutlass/kernel_hardware_info.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cute/tensor.hpp"
///////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::kernel {
///////////////////////////////////////////////////////////////////////////////
template <
class ProblemShape_,
class CollectiveMainloop_,
class CollectiveEpilogue_,
class TileScheduler_
>
class GemmUniversal<
ProblemShape_,
CollectiveMainloop_,
CollectiveEpilogue_,
TileScheduler_,
cute::enable_if_t<cute::is_base_of_v<KernelPVC, typename CollectiveMainloop_::DispatchPolicy::Schedule>
&& cute::is_same_v<TileScheduler_, cutlass::gemm::StreamKScheduler>>>
{
public:
//
// Type Aliases
//
using ProblemShape = ProblemShape_;
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types
using CollectiveMainloop = CollectiveMainloop_;
using TileShape = typename CollectiveMainloop::WorkgroupTileShape;
using WorkgroupTileShape = TileShape;
using TiledMma = typename CollectiveMainloop::TiledMma;
using ArchTag = typename CollectiveMainloop::ArchTag;
using ElementA = typename CollectiveMainloop::ElementA;
using StrideA = typename CollectiveMainloop::StrideA;
using ElementB = typename CollectiveMainloop::ElementB;
using StrideB = typename CollectiveMainloop::StrideB;
using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
using ClusterShape = typename DispatchPolicy::ClusterShape;
using MainloopArguments = typename CollectiveMainloop::Arguments;
using MainloopParams = typename CollectiveMainloop::Params;
// Epilogue derived types
using CollectiveEpilogue = CollectiveEpilogue_;
using ElementC = typename CollectiveEpilogue::ElementC;
using StrideC = typename CollectiveEpilogue::StrideC;
using ElementD = typename CollectiveEpilogue::ElementD;
using StrideD = typename CollectiveEpilogue::StrideD;
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
using EpilogueParams = typename CollectiveEpilogue::Params;
using TileSchedulerTag = TileScheduler_;
using TileScheduler = typename detail::TileSchedulerSelector<
TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler;
using TileSchedulerArguments = typename TileScheduler::Arguments;
using TileSchedulerParams = typename TileScheduler::Params;
static constexpr int SubgroupSize = CollectiveMainloop::SubgroupSize; // sub_group size
static constexpr uint32_t MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock;
using MmaAtomShape = typename CollectiveMainloop::MmaAtomShape;
using SubgroupTileShape = typename CollectiveMainloop::SubgroupTileShape;
using PrefetchATileSize = typename CollectiveMainloop::PrefetchATileSize;
static constexpr int PrefetchStrideA = static_cast<int>(get<1>(PrefetchATileSize{}));
static constexpr int PrefetchStrideB = static_cast<int>(CollectiveMainloop::SG_K);
// Kernel level shared memory storage
struct SharedStorage {
using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
EpilogueTensorStorage epilogue;
};
static constexpr int SharedStorageSize = sizeof(SharedStorage);
// Device side arguments
struct Arguments {
GemmUniversalMode mode{};
ProblemShape problem_shape{};
MainloopArguments mainloop{};
EpilogueArguments epilogue{};
KernelHardwareInfo hw_info{};
TileSchedulerArguments scheduler{};
};
// Kernel entry point API
struct Params {
GemmUniversalMode mode{};
ProblemShape problem_shape{};
MainloopParams mainloop{};
EpilogueParams epilogue{};
KernelHardwareInfo hw_info{};
TileSchedulerParams scheduler{};
void* workspace{nullptr};
};
//
// Methods
//
// Convert to underlying arguments. In this case, a simple copy for the aliased type.
static
Params
to_underlying_arguments(Arguments const& args, void* workspace) {
CUTLASS_TRACE_HOST("to_underlying_arguments():");
auto problem_shape = args.problem_shape;
auto problem_shape_MNKL = append<4>(problem_shape, 1);
// Get SM count if needed, otherwise use user supplied SM count
int sm_count = args.hw_info.sm_count;
if (sm_count <= 0) {
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
}
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
// Calculate workspace pointers
uint8_t* workspace_ptr = reinterpret_cast<uint8_t*>(workspace);
TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments(
problem_shape_MNKL, TileShape{}, hw_info, args.scheduler, workspace_ptr);
return {
args.mode,
problem_shape,
CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace_ptr),
CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace_ptr),
hw_info,
scheduler,
workspace
};
}
static bool
can_implement(Arguments const& args) {
bool mode_implementable = args.mode == GemmUniversalMode::kGemm or
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
return mode_implementable && TileScheduler::can_implement(args.scheduler);
}
static size_t
get_workspace_size(Arguments const& args) {
size_t workspace_size = 0;
workspace_size += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
args.scheduler, args.problem_shape, args.hw_info);
return workspace_size;
}
static cutlass::Status
initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr,
CudaHostAdapter* cuda_adapter = nullptr) {
Status status = Status::kSuccess;
uint8_t* workspace_ptr = reinterpret_cast<uint8_t*>(workspace);
status = TileScheduler::template initialize_workspace<ProblemShape, ElementAccumulator>(
args.scheduler, workspace_ptr, args.problem_shape, args.hw_info);
return status;
}
// Computes the kernel launch grid shape based on runtime parameters
static dim3
get_grid_shape(Params const& params) {
// Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently
return TileScheduler::get_grid_shape(params.problem_shape, TileShape{}, params.hw_info);
}
static dim3
get_block_shape() {
return dim3(MaxThreadsPerBlock, 1, 1);
}
CUTLASS_DEVICE
void
operator()(Params const& params, char* smem_buf) {
// Preconditions
CUTE_STATIC_ASSERT(is_static<WorkgroupTileShape>::value);
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
// Kernel level shared memory storage
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
// Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
auto M = get<0>(problem_shape_MNKL);
auto N = get<1>(problem_shape_MNKL);
auto K = get<2>(problem_shape_MNKL);
auto L = get<3>(problem_shape_MNKL);
TileScheduler scheduler{params.scheduler};
auto work_tile_info = scheduler.initial_work_tile_info();
int thread_idx = int(ThreadIdxX());
constexpr auto workgroup_shape = WorkgroupTileShape{}; // (BLK_M,BLK_N,BLK_K)
constexpr auto subgroup_shape = SubgroupTileShape{}; // (SUB_M,SUB_N,SUB_K)
while (work_tile_info.is_valid()) {
const int m_coord = work_tile_info.M_idx;
const int n_coord = work_tile_info.N_idx;
const int l_coord = work_tile_info.L_idx;
const auto tile_coord = make_coord(m_coord, n_coord, _, l_coord);
Tensor mA_mkl = make_tensor(make_gmem_ptr(static_cast<ElementA const*>(nullptr)), make_shape(M,K,L), StrideA{}); //(m,k,l)
Tensor mB_nkl = make_tensor(make_gmem_ptr(static_cast<ElementB const*>(nullptr)), make_shape(N,K,L), StrideB{}); //(n,k,l)
Tensor mA_mk = mA_mkl(_,_,l_coord); // (m,k)
Tensor mB_nk = mB_nkl(_,_,l_coord); // (n,k)
auto gA = local_tile(mA_mk, workgroup_shape, take<0, 3>(tile_coord), Step<_1, X, _1>{});
auto gB = local_tile(mB_nk, workgroup_shape, take<0, 3>(tile_coord), Step< X, _1, _1>{});
// Get the number of K tiles to compute for this work as well as the starting K tile offset of the work.
const int work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, workgroup_shape);
const int work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info);
auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, make_shape(K)), make_shape(K));
auto k_residue = K - get<2>(subgroup_shape) * (K / get<2>(subgroup_shape)); // K - SUB_K * k_coord_max
// Compute tile residues for predication
auto m_max_coord = M - get<0>(subgroup_shape) * m_coord; // M - SUB_M * m_coord
auto n_max_coord = N - get<1>(subgroup_shape) * n_coord; // N - SUB_N * n_coord
auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue);
TiledMma tiled_mma;
Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(workgroup_shape));
CollectiveMainloop collective_mma;
// Perform the collective scoped MMA
collective_mma.template operator()<PrefetchStrideA, PrefetchStrideB>(
accumulators,
gA,
gB,
accumulators,
k_tile_iter, work_k_tile_count,
residue_mnk,
tile_coord,
K,
thread_idx,
smem_buf,
params.mainloop
);
// Perform reduction across splits, if needed
TileScheduler::template fixup<MaxThreadsPerBlock>(
params.scheduler, work_tile_info, accumulators);
if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) {
CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue};
epilogue(
problem_shape_MNKL,
subgroup_shape,
tile_coord,
accumulators,
tiled_mma,
residue_mnk,
thread_idx,
smem_buf
);
}
// Get next work tile
work_tile_info = scheduler.fetch_next_work(work_tile_info);
}
}
};
///////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::kernel