-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
Copy pathgemmPluginProfiler.cpp
340 lines (288 loc) · 12.5 KB
/
gemmPluginProfiler.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/plugins/common/gemmPluginProfiler.h"
#include "tensorrt_llm/common/cublasMMWrapper.h"
#include "tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm.h"
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h"
#include "tensorrt_llm/kernels/cutlass_kernels/fused_gated_gemm/fused_gated_gemm.h"
#include "tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm.h"
#include "tensorrt_llm/plugins/lowLatencyGemmPlugin/lowLatencyGemmPlugin.h"
#include "tensorrt_llm/plugins/lowLatencyGemmSwigluPlugin/lowLatencyGemmSwigluPlugin.h"
#include "tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.h"
namespace tensorrt_llm::plugins
{
// Creates an empty MNK profile map and decides whether tactic profiling is skipped.
//
// Profiling is skipped when the SKIP_GEMM_PLUGIN_PROFILINGS environment variable is set to a
// non-zero integer. A value that cannot be parsed as an integer is treated as "do not skip"
// instead of letting the std::stoi exception escape from the constructor.
template <typename Config, typename RunnerPtr, typename GemmIdType, typename GemmIdHashType>
GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHashType>::GemmPluginProfiler()
{
    mMNKProfileMap = std::make_shared<MNKProfileMap>();

    // set SKIP_GEMM_PLUGIN_PROFILINGS=1 to avoid tactics profilings
    bool skipRequested = false;
    char const* const skipEnv = std::getenv("SKIP_GEMM_PLUGIN_PROFILINGS");
    if (skipEnv != nullptr)
    {
        try
        {
            skipRequested = (std::stoi(skipEnv) != 0);
        }
        catch (std::exception const&)
        {
            // std::stoi throws on non-numeric or out-of-range input (e.g. "true"); keep profiling enabled.
            TLLM_LOG_WARNING(
                "SKIP_GEMM_PLUGIN_PROFILINGS is not a valid integer. GEMM plugin profilings will not be skipped.");
        }
    }
    mSkip = skipRequested;

    if (mSkip)
    {
        TLLM_LOG_DEBUG(
            "SKIP_GEMM_PLUGIN_PROFILINGS is set. Skipping GEMM plugin profilings. It could result in runtime error "
            "if default tactic is not defined.");
    }
}
// Serializes the profiles stored for `gemmId` into `buffer` via write(): the number of entries
// (as an int) comes first, followed by every (M, best config) entry of the per-M map in
// iteration order.
template <typename Config, typename RunnerPtr, typename GemmIdType, typename GemmIdHashType>
void GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHashType>::serialize(
    char*& buffer, GemmIdType const& gemmId) const
{
    auto const profilesForId = mMNKProfileMap->getMProfileMap(gemmId);

    // Entry count is stored ahead of the entries themselves
    int const numProfiles = static_cast<int>(profilesForId->size());
    write(buffer, numProfiles);

    // One record per M: the pair of M and the best GEMM config found for it
    for (auto it = profilesForId->begin(); it != profilesForId->end(); ++it)
    {
        write(buffer, *it);
    }
}
// Restores the profiles of `gemmId` from `data` using read(): an int entry count followed by
// that many (M, optional config) pairs, each inserted into the per-M map of `gemmId`.
// Also stores `dims` in mDims and creates the per-M map if it does not exist yet.
template <typename Config, typename RunnerPtr, typename GemmIdType, typename GemmIdHashType>
void GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHashType>::deserialize(
    char const*& data, GemmDims& dims, GemmIdType const& gemmId)
{
    // NOTE: every thread owns a private map, so this lock is not strictly required; it is taken
    // only to stay consistent with the other accessors.
    writer_lock lock(mMNKProfileMap->mutex);

    mDims = dims;

    // Make sure there is a per-M map for this GEMM ID before filling it
    bool const alreadyPresent = mMNKProfileMap->existsMProfileMap(gemmId);
    if (!alreadyPresent)
    {
        mMNKProfileMap->createMProfileMap(gemmId);
    }
    auto const targetMap = mMNKProfileMap->getMProfileMap(gemmId);

    // The stream starts with the number of serialized entries
    int numEntries{};
    read(data, numEntries);

    // Then one (M, best config) pair per entry
    for (int remaining = numEntries; remaining > 0; --remaining)
    {
        std::pair<int, std::optional<Config>> entry;
        read(data, entry);
        targetMap->insert(entry);
    }
}
// Returns the number of bytes reported for the serialized profiles of `gemmId`:
// one int for the entry count plus one (M, optional config) pair per stored entry.
template <typename Config, typename RunnerPtr, typename GemmIdType, typename GemmIdHashType>
size_t GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHashType>::getSerializationSize(
    GemmIdType const& gemmId) const
{
    reader_lock lock(mMNKProfileMap->mutex);

    using SerializedEntry = std::pair<int, std::optional<Config>>;

    auto const numEntries = mMNKProfileMap->getMProfileMap(gemmId)->size();
    // sizeof(int): the leading entry count; the rest: the entries of the tactics map
    return sizeof(int) + numEntries * sizeof(SerializedEntry);
}
// Returns the largest M value used by the profiler; profileTactics() and getBestConfig()
// clamp their rounded M to this value.
template <typename Config, typename RunnerPtr, typename GemmIdType, typename GemmIdHashType>
int GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHashType>::getMaxProfileM() const
{
    constexpr int kMaxProfileM = 8192;
    return kMaxProfileM;
}
// Hook called by profileTactics() before a problem size (m, n, k) is profiled, receiving the
// temporary workspace, its size in bytes and the profiling stream.
// This implementation intentionally leaves the workspace untouched.
template <typename Config, typename RunnerPtr, typename GemmIdType, typename GemmIdHashType>
void GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHashType>::initTmpData(
    int m, int n, int k, char* workspace, size_t size, cudaStream_t stream)
{
    // Nothing to initialize here
}
// Profiles GEMM tactics for the given GEMM ID and records the best config per M.
//
// - Returns immediately when `dims` is not initialized.
// - Stores `runner` and `type`, computes the temporary workspace size for the largest profiled M
//   (nextPowerOfTwo(dims.maxM), capped by getMaxProfileM()) and creates the per-M map if needed.
// - Returns without profiling when mSkip is set.
// - Otherwise profiles every M obtained by doubling from max(1, nextPowerOfTwo(dims.minM)) while
//   it is below the largest M, and finally the largest M itself. Values of M that already have
//   an entry in the map are not profiled again.
//
// The profiling stream and the temporary workspace are released on both the normal and the
// exceptional path, so an exception thrown while profiling no longer leaks them.
template <typename Config, typename RunnerPtr, typename GemmIdType, typename GemmIdHashType>
void GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHashType>::profileTactics(
    RunnerPtr const& runner, nvinfer1::DataType const& type, GemmDims const& dims, GemmIdType const& gemmId)
{
    writer_lock lock(mMNKProfileMap->mutex);

    if (!dims.isInitialized())
    {
        return;
    }

    mRunner = runner;
    mType = type;

    int const maxM = std::min(nextPowerOfTwo(dims.maxM), getMaxProfileM());
    computeTmpSize(maxM, dims.n, dims.k);

    if (!mMNKProfileMap->existsMProfileMap(gemmId))
    {
        // Create map for GEMM ID
        mMNKProfileMap->createMProfileMap(gemmId);
    }

    if (mSkip)
    {
        return;
    }

    auto mProfileMap = mMNKProfileMap->getMProfileMap(gemmId);
    bool isAllocated{false};

    // Profiles a single (m, n, k) problem unless a result for this m is already stored.
    auto profileShape = [&mProfileMap, &isAllocated, this](int m, int n, int k)
    {
        if (mProfileMap->count(m) == 0)
        {
            if (!isAllocated)
            {
                // Allocate tmp data to run GEMMs (lazily, only when something has to be profiled)
                allocateTmpData();
                isAllocated = true;
            }
            initTmpData(m, n, k, mWorkspaceTmp, mTmpWorkspaceSizeInBytes, mStream);
            auto const tactics = this->getTactics(m, n, k);
            // Profile different tactics for particular m and insert best config to the map
            mProfileMap->insert({m, this->profileTacticsForProblem(m, n, k, tactics)});
        }
    };

    common::check_cuda_error(cudaStreamCreate(&mStream));

    try
    {
        int const startMinMRounded = nextPowerOfTwo(dims.minM);
        for (int m = std::max(1, startMinMRounded); m < maxM; m *= 2)
        {
            profileShape(m, dims.n, dims.k);
        }
        profileShape(maxM, dims.n, dims.k);
    }
    catch (...)
    {
        // Best-effort cleanup: return codes are ignored on purpose so that the original
        // exception is the one that reaches the caller.
        if (isAllocated)
        {
            (void) cudaFree(mWorkspaceTmp);
        }
        (void) cudaStreamDestroy(mStream);
        throw;
    }

    if (isAllocated)
    {
        // Free tmp data
        freeTmpData();
    }
    common::check_cuda_error(cudaStreamDestroy(mStream));
}
// Returns the config stored for `m` and `gemmId`.
//
// `m` is rounded with nextPowerOfTwo() and clamped to [1, getMaxProfileM()] before the lookup.
// When profiling is skipped (mSkip) no config is available and std::nullopt is returned.
// NOTE(review): the lookup uses at(), which presumably throws when no entry exists for the
// rounded M - confirm against the map type declared in the header.
template <typename Config, typename RunnerPtr, typename GemmIdType, typename GemmIdHashType>
std::optional<Config> GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHashType>::getBestConfig(
    int m, GemmIdType const& gemmId) const
{
    reader_lock lock(mMNKProfileMap->mutex);

    if (mSkip)
    {
        TLLM_LOG_TRACE("Skip is set, no best config is set for this instance");
        return std::nullopt;
    }

    // No stdout flush here: this is a pure lookup and must not perform I/O while holding the lock.
    int const mRounded = std::min(std::max(1, nextPowerOfTwo(m)), getMaxProfileM());
    return mMNKProfileMap->getMProfileMap(gemmId)->at(mRounded);
}
// Allocates the temporary profiling workspace (mWorkspaceTmp) with cudaMalloc.
// Requires mTmpWorkspaceSizeInBytes to be positive; both the size precondition and the
// allocation result are enforced with TLLM_CHECK_WITH_INFO.
template <typename Config, typename RunnerPtr, typename GemmIdType, typename GemmIdHashType>
void GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHashType>::allocateTmpData()
{
    // A zero-sized workspace indicates that the size was never computed
    TLLM_CHECK_WITH_INFO(mTmpWorkspaceSizeInBytes > 0, "tmpWorkspaceSizeInBytes must be larger than 0");

    cudaError_t const status = cudaMalloc(&mWorkspaceTmp, mTmpWorkspaceSizeInBytes);
    TLLM_CHECK_WITH_INFO(status == cudaSuccess, "Can't allocate tmp workspace for GEMM tactics profiling.");
}
// Releases the temporary profiling workspace (mWorkspaceTmp) with cudaFree and checks the
// result with TLLM_CHECK_WITH_INFO.
template <typename Config, typename RunnerPtr, typename GemmIdType, typename GemmIdHashType>
void GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHashType>::freeTmpData()
{
    cudaError_t const status = cudaFree(mWorkspaceTmp);
    TLLM_CHECK_WITH_INFO(status == cudaSuccess, "Can't free tmp workspace for GEMM tactics profiling.");
}
// Times every candidate in `tactics` for the problem (m, n, k) and returns the fastest one.
//
// Candidates rejected by checkTactic() are skipped. A candidate whose profiling throws a
// std::exception is logged at trace level, the last CUDA error is reset, and the candidate is
// skipped as well. Returns std::nullopt (after a warning) when no candidate could be profiled.
template <typename Config, typename RunnerPtr, typename GemmIdType, typename GemmIdHashType>
std::optional<Config> GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHashType>::profileTacticsForProblem(
    int m, int n, int k, std::vector<Config> const& tactics)
{
    TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);

    float bestTime = std::numeric_limits<float>::max();
    Config bestConfig;
    bool foundOne = false;

    // Iterate over all tactics for given M, N and K.
    // The index is size_t to match tactics.size() and avoid a signed/unsigned comparison.
    for (size_t ii = 0; ii < tactics.size(); ++ii)
    {
        Config const& candidateConfig = tactics[ii];
        float time = std::numeric_limits<float>::max();
        try
        {
            if (!checkTactic(m, n, k, candidateConfig))
            {
                continue;
            }
            // Profile particular tactic for given M, N and K
            time = profileTacticForProblem(m, n, k, candidateConfig);
            foundOne = true;
        }
        catch (std::exception const& e)
        {
            std::ostringstream msg;
            msg << "Cannot profile configuration " << ii;
            if constexpr (std::is_same_v<Config, tensorrt_llm::cutlass_extensions::CutlassGemmConfig>)
            {
                msg << ": " << candidateConfig.toString();
            }
            msg << "\n (for"
                << " m=" << m << ", n=" << n << ", k=" << k << ")"
                << ", reason: \"" << e.what() << "\". Skipped";
            TLLM_LOG_TRACE(msg.str());
            cudaGetLastError(); // Reset the last cudaError to cudaSuccess.
            continue;
        }

        // Choose the fastest tactic
        if (time < bestTime)
        {
            bestConfig = candidateConfig;
            bestTime = time;
        }
    }

    if (!foundOne)
    {
        std::ostringstream msg;
        msg << "Have not found any valid GEMM config for shape ("
            << "m=" << m << ", n=" << n << ", k=" << k << "). Will try to use default or fail at runtime";
        TLLM_LOG_WARNING(msg.str());
        return std::nullopt;
    }
    return {bestConfig};
}
// Measures one tactic for the problem (m, n, k) on mStream and returns the average time of a
// single run as reported by cudaEventElapsedTime divided by the number of timed runs.
//
// The tactic is executed `warmup` times untimed, then `runs` times between two CUDA events.
// Exceptions from runTactic() or from a failed CUDA call propagate to the caller; the CUDA
// events are destroyed on that path too, so a failing tactic does not leak them.
template <typename Config, typename RunnerPtr, typename GemmIdType, typename GemmIdHashType>
float GemmPluginProfiler<Config, RunnerPtr, GemmIdType, GemmIdHashType>::profileTacticForProblem(
    int m, int n, int k, Config const& tactic)
{
    constexpr int warmup = 5;
    constexpr int runs = 10;

    cudaStream_t stream = mStream;

    // Warmup the execution
    for (int i = 0; i < warmup; ++i)
    {
        runTactic(m, n, k, tactic, mWorkspaceTmp, stream);
    }

    cudaEvent_t start = nullptr;
    cudaEvent_t stop = nullptr;
    float elapsed = 0.0F;
    try
    {
        common::check_cuda_error(cudaEventCreate(&start));
        common::check_cuda_error(cudaEventCreate(&stop));
        // Make sure the warmup work has finished before the timed section starts
        common::check_cuda_error(cudaStreamSynchronize(stream));
        common::check_cuda_error(cudaEventRecord(start, stream));

        // Profile GEMM
        for (int i = 0; i < runs; ++i)
        {
            runTactic(m, n, k, tactic, mWorkspaceTmp, stream);
        }

        common::check_cuda_error(cudaEventRecord(stop, stream));
        common::check_cuda_error(cudaEventSynchronize(stop));
        common::check_cuda_error(cudaEventElapsedTime(&elapsed, start, stop));
    }
    catch (...)
    {
        // Best-effort cleanup: return codes are ignored on purpose so that the original
        // exception is the one that reaches the caller.
        if (start != nullptr)
        {
            (void) cudaEventDestroy(start);
        }
        if (stop != nullptr)
        {
            (void) cudaEventDestroy(stop);
        }
        throw;
    }

    common::check_cuda_error(cudaEventDestroy(start));
    common::check_cuda_error(cudaEventDestroy(stop));

    return elapsed / runs;
}
// Explicit instantiations of GemmPluginProfiler for every (config, runner, GEMM ID, hash)
// combination listed below. The member functions are defined in this translation unit, so each
// combination is instantiated here explicitly.

// CUTLASS int8 GEMM runner
template class GemmPluginProfiler<tensorrt_llm::cutlass_extensions::CutlassGemmConfig,
std::shared_ptr<tensorrt_llm::kernels::cutlass_kernels::CutlassInt8GemmRunnerInterface>, GemmIdCore,
GemmIdCoreHash>;
// CUTLASS fpA_intB GEMM runner
template class GemmPluginProfiler<tensorrt_llm::cutlass_extensions::CutlassGemmConfig,
std::shared_ptr<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunnerInterface>, GemmIdCore,
GemmIdCoreHash>;
// cuBLAS wrapper, configured by cublasLt heuristic results
template class GemmPluginProfiler<cublasLtMatmulHeuristicResult_t,
std::shared_ptr<tensorrt_llm::common::CublasMMWrapper>, GemmIdCublas, GemmIdCublasHash>;
// TODO I don't like the dependency on the MOE plugin here, but MOE needs the full context to run profiles
template class GemmPluginProfiler<tensorrt_llm::cutlass_extensions::CutlassGemmConfig, MixtureOfExpertsPlugin*,
GemmIDMoe, GemmIDMoeHash>;
// CUTLASS fused gated GEMM runner
template class GemmPluginProfiler<tensorrt_llm::cutlass_extensions::CutlassGemmConfig,
std::shared_ptr<tensorrt_llm::kernels::cutlass_kernels::CutlassFusedGatedGemmRunnerInterface>, GemmIdCore,
GemmIdCoreHash>;
// CUTLASS FP8 row-wise GEMM runner
template class GemmPluginProfiler<tensorrt_llm::cutlass_extensions::CutlassGemmConfig,
std::shared_ptr<tensorrt_llm::kernels::cutlass_kernels::CutlassFp8RowwiseGemmRunnerInterface>, GemmIdCore,
GemmIdCoreHash>;
// Low-latency GEMM plugin
template class GemmPluginProfiler<LowLatencyGemmPluginProfiler::Config, LowLatencyGemmRunnerPtr, GemmIdCore,
GemmIdCoreHash>;
// Low-latency GEMM + SwiGLU plugin
template class GemmPluginProfiler<LowLatencyGemmSwigluPluginProfiler::Config, LowLatencyGemmSwigluRunnerPtr, GemmIdCore,
GemmIdCoreHash>;
} // namespace tensorrt_llm::plugins