Commit 1b12e88

update code

Signed-off-by: jiahanc <[email protected]>
Parent: c829520

5 files changed (+96, -60 lines)

csrc/trtllm_batched_gemm_runner.cu

Lines changed: 1 addition & 1 deletion
@@ -169,7 +169,7 @@ void TrtllmGenBatchedGemmRunner::run(
   auto const configs = bmm.getBatchedGemmConfigs();
 
   auto const& config = configs[configIndex];
-
+  std::cout << "config.mFunctionName: " << config.mFunctionName << std::endl;
   FLASHINFER_CHECK(numBatches > 0, "Batched GEMM requires numBatches > 0");
   if (!mOptions.staticBatch) {
     FLASHINFER_CHECK(totalNumPaddedTokens,

csrc/trtllm_fused_moe_kernel_launcher.cu

Lines changed: 75 additions & 38 deletions
@@ -41,9 +41,9 @@ Tensor trtllm_fp8_per_tensor_scale_moe_launcher(
     Tensor routing_logits, Optional<Tensor> routing_bias, Tensor hidden_states,
     Tensor gemm1_weights, Tensor output1_scales_scalar, Tensor output1_scales_gate_scalar,
     Tensor gemm2_weights, Tensor output2_scales_scalar, int64_t const num_experts,
-    int64_t const top_k, int64_t const n_group, int64_t const topk_group,
+    int64_t const top_k, Optional<int64_t> const n_group, Optional<int64_t> const topk_group,
     int64_t const intermediate_size, int64_t const local_expert_offset,
-    int64_t const local_num_experts, double const routed_scaling_factor,
+    int64_t const local_num_experts, Optional<double> const routed_scaling_factor,
     bool const use_routing_scales_on_input, int64_t const tile_tokens_dim,
     int64_t const routing_method_type, bool enable_pdl) {
   static const std::tuple<int, int> device_props = [hidden_states] {
@@ -61,8 +61,11 @@ Tensor trtllm_fp8_per_tensor_scale_moe_launcher(
 
   if (use_routing_scales_on_input) {
     TVM_FFI_ICHECK_EQ(routing_logits->dtype, dl_bfloat16) << "routing_logits must be bfloat16.";
-  } else {
+  } else if (static_cast<RoutingMethodType>(routing_method_type) ==
+             RoutingMethodType::DeepSeekV3) {
     TVM_FFI_ICHECK_EQ(routing_logits->dtype, dl_float32) << "routing_logits must be float.";
+  } else {
+    // TVM_FFI_ICHECK_EQ(routing_logits->dtype, dl_bfloat16) << "routing_logits must be bfloat16.";
   }
   TVM_FFI_ICHECK_EQ(routing_logits->ndim, 2) << "routing_logits must be 2D.";
   TVM_FFI_ICHECK_EQ(routing_logits->shape[1], num_experts) << "routing_logits has incorrect shape.";
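
The dtype requirement on routing_logits now depends on the routing method. A minimal Python sketch of the resulting rule (the import path and helper name are assumptions, not part of the commit; note that the final else branch's check is commented out above, so bfloat16 is expected there but not yet enforced):

    import torch
    from flashinfer import RoutingMethodType  # import path assumed

    def expected_routing_logits_dtype(routing_method_type, use_routing_scales_on_input):
        # Mirrors the C++ checks in the hunk above.
        if use_routing_scales_on_input:
            return torch.bfloat16  # enforced
        if routing_method_type == RoutingMethodType.DeepSeekV3:
            return torch.float32   # enforced
        return torch.bfloat16      # intended; the check is commented out for now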
@@ -73,17 +76,32 @@ Tensor trtllm_fp8_per_tensor_scale_moe_launcher(
         << "routing_bias has incorrect shape.";
   }
 
-  if (n_group <= 0 || topk_group <= 0) {
-    TVM_FFI_ICHECK_EQ(top_k, 1) << "Current routing kernel (no groups) only supports top_k=1.";
-  } else {
-    TVM_FFI_ICHECK_LE(top_k, 8) << "Current routing kernel (with groups) only supports top_k<=8.";
-    TVM_FFI_ICHECK_LE(topk_group, 4)
-        << "Current routing kernel (with groups) only supports topk_group<=4.";
-    TVM_FFI_ICHECK_LE(topk_group, n_group) << "n_group must not be smaller than topk_group.";
-    TVM_FFI_ICHECK_EQ(num_experts % n_group, 0) << "num_experts must be divisible by n_group";
+  if (n_group.has_value() && n_group.value() != 0) {
+    TVM_FFI_ICHECK(static_cast<RoutingMethodType>(routing_method_type) ==
+                   RoutingMethodType::DeepSeekV3)
+        << "Routing kernel with groups implies DeepSeekV3 routing method.";
+    TVM_FFI_ICHECK(topk_group.has_value()) << "if n_group is given, topk_group must be given";
+    TVM_FFI_ICHECK_EQ(num_experts % n_group.value(), 0)
+        << "num_experts must be divisible by n_group";
+    TVM_FFI_ICHECK(top_k <= 8 && top_k > 0)
+        << "Current routing kernel (with groups) only supports top_k<=8 && top_k>0.";
+    TVM_FFI_ICHECK(topk_group.value() <= 4 && topk_group.value() > 0)
+        << "Current routing kernel (with groups) only supports topk_group<=4 && topk_group>0.";
+    TVM_FFI_ICHECK_LE(topk_group.value(), n_group.value())
+        << "n_group must not be smaller than topk_group.";
     // This check ensures we have enough experts in the selected groups to handle the top_k routing
-    TVM_FFI_ICHECK_LT(top_k, (topk_group * num_experts / n_group))
+    TVM_FFI_ICHECK_LT(top_k, (topk_group.value() * num_experts / n_group.value()))
         << "top_k must be less than total number of experts in selected groups";
+  } else if (static_cast<RoutingMethodType>(routing_method_type) ==
+                 RoutingMethodType::Renormalize ||
+             static_cast<RoutingMethodType>(routing_method_type) ==
+                 RoutingMethodType::RenormalizeNaive) {
+    TVM_FFI_LOG_AND_THROW(NotImplementedError)
+        << "Routing method type Renormalize(Naive) is not supported.";
+  } else if (static_cast<RoutingMethodType>(routing_method_type) ==
+             RoutingMethodType::Llama4) {
+    TVM_FFI_ICHECK_EQ(top_k, 1)
+        << "Current routing kernel (no groups, Llama4) only supports top_k=1.";
   }
   TVM_FFI_ICHECK_EQ(num_experts % 4, 0)
       << "Routing kernel expects that num_experts must be divisible by 4";
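
Taken together, the new checks admit grouped routing only for DeepSeekV3 and constrain the ungrouped paths per method. A sketch of the accepted parameter space (Python pseudocode mirroring the C++ above; the helper name and enum import are assumptions, not part of the commit):

    def validate_per_tensor_routing(routing_method_type, top_k, n_group,
                                    topk_group, num_experts):
        if n_group:  # groups given and non-zero: DeepSeekV3 only
            assert routing_method_type == RoutingMethodType.DeepSeekV3
            assert topk_group is not None
            assert num_experts % n_group == 0
            assert 0 < top_k <= 8
            assert 0 < topk_group <= min(4, n_group)
            # enough experts in the selected groups to satisfy top_k
            assert top_k < topk_group * num_experts // n_group
        elif routing_method_type in (RoutingMethodType.Renormalize,
                                     RoutingMethodType.RenormalizeNaive):
            raise NotImplementedError("Renormalize(Naive) not supported here")
        elif routing_method_type == RoutingMethodType.Llama4:
            assert top_k == 1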
@@ -121,11 +139,11 @@ Tensor trtllm_fp8_per_tensor_scale_moe_launcher(
   args.hidden_size = hidden_states->shape[1];
   args.hidden_size_output = args.hidden_size;
   args.top_k = top_k;
-  args.n_group = n_group;
-  args.topk_group = topk_group;
+  args.n_group = n_group.has_value() ? n_group.value() : 0;
+  args.topk_group = topk_group.has_value() ? topk_group.value() : 0;
   args.local_expert_offset = local_expert_offset;
   args.local_num_experts = local_num_experts;
-  args.routed_scaling_factor = routed_scaling_factor;
+  args.routed_scaling_factor = routed_scaling_factor.has_value() ? routed_scaling_factor.value() : 1.0;
   args.intermediate_size = intermediate_size;
   args.mUseRoutingScalesOnInput = use_routing_scales_on_input;
 
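When the optional arguments are omitted, the launcher substitutes neutral defaults before filling args: n_group and topk_group become 0, routed_scaling_factor becomes 1.0. In Python terms the fallback amounts to (a sketch, not the binding code itself):

    n_group = n_group if n_group is not None else 0
    topk_group = topk_group if topk_group is not None else 0
    routed_scaling_factor = routed_scaling_factor if routed_scaling_factor is not None else 1.0

The C++ ternaries are equivalent to std::optional's value_or.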
@@ -279,8 +297,8 @@ Tensor trtllm_fp8_per_tensor_scale_moe(
     Tensor routing_logits, Optional<Tensor> routing_bias, Tensor hidden_states,
     Tensor gemm1_weights, Tensor output1_scales_scalar, Tensor output1_scales_gate_scalar,
     Tensor gemm2_weights, Tensor output2_scales_scalar, int64_t num_experts, int64_t top_k,
-    int64_t n_group, int64_t topk_group, int64_t intermediate_size, int64_t local_expert_offset,
-    int64_t local_num_experts, double routed_scaling_factor, bool use_routing_scales_on_input,
+    Optional<int64_t> n_group, Optional<int64_t> topk_group, int64_t intermediate_size, int64_t local_expert_offset,
+    int64_t local_num_experts, Optional<double> routed_scaling_factor, bool use_routing_scales_on_input,
     int64_t tile_tokens_dim, int64_t routing_method_type, bool enable_pdl) {
   auto dtype = hidden_states->dtype;
   if (dtype == dl_float16 || dtype == dl_bfloat16 || dtype == dl_float8_e4m3fn) {
@@ -299,9 +317,9 @@ void trtllm_fp8_block_scale_moe_launcher(
     Tensor routing_logits, Optional<Tensor> routing_bias, Tensor hidden_states,
     Tensor hidden_states_scale, Tensor gemm1_weights, Tensor gemm1_weights_scale,
     Tensor gemm2_weights, Tensor gemm2_weights_scale, Tensor output, int64_t const num_experts,
-    int64_t const top_k, int64_t const n_group, int64_t const topk_group,
+    int64_t const top_k, Optional<int64_t> const n_group, Optional<int64_t> const topk_group,
     int64_t const intermediate_size, int64_t const local_expert_offset,
-    int64_t const local_num_experts, double const routed_scaling_factor,
+    int64_t const local_num_experts, Optional<double> const routed_scaling_factor,
     int64_t const tile_tokens_dim, int64_t const routing_method_type,
     tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner& moe_runner, int64_t moeConfigIndex,
     bool enable_pdl) {
@@ -318,7 +336,11 @@ void trtllm_fp8_block_scale_moe_launcher(
       << "This kernel requires 10.x architecture. Current device has SM "
       << std::get<0>(device_props) << std::get<1>(device_props);
 
-  TVM_FFI_ICHECK_EQ(routing_logits->dtype, dl_float32) << "routing_logits must be float.";
+  if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
+    TVM_FFI_ICHECK_EQ(routing_logits->dtype, dl_float32) << "routing_logits must be float.";
+  } else {
+    TVM_FFI_ICHECK_EQ(routing_logits->dtype, dl_bfloat16) << "routing_logits must be bfloat16.";
+  }
   TVM_FFI_ICHECK_EQ(routing_logits->ndim, 2) << "routing_logits must be 2D.";
   TVM_FFI_ICHECK_EQ(routing_logits->shape[0], hidden_states->shape[0])
       << "routing_logits and hidden_states must have the same number of tokens.";
@@ -333,18 +355,33 @@ void trtllm_fp8_block_scale_moe_launcher(
         << "routing_bias has incorrect shape.";
   }
 
-  // if (n_group <= 0 || topk_group <= 0) {
-  //   TVM_FFI_ICHECK_EQ(top_k, 1) << "Current routing kernel (no groups) only supports top_k=1.";
-  // } else {
-  //   TVM_FFI_ICHECK_LE(top_k, 8) << "Current routing kernel (with groups) only supports
-  //   top_k<=8."; TVM_FFI_ICHECK_LE(topk_group, 4)
-  //       << "Current routing kernel (with groups) only supports topk_group<=4.";
-  //   TVM_FFI_ICHECK_LE(topk_group, n_group) << "n_group must not be smaller than topk_group.";
-  //   TVM_FFI_ICHECK_EQ(num_experts % n_group, 0) << "num_experts must be divisible by n_group";
-  //   // This check ensures we have enough experts in the selected groups to handle the top_k
-  //   routing TVM_FFI_ICHECK_LT(top_k, (topk_group * num_experts / n_group))
-  //       << "top_k must be less than total number of experts in selected groups";
-  // }
+  if (n_group.has_value() && n_group.value() != 0) {
+    TVM_FFI_ICHECK(static_cast<RoutingMethodType>(routing_method_type) ==
+                   RoutingMethodType::DeepSeekV3)
+        << "Routing kernel with groups implies DeepSeekV3 routing method.";
+    TVM_FFI_ICHECK(topk_group.has_value()) << "if n_group is given, topk_group must be given";
+    TVM_FFI_ICHECK_EQ(num_experts % n_group.value(), 0)
+        << "num_experts must be divisible by n_group";
+    TVM_FFI_ICHECK(top_k <= 8 && top_k > 0)
+        << "Current routing kernel (with groups) only supports top_k<=8 && top_k>0.";
+    TVM_FFI_ICHECK(topk_group.value() <= 4 && topk_group.value() > 0)
+        << "Current routing kernel (with groups) only supports topk_group<=4 && topk_group>0.";
+    TVM_FFI_ICHECK_LE(topk_group.value(), n_group.value())
+        << "n_group must not be smaller than topk_group.";
+    // This check ensures we have enough experts in the selected groups to handle the top_k routing
+    TVM_FFI_ICHECK_LT(top_k, (topk_group.value() * num_experts / n_group.value()))
+        << "top_k must be less than total number of experts in selected groups";
+  } else if (static_cast<RoutingMethodType>(routing_method_type) ==
+                 RoutingMethodType::Renormalize ||
+             static_cast<RoutingMethodType>(routing_method_type) ==
+                 RoutingMethodType::RenormalizeNaive) {
+    TVM_FFI_ICHECK(top_k <= 10 && top_k > 0)
+        << "Current routing kernel (no groups, renormalize) only supports top_k<=10 && top_k>0.";
+  } else if (static_cast<RoutingMethodType>(routing_method_type) ==
+             RoutingMethodType::Llama4) {
+    TVM_FFI_ICHECK_EQ(top_k, 1)
+        << "Current routing kernel (no groups, Llama4) only supports top_k=1.";
+  }
   TVM_FFI_ICHECK_EQ(num_experts % 4, 0)
       << "Routing kernel expects that num_experts must be divisible by 4";
   TVM_FFI_ICHECK_GT(num_experts, top_k) << "num_experts must be greater than top_k";
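
The block-scale launcher gains the same grouped-routing checks as the per-tensor path, with one difference: Renormalize(Naive) is accepted here with top_k up to 10 rather than rejected. Only the differing branch, sketched in Python (hypothetical helper mirroring the C++ above):

    def validate_block_scale_routing(routing_method_type, top_k, n_group,
                                     topk_group, num_experts):
        if n_group:
            ...  # same DeepSeekV3 group checks as the per-tensor launcher
        elif routing_method_type in (RoutingMethodType.Renormalize,
                                     RoutingMethodType.RenormalizeNaive):
            assert 0 < top_k <= 10  # allowed here, rejected in the per-tensor path
        elif routing_method_type == RoutingMethodType.Llama4:
            assert top_k == 1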
@@ -380,11 +417,11 @@ void trtllm_fp8_block_scale_moe_launcher(
   args.hidden_size = hidden_states->shape[1];
   args.hidden_size_output = args.hidden_size;
   args.top_k = top_k;
-  args.n_group = n_group;
-  args.topk_group = topk_group;
+  args.n_group = n_group.has_value() ? n_group.value() : 0;
+  args.topk_group = topk_group.has_value() ? topk_group.value() : 0;
   args.local_expert_offset = local_expert_offset;
   args.local_num_experts = local_num_experts;
-  args.routed_scaling_factor = routed_scaling_factor;
+  args.routed_scaling_factor = routed_scaling_factor.has_value() ? routed_scaling_factor.value() : 1.0;
   args.intermediate_size = intermediate_size;
   args.mUseDeepSeekFp8 = true;
 
@@ -569,10 +606,10 @@ void trtllm_fp8_block_scale_moe(Tensor routing_logits, Optional<Tensor> routing_bias,
                                 Tensor hidden_states, Tensor hidden_states_scale,
                                 Tensor gemm1_weights, Tensor gemm1_weights_scale,
                                 Tensor gemm2_weights, Tensor gemm2_weights_scale, Tensor output,
-                                int64_t num_experts, int64_t top_k, int64_t n_group,
-                                int64_t topk_group, int64_t intermediate_size,
+                                int64_t num_experts, int64_t top_k, Optional<int64_t> n_group,
+                                Optional<int64_t> topk_group, int64_t intermediate_size,
                                 int64_t local_expert_offset, int64_t local_num_experts,
-                                double routed_scaling_factor, int64_t tile_tokens_dim,
+                                Optional<double> routed_scaling_factor, int64_t tile_tokens_dim,
                                 int64_t routing_method_type, bool use_shuffled_weight,
                                 int64_t weight_layout, bool enable_pdl) {
   auto dtype = hidden_states->dtype;

csrc/trtllm_fused_moe_routing_renormalize.cu

Lines changed: 0 additions & 1 deletion
@@ -464,7 +464,6 @@ void run(Data const& data, void* stream) {
   }
   cudaDeviceSynchronize();
   cudaError_t result = cudaGetLastError();
-  std::cout << "cudaGetLastError: " << cudaGetErrorString(result) << std::endl;
 }
 
 // void run(Data const& data, void* stream) {

flashinfer/fused_moe/core.py

Lines changed: 14 additions & 14 deletions
@@ -1069,12 +1069,12 @@ def trtllm_fp8_per_tensor_scale_moe_op(
     output2_scales_scalar: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
-    routed_scaling_factor: float,
+    routed_scaling_factor: Optional[float],
     use_routing_scales_on_input: bool,
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,

@@ -1119,8 +1119,8 @@ def _fake_trtllm_fp8_per_tensor_scale_moe(
     output2_scales_scalar: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,

@@ -1151,8 +1151,8 @@ def trtllm_fp8_block_scale_moe_op(
     output: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,

@@ -1207,8 +1207,8 @@ def _fake_trtllm_fp8_block_scale_moe(
     output: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,

@@ -1469,12 +1469,12 @@ def trtllm_fp8_per_tensor_scale_moe(
     output2_scales_scalar: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
-    routed_scaling_factor: float,
+    routed_scaling_factor: Optional[float],
     use_routing_scales_on_input: bool,
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,

@@ -1542,8 +1542,8 @@ def trtllm_fp8_block_scale_moe(
     gemm2_weights_scale: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
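
With the relaxed type hints, callers that don't use grouped routing can pass None instead of sentinel values. A hypothetical fragment showing only the keyword arguments affected by this commit (the surrounding tensor arguments are elided):

    moe_kwargs = dict(
        num_experts=256,
        top_k=8,
        n_group=None,                # Optional[int]: None for ungrouped routing
        topk_group=None,             # Optional[int]
        routed_scaling_factor=None,  # Optional[float]: launcher falls back to 1.0
    )
    # ...passed alongside the tensor arguments to trtllm_fp8_block_scale_moe(...)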

tests/moe/test_trtllm_gen_fused_moe.py

Lines changed: 6 additions & 6 deletions
@@ -105,7 +105,7 @@ def capture(self, hidden_states_sample, **runtime_args):
         self.input_tensor = hidden_states_sample.clone()
 
         # Warmup
-        with torch.cuda.stream(torch_stream), autotune(True):
+        with torch.cuda.stream(torch_stream), autotune(False):
             for _ in range(1):
                 self._run_moe_computation(runtime_args)
 

@@ -1836,9 +1836,9 @@ def cache_permute_indices():
     return _cache_permute_indices
 
 
-@pytest.mark.parametrize("num_tokens", [1, 8, 1024])
+@pytest.mark.parametrize("num_tokens", [1, 8, 1024, 512])
 @pytest.mark.parametrize("hidden_size", [1024, 2048, 8192])
-@pytest.mark.parametrize("intermediate_size", [2048, 1024, 5120, 768, 384])
+@pytest.mark.parametrize("intermediate_size", [2048, 1024, 768, 384, 512])
 @pytest.mark.parametrize(
     "moe_impl",
     [

@@ -1905,15 +1905,15 @@ def cache_permute_indices():
         ),
         pytest.param(
             {
-                "num_experts": 512,
-                "top_k": 10,
+                "num_experts": 256,
+                "top_k": 8,
                 "padding": 8,
                 "n_groups": None,
                 "top_k_groups": None,
                 "routed_scaling": None,
                 "has_routing_bias": False,
                 "routing_method_type": RoutingMethodType.Renormalize,
-                "compatible_moe_impls": [FP4Moe, FP8PerTensorMoe, FP8BlockScaleMoe],
+                "compatible_moe_impls": [FP8PerTensorMoe, FP8BlockScaleMoe, FP4Moe],
             },
             id="Renorm",
             # marks=pytest.mark.skip(
