
fix: support fp32 logits for fp8_per_tensor and fp8_block #2534

Open
yweng0828 wants to merge 6 commits into flashinfer-ai:main from yweng0828:yweng/add_fp32_logits_for_fp8_routing

Conversation


@yweng0828 commented Feb 10, 2026

📌 Description

This PR adds additional template instantiations to support FP32 logits for routing when using fp8_per_tensor and fp8_block quantization.

  • Differentiates between mDtypeScore and mDtypeExpW and adds the corresponding template instantiations.
  • Adds tests covering different logits data types.
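The dtype-selection rule this PR enables can be sketched in plain Python. The function name, string dtype tags, and routing-method names below are illustrative stand-ins, not the actual C++ identifiers; the real logic lives in FusedMoeLauncher as mDtypeScore:

```python
# Illustrative sketch only: DeepSeekV3 routing requires FP32 logits, while
# other routing methods follow the dtype of the routing_logits tensor.
# select_score_dtype and the string tags are hypothetical names.
def select_score_dtype(routing_method: str, logits_dtype: str) -> str:
    if routing_method == "DeepSeekV3":
        if logits_dtype != "float32":
            raise ValueError("DeepSeekV3: routing_logits must be float32")
        return "float32"
    if logits_dtype in ("float32", "bfloat16"):
        # Non-DeepSeek paths accept either FP32 or BF16 logits and use
        # that dtype as the score dtype.
        return logits_dtype
    raise ValueError(f"unsupported routing_logits dtype: {logits_dtype}")
```

Before this PR, non-DeepSeek paths effectively hard-coded the BF16 branch; the change threads the actual logits dtype through instead.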

🔍 Related Issues

#2469

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Configurable routing score dtype (FP32/BF16) with routing-method-specific defaults and a new routing API parameter.
  • Bug Fixes

    • Routing-logits dtype validation aligned with score-dtype rules across routing paths to prevent mismatched types.
  • Refactor

    • Routing dispatch and launcher logic updated to consider score dtype alongside existing dtypes for correct code-path selection.
  • Tests

    • Tests parameterized for FP32/BF16 logits; compatibility/skip checks updated to validate logits dtype and quantization modes.

@gemini-code-assist
Contributor

Summary of Changes

Hello @yweng0828, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the fused MoE kernels to support FP32 logits, which is necessary for compatibility with certain models like DeepSeekV3. The changes involve modifications to the kernel launcher, runner, and test suite to accommodate the new data type. This ensures that the MoE kernels can handle a wider range of models and configurations.

Highlights

  • FP32 Logits Support: This PR introduces support for FP32 logits in FP8 per-tensor and FP8 block-scale fused MoE kernels, enabling compatibility with models like DeepSeekV3.
  • Code Modifications: The changes involve modifying the FusedMoeLauncher class to handle different data types for routing scores and updating kernel launch configurations to support FP32.
  • Test Updates: The PR includes updates to the test suite to incorporate FP32 logits testing, ensuring the new functionality works as expected.


Changelog
  • csrc/trtllm_fused_moe_kernel_launcher.cu
    • Added mDtypeScore member to FusedMoeLauncher class.
    • Modified kernel launch parameters to pass mDtypeScore.
    • Added logic to determine mDtypeScore based on routing_logits dtype.
  • csrc/trtllm_fused_moe_runner.cu
    • Modified run function to accept dtypeScore.
    • Assigned dtypeScore to routingData.mDtypeScore.
  • include/flashinfer/trtllm/fused_moe/DevKernel.h
    • Updated LAUNCH_ROUTING_WITH_NUM_EXPERTS macro to handle different combinations of mDtypeScore and mDtypeExpW.
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.h
    • Added mDtypeScore member to the Data struct.
  • include/flashinfer/trtllm/fused_moe/runner.h
    • Modified Runner::run to accept dtypeScore as a parameter.
  • tests/moe/test_dpsk_fused_moe_fp8.py
    • Added logits_dtype parameter to DPSKFusedMoEFp8 class.
  • tests/moe/test_trtllm_gen_fused_moe.py
    • Removed unnecessary dtype conversion for expert_logits.
    • Added logits_dtype parameter to run_moe_test function.
    • Modified run_moe_test to create expert_logits with the specified dtype.
    • Added logits_dtype parameterization to test functions.
  • tests/moe/utils.py
    • Added logits_dtype parameter to skip_checks function.
    • Added skip logic for incompatible logits_dtype and routing method/quant mode combinations.
Activity
  • The pull request introduces support for FP32 logits in FP8 fused MoE kernels.
  • The changes involve modifying the kernel launcher, runner, and test suite.
  • The test suite is updated to include FP32 logits testing.
  • Skip logic is added to the tests to handle incompatible configurations.

@gemini-code-assist bot left a comment

Code Review

The pull request introduces support for fp32 logits in the fused MoE kernels, specifically for fp8_per_tensor and fp8_block quantization modes. This is achieved by adding a mDtypeScore member to FusedMoeLauncher and routingRenormalize::Data structs, and updating the routing_runner.run calls and kernel dispatch macros to utilize this new dtype. The routing_logits dtype validation logic in trtllm_fp8_per_tensor_scale_moe and trtllm_fp8_block_scale_moe functions is relaxed to allow float32 where appropriate, while still enforcing float32 for DeepSeekV3 routing. Corresponding test cases are updated to parameterize logits_dtype and include new skip conditions to ensure compatibility. The changes are consistent across the codebase and align with the stated goal of the pull request.

workspace.token_scales = expert_weights.data_ptr(); // Consumed by permuteGemm1 kernel
}
if (routing_logits.has_value()) {
mDtypeScore =
Contributor:
Should this piece of code be part of the FusedMoeLauncher class so that all child classes can share it? It seems that this logic is currently in the Fp8PerTensorLauncher class. Also, we might want to add an assertion to check the data type of routing_logits.

  TVM_FFI_ICHECK(routing_logits.dtype() == dl_float32 || routing_logits.dtype() == dl_bfloat16)
      << "BF16 MoE: routing_logits must be bfloat16 or float.";

@yweng0828 (Author):
Thanks for pointing it out. I have refactored this part of the logic and moved it to the base class.

kernel, numBlocks, numThreads, smemSize, stream); \
} else { \
FLASHINFER_WARN("Unsupported dtypeExpW"); \
FLASHINFER_WARN("Unsupported mDtypeScore/mDtypeExpW combination"); \
Contributor:
How about using this wording: Unsupported combination of mDtypeScore and mDtypeExpW

@yweng0828 (Author):
Thanks, it has been updated.

@yweng0828 force-pushed the yweng/add_fp32_logits_for_fp8_routing branch from a62decc to 0c876d4 on February 12, 2026 at 07:17

coderabbitai bot commented Feb 12, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

📝 Walkthrough

Adds an explicit score dtype (mDtypeScore / dtypeScore), threads it Launcher → Runner → Kernel, updates routing validations (DeepSeekV3 requires FP32), changes Runner::run signature to accept dtypeScore, extends kernel launch branching on score dtype, and parameterizes tests to pass logits dtype.

Changes

Cohort / File(s) Summary
Launcher core
csrc/trtllm_fused_moe_kernel_launcher.cu
Introduce mDtypeScore (default BF16), rename check_routing_logits_shape() → check_routing_logits(), set mDtypeScore during preparation (DeepSeekV3 → FP32, else BF16), and pass score dtype into routing/MoE calls before elt dtype.
Runner API & impl
csrc/trtllm_fused_moe_runner.cu, include/flashinfer/trtllm/fused_moe/runner.h
Extended Runner::run(...) signature to insert dtypeScore before dtypeElt; propagate dtypeScore into routingData for DeepSeekV3, Renormalize, and other paths.
Kernel interfaces & macros
include/flashinfer/trtllm/fused_moe/RoutingKernel.h, include/flashinfer/trtllm/fused_moe/DevKernel.h
Add mDtypeScore to routing data struct; expand LAUNCH_ROUTING_WITH_NUM_EXPERTS branching to select templates by (mDtypeScore, mDtypeExpW) combinations and update unsupported-dtype messaging.
Tests & test utils
tests/moe/test_trtllm_gen_fused_moe.py, tests/moe/test_dpsk_fused_moe_fp8.py, tests/moe/utils.py
Parameterize tests with logits_dtype, thread it through run/skip logic, add gating that DeepSeekV3 requires FP32 logits and validate compatibility of FP32 logits with certain quant modes.
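The expanded launch branching amounts to a lookup over (mDtypeScore, mDtypeExpW) pairs. The toy Python model below illustrates the shape of that dispatch; the exact set of supported pairs shown here is an assumption for illustration only, since the real code is the C++ LAUNCH_ROUTING_WITH_NUM_EXPERTS macro selecting kernel template instantiations:

```python
# Toy model of dispatching on (score dtype, expert-weight dtype).
# SUPPORTED_PAIRS is a guess for illustration, not the authoritative list.
SUPPORTED_PAIRS = {
    ("bfloat16", "bfloat16"),
    ("float32", "bfloat16"),
    ("float32", "float32"),
}

def launch_routing(dtype_score: str, dtype_expw: str) -> str:
    if (dtype_score, dtype_expw) not in SUPPORTED_PAIRS:
        # Mirrors the updated unsupported-dtype message in DevKernel.h.
        raise ValueError("Unsupported combination of mDtypeScore and mDtypeExpW")
    # In the real macro this selects a template instantiation to launch.
    return f"routingKernel<{dtype_score}, {dtype_expw}>"
```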

Sequence Diagram(s)

sequenceDiagram
    participant Launcher
    participant Runner
    participant Kernel
    participant GPU

    Launcher->>Runner: run(..., dtypeScore, dtypeElt, dtypeBias, ..., stream)
    Runner->>Kernel: prepare routingData (mDtypeScore := dtypeScore, mDtypeElt := dtypeElt, ...)
    Kernel->>GPU: launch routing kernels using mDtypeScore and mDtypeElt
    GPU-->>Kernel: routing results
    Kernel-->>Runner: routing outputs (indices, counts)
    Runner-->>Launcher: return routing results

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes


Suggested labels

op: moe-routing

Suggested reviewers

  • joker-eph
  • djmmoss
  • yzh119
  • cyx-6
  • aleozlx
  • jiahanc
  • nv-yunzheq

Poem

🐰 I nudged the bytes to name their score,
FP32 or BF16, hopped through every door,
Launchers passed the baton down the chain,
Runners and kernels hummed the same refrain,
Carrots routed clean — hooray, more throughput galore! 🥕

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: docstring coverage is 44.44%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
  • Title check ✅ Passed: the title accurately describes the main objective: adding support for FP32 logits in fp8_per_tensor and fp8_block quantization modes.
  • Description check ✅ Passed: the description covers the main changes, related issues, and test status, though it notes not all tests are passing.


@coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
include/flashinfer/trtllm/fused_moe/DevKernel.h (1)

219-221: ⚠️ Potential issue | 🟡 Minor

Stale error message in LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG.

The else-branch error message still says "Unsupported dtypeExpW", but the macro now gates on mDtypeScore, mDtypeBias, and mDtypeExpW. Update it similar to line 269 for consistency and easier debugging.

Proposed fix
   } else {                                                                                        \
-    FLASHINFER_WARN("Unsupported dtypeExpW");                                                     \
+    FLASHINFER_WARN("Unsupported combination of mDtypeScore, mDtypeBias, and mDtypeExpW");        \
   }
tests/moe/test_dpsk_fused_moe_fp8.py (1)

615-624: ⚠️ Potential issue | 🔴 Critical

Missing routing_method_type key in routing_config will cause KeyError in the updated skip_checks.

The routing_config dicts defined at lines 510–548 don't contain a "routing_method_type" key, but the new check in skip_checks (line 148 of utils.py) accesses routing_config["routing_method_type"] unconditionally. This will crash every test case in this file.

Either add "routing_method_type" to each routing config dict, or use .get() with a default in skip_checks:

Option 1: Fix in utils.py (safer — handles callers that don't set the key)
     if (
-        routing_config["routing_method_type"] == RoutingMethodType.DeepSeekV3
+        routing_config.get("routing_method_type") == RoutingMethodType.DeepSeekV3
         and logits_dtype != torch.float32
     ):
Option 2: Fix in this test file (add routing_method_type to each config)

For the DSv3 config:

             {
                 "num_experts": 256,
                 "top_k": 8,
+                "routing_method_type": RoutingMethodType.DeepSeekV3,
                 ...
             },

And similarly for other configs with the appropriate RoutingMethodType.

tests/moe/test_trtllm_gen_fused_moe.py (1)

2883-2893: ⚠️ Potential issue | 🔴 Critical

Bug: logits_dtype and cache_permute_indices arguments are swapped.

The run_moe_test signature (line 2337) expects cache_permute_indices as the 8th positional arg and logits_dtype as the 9th. Here, they are passed in the opposite order. This will cause moe_impl._cache_permute_indices to be set to a torch.dtype and expert_logits.to(logits_dtype) to receive a dict, resulting in a runtime crash.

Compare with test_renormalize_routing (line 2695–2696), test_topk_routing (line 2975–2976), and test_llama4_routing (line 3056–3057), which all pass the arguments in the correct order.

🐛 Proposed fix
     run_moe_test(
         num_tokens,
         hidden_size,
         intermediate_size,
         moe_impl,
         routing_config,
         weight_processing,
         activation_type,
-        logits_dtype,
         cache_permute_indices,
+        logits_dtype,
     )
🤖 Fix all issues with AI agents
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 288-298: The code currently sets mDtypeScore based solely on
routing_method_type which forces BFloat16 for non-DeepSeekV3 even when
routing_logits are float32; update the block that runs when
routing_logits.has_value() so mDtypeScore is derived from
routing_logits.value().dtype(): if routing_method_type ==
RoutingMethodType::DeepSeekV3 keep the TVM_FFI_ICHECK_EQ asserting dtype is
dl_float32 and set mDtypeScore = btg::Dtype::Fp32; otherwise inspect
routing_logits.value().dtype() and set mDtypeScore = btg::Dtype::Fp32 for
dl_float32, btg::Dtype::Bfloat16 for dl_bfloat16 (and error/ICHECK for
unsupported dtypes). Reference symbols: mDtypeScore, routing_logits,
RoutingMethodType::DeepSeekV3.

In `@tests/moe/utils.py`:
- Around line 155-162: The condition incorrectly compares type(moe_impl) to
QuantMode enum values causing all FP32-logits tests to skip; change the check to
inspect moe_impl.quant_mode instead. Update the if-statement that currently
reads "if logits_dtype == torch.float32 and type(moe_impl) not in
[QuantMode...]" to use "moe_impl.quant_mode not in [QuantMode.FP8_PER_TENSOR,
QuantMode.FP8_BLOCK_SCALE, QuantMode.BF16]" so the pytest.skip call only
triggers for incompatible quant modes; keep the existing pytest.skip message and
variables (logits_dtype, moe_impl, QuantMode) unchanged.
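The corrected skip condition described above can be sketched in plain Python. The QuantMode members and the helper name here are simplified stand-ins for the real tests/moe/utils.py logic, not its actual API:

```python
from enum import Enum, auto

# Simplified stand-in for the QuantMode enum in tests/moe/utils.py;
# member names are illustrative.
class QuantMode(Enum):
    FP8_PER_TENSOR = auto()
    FP8_BLOCK_SCALE = auto()
    BF16 = auto()
    NVFP4 = auto()

def should_skip_fp32_logits(logits_dtype: str, quant_mode: QuantMode) -> bool:
    # Key point from the review: inspect moe_impl.quant_mode, not
    # type(moe_impl), when whitelisting FP32-logits configurations.
    compatible = {QuantMode.FP8_PER_TENSOR, QuantMode.FP8_BLOCK_SCALE, QuantMode.BF16}
    return logits_dtype == "float32" and quant_mode not in compatible
```

In the real test helper the True case would call pytest.skip with a message naming the incompatible combination, rather than returning a bool.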
🧹 Nitpick comments (1)
tests/moe/test_trtllm_gen_fused_moe.py (1)

2865-2870: DeepSeekV3 routing is only parametrized with FP32 logits — intentional?

Unlike test_renormalize_routing and test_topk_routing which test both FP32 and BF16, this test only exercises FP32 logits. If BF16 logits are also a valid input for DeepSeekV3 routing in production, consider adding BF16 coverage here too.

@yweng0828

/bot run

@flashinfer-bot

@yweng0828 is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

@aleozlx

aleozlx commented Feb 18, 2026

/bot run

@flashinfer-bot

GitLab MR !321 has been created, and the CI pipeline #44270124 is currently running. I'll report back once the pipeline job completes.

@aleozlx

aleozlx commented Feb 18, 2026

Hi @yweng0828, thanks for the contribution. Want to sync with you:

  • ready for review?
  • tests passing locally?

@aleozlx aleozlx self-assigned this Feb 18, 2026
@flashinfer-bot

[SUCCESS] Pipeline #44270124: 16/20 passed

@aleozlx (Collaborator) left a comment

lgtm

tests clean, ready to merge

@aleozlx aleozlx added the run-ci label Feb 20, 2026
@aleozlx aleozlx enabled auto-merge (squash) February 20, 2026 19:20
@aleozlx aleozlx added the ready label Feb 20, 2026
@yweng0828

Hi @aleozlx , thank you for your review. The PR is ready. Local testing has also passed.
Do I need to rebase to main? Or can we just merge it directly?

@wenscarl

wenscarl commented Mar 4, 2026

@yweng0828 Does the change also apply to trtllm_fp4_block_scale_moe?

@yweng0828

@yweng0828 Does the change also apply to trtllm_fp4_block_scale_moe?

Hi @wenscarl, no, this change does not apply to trtllm_fp4_block_scale_moe. The issue report (#2469) only mentions fp8_per_tensor and fp8_block, and we want to minimize the scope of changes.

@wzhao18

wzhao18 commented Mar 16, 2026

@aleozlx Any update on merging this?

@wzhao18
wzhao18 commented Mar 17, 2026

Hi @yweng0828, can you rebase to main so we can restart the CI?

@aleozlx aleozlx requested a review from nv-yunzheq as a code owner March 17, 2026 18:57
@coderabbitai bot left a comment

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tests/moe/test_trtllm_gen_fused_moe.py (2)

3131-3140: ⚠️ Potential issue | 🟠 Major

logits_dtype and cache_permute_indices are swapped in this call.

run_moe_test() expects (cache_permute_indices, logits_dtype). In the current order, the cache dict is treated as the dtype, so the DeepSeekV3 guard in skip_checks() skips the whole suite instead of exercising the new FP32-logits path.

Suggested fix
     run_moe_test(
         num_tokens,
         hidden_size,
         intermediate_size,
         moe_impl,
         routing_config,
         weight_processing,
         activation_type,
-        logits_dtype,
         cache_permute_indices,
+        logits_dtype,
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/moe/test_trtllm_gen_fused_moe.py` around lines 3131 - 3140, The call to
run_moe_test is passing logits_dtype and cache_permute_indices in the wrong
order so the cache dict is being interpreted as the dtype; update the call site
to swap those two arguments so you pass cache_permute_indices first and
logits_dtype second. Locate the call to run_moe_test (the one with parameters
num_tokens, hidden_size, intermediate_size, moe_impl, routing_config,
weight_processing, activation_type, logits_dtype, cache_permute_indices) and
reorder the last two args accordingly to avoid triggering the DeepSeekV3 guard
in skip_checks().

2567-2577: ⚠️ Potential issue | 🟠 Major

This helper change breaks the existing GEMM-bias test.

The test_nvfp4_moe_gemm_bias() call near Line 3332 still invokes run_moe_test() without logits_dtype, so this new required parameter turns that test into a TypeError before the bias path runs. Either give logits_dtype a backward-compatible default or update the remaining caller.

One backward-compatible option
-    logits_dtype,
+    logits_dtype=torch.bfloat16,
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/moe/test_trtllm_gen_fused_moe.py` around lines 2567 - 2577, The new
required parameter logits_dtype on run_moe_test breaks callers like
test_nvfp4_moe_gemm_bias that still call run_moe_test() without it; make
logits_dtype optional with a backward-compatible default (e.g., None or the
previous default dtype) in the run_moe_test signature and branch inside
run_moe_test (or set a local default variable) so existing callers continue to
exercise the GEMM-bias path without modification; alternatively, update all
callers such as test_nvfp4_moe_gemm_bias to pass the intended logits_dtype
explicitly, but prefer adding the default to run_moe_test to avoid many caller
changes.
♻️ Duplicate comments (2)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)

313-322: ⚠️ Potential issue | 🔴 Critical

Derive mDtypeScore from routing_logits.dtype() for non-DeepSeek paths.

This still forces btg::Dtype::Bfloat16 for every non-DeepSeek route. The new routing_runner.run(..., mDtypeScore, ...) plumbing forwards that value into routingData.mDtypeScore, so FP32 logits added by this PR are still dispatched/read as BF16 and will produce incorrect routing weights.

Suggested fix
     // Set dtype of score
     if (routing_logits.has_value()) {
       if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
         TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_float32)
             << "routing_logits must be float.";
         mDtypeScore = btg::Dtype::Fp32;
+      } else if (routing_logits.value().dtype() == dl_float32) {
+        mDtypeScore = btg::Dtype::Fp32;
       } else {
         mDtypeScore = btg::Dtype::Bfloat16;
       }
     }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fused_moe_kernel_launcher.cu` around lines 313 - 322, The code
currently forces mDtypeScore = btg::Dtype::Bfloat16 for all non-DeepSeek routes
causing FP32 routing_logits to be misinterpreted; change the non-DeepSeek branch
to derive mDtypeScore from routing_logits.value().dtype() (e.g., map dl_float32
-> btg::Dtype::Fp32, dl_bfloat16 -> btg::Dtype::Bfloat16, etc.) while keeping
the existing RoutingMethodType::DeepSeekV3 check that ICHECKs float32 and sets
Fp32; this ensures the value passed into routing_runner.run(...) and stored in
routingData.mDtypeScore matches the actual routing_logits.dtype().
tests/moe/utils.py (1)

170-177: ⚠️ Potential issue | 🟠 Major

The FP32-logits whitelist can never succeed.

type(moe_impl) is a class, not a QuantMode, and QuantMode.FP8_BLOCK_SCALE is not a member of this enum. On FP32 cases this branch either raises AttributeError or skips every implementation, so the new coverage never runs.

Suggested fix
-    if logits_dtype == torch.float32 and type(moe_impl) not in [
-        QuantMode.FP8_PER_TENSOR,
-        QuantMode.FP8_BLOCK_SCALE,
-        QuantMode.BF16,
-    ]:
+    if logits_dtype == torch.float32 and moe_impl.quant_mode not in [
+        QuantMode.FP8_PER_TENSOR,
+        QuantMode.FP8_BLOCK_SCALE_DEEPSEEK,
+        QuantMode.FP8_BLOCK_SCALE_MXFP8,
+        QuantMode.BF16,
+    ]:
         pytest.skip(
             f"Incompatible: logits_dtype={logits_dtype} with {type(moe_impl).__name__} + {moe_impl.quant_mode}"
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/moe/utils.py` around lines 170 - 177, The condition is comparing
type(moe_impl) to enum members (and includes a non-existent
QuantMode.FP8_BLOCK_SCALE), so replace the check with a comparison against
moe_impl.quant_mode (e.g., if logits_dtype == torch.float32 and
moe_impl.quant_mode not in [QuantMode.FP8_PER_TENSOR, QuantMode.BF16]:) and
remove or correct the invalid QuantMode member; keep the pytest.skip call but
use moe_impl.__class__.__name__ and moe_impl.quant_mode in the message to
preserve context.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Outside diff comments:
In `@tests/moe/test_trtllm_gen_fused_moe.py`:
- Around line 3131-3140: The call to run_moe_test is passing logits_dtype and
cache_permute_indices in the wrong order so the cache dict is being interpreted
as the dtype; update the call site to swap those two arguments so you pass
cache_permute_indices first and logits_dtype second. Locate the call to
run_moe_test (the one with parameters num_tokens, hidden_size,
intermediate_size, moe_impl, routing_config, weight_processing, activation_type,
logits_dtype, cache_permute_indices) and reorder the last two args accordingly
to avoid triggering the DeepSeekV3 guard in skip_checks().
- Around line 2567-2577: The new required parameter logits_dtype on run_moe_test
breaks callers like test_nvfp4_moe_gemm_bias that still call run_moe_test()
without it; make logits_dtype optional with a backward-compatible default (e.g.,
None or the previous default dtype) in the run_moe_test signature and branch
inside run_moe_test (or set a local default variable) so existing callers
continue to exercise the GEMM-bias path without modification; alternatively,
update all callers such as test_nvfp4_moe_gemm_bias to pass the intended
logits_dtype explicitly, but prefer adding the default to run_moe_test to avoid
many caller changes.

---

Duplicate comments:
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 313-322: The code currently forces mDtypeScore =
btg::Dtype::Bfloat16 for all non-DeepSeek routes causing FP32 routing_logits to
be misinterpreted; change the non-DeepSeek branch to derive mDtypeScore from
routing_logits.value().dtype() (e.g., map dl_float32 -> btg::Dtype::Fp32,
dl_bfloat16 -> btg::Dtype::Bfloat16, etc.) while keeping the existing
RoutingMethodType::DeepSeekV3 check that ICHECKs float32 and sets Fp32; this
ensures the value passed into routing_runner.run(...) and stored in
routingData.mDtypeScore matches the actual routing_logits.dtype().

In `@tests/moe/utils.py`:
- Around line 170-177: The condition is comparing type(moe_impl) to enum members
(and includes a non-existent QuantMode.FP8_BLOCK_SCALE), so replace the check
with a comparison against moe_impl.quant_mode (e.g., if logits_dtype ==
torch.float32 and moe_impl.quant_mode not in [QuantMode.FP8_PER_TENSOR,
QuantMode.BF16]:) and remove or correct the invalid QuantMode member; keep the
pytest.skip call but use moe_impl.__class__.__name__ and moe_impl.quant_mode in
the message to preserve context.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 7b10c190-eea2-4f30-b759-706e05e54311

📥 Commits

Reviewing files that changed from the base of the PR and between 0c876d4 and 5647136.

📒 Files selected for processing (6)
  • csrc/trtllm_fused_moe_kernel_launcher.cu
  • csrc/trtllm_fused_moe_runner.cu
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.h
  • tests/moe/test_dpsk_fused_moe_fp8.py
  • tests/moe/test_trtllm_gen_fused_moe.py
  • tests/moe/utils.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.h
  • tests/moe/test_dpsk_fused_moe_fp8.py

@coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tests/moe/test_trtllm_gen_fused_moe.py (2)

2567-2577: ⚠️ Potential issue | 🟠 Major

run_moe_test now requires logits_dtype, but not all call sites provide it.

Making logits_dtype mandatory here breaks unchanged callers (e.g., test_nvfp4_moe_gemm_bias) with a missing-argument failure. Please keep this helper backward-compatible.

Proposed fix
 def run_moe_test(
@@
-    logits_dtype,
+    logits_dtype=torch.float32,
     zero_hidden_states=False,
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/moe/test_trtllm_gen_fused_moe.py` around lines 2567 - 2577, The helper
function run_moe_test was made incompatible by adding a required parameter
logits_dtype; revert this by giving logits_dtype a sensible default (e.g., None
or a default dtype like torch.float32) in run_moe_test's signature so existing
callers (for example test_nvfp4_moe_gemm_bias) continue to work, and then update
any internal uses of logits_dtype inside run_moe_test (and related helpers) to
handle the default case (use the default dtype when logits_dtype is None or
unspecified) so behavior remains backward-compatible.

3151-3160: ⚠️ Potential issue | 🟠 Major

DeepSeek test passes run_moe_test arguments in the wrong order.

At lines 3159-3160, logits_dtype and cache_permute_indices are swapped. This causes runtime type errors when creating expert_logits.

Proposed fix
     run_moe_test(
@@
-        activation_type,
-        logits_dtype,
-        cache_permute_indices,
+        activation_type,
+        cache_permute_indices,
+        logits_dtype,
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/moe/test_trtllm_gen_fused_moe.py` around lines 3151 - 3160, The call to
run_moe_test in the DeepSeek test passes logits_dtype and cache_permute_indices
in the wrong order, causing runtime type errors when expert_logits are created;
fix it by swapping those two arguments so that the parameter named logits_dtype
receives the dtype value and cache_permute_indices receives the permutation
indices, i.e., locate the run_moe_test invocation and ensure the argument
corresponding to logits_dtype is the logits dtype variable and the argument
corresponding to cache_permute_indices is the indices variable.
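Argument-order bugs like this one are easiest to rule out by calling many-parameter helpers with keyword arguments. The parameter names below mirror this PR's `run_moe_test`; the body is a hypothetical stand-in:

```python
# Hypothetical sketch: keyword arguments make positional swaps impossible.
def run_moe_test(activation_type, cache_permute_indices, logits_dtype):
    return {
        "activation": activation_type,
        "cache": cache_permute_indices,
        "logits_dtype": logits_dtype,
    }

result = run_moe_test(
    activation_type="swiglu",
    cache_permute_indices={},   # indices object in the real test
    logits_dtype="float32",     # stand-in for torch.float32
)
assert result["logits_dtype"] == "float32"
```

Had the DeepSeek test used keywords, the swapped pair would have bound to the right parameters regardless of ordering at the call site.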
♻️ Duplicate comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)

313-322: ⚠️ Potential issue | 🔴 Critical

mDtypeScore is still derived from routing method instead of actual routing_logits dtype.

At lines 319-320, non-DeepSeek paths force BF16 even when routing_logits is FP32, so FP32 logits can be interpreted with the wrong score dtype. This is a correctness bug.

Proposed fix
-    // Set dtype of score
-    if (routing_logits.has_value()) {
-      if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
-        TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_float32)
-            << "routing_logits must be float.";
-        mDtypeScore = btg::Dtype::Fp32;
-      } else {
-        mDtypeScore = btg::Dtype::Bfloat16;
-      }
-    }
+    // Set dtype of score from actual routing_logits dtype
+    if (routing_logits.has_value()) {
+      auto const logits_dtype = routing_logits.value().dtype();
+      if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
+        TVM_FFI_ICHECK_EQ(logits_dtype, dl_float32) << "routing_logits must be float.";
+      }
+      if (logits_dtype == dl_float32) {
+        mDtypeScore = btg::Dtype::Fp32;
+      } else if (logits_dtype == dl_bfloat16) {
+        mDtypeScore = btg::Dtype::Bfloat16;
+      } else {
+        TVM_FFI_LOG_AND_THROW(NotImplementedError)
+            << "routing_logits must be float32 or bfloat16.";
+      }
+    }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fused_moe_kernel_launcher.cu` around lines 313 - 322, The code
currently sets mDtypeScore based on routing_method_type which forces BF16 for
non-DeepSeek methods and can misinterpret FP32 routing_logits; change the logic
in the block that checks routing_logits.has_value() to derive mDtypeScore from
routing_logits.value().dtype() instead: if routing_logits.value().dtype() ==
dl_float32 set mDtypeScore = btg::Dtype::Fp32, else set mDtypeScore =
btg::Dtype::Bfloat16 (and add a defensive check/error via TVM_FFI_ICHECK if an
unexpected dtype appears); replace the existing routing_method_type conditional
around mDtypeScore so routing_logits dtype is the single source of truth
(symbols: routing_logits, mDtypeScore, routing_method_type,
RoutingMethodType::DeepSeekV3).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@include/flashinfer/trtllm/fused_moe/DevKernel.h`:
- Around line 201-202: There are two conflicting macro definitions of
LAUNCH_ROUTING_WITH_NUM_EXPERTS; remove the earlier 10-parameter version (the
one that includes the numTopExperts parameter) so it does not get silently
overridden by the later 9-parameter definition, ensuring callers that invoke
LAUNCH_ROUTING_WITH_NUM_EXPERTS with ten arguments (e.g., passing numTopExperts)
continue to expand correctly; specifically delete the first definition (the one
that lists numTopExperts in its parameter list) so only the intended macro
remains and compilation/macro-argument mismatches are resolved.
- Around line 201-236: The DeepSeek routing migration broke because the original
macro LAUNCH_ROUTING_WITH_NUM_EXPERTS (defined in DevKernel.h) was intended to
be LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT and callers in
RoutingDeepSeekCommon.cuh still call the undefined
LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT and pass the old
forceFloatInput parameter; rename the first macro definition to
LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT in DevKernel.h and then update
the DeepSeek launchers in RoutingDeepSeekCommon.cuh (and any other DeepSeek
routing call sites) to stop using the legacy forceFloatInput dispatch and
instead call the new score-dtype-aware macro LAUNCH_ROUTING_WITH_NUM_EXPERTS
(which dispatches on data.mDtypeScore and data.mDtypeExpW); ensure call-site
argument lists match the new macro signature (remove forceFloatInput) and that
any logic that relied on forceFloatInput is expressed via the
extraFlag/score-dtype checks already present in the new macro.
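The dual-dtype dispatch the new macro performs can be pictured with a small Python analogue. The instantiation table and kernel-name strings below are hypothetical; the real macro expands to C++ template instantiations selected by `data.mDtypeScore` and `data.mDtypeExpW`:

```python
# Python analogue of dispatching on (mDtypeScore, mDtypeExpW): each pair maps
# to one template instantiation, replacing the old boolean forceFloatInput flag.
KERNELS = {
    ("fp32", "fp32"): "routingKernel<float, float>",
    ("fp32", "bf16"): "routingKernel<float, __nv_bfloat16>",
    ("bf16", "bf16"): "routingKernel<__nv_bfloat16, __nv_bfloat16>",
}

def launch_routing(dtype_score, dtype_expw):
    try:
        return KERNELS[(dtype_score, dtype_expw)]
    except KeyError:
        raise NotImplementedError(
            f"no instantiation for score={dtype_score}, expW={dtype_expw}")

assert launch_routing("fp32", "bf16") == "routingKernel<float, __nv_bfloat16>"
```

Dispatching on the dtype pair keeps every supported combination explicit and makes unsupported combinations fail loudly instead of silently reinterpreting bytes.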

---

Outside diff comments:
In `@tests/moe/test_trtllm_gen_fused_moe.py`:
- Around line 2567-2577: The helper function run_moe_test was made incompatible
by adding a required parameter logits_dtype; revert this by giving logits_dtype
a sensible default (e.g., None or a default dtype like torch.float32) in
run_moe_test's signature so existing callers (for example
test_nvfp4_moe_gemm_bias) continue to work, and then update any internal uses of
logits_dtype inside run_moe_test (and related helpers) to handle the default
case (use the default dtype when logits_dtype is None or unspecified) so
behavior remains backward-compatible.
- Around line 3151-3160: The call to run_moe_test in the DeepSeek test passes
logits_dtype and cache_permute_indices in the wrong order, causing runtime type
errors when expert_logits are created; fix it by swapping those two arguments so
that the parameter named logits_dtype receives the dtype value and
cache_permute_indices receives the permutation indices, i.e., locate the
run_moe_test invocation and ensure the argument corresponding to logits_dtype is
the logits dtype variable and the argument corresponding to
cache_permute_indices is the indices variable.

---

Duplicate comments:
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 313-322: The code currently sets mDtypeScore based on
routing_method_type which forces BF16 for non-DeepSeek methods and can
misinterpret FP32 routing_logits; change the logic in the block that checks
routing_logits.has_value() to derive mDtypeScore from
routing_logits.value().dtype() instead: if routing_logits.value().dtype() ==
dl_float32 set mDtypeScore = btg::Dtype::Fp32, else set mDtypeScore =
btg::Dtype::Bfloat16 (and add a defensive check/error via TVM_FFI_ICHECK if an
unexpected dtype appears); replace the existing routing_method_type conditional
around mDtypeScore so routing_logits dtype is the single source of truth
(symbols: routing_logits, mDtypeScore, routing_method_type,
RoutingMethodType::DeepSeekV3).

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: daefa7eb-107b-4cbc-b220-890e41b1f3f6

📥 Commits

Reviewing files that changed from the base of the PR and between 5647136 and b9d4245.

📒 Files selected for processing (5)
  • csrc/trtllm_fused_moe_kernel_launcher.cu
  • csrc/trtllm_fused_moe_runner.cu
  • include/flashinfer/trtllm/fused_moe/DevKernel.h
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.h
  • tests/moe/test_trtllm_gen_fused_moe.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.h

@aleozlx
Collaborator

aleozlx commented Mar 17, 2026

Hi @yweng0828, we are trying to help merge it. Last week a CI issue was blocking all PRs from merging.
I refreshed the branch but had to manually resolve some conflicts in include/flashinfer/trtllm/fused_moe/DevKernel.h.

Please double-check that it's fine.

@aleozlx
Copy link
Collaborator

aleozlx commented Mar 17, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !321 has been updated with latest changes, and the CI pipeline #46367047 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #46367047: 8/20 passed

@aleozlx
Copy link
Collaborator

aleozlx commented Mar 18, 2026

there is currently an error on JIT Unittest (H100) unfortunately

E       TypeError: run_moe_test() missing 1 required positional argument: 'logits_dtype'

tests/moe/test_trtllm_gen_fused_moe.py:3352: TypeError

auto-merge was automatically disabled March 18, 2026 23:26

Head branch was pushed to by a user without write access

@yweng0828
Author

Hi @aleozlx @wzhao18, thanks for your help.
I made some changes to fix the UT. Could you please restart the CI? Thanks.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
tests/moe/test_trtllm_gen_fused_moe.py (1)

3298-3326: Consider adding FP32 logits test for Llama4 routing.

FP8PerTensorMoe supports FP32 logits per the skip logic in utils.py (line 171), but this test only parametrizes BF16. If this is intentional for test speed, it's fine, but FP32 coverage could be added for completeness.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/moe/test_trtllm_gen_fused_moe.py` around lines 3298 - 3326, The test
test_llama4_routing currently only parametrizes logits_dtype with
torch.bfloat16; add an additional parametrization for torch.float32 so Llama4
routing is exercised with FP32 logits (since FP8PerTensorMoe supports FP32 per
the skip logic in utils.py around the FP32 check) by updating the
pytest.mark.parametrize block to include pytest.param(torch.float32,
id="FP32_logits"); ensure test_llama4_routing still calls run_moe_test with the
new logits_dtype value so the FP32 path is covered.
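The extra coverage suggested above is just a cross-product of the existing cases with the logits dtypes. The real test uses `pytest.mark.parametrize` with torch dtypes; the sketch below builds the same cross-product by hand (ids and helper are hypothetical) to show how adding FP32 doubles the case list:

```python
# Hypothetical sketch of the dtype parametrization suggested above.
LOGITS_DTYPES = [("float32", "FP32_logits"), ("bfloat16", "BF16_logits")]

def expand_case_ids(base_ids):
    """Cross base test ids with each logits dtype, as parametrize would."""
    return [f"{base}-{dtype_id}"
            for base in base_ids
            for _dtype, dtype_id in LOGITS_DTYPES]

assert expand_case_ids(["llama4_routing"]) == [
    "llama4_routing-FP32_logits",
    "llama4_routing-BF16_logits",
]
```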
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 521-523: The expert_weights buffer is allocated using mDtypeScore
but always read by the finalize kernel as TypeExpW instantiated from mDtypeExpW,
causing mismatched interpretation; fix by centralizing the expert-weights dtype
policy: derive a single expW_dtype (based on mDtypeExpW) and use that when
calling alloc_tensor to set FusedMoeLauncher::expert_weights at all allocation
sites (the ones allocating expert_weights), and ensure the same expW_dtype is
passed/visible to the runner/finalize kernel invocation so the template TypeExpW
and the allocated buffer use the same dtype.

---

Nitpick comments:
In `@tests/moe/test_trtllm_gen_fused_moe.py`:
- Around line 3298-3326: The test test_llama4_routing currently only
parametrizes logits_dtype with torch.bfloat16; add an additional parametrization
for torch.float32 so Llama4 routing is exercised with FP32 logits (since
FP8PerTensorMoe supports FP32 per the skip logic in utils.py around the FP32
check) by updating the pytest.mark.parametrize block to include
pytest.param(torch.float32, id="FP32_logits"); ensure test_llama4_routing still
calls run_moe_test with the new logits_dtype value so the FP32 path is covered.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 6b27a662-6c32-4aec-a370-49cdbe427185

📥 Commits

Reviewing files that changed from the base of the PR and between a934912 and af075f1.

📒 Files selected for processing (3)
  • csrc/trtllm_fused_moe_kernel_launcher.cu
  • tests/moe/test_trtllm_gen_fused_moe.py
  • tests/moe/utils.py

Comment on lines +521 to +523
+  auto ew_dtype = mDtypeScore == btg::Dtype::Fp32 ? dl_float32 : dl_bfloat16;
   FusedMoeLauncher::expert_weights =
-      alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device());
+      alloc_tensor({args->num_tokens, args->top_k}, ew_dtype, hidden_states.device());
Contributor


⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== DeepSeek expW dtype in routing runner =="
rg -n -C3 'RoutingMethodType::DeepSeekV3|mDtypeExpW' csrc/trtllm_fused_moe_runner.cu

echo
echo "== expert_weights dtype allocation sites in launcher =="
rg -n -C2 'ew_dtype|expert_weights_dtype|alloc_tensor\(\{args->num_tokens, args->top_k\}' csrc/trtllm_fused_moe_kernel_launcher.cu

Repository: flashinfer-ai/flashinfer

Length of output: 3594


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== mDtypeScore initialization in launcher =="
rg -n 'mDtypeScore\s*=' csrc/trtllm_fused_moe_kernel_launcher.cu | head -20

echo
echo "== Where expert_weights dtype flows into runner args =="
rg -n 'args->mDtypeExpW|mDtypeExpW\s*=' csrc/trtllm_fused_moe_kernel_launcher.cu | head -20

echo
echo "== Check if Llama and Renormalize/TopK also have the issue =="
rg -n 'class.*Llama|class.*Renormalize|class.*Topk' csrc/trtllm_fused_moe_kernel_launcher.cu

echo
echo "== Allocation sites - check their context for routing method =="
rg -n -B10 'auto ew_dtype = mDtypeScore' csrc/trtllm_fused_moe_kernel_launcher.cu | head -40

Repository: flashinfer-ai/flashinfer

Length of output: 670


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Context: Which routing methods trigger Fp32 vs Bfloat16 mDtypeScore =="
sed -n '310,330p' csrc/trtllm_fused_moe_kernel_launcher.cu

echo
echo "== Class definitions to identify which classes set mDtypeScore =="
rg -n 'class.*FusedMoe|struct.*FusedMoe' csrc/trtllm_fused_moe_kernel_launcher.cu | head -20

echo
echo "== How is workspace.expert_weights used in runner =="
rg -n -A3 'workspace\.expert_weights' csrc/trtllm_fused_moe_runner.cu | head -30

Repository: flashinfer-ai/flashinfer

Length of output: 1692


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Where is args.mDtypeExpW set =="
rg -n 'args.*mDtypeExpW|mDtypeExpW.*args' csrc/trtllm_fused_moe_kernel_launcher.cu

echo
echo "== Full context of finalizeData usage in runner =="
sed -n '490,510p' csrc/trtllm_fused_moe_runner.cu

echo
echo "== What does finalizeData do with expertWeightsPtr and mDtypeExpW =="
rg -n 'finalizeData\.' csrc/trtllm_fused_moe_runner.cu | head -20

echo
echo "== Check if expert_weights buffer is actually read/interpreted as a specific dtype =="
rg -n 'expert_weights|expertWeights' csrc/trtllm_fused_moe_runner.cu | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 102


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Runner.run signature - understand args =="
rg -n 'void run\(|struct.*Args|args->num_tokens' csrc/trtllm_fused_moe_runner.cu | head -20

echo
echo "== Check if finalizeData operations use mDtypeExpW =="
rg -n 'mDtypeExpW' csrc/trtllm_fused_moe_runner.cu

echo
echo "== Look for kernel calls that use expertWeightsPtr and check if dtype matters =="
rg -n 'expertWeights|expW' csrc/ -t cu | grep -E '(kernel|<<<|>>>)' | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 112



🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Routing runner file structure - find main entry point =="
head -100 csrc/trtllm_fused_moe_runner.cu

echo
echo "== Search more broadly for mDtypeExpW usage =="
rg -n 'mDtypeExpW' csrc/

echo
echo "== Check finalize kernel or operation that uses expertWeightsPtr =="
rg -n 'Finalize|finalize' csrc/trtllm_fused_moe_runner.cu | head -10

Repository: flashinfer-ai/flashinfer

Length of output: 5503


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Look for where args.mDtypeExpW is initialized =="
rg -n 'struct.*Args|args\.mDtypeExpW|mDtypeExpW' csrc/trtllm_fused_moe_runner.cu | head -30

echo
echo "== Search in launcher for any Args struct definition =="
rg -n 'struct.*Args|\.mDtypeExpW' csrc/trtllm_fused_moe_kernel_launcher.cu | head -30

echo
echo "== Check header files for Args struct =="
fd 'runner\.h|fused_moe.*\.h' csrc/ --exec head -150 {} \; 2>/dev/null | grep -A20 'struct Args'

Repository: flashinfer-ai/flashinfer

Length of output: 369


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Finalize function signature =="
rg -n 'void.*finalize|finalize.*Data' csrc/trtllm_fused_moe_runner.cu -A2 | head -30

echo
echo "== Check line 467 context for finalize function =="
sed -n '460,510p' csrc/trtllm_fused_moe_runner.cu

echo
echo "== Look for expert_weights dtype handling in finalize kernel =="
fd '\.h$' csrc/ --exec rg -l 'finalize' {} \; 2>/dev/null | head -5 | xargs rg -n 'struct.*finalize|mDtypeExpW' 2>/dev/null | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 4018


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Find MoERunnerArgs struct definition =="
fd '\.h$' csrc/ | xargs rg -l 'MoERunnerArgs' | head -3

echo
echo "== Search for mDtypeExpW in those headers =="
fd '\.h$' csrc/ | xargs rg -n 'mDtypeExpW|struct MoERunnerArgs' 2>/dev/null | head -40

echo
echo "== Check if launcher sets args.mDtypeExpW anywhere =="
rg -n 'args\.mDtypeExpW|mDtypeExpW.*=' csrc/trtllm_fused_moe_kernel_launcher.cu

Repository: flashinfer-ai/flashinfer

Length of output: 112


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Search for MoERunnerArgs definition =="
rg -n 'struct MoERunnerArgs|mDtypeExpW' csrc/ -t h 2>/dev/null | head -50

echo
echo "== Look at Bf16MoeLauncher class (lines 458+) =="
sed -n '458,550p' csrc/trtllm_fused_moe_kernel_launcher.cu

Repository: flashinfer-ai/flashinfer

Length of output: 109


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Look at actual class that creates workspace and calls runner =="
sed -n '458,550p' csrc/trtllm_fused_moe_kernel_launcher.cu

echo
echo "== Search for where Runner::run is called from launcher =="
rg -n 'Runner::run|\.run\(' csrc/trtllm_fused_moe_kernel_launcher.cu | head -20

echo
echo "== Check the workspace initialization to understand args flow =="
rg -n 'workspace\.|MoEWorkspace' csrc/trtllm_fused_moe_kernel_launcher.cu | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 6457


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Check the routing_runner.run call signature and what args it receives =="
sed -n '398,410p' csrc/trtllm_fused_moe_kernel_launcher.cu

echo
echo "== Look for args.mDtypeExpW initialization anywhere in launcher =="
rg -n 'mDtypeExpW' csrc/trtllm_fused_moe_kernel_launcher.cu

echo
echo "== Check what happens with expert_weights in routing kernel (DeepSeekV3) =="
rg -n 'mPtrTopKWeights|mDtypeExpW' csrc/ -t cu 2>/dev/null | grep -i deepseek

echo
echo "== Verify: is args.mDtypeExpW ever set in launcher classes? =="
rg -n 'args->mDtypeExpW\s*=' csrc/trtllm_fused_moe_kernel_launcher.cu

Repository: flashinfer-ai/flashinfer

Length of output: 1120


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Full routing_runner.run call =="
sed -n '398,415p' csrc/trtllm_fused_moe_kernel_launcher.cu

echo
echo "== Search for where args is created/initialized =="
rg -n 'args.*=' csrc/trtllm_fused_moe_kernel_launcher.cu | grep -E 'mDtypeExpW|make_unique|new' | head -20

echo
echo "== Check if expert_weights buffer is actually used by routing kernel =="
rg -n 'mPtrTopKWeights' csrc/ -A5 -B5 2>/dev/null | head -40

Repository: flashinfer-ai/flashinfer

Length of output: 4895


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Check if mDtypeExpW is actually used in the routing kernel =="
fd '\.h$' csrc/ | xargs rg -l 'mDtypeExpW|routingDeepSeek' 2>/dev/null | head -5

echo
echo "== Check DeepSeekV3 routing kernel implementation =="
rg -n 'routingDeepSeek|mDtypeExpW|mPtrTopKWeights' csrc/ -t h -t cu 2>/dev/null | grep -A5 -B5 'mPtrTopKWeights' | head -50

echo
echo "== Verify: args.mDtypeExpW default value - check if it's initialized =="
rg -n 'MoERunnerArgs' csrc/ -t h --max-count=5 2>/dev/null | head -10

Repository: flashinfer-ai/flashinfer

Length of output: 134


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Look for finalize kernel that uses mDtypeExpW and expertWeightsPtr =="
find csrc -name "*.h" -o -name "*.cu" | xargs rg -l 'finalize.*Data|expertWeightsPtr' 2>/dev/null | head -5

echo
echo "== Check how finalize uses expert_weights dtype =="
rg -n 'expertWeightsPtr|mDtypeExpW' csrc/ 2>/dev/null | grep -v 'Binary' | head -30

echo
echo "== Check the actual MoERunnerArgs initialization - look at constructor/default values =="
find csrc -name "*.h" | xargs grep -A30 'struct MoERunnerArgs\|class MoERunnerArgs' 2>/dev/null | head -60

Repository: flashinfer-ai/flashinfer

Length of output: 1675


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Check finalize kernel - how TypeExpW is defined and used =="
sed -n '660,680p' csrc/fused_moe/trtllm_backend/trtllm_fused_moe_dev_kernel.cu

echo
echo "== Check if TypeExpW is template parameter or determined by mDtypeExpW =="
rg -n 'template.*TypeExpW|TypeExpW|mDtypeExpW' csrc/fused_moe/trtllm_backend/trtllm_fused_moe_dev_kernel.cu | head -30

echo
echo "== Find where args.mDtypeExpW is actually set/initialized =="
rg -n 'mDtypeExpW\s*=' csrc/trtllm_fused_moe_kernel_launcher.cu

Repository: flashinfer-ai/flashinfer

Length of output: 1771


Expert_weights dtype mismatch between launcher allocation and kernel interpretation affects all routing methods.

The launcher allocates expert_weights with dtype determined by mDtypeScore (which is Fp32 for DeepSeekV3 and conditionally Fp32 for other methods). However, all routing paths—DeepSeekV3, Llama4, and TopK—hardcode mDtypeExpW = Bfloat16 independent of the launcher's allocation. This buffer is then passed to the finalize kernel, which reads it through the TypeExpW template parameter instantiated from mDtypeExpW. When expert_weights is allocated as Fp32 but read as Bfloat16, the bytes are misinterpreted, causing data corruption.

This issue is broader than DeepSeekV3 alone: it affects any configuration where mDtypeScore is Fp32 across all routing methods. The launcher never communicates the actual expert_weights dtype back to the runner or finalize kernel.

🔧 Suggested fix (centralize expW dtype policy)
 class FusedMoeLauncher {
  protected:
+  DLDataType get_expert_weights_dtype() const {
+    if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
+      // Runner DeepSeek path currently expects expW as BF16.
+      return dl_bfloat16;
+    }
+    return mDtypeScore == btg::Dtype::Fp32 ? dl_float32 : dl_bfloat16;
+  }
-      auto ew_dtype = mDtypeScore == btg::Dtype::Fp32 ? dl_float32 : dl_bfloat16;
+      auto ew_dtype = get_expert_weights_dtype();

Apply the replacement at all four allocation sites (lines 521–523, 662–664, 938–940, 1213–1215).

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-auto ew_dtype = mDtypeScore == btg::Dtype::Fp32 ? dl_float32 : dl_bfloat16;
+auto ew_dtype = get_expert_weights_dtype();
 FusedMoeLauncher::expert_weights =
     alloc_tensor({args->num_tokens, args->top_k}, ew_dtype, hidden_states.device());
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fused_moe_kernel_launcher.cu` around lines 521 - 523, The
expert_weights buffer is allocated using mDtypeScore but always read by the
finalize kernel as TypeExpW instantiated from mDtypeExpW, causing mismatched
interpretation; fix by centralizing the expert-weights dtype policy: derive a
single expW_dtype (based on mDtypeExpW) and use that when calling alloc_tensor
to set FusedMoeLauncher::expert_weights at all allocation sites (the ones
allocating expert_weights), and ensure the same expW_dtype is passed/visible to
the runner/finalize kernel invocation so the template TypeExpW and the allocated
buffer use the same dtype.
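The corruption described above—one dtype at allocation, another at read time—can be demonstrated in a few lines of Python. The `struct` module has no bf16 format, so fp16 (`'e'`) stands in here to show how fp32 bytes read under a different 16-bit dtype become unrelated values:

```python
import struct

# Same bytes, different dtype: pack one fp32 value, then reinterpret the
# raw bytes as two fp16 values (analogue of the fp32-vs-bf16 mismatch).
fp32_bytes = struct.pack("<f", 1.0)          # 4 bytes: 00 00 80 3f
as_fp16 = struct.unpack("<2e", fp32_bytes)   # reinterpret as two fp16 values
assert as_fp16 == (0.0, 1.875)               # neither slot recovers 1.0
```

This is why the allocation dtype and the kernel's template dtype must come from a single source of truth, as the suggested `get_expert_weights_dtype()` helper does.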

@aleozlx
Collaborator

aleozlx commented Mar 19, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !321 has been updated with latest changes, and the CI pipeline #46552887 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #46552887: 8/20 passed
