fix: support fp32 logits for fp8_per_tensor and fp8_block #2534
yweng0828 wants to merge 6 commits into flashinfer-ai:main
Conversation
Summary of Changes

Hello @yweng0828, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the fused MoE kernels to support FP32 logits, which is necessary for compatibility with certain models such as DeepSeekV3. The changes modify the kernel launcher, runner, and test suite to accommodate the new data type, so the MoE kernels can handle a wider range of models and configurations.

Highlights
Activity
Code Review
The pull request introduces support for fp32 logits in the fused MoE kernels, specifically for fp8_per_tensor and fp8_block quantization modes. This is achieved by adding a mDtypeScore member to FusedMoeLauncher and routingRenormalize::Data structs, and updating the routing_runner.run calls and kernel dispatch macros to utilize this new dtype. The routing_logits dtype validation logic in trtllm_fp8_per_tensor_scale_moe and trtllm_fp8_block_scale_moe functions is relaxed to allow float32 where appropriate, while still enforcing float32 for DeepSeekV3 routing. Corresponding test cases are updated to parameterize logits_dtype and include new skip conditions to ensure compatibility. The changes are consistent across the codebase and align with the stated goal of the pull request.
```cpp
    workspace.token_scales = expert_weights.data_ptr();  // Consumed by permuteGemm1 kernel
  }
  if (routing_logits.has_value()) {
    mDtypeScore =
```
Should this piece of code be part of the FusedMoeLauncher class so that all child classes can share it? It seems that this logic is currently in the Fp8PerTensorLauncher class. Also, we might want to add an assertion to check the data type of routing_logits.
```cpp
TVM_FFI_ICHECK(routing_logits.dtype() == dl_float32 || routing_logits.dtype() == dl_bfloat16)
    << "BF16 MoE: routing_logits must be bfloat16 or float.";
```
Thanks for pointing it out. I have refactored this part of the logic and moved it to the base class.
```diff
       kernel, numBlocks, numThreads, smemSize, stream);                      \
   } else {                                                                   \
-    FLASHINFER_WARN("Unsupported dtypeExpW");                                \
+    FLASHINFER_WARN("Unsupported mDtypeScore/mDtypeExpW combination");       \
```
How about using this message instead: "Unsupported combination of mDtypeScore and mDtypeExpW"?
Force-pushed from a62decc to 0c876d4
📝 Walkthrough

Adds an explicit score dtype (mDtypeScore / dtypeScore), threads it Launcher → Runner → Kernel, updates routing validations (DeepSeekV3 requires FP32), changes Runner::run signature to accept dtypeScore, extends kernel launch branching on score dtype, and parameterizes tests to pass logits dtype.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Launcher
    participant Runner
    participant Kernel
    participant GPU
    Launcher->>Runner: run(..., dtypeScore, dtypeElt, dtypeBias, ..., stream)
    Runner->>Kernel: prepare routingData (mDtypeScore := dtypeScore, mDtypeElt := dtypeElt, ...)
    Kernel->>GPU: launch routing kernels using mDtypeScore and mDtypeElt
    GPU-->>Kernel: routing results
    Kernel-->>Runner: routing outputs (indices, counts)
    Runner-->>Launcher: return routing results
```
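The flow in the diagram can be sketched in plain Python; the names below (`ScoreDtype`, `RoutingRunner`, `Launcher`) are illustrative stand-ins for the C++ types, not the actual API:

```python
from enum import Enum, auto

class ScoreDtype(Enum):
    FP32 = auto()
    BF16 = auto()

class RoutingRunner:
    def run(self, logits, dtype_score):
        # The runner records the score dtype and dispatches on it,
        # mirroring how routingData.mDtypeScore drives kernel selection.
        if dtype_score is ScoreDtype.FP32:
            return [float(x) for x in logits]     # full-precision path
        elif dtype_score is ScoreDtype.BF16:
            return [round(x, 2) for x in logits]  # stand-in for reduced precision
        raise NotImplementedError("Unsupported score dtype")

class Launcher:
    def __init__(self):
        self.runner = RoutingRunner()

    def launch(self, logits, logits_dtype):
        # The launcher derives the score dtype from the logits dtype
        # and threads it down to the runner, as in the sequence above.
        dtype_score = ScoreDtype.FP32 if logits_dtype == "float32" else ScoreDtype.BF16
        return self.runner.run(logits, dtype_score)
```

The key point the PR makes is that the score dtype becomes an explicit parameter at every layer instead of an implicit assumption inside the kernel.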
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
include/flashinfer/trtllm/fused_moe/DevKernel.h (1)
219-221: ⚠️ Potential issue | 🟡 Minor

Stale error message in `LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG`.

The else-branch error message still says `"Unsupported dtypeExpW"`, but the macro now gates on `mDtypeScore`, `mDtypeBias`, and `mDtypeExpW`. Update it similarly to line 269 for consistency and easier debugging.

Proposed fix

```diff
 } else {                                                                     \
-  FLASHINFER_WARN("Unsupported dtypeExpW");                                  \
+  FLASHINFER_WARN("Unsupported combination of mDtypeScore, mDtypeBias, and mDtypeExpW"); \
 }
```

tests/moe/test_dpsk_fused_moe_fp8.py (1)
615-624: ⚠️ Potential issue | 🔴 Critical

Missing `routing_method_type` key in `routing_config` will cause a `KeyError` in the updated `skip_checks`.

The `routing_config` dicts defined at lines 510–548 don't contain a `"routing_method_type"` key, but the new check in `skip_checks` (line 148 of `utils.py`) accesses `routing_config["routing_method_type"]` unconditionally. This will crash every test case in this file.

Either add `"routing_method_type"` to each routing config dict, or use `.get()` with a default in `skip_checks`:

Option 1: Fix in utils.py (safer — handles callers that don't set the key)

```diff
 if (
-    routing_config["routing_method_type"] == RoutingMethodType.DeepSeekV3
+    routing_config.get("routing_method_type") == RoutingMethodType.DeepSeekV3
     and logits_dtype != torch.float32
 ):
```

Option 2: Fix in this test file (add routing_method_type to each config)

For the DSv3 config:

```diff
 {
     "num_experts": 256,
     "top_k": 8,
+    "routing_method_type": RoutingMethodType.DeepSeekV3,
     ...
 },
```

And similarly for other configs with the appropriate `RoutingMethodType`.

tests/moe/test_trtllm_gen_fused_moe.py (1)
2883-2893: ⚠️ Potential issue | 🔴 Critical

Bug: `logits_dtype` and `cache_permute_indices` arguments are swapped.

The `run_moe_test` signature (line 2337) expects `cache_permute_indices` as the 8th positional arg and `logits_dtype` as the 9th. Here, they are passed in the opposite order. This will cause `moe_impl._cache_permute_indices` to be set to a `torch.dtype` and `expert_logits.to(logits_dtype)` to receive a dict, resulting in a runtime crash.

Compare with `test_renormalize_routing` (lines 2695–2696), `test_topk_routing` (lines 2975–2976), and `test_llama4_routing` (lines 3056–3057), which all pass the arguments in the correct order.

🐛 Proposed fix

```diff
 run_moe_test(
     num_tokens,
     hidden_size,
     intermediate_size,
     moe_impl,
     routing_config,
     weight_processing,
     activation_type,
-    logits_dtype,
     cache_permute_indices,
+    logits_dtype,
 )
```
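A general way to avoid this class of positional-order bug is to make trailing parameters keyword-only, so swapped call sites fail loudly at the call instead of silently mixing up values downstream. The sketch below is a simplified stand-in for `run_moe_test`, not the real signature:

```python
def run_moe_test(num_tokens, hidden_size, *, cache_permute_indices, logits_dtype):
    # The bare * makes cache_permute_indices and logits_dtype keyword-only,
    # so a caller can never swap them by position.
    return {"cache": cache_permute_indices, "dtype": logits_dtype}

# Order no longer matters at the call site:
result = run_moe_test(4, 8, logits_dtype="float32", cache_permute_indices={})
```

This is a common mitigation for helpers with long positional parameter lists.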
🤖 Fix all issues with AI agents
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 288-298: The code currently sets mDtypeScore based solely on
routing_method_type which forces BFloat16 for non-DeepSeekV3 even when
routing_logits are float32; update the block that runs when
routing_logits.has_value() so mDtypeScore is derived from
routing_logits.value().dtype(): if routing_method_type ==
RoutingMethodType::DeepSeekV3 keep the TVM_FFI_ICHECK_EQ asserting dtype is
dl_float32 and set mDtypeScore = btg::Dtype::Fp32; otherwise inspect
routing_logits.value().dtype() and set mDtypeScore = btg::Dtype::Fp32 for
dl_float32, btg::Dtype::Bfloat16 for dl_bfloat16 (and error/ICHECK for
unsupported dtypes). Reference symbols: mDtypeScore, routing_logits,
RoutingMethodType::DeepSeekV3.
In `@tests/moe/utils.py`:
- Around line 155-162: The condition incorrectly compares type(moe_impl) to
QuantMode enum values causing all FP32-logits tests to skip; change the check to
inspect moe_impl.quant_mode instead. Update the if-statement that currently
reads "if logits_dtype == torch.float32 and type(moe_impl) not in
[QuantMode...]" to use "moe_impl.quant_mode not in [QuantMode.FP8_PER_TENSOR,
QuantMode.FP8_BLOCK_SCALE, QuantMode.BF16]" so the pytest.skip call only
triggers for incompatible quant modes; keep the existing pytest.skip message and
variables (logits_dtype, moe_impl, QuantMode) unchanged.
🧹 Nitpick comments (1)
tests/moe/test_trtllm_gen_fused_moe.py (1)
2865-2870: DeepSeekV3 routing is only parametrized with FP32 logits — intentional?

Unlike `test_renormalize_routing` and `test_topk_routing`, which test both FP32 and BF16, this test only exercises FP32 logits. If BF16 logits are also a valid input for DeepSeekV3 routing in production, consider adding BF16 coverage here too.
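The parametrization pattern these comments refer to looks roughly like the following (a simplified sketch, not the actual test code):

```python
import pytest

@pytest.mark.parametrize(
    "logits_dtype",
    [
        pytest.param("float32", id="FP32_logits"),
        pytest.param("bfloat16", id="BF16_logits"),
    ],
)
def test_routing_logits_dtype(logits_dtype):
    # Each parametrized case exercises one logits dtype end to end;
    # adding a dtype to the list above adds a test case with no other changes.
    assert logits_dtype in ("float32", "bfloat16")
```

Widening coverage, as the reviewer suggests, is a one-line change to the parameter list.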
/bot run

@yweng0828 is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

/bot run

hi @yweng0828 thx for the contrib
[SUCCESS] Pipeline #44270124: 16/20 passed

Hi @aleozlx, thank you for your review. The PR is ready. Local testing has also passed.
@yweng0828 Does the change also apply to

Hi @wenscarl, No, this change does not apply to

@aleozlx Any update on merging this?

Hi @yweng0828, can you rebase to main so we can restart the CI?
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tests/moe/test_trtllm_gen_fused_moe.py (2)
3131-3140: ⚠️ Potential issue | 🟠 Major

`logits_dtype` and `cache_permute_indices` are swapped in this call.

`run_moe_test()` expects `(cache_permute_indices, logits_dtype)`. In the current order, the cache dict is treated as the dtype, so the DeepSeekV3 guard in `skip_checks()` skips the whole suite instead of exercising the new FP32-logits path.

Suggested fix

```diff
 run_moe_test(
     num_tokens,
     hidden_size,
     intermediate_size,
     moe_impl,
     routing_config,
     weight_processing,
     activation_type,
-    logits_dtype,
     cache_permute_indices,
+    logits_dtype,
 )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/moe/test_trtllm_gen_fused_moe.py` around lines 3131 - 3140, The call to run_moe_test is passing logits_dtype and cache_permute_indices in the wrong order so the cache dict is being interpreted as the dtype; update the call site to swap those two arguments so you pass cache_permute_indices first and logits_dtype second. Locate the call to run_moe_test (the one with parameters num_tokens, hidden_size, intermediate_size, moe_impl, routing_config, weight_processing, activation_type, logits_dtype, cache_permute_indices) and reorder the last two args accordingly to avoid triggering the DeepSeekV3 guard in skip_checks().
2567-2577: ⚠️ Potential issue | 🟠 Major

This helper change breaks the existing GEMM-bias test.

The `test_nvfp4_moe_gemm_bias()` call near line 3332 still invokes `run_moe_test()` without `logits_dtype`, so this new required parameter turns that test into a `TypeError` before the bias path runs. Either give `logits_dtype` a backward-compatible default or update the remaining caller.

One backward-compatible option

```diff
-    logits_dtype,
+    logits_dtype=torch.bfloat16,
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/moe/test_trtllm_gen_fused_moe.py` around lines 2567 - 2577, The new required parameter logits_dtype on run_moe_test breaks callers like test_nvfp4_moe_gemm_bias that still call run_moe_test() without it; make logits_dtype optional with a backward-compatible default (e.g., None or the previous default dtype) in the run_moe_test signature and branch inside run_moe_test (or set a local default variable) so existing callers continue to exercise the GEMM-bias path without modification; alternatively, update all callers such as test_nvfp4_moe_gemm_bias to pass the intended logits_dtype explicitly, but prefer adding the default to run_moe_test to avoid many caller changes.
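The backward-compatible-default option described above can be illustrated with a sentinel default; the signature below is hypothetical (the real helper takes many more parameters):

```python
def run_moe_test(num_tokens, hidden_size, logits_dtype=None):
    # None acts as a sentinel: old call sites that never passed a dtype
    # keep the previous behaviour, while new ones can opt in explicitly.
    if logits_dtype is None:
        logits_dtype = "bfloat16"  # stand-in for the pre-change default
    return logits_dtype
```

With this shape, existing callers keep working and new tests can pass `"float32"` explicitly.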
♻️ Duplicate comments (2)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
313-322: ⚠️ Potential issue | 🔴 Critical

Derive `mDtypeScore` from `routing_logits.dtype()` for non-DeepSeek paths.

This still forces `btg::Dtype::Bfloat16` for every non-DeepSeek route. The new `routing_runner.run(..., mDtypeScore, ...)` plumbing forwards that value into `routingData.mDtypeScore`, so FP32 logits added by this PR are still dispatched/read as BF16 and will produce incorrect routing weights.

Suggested fix

```diff
 // Set dtype of score
 if (routing_logits.has_value()) {
   if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
     TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_float32)
         << "routing_logits must be float.";
     mDtypeScore = btg::Dtype::Fp32;
+  } else if (routing_logits.value().dtype() == dl_float32) {
+    mDtypeScore = btg::Dtype::Fp32;
   } else {
     mDtypeScore = btg::Dtype::Bfloat16;
   }
 }
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/trtllm_fused_moe_kernel_launcher.cu` around lines 313 - 322, The code currently forces mDtypeScore = btg::Dtype::Bfloat16 for all non-DeepSeek routes causing FP32 routing_logits to be misinterpreted; change the non-DeepSeek branch to derive mDtypeScore from routing_logits.value().dtype() (e.g., map dl_float32 -> btg::Dtype::Fp32, dl_bfloat16 -> btg::Dtype::Bfloat16, etc.) while keeping the existing RoutingMethodType::DeepSeekV3 check that ICHECKs float32 and sets Fp32; this ensures the value passed into routing_runner.run(...) and stored in routingData.mDtypeScore matches the actual routing_logits.dtype().tests/moe/utils.py (1)
170-177: ⚠️ Potential issue | 🟠 Major

The FP32-logits whitelist can never succeed.

`type(moe_impl)` is a class, not a `QuantMode`, and `QuantMode.FP8_BLOCK_SCALE` is not a member of this enum. On FP32 cases this branch either raises `AttributeError` or skips every implementation, so the new coverage never runs.

Suggested fix

```diff
-    if logits_dtype == torch.float32 and type(moe_impl) not in [
-        QuantMode.FP8_PER_TENSOR,
-        QuantMode.FP8_BLOCK_SCALE,
-        QuantMode.BF16,
-    ]:
+    if logits_dtype == torch.float32 and moe_impl.quant_mode not in [
+        QuantMode.FP8_PER_TENSOR,
+        QuantMode.FP8_BLOCK_SCALE_DEEPSEEK,
+        QuantMode.FP8_BLOCK_SCALE_MXFP8,
+        QuantMode.BF16,
+    ]:
         pytest.skip(
             f"Incompatible: logits_dtype={logits_dtype} with {type(moe_impl).__name__} + {moe_impl.quant_mode}"
         )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/moe/utils.py` around lines 170 - 177, The condition is comparing type(moe_impl) to enum members (and includes a non-existent QuantMode.FP8_BLOCK_SCALE), so replace the check with a comparison against moe_impl.quant_mode (e.g., if logits_dtype == torch.float32 and moe_impl.quant_mode not in [QuantMode.FP8_PER_TENSOR, QuantMode.BF16]:) and remove or correct the invalid QuantMode member; keep the pytest.skip call but use moe_impl.__class__.__name__ and moe_impl.quant_mode in the message to preserve context.
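The corrected check compares the instance's `quant_mode` attribute against enum members rather than `type(moe_impl)`. A minimal sketch, with illustrative `QuantMode` members that stand in for the real enum:

```python
from enum import Enum, auto

class QuantMode(Enum):
    FP8_PER_TENSOR = auto()
    FP8_BLOCK_SCALE_DEEPSEEK = auto()
    BF16 = auto()
    NVFP4 = auto()

class MoeImpl:
    def __init__(self, quant_mode):
        self.quant_mode = quant_mode

FP32_LOGITS_CAPABLE = {
    QuantMode.FP8_PER_TENSOR,
    QuantMode.FP8_BLOCK_SCALE_DEEPSEEK,
    QuantMode.BF16,
}

def should_skip(moe_impl, logits_dtype):
    # type(moe_impl) would be the class MoeImpl, never a QuantMode member,
    # so the original membership test could never pass; the instance's
    # quant_mode attribute is what actually carries the enum value.
    return logits_dtype == "float32" and moe_impl.quant_mode not in FP32_LOGITS_CAPABLE
```

Membership tests on `Enum` values use identity, so comparing a class object against enum members silently fails rather than raising, which is why the bug skipped every FP32 case.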
📒 Files selected for processing (6)

- csrc/trtllm_fused_moe_kernel_launcher.cu
- csrc/trtllm_fused_moe_runner.cu
- include/flashinfer/trtllm/fused_moe/RoutingKernel.h
- tests/moe/test_dpsk_fused_moe_fp8.py
- tests/moe/test_trtllm_gen_fused_moe.py
- tests/moe/utils.py
🚧 Files skipped from review as they are similar to previous changes (2)
- include/flashinfer/trtllm/fused_moe/RoutingKernel.h
- tests/moe/test_dpsk_fused_moe_fp8.py
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tests/moe/test_trtllm_gen_fused_moe.py (2)
2567-2577: ⚠️ Potential issue | 🟠 Major

`run_moe_test` now requires `logits_dtype`, but not all call sites provide it.

Making `logits_dtype` mandatory here breaks unchanged callers (e.g., `test_nvfp4_moe_gemm_bias`) with a missing-argument failure. Please keep this helper backward-compatible.

Proposed fix

```diff
 def run_moe_test(
 @@
-    logits_dtype,
+    logits_dtype=torch.float32,
     zero_hidden_states=False,
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/moe/test_trtllm_gen_fused_moe.py` around lines 2567 - 2577, The helper function run_moe_test was made incompatible by adding a required parameter logits_dtype; revert this by giving logits_dtype a sensible default (e.g., None or a default dtype like torch.float32) in run_moe_test's signature so existing callers (for example test_nvfp4_moe_gemm_bias) continue to work, and then update any internal uses of logits_dtype inside run_moe_test (and related helpers) to handle the default case (use the default dtype when logits_dtype is None or unspecified) so behavior remains backward-compatible.
3151-3160: ⚠️ Potential issue | 🟠 Major

DeepSeek test passes `run_moe_test` arguments in the wrong order.

At lines 3159–3160, `logits_dtype` and `cache_permute_indices` are swapped. This causes runtime type errors when creating `expert_logits`.

Proposed fix

```diff
 run_moe_test(
 @@
-    activation_type,
-    logits_dtype,
-    cache_permute_indices,
+    activation_type,
+    cache_permute_indices,
+    logits_dtype,
 )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/moe/test_trtllm_gen_fused_moe.py` around lines 3151 - 3160, The call to run_moe_test in the DeepSeek test passes logits_dtype and cache_permute_indices in the wrong order, causing runtime type errors when expert_logits are created; fix it by swapping those two arguments so that the parameter named logits_dtype receives the dtype value and cache_permute_indices receives the permutation indices, i.e., locate the run_moe_test invocation and ensure the argument corresponding to logits_dtype is the logits dtype variable and the argument corresponding to cache_permute_indices is the indices variable.
♻️ Duplicate comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
313-322: ⚠️ Potential issue | 🔴 Critical

`mDtypeScore` is still derived from the routing method instead of the actual `routing_logits` dtype.

At lines 319–320, non-DeepSeek paths force BF16 even when `routing_logits` is FP32, so FP32 logits can be interpreted with the wrong score dtype. This is a correctness bug.

Proposed fix

```diff
-  // Set dtype of score
-  if (routing_logits.has_value()) {
-    if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
-      TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_float32)
-          << "routing_logits must be float.";
-      mDtypeScore = btg::Dtype::Fp32;
-    } else {
-      mDtypeScore = btg::Dtype::Bfloat16;
-    }
-  }
+  // Set dtype of score from actual routing_logits dtype
+  if (routing_logits.has_value()) {
+    auto const logits_dtype = routing_logits.value().dtype();
+    if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
+      TVM_FFI_ICHECK_EQ(logits_dtype, dl_float32) << "routing_logits must be float.";
+    }
+    if (logits_dtype == dl_float32) {
+      mDtypeScore = btg::Dtype::Fp32;
+    } else if (logits_dtype == dl_bfloat16) {
+      mDtypeScore = btg::Dtype::Bfloat16;
+    } else {
+      TVM_FFI_LOG_AND_THROW(NotImplementedError)
+          << "routing_logits must be float32 or bfloat16.";
+    }
+  }
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/trtllm_fused_moe_kernel_launcher.cu` around lines 313 - 322, The code currently sets mDtypeScore based on routing_method_type which forces BF16 for non-DeepSeek methods and can misinterpret FP32 routing_logits; change the logic in the block that checks routing_logits.has_value() to derive mDtypeScore from routing_logits.value().dtype() instead: if routing_logits.value().dtype() == dl_float32 set mDtypeScore = btg::Dtype::Fp32, else set mDtypeScore = btg::Dtype::Bfloat16 (and add a defensive check/error via TVM_FFI_ICHECK if an unexpected dtype appears); replace the existing routing_method_type conditional around mDtypeScore so routing_logits dtype is the single source of truth (symbols: routing_logits, mDtypeScore, routing_method_type, RoutingMethodType::DeepSeekV3).
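The single-source-of-truth derivation the reviewer asks for can be sketched in Python; the dtype strings are stand-ins for the `dl_*` / `btg::Dtype` values in the C++ code:

```python
def derive_score_dtype(logits_dtype, routing_method):
    # DeepSeekV3 routing requires FP32 logits; every other method derives
    # the score dtype from the actual logits dtype rather than from the
    # routing method, so FP32 logits are never silently read as BF16.
    if routing_method == "DeepSeekV3" and logits_dtype != "float32":
        raise ValueError("routing_logits must be float32 for DeepSeekV3")
    if logits_dtype == "float32":
        return "Fp32"
    if logits_dtype == "bfloat16":
        return "Bfloat16"
    raise NotImplementedError("routing_logits must be float32 or bfloat16")
```

Keeping the logits tensor's dtype as the only input to this decision removes the mismatch between the allocated buffer and the kernel's interpretation of it.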
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@include/flashinfer/trtllm/fused_moe/DevKernel.h`:
- Around line 201-202: There are two conflicting macro definitions of
LAUNCH_ROUTING_WITH_NUM_EXPERTS; remove the earlier 10-parameter version (the
one that includes the numTopExperts parameter) so it does not get silently
overridden by the later 9-parameter definition, ensuring callers that invoke
LAUNCH_ROUTING_WITH_NUM_EXPERTS with ten arguments (e.g., passing numTopExperts)
continue to expand correctly; specifically delete the first definition (the one
that lists numTopExperts in its parameter list) so only the intended macro
remains and compilation/macro-argument mismatches are resolved.
- Around line 201-236: The DeepSeek routing migration broke because the original
macro LAUNCH_ROUTING_WITH_NUM_EXPERTS (defined in DevKernel.h) was intended to
be LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT and callers in
RoutingDeepSeekCommon.cuh still call the undefined
LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT and pass the old
forceFloatInput parameter; rename the first macro definition to
LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT in DevKernel.h and then update
the DeepSeek launchers in RoutingDeepSeekCommon.cuh (and any other DeepSeek
routing call sites) to stop using the legacy forceFloatInput dispatch and
instead call the new score-dtype-aware macro LAUNCH_ROUTING_WITH_NUM_EXPERTS
(which dispatches on data.mDtypeScore and data.mDtypeExpW); ensure call-site
argument lists match the new macro signature (remove forceFloatInput) and that
any logic that relied on forceFloatInput is expressed via the
extraFlag/score-dtype checks already present in the new macro.
📒 Files selected for processing (5)

- csrc/trtllm_fused_moe_kernel_launcher.cu
- csrc/trtllm_fused_moe_runner.cu
- include/flashinfer/trtllm/fused_moe/DevKernel.h
- include/flashinfer/trtllm/fused_moe/RoutingKernel.h
- tests/moe/test_trtllm_gen_fused_moe.py
🚧 Files skipped from review as they are similar to previous changes (1)
- include/flashinfer/trtllm/fused_moe/RoutingKernel.h
Hi @yweng0828, we are trying to help merge it. Last week we had a CI issue blocking all PRs from merging. Please double-check that it's fine.

/bot run

[SUCCESS] Pipeline #46367047: 8/20 passed

there is currently an error on JIT Unittest (H100), unfortunately
Head branch was pushed to by a user without write access
Actionable comments posted: 1
🧹 Nitpick comments (1)
tests/moe/test_trtllm_gen_fused_moe.py (1)
3298-3326: Consider adding FP32 logits test for Llama4 routing.
`FP8PerTensorMoe` supports FP32 logits per the skip logic in `utils.py` (line 171), but this test only parametrizes BF16. If this is intentional for test speed, it's fine, but FP32 coverage could be added for completeness.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/moe/test_trtllm_gen_fused_moe.py` around lines 3298 - 3326, The test test_llama4_routing currently only parametrizes logits_dtype with torch.bfloat16; add an additional parametrization for torch.float32 so Llama4 routing is exercised with FP32 logits (since FP8PerTensorMoe supports FP32 per the skip logic in utils.py around the FP32 check) by updating the pytest.mark.parametrize block to include pytest.param(torch.float32, id="FP32_logits"); ensure test_llama4_routing still calls run_moe_test with the new logits_dtype value so the FP32 path is covered.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 521-523: The expert_weights buffer is allocated using mDtypeScore
but always read by the finalize kernel as TypeExpW instantiated from mDtypeExpW,
causing mismatched interpretation; fix by centralizing the expert-weights dtype
policy: derive a single expW_dtype (based on mDtypeExpW) and use that when
calling alloc_tensor to set FusedMoeLauncher::expert_weights at all allocation
sites (the ones allocating expert_weights), and ensure the same expW_dtype is
passed/visible to the runner/finalize kernel invocation so the template TypeExpW
and the allocated buffer use the same dtype.
📒 Files selected for processing (3)

- csrc/trtllm_fused_moe_kernel_launcher.cu
- tests/moe/test_trtllm_gen_fused_moe.py
- tests/moe/utils.py
```diff
+  auto ew_dtype = mDtypeScore == btg::Dtype::Fp32 ? dl_float32 : dl_bfloat16;
   FusedMoeLauncher::expert_weights =
-      alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device());
+      alloc_tensor({args->num_tokens, args->top_k}, ew_dtype, hidden_states.device());
```
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== DeepSeek expW dtype in routing runner =="
rg -n -C3 'RoutingMethodType::DeepSeekV3|mDtypeExpW' csrc/trtllm_fused_moe_runner.cu
echo
echo "== expert_weights dtype allocation sites in launcher =="
rg -n -C2 'ew_dtype|expert_weights_dtype|alloc_tensor\(\{args->num_tokens, args->top_k\}' csrc/trtllm_fused_moe_kernel_launcher.cu
Repository: flashinfer-ai/flashinfer
Length of output: 3594
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== mDtypeScore initialization in launcher =="
rg -n 'mDtypeScore\s*=' csrc/trtllm_fused_moe_kernel_launcher.cu | head -20
echo
echo "== Where expert_weights dtype flows into runner args =="
rg -n 'args->mDtypeExpW|mDtypeExpW\s*=' csrc/trtllm_fused_moe_kernel_launcher.cu | head -20
echo
echo "== Check if Llama and Renormalize/TopK also have the issue =="
rg -n 'class.*Llama|class.*Renormalize|class.*Topk' csrc/trtllm_fused_moe_kernel_launcher.cu
echo
echo "== Allocation sites - check their context for routing method =="
rg -n -B10 'auto ew_dtype = mDtypeScore' csrc/trtllm_fused_moe_kernel_launcher.cu | head -40
Repository: flashinfer-ai/flashinfer
Length of output: 670
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== Context: Which routing methods trigger Fp32 vs Bfloat16 mDtypeScore =="
sed -n '310,330p' csrc/trtllm_fused_moe_kernel_launcher.cu
echo
echo "== Class definitions to identify which classes set mDtypeScore =="
rg -n 'class.*FusedMoe|struct.*FusedMoe' csrc/trtllm_fused_moe_kernel_launcher.cu | head -20
echo
echo "== How is workspace.expert_weights used in runner =="
rg -n -A3 'workspace\.expert_weights' csrc/trtllm_fused_moe_runner.cu | head -30
Repository: flashinfer-ai/flashinfer
Length of output: 1692
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== Where is args.mDtypeExpW set =="
rg -n 'args.*mDtypeExpW|mDtypeExpW.*args' csrc/trtllm_fused_moe_kernel_launcher.cu
echo
echo "== Full context of finalizeData usage in runner =="
sed -n '490,510p' csrc/trtllm_fused_moe_runner.cu
echo
echo "== What does finalizeData do with expertWeightsPtr and mDtypeExpW =="
rg -n 'finalizeData\.' csrc/trtllm_fused_moe_runner.cu | head -20
echo
echo "== Check if expert_weights buffer is actually read/interpreted as a specific dtype =="
rg -n 'expert_weights|expertWeights' csrc/trtllm_fused_moe_runner.cu | head -20
Repository: flashinfer-ai/flashinfer
Length of output: 102
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== Runner.run signature - understand args =="
rg -n 'void run\(|struct.*Args|args->num_tokens' csrc/trtllm_fused_moe_runner.cu | head -20
echo
echo "== Check if finalizeData operations use mDtypeExpW =="
rg -n 'mDtypeExpW' csrc/trtllm_fused_moe_runner.cu
echo
echo "== Look for kernel calls that use expertWeightsPtr and check if dtype matters =="
rg -n 'expertWeights|expW' csrc/ -t cu | grep -E '(kernel|<<<|>>>)' | head -20
Repository: flashinfer-ai/flashinfer
Length of output: 112
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== Routing runner file structure - find main entry point =="
head -100 csrc/trtllm_fused_moe_runner.cu
echo
echo "== Search more broadly for mDtypeExpW usage =="
rg -n 'mDtypeExpW' csrc/
echo
echo "== Check finalize kernel or operation that uses expertWeightsPtr =="
rg -n 'Finalize|finalize' csrc/trtllm_fused_moe_runner.cu | head -10
Repository: flashinfer-ai/flashinfer
Length of output: 5503
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== Look for where args.mDtypeExpW is initialized =="
rg -n 'struct.*Args|args\.mDtypeExpW|mDtypeExpW' csrc/trtllm_fused_moe_runner.cu | head -30
echo
echo "== Search in launcher for any Args struct definition =="
rg -n 'struct.*Args|\.mDtypeExpW' csrc/trtllm_fused_moe_kernel_launcher.cu | head -30
echo
echo "== Check header files for Args struct =="
fd 'runner\.h|fused_moe.*\.h' csrc/ --exec head -150 {} \; 2>/dev/null | grep -A20 'struct Args'
Repository: flashinfer-ai/flashinfer
Length of output: 369
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== Finalize function signature =="
rg -n 'void.*finalize|finalize.*Data' csrc/trtllm_fused_moe_runner.cu -A2 | head -30
echo
echo "== Check line 467 context for finalize function =="
sed -n '460,510p' csrc/trtllm_fused_moe_runner.cu
echo
echo "== Look for expert_weights dtype handling in finalize kernel =="
fd '\.h$' csrc/ --exec rg -l 'finalize' {} \; 2>/dev/null | head -5 | xargs rg -n 'struct.*finalize|mDtypeExpW' 2>/dev/null | head -20
Repository: flashinfer-ai/flashinfer
Length of output: 4018
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== Find MoERunnerArgs struct definition =="
fd '\.h$' csrc/ | xargs rg -l 'MoERunnerArgs' | head -3
echo
echo "== Search for mDtypeExpW in those headers =="
fd '\.h$' csrc/ | xargs rg -n 'mDtypeExpW|struct MoERunnerArgs' 2>/dev/null | head -40
echo
echo "== Check if launcher sets args.mDtypeExpW anywhere =="
rg -n 'args\.mDtypeExpW|mDtypeExpW.*=' csrc/trtllm_fused_moe_kernel_launcher.cu
Repository: flashinfer-ai/flashinfer
Length of output: 112
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== Search for MoERunnerArgs definition =="
rg -n 'struct MoERunnerArgs|mDtypeExpW' csrc/ -t h 2>/dev/null | head -50
echo
echo "== Look at Bf16MoeLauncher class (lines 458+) =="
sed -n '458,550p' csrc/trtllm_fused_moe_kernel_launcher.cu
Repository: flashinfer-ai/flashinfer
Length of output: 109
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== Look at actual class that creates workspace and calls runner =="
sed -n '458,550p' csrc/trtllm_fused_moe_kernel_launcher.cu
echo
echo "== Search for where Runner::run is called from launcher =="
rg -n 'Runner::run|\.run\(' csrc/trtllm_fused_moe_kernel_launcher.cu | head -20
echo
echo "== Check the workspace initialization to understand args flow =="
rg -n 'workspace\.|MoEWorkspace' csrc/trtllm_fused_moe_kernel_launcher.cu | head -20
Repository: flashinfer-ai/flashinfer
Length of output: 6457
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== Check the routing_runner.run call signature and what args it receives =="
sed -n '398,410p' csrc/trtllm_fused_moe_kernel_launcher.cu
echo
echo "== Look for args.mDtypeExpW initialization anywhere in launcher =="
rg -n 'mDtypeExpW' csrc/trtllm_fused_moe_kernel_launcher.cu
echo
echo "== Check what happens with expert_weights in routing kernel (DeepSeekV3) =="
rg -n 'mPtrTopKWeights|mDtypeExpW' csrc/ -t cu 2>/dev/null | grep -i deepseek
echo
echo "== Verify: is args.mDtypeExpW ever set in launcher classes? =="
rg -n 'args->mDtypeExpW\s*=' csrc/trtllm_fused_moe_kernel_launcher.cu
Repository: flashinfer-ai/flashinfer
Length of output: 1120
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== Full routing_runner.run call =="
sed -n '398,415p' csrc/trtllm_fused_moe_kernel_launcher.cu
echo
echo "== Search for where args is created/initialized =="
rg -n 'args.*=' csrc/trtllm_fused_moe_kernel_launcher.cu | grep -E 'mDtypeExpW|make_unique|new' | head -20
echo
echo "== Check if expert_weights buffer is actually used by routing kernel =="
rg -n 'mPtrTopKWeights' csrc/ -A5 -B5 2>/dev/null | head -40
Repository: flashinfer-ai/flashinfer
Length of output: 4895
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== Check if mDtypeExpW is actually used in the routing kernel =="
fd '\.h$' csrc/ | xargs rg -l 'mDtypeExpW|routingDeepSeek' 2>/dev/null | head -5
echo
echo "== Check DeepSeekV3 routing kernel implementation =="
rg -n 'routingDeepSeek|mDtypeExpW|mPtrTopKWeights' csrc/ -t h -t cu 2>/dev/null | grep -A5 -B5 'mPtrTopKWeights' | head -50
echo
echo "== Verify: args.mDtypeExpW default value - check if it's initialized =="
rg -n 'MoERunnerArgs' csrc/ -t h --max-count=5 2>/dev/null | head -10
Repository: flashinfer-ai/flashinfer
Length of output: 134
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== Look for finalize kernel that uses mDtypeExpW and expertWeightsPtr =="
find csrc -name "*.h" -o -name "*.cu" | xargs rg -l 'finalize.*Data|expertWeightsPtr' 2>/dev/null | head -5
echo
echo "== Check how finalize uses expert_weights dtype =="
rg -n 'expertWeightsPtr|mDtypeExpW' csrc/ 2>/dev/null | grep -v 'Binary' | head -30
echo
echo "== Check the actual MoERunnerArgs initialization - look at constructor/default values =="
find csrc -name "*.h" | xargs grep -A30 'struct MoERunnerArgs\|class MoERunnerArgs' 2>/dev/null | head -60
Repository: flashinfer-ai/flashinfer
Length of output: 1675
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== Check finalize kernel - how TypeExpW is defined and used =="
sed -n '660,680p' csrc/fused_moe/trtllm_backend/trtllm_fused_moe_dev_kernel.cu
echo
echo "== Check if TypeExpW is template parameter or determined by mDtypeExpW =="
rg -n 'template.*TypeExpW|TypeExpW|mDtypeExpW' csrc/fused_moe/trtllm_backend/trtllm_fused_moe_dev_kernel.cu | head -30
echo
echo "== Find where args.mDtypeExpW is actually set/initialized =="
rg -n 'mDtypeExpW\s*=' csrc/trtllm_fused_moe_kernel_launcher.cu
Repository: flashinfer-ai/flashinfer
Length of output: 1771
Expert_weights dtype mismatch between launcher allocation and kernel interpretation affects all routing methods.
The launcher allocates expert_weights with a dtype determined by mDtypeScore (Fp32 for DeepSeekV3, and conditionally Fp32 for other methods). However, all routing paths (DeepSeekV3, Llama4, and TopK) hardcode mDtypeExpW = Bfloat16, independent of the launcher's allocation. The buffer is then passed to the finalize kernel, which reads it through a TypeExpW template instantiated from mDtypeExpW. When expert_weights is allocated as Fp32 but read as Bfloat16, the bytes are misinterpreted and the data is corrupted.
This issue is broader than DeepSeekV3 alone: it affects any configuration where mDtypeScore is Fp32 across all routing methods. The launcher never communicates the actual expert_weights dtype back to the runner or finalize kernel.
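The misinterpretation can be reproduced in isolation. The standalone sketch below (not kernel code, assuming a little-endian host) shows that reading an fp32 buffer as consecutive 16-bit bf16 words yields the low mantissa half first, rather than the bf16 truncation of the value:

```python
import struct

def first_bf16_word(x: float) -> int:
    # First 16-bit word a little-endian reader sees when an fp32 buffer is
    # reinterpreted as bf16 elements (what a bf16-typed finalize pass would read).
    lo, hi = struct.unpack("<2H", struct.pack("<f", x))
    return lo

def bf16_of(x: float) -> int:
    # Correct bf16 truncation: the top 16 bits of the fp32 bit pattern.
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return bits >> 16

w = 0.75  # fp32 bit pattern 0x3F400000; bf16(0.75) is 0x3F40
print(hex(first_bf16_word(w)), hex(bf16_of(w)))  # → 0x0 0x3f40
```

The misread word (0x0000) decodes as bf16 zero, so every other "weight" the kernel sees is garbage rather than the routed score.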
🔧 Suggested fix (centralize expW dtype policy)
 class FusedMoeLauncher {
  protected:
+  DLDataType get_expert_weights_dtype() const {
+    if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
+      // Runner DeepSeek path currently expects expW as BF16.
+      return dl_bfloat16;
+    }
+    return mDtypeScore == btg::Dtype::Fp32 ? dl_float32 : dl_bfloat16;
+  }

- auto ew_dtype = mDtypeScore == btg::Dtype::Fp32 ? dl_float32 : dl_bfloat16;
+ auto ew_dtype = get_expert_weights_dtype();

Apply the replacement at all four allocation sites (lines 521–523, 662–664, 938–940, 1213–1215).
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
- auto ew_dtype = mDtypeScore == btg::Dtype::Fp32 ? dl_float32 : dl_bfloat16;
+ auto ew_dtype = get_expert_weights_dtype();
  FusedMoeLauncher::expert_weights =
      alloc_tensor({args->num_tokens, args->top_k}, ew_dtype, hidden_states.device());
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@csrc/trtllm_fused_moe_kernel_launcher.cu` around lines 521 - 523, The
expert_weights buffer is allocated using mDtypeScore but always read by the
finalize kernel as TypeExpW instantiated from mDtypeExpW, causing mismatched
interpretation; fix by centralizing the expert-weights dtype policy: derive a
single expW_dtype (based on mDtypeExpW) and use that when calling alloc_tensor
to set FusedMoeLauncher::expert_weights at all allocation sites (the ones
allocating expert_weights), and ensure the same expW_dtype is passed/visible to
the runner/finalize kernel invocation so the template TypeExpW and the allocated
buffer use the same dtype.
/bot run

[SUCCESS] Pipeline #46552887: 8/20 passed
📌 Description
This PR supports FP32 logits for routing when using fp8_per_tensor and fp8_block quantization. It adjusts mDtypeScore and mDtypeExpW and adds more template instantiations.
🔍 Related Issues
#2469
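The added instantiations can be pictured as extending a dispatch table keyed on quantization mode and routing-logits dtype. The names below (SUPPORTED, select_kernel, the key strings) are illustrative stand-ins, not FlashInfer's actual API:

```python
# Hypothetical instantiation table: before this PR, routing logits were assumed
# bf16; the float32 entries represent the newly added template instantiations.
SUPPORTED = {
    ("fp8_per_tensor", "bfloat16"),
    ("fp8_block", "bfloat16"),
    ("fp8_per_tensor", "float32"),  # new
    ("fp8_block", "float32"),       # new
}

def select_kernel(quant_mode: str, logits_dtype: str) -> str:
    """Map a (quant mode, logits dtype) pair to an instantiation key."""
    if (quant_mode, logits_dtype) not in SUPPORTED:
        raise NotImplementedError(
            f"no instantiation for {quant_mode} + {logits_dtype}")
    return f"moe_{quant_mode}_logits_{logits_dtype}"
```

Without the new entries, an FP32-logits request under fp8 quantization would have no matching kernel instantiation to dispatch to.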
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
I have installed pre-commit by running pip install pre-commit (or used your preferred method).
I have installed the hooks with pre-commit install.
I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.
🧪 Tests
Tests have been added or updated as needed, and all tests are passing (unittest, etc.).
Reviewer Notes
Summary by CodeRabbit
New Features
Bug Fixes
Refactor
Tests