fix: segfault using packed topk_id/weight as trtllm_bf16_routed_moe input in DeepSeek routing #2911
Conversation
📝 Walkthrough

The pull request modifies the DeepSeek MoE routing logic in the CUDA kernel, changing the path selection from checking for precomputed top-k IDs to checking for the presence of routing scores, and extends test coverage with a shared helper function and a DeepSeekV3-specific test case.
Code Review
This pull request updates the DeepSeek routing kernel logic in the TRT-LLM backend and refactors the BF16 routed MoE tests to support DeepSeekV3 routing. The test suite now includes a helper function for equivalence testing and a specific test case for DeepSeek routing. A review comment suggests parameterizing the new DeepSeek routing test to enhance coverage of different grouping strategies and input configurations.
```python
def test_trtllm_gen_bf16_routed_fused_moe_deepseek_routing():
    run_bf16_routed_equivalence_test(
        num_tokens=1,
        hidden_size=128,
        intermediate_size=128,
        top_k=1,
        num_experts=32,
        routing_method_type=RoutingMethodType.DeepSeekV3,
        routed_scaling_factor=2.0,
        n_group=1,
        topk_group=1,
        seed=0,
    )
```
This test for DeepSeek routing uses a single set of hardcoded parameters. To improve test coverage and ensure the fix is robust across different scenarios, consider parameterizing this test to cover a wider range of inputs. This is especially important for n_group and topk_group to test different grouping strategies.
```python
@pytest.mark.parametrize(
    "num_tokens, hidden_size, intermediate_size, top_k, num_experts, n_group, topk_group",
    [
        (1, 128, 128, 1, 32, 1, 1),  # Original case
        (64, 256, 256, 4, 64, 1, 1),
        (8, 1024, 512, 8, 256, 8, 4),  # DSv3-like config
    ],
)
def test_trtllm_gen_bf16_routed_fused_moe_deepseek_routing(
    num_tokens: int,
    hidden_size: int,
    intermediate_size: int,
    top_k: int,
    num_experts: int,
    n_group: int,
    topk_group: int,
):
    run_bf16_routed_equivalence_test(
        num_tokens=num_tokens,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        top_k=top_k,
        num_experts=num_experts,
        routing_method_type=RoutingMethodType.DeepSeekV3,
        routed_scaling_factor=2.0,
        n_group=n_group,
        topk_group=topk_group,
        seed=0,
    )
```
🧹 Nitpick comments (1)
tests/moe/test_trtllm_gen_routed_fused_moe.py (1)
409-420: Fail fast when DeepSeekV3-specific args are missing.
`routing_reference_no_aux()` treats `n_group`, `topk_group`, and `routed_scaling_factor` as required inputs, but this helper keeps them optional and only blows up later in the DeepSeekV3 branch. A small precondition here would make the next DeepSeekV3 caller fail with a clear setup error instead of an opaque reference-path exception.

Suggested guard
```diff
 def run_bf16_routed_equivalence_test(
     num_tokens: int,
     hidden_size: int,
     intermediate_size: int,
     top_k: int,
@@
 ):
     """Test Bf16 scale routed MoE matches standard routing."""
     compute_capability = get_compute_capability(torch.device(device="cuda"))
     if compute_capability[0] not in [10]:
         pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.")
     torch.manual_seed(seed)
     device = torch.device("cuda:0")
     enable_pdl = device_support_pdl(device)

     # Generate random routing logits for reference
     if routing_method_type == RoutingMethodType.DeepSeekV3:
+        if n_group is None or topk_group is None or routed_scaling_factor is None:
+            raise ValueError(
+                "DeepSeekV3 requires n_group, topk_group, and routed_scaling_factor"
+            )
         routing_logits = torch.randn(num_tokens, num_experts, device=device).to(
             torch.float
         )
         if routing_bias is None:
             routing_bias = torch.randn(num_experts, device=device, dtype=torch.bfloat16)
```

Also applies to: 431-437, 501-512
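As a standalone sketch of the same fail-fast pattern (the function and parameter names follow the test file, but this simplified body is hypothetical and only illustrates the precondition check):

```python
from typing import Optional


def check_deepseek_v3_args(
    is_deepseek_v3: bool,
    n_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    routed_scaling_factor: Optional[float] = None,
) -> None:
    """Raise a descriptive error up front instead of an opaque
    exception later in the reference routing path."""
    if not is_deepseek_v3:
        return
    missing = [
        name
        for name, value in (
            ("n_group", n_group),
            ("topk_group", topk_group),
            ("routed_scaling_factor", routed_scaling_factor),
        )
        if value is None
    ]
    if missing:
        raise ValueError(f"DeepSeekV3 routing requires: {', '.join(missing)}")


# A misconfigured DeepSeekV3 call now fails immediately with a clear message:
try:
    check_deepseek_v3_args(True, n_group=1)
except ValueError as e:
    print(e)  # DeepSeekV3 routing requires: topk_group, routed_scaling_factor
```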
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/moe/test_trtllm_gen_routed_fused_moe.py` around lines 409 - 420, In run_bf16_routed_equivalence_test add an early precondition check when routing_method_type == RoutingMethodType.DeepSeekV3 to validate that n_group, topk_group, and routed_scaling_factor are not None (and optionally routing_bias as required by routing_reference_no_aux), and raise a clear ValueError if any are missing; update the same guard pattern for the other two similar call sites mentioned (around the blocks at the other locations) so callers fail fast with a descriptive message rather than raising an opaque exception later in routing_reference_no_aux.
📒 Files selected for processing (2)
csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu
tests/moe/test_trtllm_gen_routed_fused_moe.py
📌 Description
When deciding whether to enter the main routing kernel inside the MoE, switch the condition from
mPtrTopKIds == nullptrtomPtrScores != nullptr.mPtrTopKIds == nullptris true for the "packed" pre-routed inputs so it will go down the main routing kernel path. It is the wrong code path. AsmPtrScoresis the routing logits, the condition should bemPtrScores != nullptr.🔍 Related Issues
Found when integrating FlashInfer BF16 TRTLLM MoE in NVIDIA/TensorRT-LLM#12557
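The condition change described above can be sketched as follows (a minimal illustration in Python; the field names mirror the kernel's `mPtrScores`/`mPtrTopKIds` members, and the data layout is an assumption for demonstration only):

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class RoutingData:
    # Raw routing logits; None when the caller supplies pre-routed inputs.
    scores: Optional[Sequence[float]]
    # Top-k expert ids; in the "packed" pre-routed layout this pointer is
    # also None, which is exactly what broke the old dispatch check.
    topk_ids: Optional[Sequence[int]]


def should_run_routing_kernel(d: RoutingData) -> bool:
    # Old (buggy) condition: d.topk_ids is None
    # Packed pre-routed inputs also leave topk_ids unset, so they were
    # wrongly sent through the main routing kernel and crashed.
    # Fixed condition: only route when raw scores are actually present.
    return d.scores is not None


# Logits present -> run the routing kernel.
assert should_run_routing_kernel(RoutingData(scores=[0.1, 0.9], topk_ids=None))
# Packed pre-routed input (no logits, no separate topk_ids) -> skip routing.
assert not should_run_routing_kernel(RoutingData(scores=None, topk_ids=None))
```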
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- [x] Tests have been added or updated as needed (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit

Bug Fixes
- Fixed a segfault when packed top-k IDs/weights are passed as trtllm_bf16_routed_moe inputs with DeepSeek routing.

Tests
- Added a shared equivalence-test helper and a DeepSeekV3-specific routed MoE test case.