forked from codeplaysoftware/cutlass-fork
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsm70_gemm.hpp
272 lines (234 loc) · 10.8 KB
/
sm70_gemm.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/kernel_hardware_info.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cute/tensor.hpp"
namespace cutlass::gemm::kernel {
///////////////////////////////////////////////////////////////////////////////
/// Partial specialization of GemmUniversal that is selected when the mainloop's
/// dispatch-policy schedule derives from KernelMultistage.
///
/// Work decomposition as implemented below: the launch grid has one thread block per
/// (BLK_M x BLK_N) output tile per batch index (see get_grid_shape), and each thread block
/// runs the collective mainloop over every K tile followed by the collective epilogue
/// (see operator()).
template <
  class ProblemShape_,
  class CollectiveMainloop_,
  class CollectiveEpilogue_,
  class TileScheduler_
>
class GemmUniversal<
  ProblemShape_,
  CollectiveMainloop_,
  CollectiveEpilogue_,
  TileScheduler_,
  cute::enable_if_t<cute::is_base_of_v<KernelMultistage, typename CollectiveMainloop_::DispatchPolicy::Schedule>>>
{
public:
  //
  // Type Aliases
  //
  using ProblemShape = ProblemShape_;
  static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4,
    "ProblemShape{} should be <M,N,K> or <M,N,K,L>");

  // Mainloop derived types
  using CollectiveMainloop = CollectiveMainloop_;
  using TileShape = typename CollectiveMainloop::TileShape;
  using TiledMma = typename CollectiveMainloop::TiledMma;
  using ArchTag = typename CollectiveMainloop::ArchTag;
  using ElementA = typename CollectiveMainloop::ElementA;
  using StrideA = typename CollectiveMainloop::StrideA;
  using ElementB = typename CollectiveMainloop::ElementB;
  using StrideB = typename CollectiveMainloop::StrideB;
  using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
  using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
  using MainloopArguments = typename CollectiveMainloop::Arguments;
  using MainloopParams = typename CollectiveMainloop::Params;

  // Tile scheduler types. The scheduler is only consulted by can_implement(); the device
  // entry point below derives its tile coordinates directly from the block indices.
  // NOTE(review): the trailing Shape<1,1,1> selector argument is presumably the cluster
  // shape -- confirm against detail::TileSchedulerSelector.
  using TileSchedulerTag = TileScheduler_;
  using TileScheduler = typename detail::TileSchedulerSelector<
    TileScheduler_, ArchTag, TileShape,
    cute::Shape<cute::Int<1>, cute::Int<1>, cute::Int<1>>>::Scheduler;
  using TileSchedulerArguments = typename TileScheduler::Arguments;

  static constexpr bool IsGdcEnabled = false;

  // Only the default (void) tag and PersistentScheduler are accepted as the scheduler tag.
  static constexpr bool is_valid_tile_scheduler =
    cute::is_void_v<TileScheduler_> or cute::is_same_v<TileScheduler_, PersistentScheduler>;
  static_assert(is_valid_tile_scheduler, "SM70 kernel does not support specializing the tile scheduler.");

  // Epilogue derived types
  using CollectiveEpilogue = CollectiveEpilogue_;
  using ElementC = typename CollectiveEpilogue::ElementC;
  using StrideC = typename CollectiveEpilogue::StrideC;
  using ElementD = typename CollectiveEpilogue::ElementD;
  using StrideD = typename CollectiveEpilogue::StrideD;
  using EpilogueArguments = typename CollectiveEpilogue::Arguments;
  using EpilogueParams = typename CollectiveEpilogue::Params;
  static_assert(cute::is_same_v<ElementAccumulator, typename CollectiveEpilogue::ElementAccumulator>,
    "Mainloop and epilogue do not agree on accumulator value type.");

  // The mainloop and the epilogue are handed the same smem_buf pointer in operator(), so the
  // buffer only has to be as large as the bigger of the two SharedStorage types.
  // MSVC requires the cast to fix a warning-as-error.
  static constexpr int SharedStorageSize = static_cast<int>(cute::max(
      sizeof(typename CollectiveMainloop::SharedStorage),
      sizeof(typename CollectiveEpilogue::SharedStorage)));

  // Thread-block size equals the number of threads in the tiled MMA.
  static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(cute::size(TiledMma{}));
  static constexpr uint32_t MinBlocksPerMultiprocessor = 1;

  /// Host-side arguments supplied by the caller.
  /// `hw_info` is not read anywhere in this kernel and `scheduler` is only read by
  /// can_implement(); neither is carried over into Params.
  struct Arguments {
    GemmUniversalMode mode{};
    ProblemShape problem_shape{};
    MainloopArguments mainloop{};
    EpilogueArguments epilogue{};
    KernelHardwareInfo hw_info{};
    TileSchedulerArguments scheduler{};
  };

  /// Kernel entry point API: the parameter pack consumed by operator().
  struct Params {
    GemmUniversalMode mode{};
    ProblemShape problem_shape{};
    MainloopParams mainloop{};
    EpilogueParams epilogue{};
  };

  //
  // Methods
  //

  /// Converts host-side Arguments into kernel Params.
  ///
  /// `mode` and `problem_shape` are copied through unchanged; the mainloop and epilogue
  /// parameters are produced by the respective collective's own to_underlying_arguments().
  ///
  /// @param args       Host-side arguments.
  /// @param workspace  Forwarded as-is to both collectives. This kernel itself reports a
  ///                   workspace size of 0 (see get_workspace_size()).
  /// @return           Params ready to be passed to operator().
  static
  Params
  to_underlying_arguments(Arguments const& args, void* workspace) {
    return {
      args.mode,
      args.problem_shape,
      CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace),
      CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace)
    };
  }

  /// Returns true when this kernel can run the requested problem.
  ///
  /// kGemm mode is always accepted; kBatched mode is accepted only when ProblemShape is
  /// rank-4 (i.e. carries an L extent). The tile scheduler must additionally accept
  /// `args.scheduler`.
  static bool
  can_implement(Arguments const& args) {
    bool mode_implementable = args.mode == GemmUniversalMode::kGemm or
        (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
    return mode_implementable && TileScheduler::can_implement(args.scheduler);
  }

  /// Size in bytes of the device workspace required by this kernel: always 0.
  static size_t
  get_workspace_size(Arguments const& args) {
    size_t workspace_size = 0;
    return workspace_size;
  }

  /// No workspace initialization is performed; always returns Status::kSuccess.
  /// All parameters are accepted but ignored.
  static
  cutlass::Status
  initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr,
    CudaHostAdapter* cuda_adapter = nullptr) {
    cutlass::Status status = Status::kSuccess;
    return status;
  }

  /// Computes the launch grid.
  ///
  /// @return dim3(ceil_div(M, BLK_M), ceil_div(N, BLK_N), batch_count), where batch_count is
  ///         the 4th extent of the problem shape when ProblemShape is rank-4 and 1 otherwise.
  static dim3
  get_grid_shape(Params const& params) {
    int batch_count = 1;
    if constexpr (cute::rank(ProblemShape{}) == 4) {
      batch_count = cute::size<3>(params.problem_shape);
    }

    return dim3(
      cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(TileShape{}))),
      cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(TileShape{}))),
      batch_count
    );
  }

  /// Thread-block shape: a 1-D block of MaxThreadsPerBlock threads.
  static dim3
  get_block_shape() {
    return dim3(MaxThreadsPerBlock, 1, 1);
  }

  /// Device entry point: computes one output tile for one batch index.
  ///
  /// The block's x/y/z indices are used as the tile's m/n/l coordinates, matching the grid
  /// produced by get_grid_shape().
  ///
  /// @param params    Kernel parameters produced by to_underlying_arguments().
  /// @param smem_buf  Buffer handed to both the mainloop and the epilogue; see
  ///                  SharedStorageSize for the size it is expected to have.
  CUTLASS_DEVICE
  void
  operator()(Params const& params, char* smem_buf) {
    using namespace cute;
    using X = Underscore;

    // Preconditions
    CUTE_STATIC_ASSERT(is_static<TileShape>::value);

    // Separate out problem shape for convenience
    // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
    auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
    auto [M,N,K,L] = problem_shape_MNKL;

    // Preconditions
    static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
    static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
    static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
    static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");

    // Get the appropriate blocks for this thread block -- potential for thread block locality
    int thread_idx = int(ThreadIdxX());
    auto blk_shape = TileShape{};                                                  // (BLK_M,BLK_N,BLK_K)
    auto m_coord = BlockIdxX();
    auto n_coord = BlockIdxY();
    auto l_coord = BlockIdxZ();
    // The k coordinate is left as a wildcard: this block iterates over all k tiles itself.
    auto blk_coord_mnkl = make_coord(m_coord, n_coord, _, l_coord);                // (m,n,k,l)

    // Represent the full tensors
    Tensor mA_mkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_A), make_shape(M,K,L), params.mainloop.dA); //(m,k,l)
    Tensor mB_nkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_B), make_shape(N,K,L), params.mainloop.dB); //(n,k,l)

    // Get batch slice
    Tensor mA_mk = mA_mkl(_,_,l_coord);                                            // (m,k)
    Tensor mB_nk = mB_nkl(_,_,l_coord);                                            // (n,k)

    // Slice to get the tiles this thread block is responsible for
    Tensor gA = local_tile(mA_mk, blk_shape, take<0,3>(blk_coord_mnkl), Step<_1, X,_1>{}); // (BLK_M,BLK_K,k)
    Tensor gB = local_tile(mB_nk, blk_shape, take<0,3>(blk_coord_mnkl), Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)

    // Compute tile residues for predication
    auto m_max_coord = M - size<0>(gA) * get<0>(blk_coord_mnkl);                   // M - BLK_M * m_coord
    auto n_max_coord = N - size<0>(gB) * get<1>(blk_coord_mnkl);                   // N - BLK_N * n_coord
    auto k_residue   = K - size<1>(gA) * size<2>(gA);                              // K - BLK_K * k_coord_max
    auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue);

    // Allocate the tiled_mma and the accumulators for the (M,N) blk_shape
    TiledMma tiled_mma;
    Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape));   // (MMA,MMA_M,MMA_N)
    clear(accumulators);

    // Iterate over every k tile of gA
    auto k_tile_iter  = cute::make_coord_iterator(shape<2>(gA));
    int  k_tile_count = size<2>(gA);

    // Perform the collective scoped MMA
    // NOTE(review): `accumulators` is passed as both the first and the fourth argument --
    // presumably destination and source accumulator (in-place accumulation); confirm against
    // the collective mainloop's operator().
    CollectiveMainloop collective_mma;
    collective_mma(
      accumulators,
      gA,
      gB,
      accumulators,
      k_tile_iter, k_tile_count,
      residue_mnk,
      thread_idx,
      smem_buf
    );

    // Epilogue and write to gD
    CollectiveEpilogue epilogue{params.epilogue};
    epilogue(
      problem_shape_MNKL,
      blk_shape,
      blk_coord_mnkl,
      accumulators,
      tiled_mma,
      residue_mnk,
      thread_idx,
      smem_buf
    );
  }
};
///////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::kernel