-
Notifications
You must be signed in to change notification settings - Fork 585
Update trtllm-gen fused moe routing kernel and add more kernels #1955
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
553724f
eac52be
1906f4a
330a6f6
259820d
7f8c255
ee54764
ed54446
a438123
321d2fd
4a8d9a7
38496fb
af6860a
e0599d2
079d089
2e6d1b2
ecb5be0
e55b52b
0e88417
67a5ffb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -8,13 +8,109 @@ | |||||||||||||||||||||||
| fp4_quantize, | ||||||||||||||||||||||||
| mxfp8_quantize, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| from flashinfer.fused_moe import trtllm_fp4_block_scale_moe | ||||||||||||||||||||||||
| from flashinfer.fused_moe import ( | ||||||||||||||||||||||||
| trtllm_fp4_block_scale_moe, | ||||||||||||||||||||||||
| trtllm_fp8_per_tensor_scale_moe, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| from flashinfer.autotuner import autotune | ||||||||||||||||||||||||
| from flashinfer.testing.utils import bench_gpu_time | ||||||||||||||||||||||||
| from flashinfer.utils import device_support_pdl | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max | ||||||||||||||||||||||||
| FLOAT4_E2M1_MAX = 6.0 | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def fp8_quantize(x): | ||||||||||||||||||||||||
| max = x.float().abs().nan_to_num().max() | ||||||||||||||||||||||||
| scale = FLOAT8_E4M3_MAX / max | ||||||||||||||||||||||||
| x = (x * scale).to(torch.float8_e4m3fn) | ||||||||||||||||||||||||
| return x, 1.0 / scale | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def bench_trtllm_gen_fused_moe_autotuner( | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def bench_trtllm_gen_fused_moe_autotuner_fp8( | ||||||||||||||||||||||||
| tune_max_num_tokens: Optional[int], | ||||||||||||||||||||||||
| quant_mode: Literal["Fp8-Per-Tensor"], | ||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove unused parameter. The Apply this diff if the parameter is not needed: def bench_trtllm_gen_fused_moe_autotuner_fp8(
tune_max_num_tokens: Optional[int],
- quant_mode: Literal["Fp8-Per-Tensor"],
num_tokens: int,
num_experts: int,
hidden_size: int,
intermediate_size: int,
top_k: int,
warmups: int,
iterations: int,
):📝 Committable suggestion
Suggested change
🧰 Tools🪛 Ruff (0.14.3)29-29: Unused function argument: (ARG001) 🤖 Prompt for AI Agents |
||||||||||||||||||||||||
| num_tokens: int, | ||||||||||||||||||||||||
| num_experts: int, | ||||||||||||||||||||||||
| hidden_size: int, | ||||||||||||||||||||||||
| intermediate_size: int, | ||||||||||||||||||||||||
| top_k: int, | ||||||||||||||||||||||||
| warmups: int, | ||||||||||||||||||||||||
| iterations: int, | ||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||
| device = torch.device("cuda:0") | ||||||||||||||||||||||||
| enable_pdl = device_support_pdl(device) | ||||||||||||||||||||||||
| routing_logits = torch.rand(num_tokens, num_experts, device=device).to( | ||||||||||||||||||||||||
| torch.bfloat16 | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| hidden_states = torch.randn(num_tokens, hidden_size, device=device).to( | ||||||||||||||||||||||||
| torch.bfloat16 | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| w13 = torch.randn( | ||||||||||||||||||||||||
| num_experts, intermediate_size * 2, hidden_size, device=device | ||||||||||||||||||||||||
| ).to(torch.bfloat16) | ||||||||||||||||||||||||
| w2 = torch.randn(num_experts, hidden_size, intermediate_size, device=device).to( | ||||||||||||||||||||||||
| torch.bfloat16 | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| hidden_states, hidden_states_scale = fp8_quantize(hidden_states) | ||||||||||||||||||||||||
| w13, w13_scale = fp8_quantize(w13) | ||||||||||||||||||||||||
| w2, w2_scale = fp8_quantize(w2) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| output1_scale_scalar = torch.tensor( | ||||||||||||||||||||||||
| [hidden_states_scale * w13_scale] * num_experts, device=device | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| output1_scales_gate_scalar = torch.ones( | ||||||||||||||||||||||||
| num_experts, device=device, dtype=torch.float32 | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| output2_scale_scalar = torch.tensor( | ||||||||||||||||||||||||
| [hidden_states_scale * w2_scale] * num_experts, device=device | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
Comment on lines
+60
to
+68
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Construct FP8 scale vectors without CPU conversion errors. Lines 60-68 call Apply this diff to keep the scales on CUDA: - output1_scale_scalar = torch.tensor(
- [hidden_states_scale * w13_scale] * num_experts, device=device
- )
+ scale_prod_1 = (hidden_states_scale * w13_scale).item()
+ output1_scale_scalar = torch.full(
+ (num_experts,),
+ scale_prod_1,
+ device=device,
+ dtype=torch.float32,
+ )
@@
- output2_scale_scalar = torch.tensor(
- [hidden_states_scale * w2_scale] * num_experts, device=device
- )
+ scale_prod_2 = (hidden_states_scale * w2_scale).item()
+ output2_scale_scalar = torch.full(
+ (num_experts,),
+ scale_prod_2,
+ device=device,
+ dtype=torch.float32,
+ )🤖 Prompt for AI Agents |
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| fn = lambda: trtllm_fp8_per_tensor_scale_moe( | ||||||||||||||||||||||||
| routing_logits, | ||||||||||||||||||||||||
| None, # routing_bias | ||||||||||||||||||||||||
| hidden_states, | ||||||||||||||||||||||||
| w13, | ||||||||||||||||||||||||
| output1_scale_scalar, | ||||||||||||||||||||||||
| output1_scales_gate_scalar, | ||||||||||||||||||||||||
| w2, | ||||||||||||||||||||||||
| output2_scale_scalar, | ||||||||||||||||||||||||
| num_experts, | ||||||||||||||||||||||||
| top_k, | ||||||||||||||||||||||||
| None, # n_group | ||||||||||||||||||||||||
| None, # topk_group | ||||||||||||||||||||||||
| intermediate_size, | ||||||||||||||||||||||||
| 0, # local_expert_offset | ||||||||||||||||||||||||
| num_experts, | ||||||||||||||||||||||||
| 1.0, # routed_scaling_factor | ||||||||||||||||||||||||
| False, # use_routing_scales_on_input | ||||||||||||||||||||||||
| None, | ||||||||||||||||||||||||
| RoutingMethodType.TopK.value, | ||||||||||||||||||||||||
| enable_pdl, | ||||||||||||||||||||||||
| num_tokens if tune_max_num_tokens is None else tune_max_num_tokens, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def bench(do_autotune): | ||||||||||||||||||||||||
| with autotune(do_autotune): | ||||||||||||||||||||||||
| fn() | ||||||||||||||||||||||||
| ms_list = bench_gpu_time( | ||||||||||||||||||||||||
| fn, | ||||||||||||||||||||||||
| dry_run_iters=warmups, | ||||||||||||||||||||||||
| repeat_iters=iterations, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| median_ms = np.median(ms_list) | ||||||||||||||||||||||||
| return median_ms | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| ms = bench(do_autotune=False) | ||||||||||||||||||||||||
| ms_tuned = bench(do_autotune=True) | ||||||||||||||||||||||||
| print( | ||||||||||||||||||||||||
| f"num tokens: {num_tokens}, num experts: {num_experts}, hidden size: {hidden_size}, intermediate size: {intermediate_size}, top k: {top_k}" | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| print(f"No autotune: {ms:.3f} ms; with autotune: {ms_tuned:.3f} ms") | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def bench_trtllm_gen_fused_moe_autotuner_fp4( | ||||||||||||||||||||||||
| tune_max_num_tokens: Optional[int], | ||||||||||||||||||||||||
| quant_mode: Literal["NvFP4xNvFP4", "MxFP4xMxFP8", "MxFP4xBf16"], | ||||||||||||||||||||||||
| num_tokens: int, | ||||||||||||||||||||||||
|
|
@@ -143,12 +239,11 @@ def bench_trtllm_gen_fused_moe_autotuner( | |||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def bench(do_autotune): | ||||||||||||||||||||||||
| # warmup | ||||||||||||||||||||||||
| with autotune(do_autotune): | ||||||||||||||||||||||||
| for _ in range(warmups): | ||||||||||||||||||||||||
| fn() | ||||||||||||||||||||||||
| fn() | ||||||||||||||||||||||||
| ms_list = bench_gpu_time( | ||||||||||||||||||||||||
| fn, | ||||||||||||||||||||||||
| dry_run_iters=warmups, | ||||||||||||||||||||||||
| repeat_iters=iterations, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| median_ms = np.median(ms_list) | ||||||||||||||||||||||||
|
|
@@ -168,7 +263,7 @@ def bench(do_autotune): | |||||||||||||||||||||||
| "--quant-mode", | ||||||||||||||||||||||||
| type=str, | ||||||||||||||||||||||||
| default="MxFP4xMxFP8", | ||||||||||||||||||||||||
| choices=["NvFP4xNvFP4", "MxFP4xMxFP8", "MxFP4xBf16"], | ||||||||||||||||||||||||
| choices=["NvFP4xNvFP4", "MxFP4xMxFP8", "MxFP4xBf16", "Fp8-Per-Tensor"], | ||||||||||||||||||||||||
| help="Quantization mode", | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| parser.add_argument("--num-tokens", type=int, default=512, help="Number of tokens") | ||||||||||||||||||||||||
|
|
@@ -193,14 +288,27 @@ def bench(do_autotune): | |||||||||||||||||||||||
| "--iterations", type=int, default=100, help="Number of benchmark iterations" | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| args = parser.parse_args() | ||||||||||||||||||||||||
| bench_trtllm_gen_fused_moe_autotuner( | ||||||||||||||||||||||||
| args.tune_max_num_tokens, | ||||||||||||||||||||||||
| args.quant_mode, | ||||||||||||||||||||||||
| args.num_tokens, | ||||||||||||||||||||||||
| args.num_experts, | ||||||||||||||||||||||||
| args.hidden_size, | ||||||||||||||||||||||||
| args.intermediate_size, | ||||||||||||||||||||||||
| args.top_k, | ||||||||||||||||||||||||
| args.warmups, | ||||||||||||||||||||||||
| args.iterations, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| if args.quant_mode == "Fp8-Per-Tensor": | ||||||||||||||||||||||||
| bench_trtllm_gen_fused_moe_autotuner_fp8( | ||||||||||||||||||||||||
| args.tune_max_num_tokens, | ||||||||||||||||||||||||
| args.quant_mode, | ||||||||||||||||||||||||
| args.num_tokens, | ||||||||||||||||||||||||
| args.num_experts, | ||||||||||||||||||||||||
| args.hidden_size, | ||||||||||||||||||||||||
| args.intermediate_size, | ||||||||||||||||||||||||
| args.top_k, | ||||||||||||||||||||||||
| args.warmups, | ||||||||||||||||||||||||
| args.iterations, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||
| bench_trtllm_gen_fused_moe_autotuner_fp4( | ||||||||||||||||||||||||
| args.tune_max_num_tokens, | ||||||||||||||||||||||||
| args.quant_mode, | ||||||||||||||||||||||||
| args.num_tokens, | ||||||||||||||||||||||||
| args.num_experts, | ||||||||||||||||||||||||
| args.hidden_size, | ||||||||||||||||||||||||
| args.intermediate_size, | ||||||||||||||||||||||||
| args.top_k, | ||||||||||||||||||||||||
| args.warmups, | ||||||||||||||||||||||||
| args.iterations, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -63,13 +63,22 @@ std::set<int32_t> computeSelectedTileN(std::vector<int32_t> const& supported_til | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int64_t const num_tokens, int64_t const top_k, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int64_t const num_local_experts) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| float const avg_tokens_per_expert = static_cast<float>(num_tokens * top_k) / num_local_experts; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // assume supported_tile_nums is sorted | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int32_t tile_tokens_dim = std::clamp(nextPowerOfTwo(avg_tokens_per_expert), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| supported_tile_nums.front(), supported_tile_nums.back()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::set<int32_t> selected_tile_nums = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::max(supported_tile_nums.front(), tile_tokens_dim / 2), tile_tokens_dim, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::min(supported_tile_nums.back(), tile_tokens_dim * 2), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::min(supported_tile_nums.back(), tile_tokens_dim * 4)}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto it = std::find(supported_tile_nums.begin(), supported_tile_nums.end(), tile_tokens_dim); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::set<int32_t> selected_tile_nums; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| selected_tile_nums.insert(tile_tokens_dim); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (std::next(it) != supported_tile_nums.end()) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| selected_tile_nums.insert(*std::next(it)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (std::next(std::next(it)) != supported_tile_nums.end()) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| selected_tile_nums.insert(*std::next(std::next(it))); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (it != supported_tile_nums.begin()) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| selected_tile_nums.insert(*std::prev(it)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
67
to
82
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Guard against missing tile entry before iterating neighbors
- auto it = std::find(supported_tile_nums.begin(), supported_tile_nums.end(), tile_tokens_dim);
-
- std::set<int32_t> selected_tile_nums;
- selected_tile_nums.insert(tile_tokens_dim);
+ auto it =
+ std::lower_bound(supported_tile_nums.begin(), supported_tile_nums.end(), tile_tokens_dim);
+ if (it == supported_tile_nums.end()) {
+ it = std::prev(supported_tile_nums.end());
+ }
+ tile_tokens_dim = *it;
+
+ std::set<int32_t> selected_tile_nums;
+ selected_tile_nums.insert(tile_tokens_dim);
if (std::next(it) != supported_tile_nums.end()) {
selected_tile_nums.insert(*std::next(it));
if (std::next(std::next(it)) != supported_tile_nums.end()) {
selected_tile_nums.insert(*std::next(std::next(it)));
}
}
if (it != supported_tile_nums.begin()) {
selected_tile_nums.insert(*std::prev(it));
}📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch. We should make sure
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @IwakuraRein, I appreciate you acknowledging the issue! However, I notice that the current implementation already includes non-power-of-2 values in More importantly, even if we constrain
Recommendation: I suggest implementing CodeRabbit's defensive fix using auto it = std::lower_bound(supported_tile_nums.begin(), supported_tile_nums.end(), tile_tokens_dim);
if (it == supported_tile_nums.end()) {
it = std::prev(it); // Use largest supported value
}
tile_tokens_dim = *it; // Snap to actual supported valueThis ensures robustness regardless of the supported tile set. If you want to enforce a design constraint (consecutive powers of 2), consider adding a debug assertion in the callers that construct Would you like me to open an issue to track adding this fix? |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return selected_tile_nums; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -369,7 +378,7 @@ void trtllm_fp8_per_tensor_scale_moe( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto const hidden_size = hidden_states.size(1); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bool mUseDeepSeekFp8{false}; // FP8 per-tensor doesn't use DeepSeek FP8 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::vector<int32_t> mSupportedTileN = {8, 16, 32, 64, 128}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::vector<int32_t> mSupportedTileN = {8, 16, 32, 64, 128, 192, 256}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::set<int32_t> selected_tile_nums = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -718,7 +727,7 @@ void trtllm_fp8_block_scale_moe( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto const num_tokens = hidden_states.size(0); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto const hidden_size = hidden_states.size(1); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::vector<int32_t> mSupportedTileN = {8, 16, 32, 64}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::vector<int32_t> mSupportedTileN = {8, 16, 32, 64, 128}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::set<int32_t> selected_tile_nums = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1228,6 +1237,11 @@ Array<Tensor> trtllm_fp4_block_scale_moe( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (mDtypeAct != btg::Dtype::Bfloat16) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mSupportedTileN.push_back(128); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if ((mDtypeAct == btg::Dtype::MxE4m3 && mDtypeWeights == btg::Dtype::MxE2m1) || | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| (mDtypeAct == btg::Dtype::E2m1 && mDtypeWeights == btg::Dtype::E2m1)) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // MxFP4 x MxFP4 or NvFP4 x NvFP4 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mSupportedTileN.push_back(256); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::set<int32_t> selected_tile_nums = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Build runners for all supported tile sizes | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1305,8 +1319,20 @@ Array<Array<int64_t>> trtllm_get_valid_moe_configs( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bool is_fp8_per_tensor = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dtype_weights == btg::Dtype::E4m3 && dtype_act == btg::Dtype::E4m3 && !useDeepSeekFp8; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (is_fp4_without_bf16_act || is_fp8_per_tensor) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (useDeepSeekFp8) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| supported_tile_nums.push_back(128); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } else if (is_fp8_per_tensor) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| supported_tile_nums.push_back(128); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| supported_tile_nums.push_back(192); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| supported_tile_nums.push_back(256); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } else if (is_fp4_without_bf16_act) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| supported_tile_nums.push_back(128); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if ((dtype_act == btg::Dtype::MxE4m3 && dtype_weights == btg::Dtype::MxE2m1) || | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| (dtype_act == btg::Dtype::E2m1 && dtype_weights == btg::Dtype::E2m1)) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // MxFP4 x MxFP4 or NvFP4 x NvFP4 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| supported_tile_nums.push_back(256); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::set<int32_t> selected_tile_nums = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| computeSelectedTileN(supported_tile_nums, num_tokens, top_k, num_local_experts); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Guard FP8 quantization against all-zero inputs
All-zero inputs make
maxzero,scaleinfinite, and the quantized tensor NaN (0 × ∞), which breaks the benchmark when buffers start cleared. Please handle the zero case before inverting the scale.🤖 Prompt for AI Agents