
fix: Fix autotuner crash on meta-device tensor in trtllm_fp4_block_scale_routed_moe #2916

Merged
aleozlx merged 2 commits into flashinfer-ai:main from bkryu:autotune_trtllm_fp4_block_scale_routed_moe
Apr 1, 2026

Conversation

@bkryu
Collaborator

@bkryu bkryu commented Mar 30, 2026

📌 Description

Summary

  • Fixes RuntimeError: Cannot pack tensors on meta when trtllm_fp4_block_scale_routed_moe is called with autotuning enabled
  • Ensures the autotuner profiles the correct kernel code path (no-routing) when routing is pre-computed

Root Cause

When trtllm_fp4_block_scale_routed_moe is called, routing_logits is None because routing has already been done (pre-computed topk_ids are provided instead). To give the autotuner a tensor with the right shape/dtype for profile generation, a placeholder was created with device="meta":

torch.empty(num_tokens, num_experts, dtype=routing_dtype, device="meta")

This worked without autotuning because choose_one returns early (only inspects .size() for the cache key, never passes the tensor to a kernel).

With autotuning enabled, choose_one enters the profiling loop, which calls _create_tensor_like on the placeholder. That method copies origin_tensor.device, so the derived profiling tensor is also on "meta". When the profiling path calls MoERunner.forward, this meta tensor is passed to the C++ kernel via TVM FFI, which attempts DLPack conversion and fails: the meta device has no real memory.
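The failure mode can be reproduced in isolation. This is a minimal sketch assuming PyTorch is installed; `_create_tensor_like` is approximated here with `torch.empty_like`, which likewise copies the origin tensor's device:

```python
import torch

# A placeholder on the meta device carries shape/dtype but no storage.
placeholder = torch.empty(4, 8, dtype=torch.bfloat16, device="meta")
assert placeholder.shape == (4, 8)

# A _create_tensor_like-style derivation copies the device, so the
# profiling tensor ends up on "meta" as well...
profiling = torch.empty_like(placeholder)
assert profiling.device.type == "meta"

# ...and handing it to a real kernel fails: DLPack export needs a data
# pointer, which a meta tensor does not have.
try:
    torch.utils.dlpack.to_dlpack(profiling)
    raised = False
except Exception:
    raised = True
assert raised
```

This mirrors the production failure: the shape/dtype bookkeeping works fine, but the first attempt to materialize the tensor for the C++ kernel raises.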

Fix

Three changes in flashinfer/fused_moe/core.py:

  • Replace device="meta" with device=hidden_states.device — the placeholder is now a real CUDA tensor so the autotuner can safely derive profiling tensors from it.
  • Pass skip_routing=(routing_logits is None) through kwargs to choose_one, signaling that routing was pre-computed.
  • In MoERunner.forward, set routing_logits = None when skip_routing=True — this ensures the C++ kernel takes the same no-routing code path during profiling as it does in production. Without this, the autotuner would profile with routing computation enabled (random routing_logits data), potentially selecting a suboptimal tactic for the actual inference path where routing is skipped.
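The pattern behind these three changes can be sketched as follows. This is a simplified, self-contained illustration, not the actual FlashInfer code: `FakeTensor` stands in for torch.Tensor, and the function names are hypothetical.

```python
class FakeTensor:
    """Stand-in for a torch tensor: only shape and device matter here."""
    def __init__(self, shape, device):
        self.shape = shape
        self.device = device

def make_routing_placeholder(hidden_states, num_tokens, num_experts):
    # Before the fix: device="meta", so profiling tensors derived from the
    # placeholder had no storage and crashed in DLPack conversion.
    # After the fix: allocate on the same real device as hidden_states.
    return FakeTensor((num_tokens, num_experts), device=hidden_states.device)

def forward(routing_logits, skip_routing=False):
    # skip_routing=True forces the same no-routing code path during
    # profiling that production uses when topk_ids are pre-computed,
    # instead of profiling with random routing_logits data.
    if skip_routing:
        routing_logits = None
    return "no-routing path" if routing_logits is None else "routing path"

hidden_states = FakeTensor((16, 1024), device="cuda:0")
placeholder = make_routing_placeholder(hidden_states, 16, 8)
assert placeholder.device == "cuda:0"
assert forward(placeholder, skip_routing=True) == "no-routing path"
```

The key design point is that the placeholder exists only to give the autotuner correct shape/dtype/device metadata; the skip_routing flag is what keeps the profiled code path consistent with the inference path.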

Unit test changes

Added test_fp4_routed_moe_autotune_no_crash regression test in tests/autotuner/test_trtllm_fused_moe_autotuner_integration.py. The test calls trtllm_fp4_block_scale_routed_moe inside autotune(True) with num_tokens=1 and num_tokens=16, verifying no crash occurs.
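The shape of such a no-crash regression test can be sketched like this. This is a standalone illustration with stand-ins for the op and the autotune context manager; the real test builds full FP4 inputs with packed topk_ids and calls the actual kernel:

```python
import pytest

def fake_routed_moe(num_tokens):
    # Stand-in for trtllm_fp4_block_scale_routed_moe: all that matters
    # here is that the call completes without raising.
    return [0.0] * num_tokens

class autotune:
    """Minimal stand-in for the autotuning context manager."""
    def __init__(self, enabled):
        self.enabled = enabled

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return False

@pytest.mark.parametrize("num_tokens", [1, 16])
def test_routed_moe_autotune_no_crash(num_tokens):
    with autotune(True):
        out = fake_routed_moe(num_tokens)  # must not raise
    assert len(out) == num_tokens
```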

Main branch fails the newly added tests before the changes in flashinfer/fused_moe/core.py:

$ pytest tests/autotuner/test_trtllm_fused_moe_autotuner_integration.py
...
tests/autotuner/test_trtllm_fused_moe_autotuner_integration.py ....FF...   
...
FAILED tests/autotuner/test_trtllm_fused_moe_autotuner_integration.py::test_fp4_routed_moe_autotune_no_crash[4-16-1] - RuntimeError: Cannot pack tensors on meta
FAILED tests/autotuner/test_trtllm_fused_moe_autotuner_integration.py::test_fp4_routed_moe_autotune_no_crash[4-16-16] - RuntimeError: Cannot pack tensors on meta
====================================================================================== 2 failed, 7 passed in 4.08s ======================================================================================

After fix:

$ pytest tests/autotuner/test_trtllm_fused_moe_autotuner_integration.py
...
tests/autotuner/test_trtllm_fused_moe_autotuner_integration.py .........                                                                                                                           [100%]

=========================================================================================== 9 passed in 6.47s ============================================================================================

🔍 Related Issues

#2023

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added skip routing support for Mixture of Experts operations, enabling optimization when routing computation can be bypassed.
  • Tests

    • Added regression test for FP4 routed Mixture of Experts autotuning to ensure stability across token configurations.

@coderabbitai
Contributor

coderabbitai bot commented Mar 30, 2026

📝 Walkthrough

Walkthrough

Added support for a skip_routing flag in MoERunner to bypass routing computation. Updated TensorRT-LLM FP4 block-scale MoE operator to allocate routing logits workspace on the correct device and pass the skip_routing state to the autotuner. Included integration test for FP4 routed MoE autotuning.

Changes

Cohort / File(s) Summary
Core MoE routing bypass support
flashinfer/fused_moe/core.py
Added conditional skip_routing flag handling in MoERunner.forward to set routing_logits to None when requested. Updated workspace allocation for routing logits to use hidden_states.device instead of meta device. Propagated skip_routing state to autotuner call for consistent tactic selection.
FP4 MoE autotuner integration test
tests/autotuner/test_trtllm_fused_moe_autotuner_integration.py
Added SM100-only regression test test_fp4_routed_moe_autotune_no_crash parameterized over token counts and expert configurations. Constructs FP4 routed MoE inputs with packed topk_ids and block-scale quantized weights, runs operator within autotuning context, validates no crash occurs.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~12 minutes

Possibly related issues

  • 0.6.5 auto-tuning crash #2701: Changes to skip_routing flag propagation and routing_logits device placement directly affect the same trtllm MoE autotuner/executor interaction path.

Possibly related PRs

Suggested labels

run-ci, op: moe, op: moe-routing

Suggested reviewers

  • sricketts
  • aleozlx
  • cyx-6
  • yzh119
  • samuellees
  • djmmoss

Poem

🐰✨ A hop through the routing bypass flow,
Where skip_routing flags let decisions go!
Device placement dancing, workspace repositioned,
The autotuner's wisdom expertly conditioned.
Tests ensure no crashes on the SM100 stage—
A well-tuned MoE writes a cleaner page! 🎯

🚥 Pre-merge checks | ✅ 3 passed

  • Docstring Coverage — ✅ Passed: No functions found in the changed files to evaluate docstring coverage; skipping the docstring coverage check.
  • Description check — ✅ Passed: The pull request description is comprehensive and well-structured, following the template with detailed explanations of the problem, root cause, and the implemented fix.
  • Title check — ✅ Passed: The title clearly identifies the specific bug fix (autotuner crash) and the affected function (trtllm_fp4_block_scale_routed_moe), directly matching the main changes in the PR.


@bkryu bkryu added the run-ci label Mar 30, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request fixes a crash in the autotuner for FP4 routed MoE by ensuring that placeholder tensors are created on the actual device instead of the 'meta' device, which lacks storage for C++ kernel interaction. It also introduces logic to skip routing when routing_logits are absent and adds a regression test to verify the fix. I have no feedback to provide as the changes correctly address the issue.

@bkryu
Collaborator Author

bkryu commented Mar 30, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !476 has been created, and the CI pipeline #47290871 is currently running. I'll report back once the pipeline job completes.

@aleozlx aleozlx linked an issue Mar 30, 2026 that may be closed by this pull request
@aleozlx aleozlx enabled auto-merge (squash) March 30, 2026 22:14
@nvpohanh
Contributor

cc @trevor-m who is working on integrating routed moe into SGL

@flashinfer-bot
Collaborator

[FAILED] Pipeline #47290871: 5/20 passed

@bkryu bkryu changed the title from "fix: Fix autotuner crash on meta-device tensor in trtllm_fp4_block_scale_routed_moe" to "Description: fix: Fix autotuner crash on meta-device tensor in trtllm_fp4_block_scale_routed_moe" on Mar 31, 2026
@bkryu
Collaborator Author

bkryu commented Mar 31, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !476 has been updated with latest changes, and the CI pipeline #47381824 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #47381824: 11/20 passed

@aleozlx aleozlx merged commit 23b3279 into flashinfer-ai:main Apr 1, 2026
29 of 34 checks passed
@bkryu bkryu deleted the autotune_trtllm_fp4_block_scale_routed_moe branch April 1, 2026 19:50


Development

Successfully merging this pull request may close these issues.

[Bug] Autotuning + trtllm_fp4_block_scale_routed_moe Issue

4 participants