Skip to content

Commit 719fb18

Browse files
committed
address feedbacks
Signed-off-by: Anthony Chang <[email protected]>
1 parent 757579a commit 719fb18

File tree

4 files changed

+43
-4
lines changed

4 files changed

+43
-4
lines changed

cpp/include/tensorrt_llm/common/utils.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,8 @@
1717
#pragma once
1818

1919
#include <algorithm>
20-
#include <csignal>
2120
#include <initializer_list>
2221
#include <string>
23-
#include <unistd.h>
2422

2523
#ifndef _WIN32
2624
#include <pthread.h>

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,49 +153,69 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(TrtllmGenBatchedGemmRunne
153153
if (!acceptIf(options.mDtypeA == mOptions.dtypeA,
154154
fmtstr("dtypeA mismatch (kernel: %s, expected: %s)", tg::dtypeToString(options.mDtypeA).c_str(),
155155
tg::dtypeToString(mOptions.dtypeA).c_str())))
156+
{
156157
continue;
158+
}
157159

158160
if (!acceptIf(options.mDtypeB == mOptions.dtypeB,
159161
fmtstr("dtypeB mismatch (kernel: %s, expected: %s)", tg::dtypeToString(options.mDtypeB).c_str(),
160162
tg::dtypeToString(mOptions.dtypeB).c_str())))
163+
{
161164
continue;
165+
}
162166

163167
if (!acceptIf(options.mDtypeC == mOptions.dtypeC,
164168
fmtstr("dtypeC mismatch (kernel: %s, expected: %s)", tg::dtypeToString(options.mDtypeC).c_str(),
165169
tg::dtypeToString(mOptions.dtypeC).c_str())))
170+
{
166171
continue;
172+
}
167173

168174
if (!acceptIf(options.mUseDeepSeekFp8 == mOptions.deepSeekFp8,
169175
fmtstr(
170176
"deepSeekFp8 mismatch (kernel: %d, expected: %d)", options.mUseDeepSeekFp8, mOptions.deepSeekFp8)))
177+
{
171178
continue;
179+
}
172180

173181
if (!acceptIf(options.mTransposeMmaOutput == mOptions.transposeMmaOutput,
174182
fmtstr("transposeMmaOutput mismatch (kernel: %d, expected: %d)", options.mTransposeMmaOutput,
175183
mOptions.transposeMmaOutput)))
184+
{
176185
continue;
186+
}
177187

178188
if (!acceptIf((!doesRouteImplUseNoRoute(options.mRouteImpl)) == mOptions.routeAct,
179189
fmtstr("routeAct mismatch (kernel: %d, expected: %d)", !doesRouteImplUseNoRoute(options.mRouteImpl),
180190
mOptions.routeAct)))
191+
{
181192
continue;
193+
}
182194

183195
if (!acceptIf(options.mFusedAct == mOptions.fusedAct,
184196
fmtstr("fusedAct mismatch (kernel: %d, expected: %d)", options.mFusedAct, mOptions.fusedAct)))
197+
{
185198
continue;
199+
}
186200

187201
if (!acceptIf(options.mIsStaticBatch == mOptions.staticBatch,
188202
fmtstr(
189203
"staticBatch mismatch (kernel: %d, expected: %d)", options.mIsStaticBatch, mOptions.staticBatch)))
204+
{
190205
continue;
206+
}
191207

192208
if (!acceptIf(tileSize == mOptions.tileSize,
193209
fmtstr("tileSize mismatch (kernel: %d, expected: %d)", tileSize, mOptions.tileSize)))
210+
{
194211
continue;
212+
}
195213

196214
if (!acceptIf(isSMCompatible(gpuSM, configs[i].mSm),
197215
fmtstr("SM not compatible (gpuSM: %d, kernelSM: %d)", gpuSM, static_cast<int>(configs[i].mSm))))
216+
{
198217
continue;
218+
}
199219

200220
auto sm = configs[i].mSm;
201221
if (sm != SmVersion::Sm100f)
@@ -205,28 +225,36 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(TrtllmGenBatchedGemmRunne
205225
{
206226
if (!acceptIf(sm == SmVersion::Sm100a,
207227
fmtstr("SM version 100 requires Sm100a (kernel has: %d)", static_cast<int>(sm))))
228+
{
208229
continue;
230+
}
209231
}
210232
else if (smVersion == 103)
211233
{
212234
if (!acceptIf(sm == SmVersion::Sm103a,
213235
fmtstr("SM version 103 requires Sm103a (kernel has: %d)", static_cast<int>(sm))))
236+
{
214237
continue;
238+
}
215239
}
216240
}
217241

218242
if (options.mUseDeepSeekFp8)
219243
{
220244
if (!acceptIf(options.mUseShuffledMatrixA == false, "useShuffledMatrixA should be false for DeepSeek Fp8"))
245+
{
221246
continue;
247+
}
222248
}
223249

224250
if (options.mFusedAct)
225251
{
226252
if (!acceptIf(options.mActType == static_cast<batchedGemm::gemmGatedAct::ActType>(mOptions.actType),
227253
fmtstr("actType mismatch (kernel: %d, expected: %d)", static_cast<int>(options.mActType),
228254
static_cast<int>(mOptions.actType))))
255+
{
229256
continue;
257+
}
230258
}
231259

232260
// FIXME: Disables a few static scheduler kernels (schedS) that appears to have issues;
@@ -236,14 +264,18 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(TrtllmGenBatchedGemmRunne
236264
if (!acceptIf(!(options.mTileScheduler == TileScheduler::Static && options.mUseTmaOobOpt == true
237265
&& options.mTileN == 64),
238266
"Static scheduler with TmaOobOpt and TileN=64 (known issue)"))
267+
{
239268
continue;
269+
}
240270

241271
if (mOptions.transposeMmaOutput)
242272
{
243273
if (!acceptIf(options.mEpilogueTileM == mOptions.epilogueTileM,
244274
fmtstr("epilogueTileM mismatch (kernel: %d, expected: %d)", options.mEpilogueTileM,
245275
mOptions.epilogueTileM)))
276+
{
246277
continue;
278+
}
247279
}
248280

249281
// Kernel passed all filters

cpp/tensorrt_llm/thop/fp8PerTensorScaleMoe.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ torch::Tensor fp8_per_tensor_scale_moe_runner(torch::optional<torch::Tensor> con
3737
torch::optional<torch::Tensor> const& topk_weights, torch::optional<torch::Tensor> const& topk_ids)
3838
{
3939
TORCH_CHECK(tensorrt_llm::common::isSM100Family(), "Only SM100f is supported by FP8 block scale MOE");
40+
TORCH_CHECK(tile_tokens_dim == 8 || tile_tokens_dim == 16 || tile_tokens_dim == 32 || tile_tokens_dim == 64
41+
|| tile_tokens_dim == 128 || tile_tokens_dim == 192 || tile_tokens_dim == 256,
42+
"tile_tokens_dim must be 8, 16, 32, 64, 128, 192, 256"); // message lists every value the condition accepts
4043
if (topk_ids.has_value() && topk_weights.has_value())
4144
{
4245
TORCH_CHECK(topk_ids.value().scalar_type() == at::ScalarType::Int, "topk_ids must be int");

tests/unittest/_torch/modules/test_fused_moe.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,20 @@
5555

5656
@contextmanager
5757
def moe_trtllm_debug_msg(enable=False):
58+
TLLM_BATCHED_GEMM_PRINT_NAME = os.environ.get(
59+
"TLLM_BATCHED_GEMM_PRINT_NAME", "0")
60+
TLLM_BATCHED_GEMM_PRINT_CONFIGS = os.environ.get(
61+
"TLLM_BATCHED_GEMM_PRINT_CONFIGS", "0")
5862
if enable:
5963
os.environ["TLLM_BATCHED_GEMM_PRINT_NAME"] = "1"
6064
os.environ["TLLM_BATCHED_GEMM_PRINT_CONFIGS"] = "1"
6165
try:
6266
yield
6367
finally:
64-
os.environ["TLLM_BATCHED_GEMM_PRINT_NAME"] = "0"
65-
os.environ["TLLM_BATCHED_GEMM_PRINT_CONFIGS"] = "0"
68+
os.environ[
69+
"TLLM_BATCHED_GEMM_PRINT_NAME"] = TLLM_BATCHED_GEMM_PRINT_NAME
70+
os.environ[
71+
"TLLM_BATCHED_GEMM_PRINT_CONFIGS"] = TLLM_BATCHED_GEMM_PRINT_CONFIGS
6672

6773

6874
def round_up(x, alignment):

0 commit comments

Comments
 (0)