
fix: segfault using packed topk_id/weight as trtllm_bf16_routed_moe input in DeepSeek routing#2911

Open
rosenrodt wants to merge 1 commit into flashinfer-ai:main from rosenrodt:fix/trtllm-routed-moe-bf16-ds-routing

Conversation

@rosenrodt
Contributor

@rosenrodt rosenrodt commented Mar 29, 2026

📌 Description

When deciding whether to enter the main routing kernel inside the MoE, switch the condition from `mPtrTopKIds == nullptr` to `mPtrScores != nullptr`.

`mPtrTopKIds == nullptr` is also true for the "packed" pre-routed inputs, so those inputs incorrectly go down the main routing kernel path. Since `mPtrScores` holds the routing logits, the correct condition is `mPtrScores != nullptr`.

🔍 Related Issues

Found when integrating FlashInfer BF16 TRTLLM MoE in NVIDIA/TensorRT-LLM#12557

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes

    • Corrected routing execution path logic for expert selection based on scores availability
  • Tests

    • Extended routed MoE equivalence testing to support DeepSeekV3 routing method
    • Refactored routing test helper for configurable parameters
    • Added DeepSeekV3-specific routing test case

@coderabbitai
Contributor

coderabbitai bot commented Mar 29, 2026

📝 Walkthrough

Walkthrough

The pull request modifies DeepSeek MoE routing logic in the CUDA kernel to change the path selection from checking precomputed top-k IDs to checking presence of routing scores, and extends test coverage with a shared helper function and DeepSeekV3-specific test case.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| CUDA Routing Logic — `csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu` | Modified routing execution path condition: now selects the main kernel path when `data.mPtrScores != nullptr` (previously `data.mPtrTopKIds == nullptr`), triggering DeepSeek validation checks and kernel launch behavior accordingly. |
| MoE Routing Tests — `tests/moe/test_trtllm_gen_routed_fused_moe.py` | Added `routing_reference_no_aux` import and created shared helper `run_bf16_routed_equivalence_test(...)` with configurable seeding and DeepSeekV3 routing support. Refactored the existing test to use the helper and added a new `test_trtllm_gen_bf16_routed_fused_moe_deepseek_routing()` test case with explicit routing parameters. |

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

Suggested labels

run-ci, op: moe, op: moe-routing

Suggested reviewers

  • aleozlx
  • yzh119
  • jiahanc
  • nv-yunzheq
  • cyx-6
  • djmmoss

Poem

🐰 A hop and a skip through routing's new way,
Scores now lead where top-k once held sway,
DeepSeek's deep wisdom blooms bright,
Tests multiply tests, all working just right,
The MoE farm hops into the light! ✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 66.67%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title clearly and specifically describes the main fix: correcting a segfault in DeepSeek routing when using packed topk_id/weight as input to trtllm_bf16_routed_moe. |
| Description check | ✅ Passed | The description adequately explains the problem, the fix rationale, and the related issue context. Test completion is marked. However, pre-commit checks are unchecked despite the checklist being included. |



Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request updates the DeepSeek routing kernel logic in the TRT-LLM backend and refactors the BF16 routed MoE tests to support DeepSeekV3 routing. The test suite now includes a helper function for equivalence testing and a specific test case for DeepSeek routing. A review comment suggests parameterizing the new DeepSeek routing test to enhance coverage of different grouping strategies and input configurations.

Comment on lines +748 to +760

```python
def test_trtllm_gen_bf16_routed_fused_moe_deepseek_routing():
    run_bf16_routed_equivalence_test(
        num_tokens=1,
        hidden_size=128,
        intermediate_size=128,
        top_k=1,
        num_experts=32,
        routing_method_type=RoutingMethodType.DeepSeekV3,
        routed_scaling_factor=2.0,
        n_group=1,
        topk_group=1,
        seed=0,
    )
```
Contributor


medium

This test for DeepSeek routing uses a single set of hardcoded parameters. To improve test coverage and ensure the fix is robust across different scenarios, consider parameterizing this test to cover a wider range of inputs. This is especially important for n_group and topk_group to test different grouping strategies.

```python
@pytest.mark.parametrize(
    "num_tokens, hidden_size, intermediate_size, top_k, num_experts, n_group, topk_group",
    [
        (1, 128, 128, 1, 32, 1, 1),  # Original case
        (64, 256, 256, 4, 64, 1, 1),
        (8, 1024, 512, 8, 256, 8, 4),  # DSv3-like config
    ],
)
def test_trtllm_gen_bf16_routed_fused_moe_deepseek_routing(
    num_tokens: int,
    hidden_size: int,
    intermediate_size: int,
    top_k: int,
    num_experts: int,
    n_group: int,
    topk_group: int,
):
    run_bf16_routed_equivalence_test(
        num_tokens=num_tokens,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        top_k=top_k,
        num_experts=num_experts,
        routing_method_type=RoutingMethodType.DeepSeekV3,
        routed_scaling_factor=2.0,
        n_group=n_group,
        topk_group=topk_group,
        seed=0,
    )
```

Contributor

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
tests/moe/test_trtllm_gen_routed_fused_moe.py (1)

409-420: Fail fast when DeepSeekV3-specific args are missing.

routing_reference_no_aux() treats n_group, topk_group, and routed_scaling_factor as required inputs, but this helper keeps them optional and only blows up later in the DeepSeekV3 branch. A small precondition here would make the next DeepSeekV3 caller fail with a clear setup error instead of an opaque reference-path exception.

Suggested guard

```diff
 def run_bf16_routed_equivalence_test(
     num_tokens: int,
     hidden_size: int,
     intermediate_size: int,
     top_k: int,
@@
 ):
     """Test Bf16 scale routed MoE matches standard routing."""
     compute_capability = get_compute_capability(torch.device(device="cuda"))
     if compute_capability[0] not in [10]:
         pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.")
     torch.manual_seed(seed)
     device = torch.device("cuda:0")
     enable_pdl = device_support_pdl(device)

     # Generate random routing logits for reference
     if routing_method_type == RoutingMethodType.DeepSeekV3:
+        if n_group is None or topk_group is None or routed_scaling_factor is None:
+            raise ValueError(
+                "DeepSeekV3 requires n_group, topk_group, and routed_scaling_factor"
+            )
         routing_logits = torch.randn(num_tokens, num_experts, device=device).to(
             torch.float
         )
         if routing_bias is None:
             routing_bias = torch.randn(num_experts, device=device, dtype=torch.bfloat16)
```

Also applies to: 431-437, 501-512

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/moe/test_trtllm_gen_routed_fused_moe.py` around lines 409 - 420, In
run_bf16_routed_equivalence_test add an early precondition check when
routing_method_type == RoutingMethodType.DeepSeekV3 to validate that n_group,
topk_group, and routed_scaling_factor are not None (and optionally routing_bias
as required by routing_reference_no_aux), and raise a clear ValueError if any
are missing; update the same guard pattern for the other two similar call sites
mentioned (around the blocks at the other locations) so callers fail fast with a
descriptive message rather than raising an opaque exception later in
routing_reference_no_aux.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 2aa9f21d-6547-4b74-b53c-29a382d924e2

📥 Commits

Reviewing files that changed from the base of the PR and between 779c24d and 79077e6.

📒 Files selected for processing (2)
  • csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu
  • tests/moe/test_trtllm_gen_routed_fused_moe.py

