fix: autotuner cache key mismatch for trtllm-gen FP8 block scale MoE and FP8 routed MoE #2640
Linda-Stadter wants to merge 2 commits into flashinfer-ai:main
Conversation
📝 Walkthrough: This PR centralizes SM100 MoE tuning in a new MoETuningSetup class, refactors MoERunner to use it, updates FP8 MoE input validation and tuning-config selection, and fixes a log typo.
Summary of Changes (Gemini Code Assist): This pull request addresses an autotuner cache key mismatch that was causing performance degradation in FP8 block scale Mixture-of-Experts (MoE) operations. The fix refactors the MoE tuning configuration logic, centralizing it in a dedicated setup class. The new structure allows more precise selection of tuning profiles based on the specific FP8 quantization type and routing method, so the autotuner can consistently find and apply tuned tactics. The changes are validated with new tests covering various MoE configurations.
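The "selection of tuning profiles based on quantization type and routing method" described above can be pictured as a small lookup table. This is an illustrative sketch only: the enum values, profile names, and dict layout are assumptions for the example, not FlashInfer's actual MoETuningSetup implementation.

```python
from enum import Enum, auto

# Hypothetical enum mirroring the Fp8QuantizationType referenced in the PR.
class Fp8QuantizationType(Enum):
    DeepSeekFp8 = auto()    # block-scale FP8
    PerTensorFp8 = auto()

# Hypothetical mapping: (quantization type, routing logits present) -> profile name.
_CONFIGS = {
    (Fp8QuantizationType.DeepSeekFp8, True): "tuning_config_with_hidden_states_scales",
    (Fp8QuantizationType.DeepSeekFp8, False): "tuning_config_routing_precomputed",
    (Fp8QuantizationType.PerTensorFp8, True): "tuning_config_no_hidden_states_scales",
    (Fp8QuantizationType.PerTensorFp8, False): "tuning_config_routing_precomputed",
}

def select_fp8_tuning_config(fp8_quantization_type, has_routing_logits):
    """Pick a tuning profile from the quantization type and routing mode."""
    return _CONFIGS[(fp8_quantization_type, has_routing_logits)]

print(select_fp8_tuning_config(Fp8QuantizationType.DeepSeekFp8, True))
# -> tuning_config_with_hidden_states_scales
```

Centralizing the mapping this way means tuning and replay both resolve the same profile for a given input combination, which is what keeps the cache keys aligned.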
Activity
Code Review
This pull request effectively addresses an autotuner cache key mismatch for FP8 MoE by refactoring the tuning configuration logic. The introduction of the MoETuningSetup class is a commendable improvement, enhancing modularity and correctness by centralizing the complex tuning configurations. The changes are well-structured, and the addition of comprehensive tests in tests/moe/test_moe_autotuner_cache_keys.py is a valuable contribution that ensures the fix is robust and helps prevent future regressions. I have one suggestion to further improve the robustness of a new helper method.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/fused_moe/core.py (1)
1185-1240: ⚠️ Potential issue | 🟠 Major — Precomputed-routing autotuning can crash when routing_logits is None.
trtllm_fp8_block_scale_moe_op now routes a precomputed path with routing_logits=None, but MoERunner.get_valid_tactics (line 1198) and MoERunner.forward (line 1239) still unconditionally use routing_logits.shape[0]. In tuning mode this raises before tactic selection.
🐛 Proposed fix
 def get_valid_tactics(
     self,
     inputs: List[torch.Tensor],
     profile: OptimizationProfile,
 ) -> List[int]:
     (
         output,
         routing_logits,
         topk_ids,
         expert_weights,
         hidden_states,
         *extra_inputs,
     ) = inputs
-    num_tokens = routing_logits.shape[0]
+    token_source = routing_logits if routing_logits is not None else topk_ids
+    assert token_source is not None, (
+        "Either routing_logits or topk_ids must be provided."
+    )
+    num_tokens = token_source.shape[0]
@@
 def forward(
     self,
     inputs: List[torch.Tensor],
     tactic: int = -1,
     do_preparation: bool = False,
     **kwargs,
 ):
     (
         output,
         routing_logits,
         topk_ids,
         expert_weights,
         hidden_states,
         *extra_inputs,
     ) = inputs
-    num_tokens = routing_logits.shape[0]
+    token_source = routing_logits if routing_logits is not None else topk_ids
+    assert token_source is not None, (
+        "Either routing_logits or topk_ids must be provided."
+    )
+    num_tokens = token_source.shape[0]
ℹ️ Review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- flashinfer/autotuner.py
- flashinfer/fused_moe/core.py
- tests/moe/test_moe_autotuner_cache_keys.py
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/fused_moe/core.py`:
- Around line 1049-1057: The refine_tuning_config classmethod is decorated with
functools.lru_cache but mutates multiple class attributes and returns None,
causing stale cached None to prevent re-execution; remove the
`@functools.lru_cache`(maxsize=None) decorator above refine_tuning_config so the
method always runs and updates class state (refer to refine_tuning_config and
helper _make_tuning_config in class fused_moe.core) and ensure no other callers
depend on its caching behavior.
In `@tests/moe/test_moe_autotuner_cache_keys.py`:
- Around line 17-22: Add an architecture-based skip at the top of
tests/moe/test_moe_autotuner_cache_keys.py using flashinfer.utils (e.g., import
get_compute_capability or is_sm90a_supported from flashinfer.utils) and use
pytest.mark.skipif (or a runtime pytest.skip) to skip the module when the
current GPU compute capability is not supported by the SM-specific FP8 MoE
tuning paths; update the module imports (alongside AutoTuner,
Fp8QuantizationType, MoETuningSetup) and add a skip condition referencing the
utility (get_compute_capability/is_sm90a_supported) so the entire test file is
skipped on unsupported architectures.
@classmethod
@functools.lru_cache(maxsize=None)
def refine_tuning_config(cls, tune_max_num_tokens: int):
    mk = lambda idx, dim_idx, inits: cls._make_tuning_config(
        idx,
        dim_idx,
        inits,
        tune_max_num_tokens,
    )
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# First, let's check the file structure and locate the method
fd -t f "core.py" flashinfer/fused_moe/

Repository: flashinfer-ai/flashinfer
Length of output: 96

🏁 Script executed:
#!/bin/bash
# Get the actual implementation of refine_tuning_config and surrounding code
cat -n flashinfer/fused_moe/core.py | sed -n '1040,1100p'

Length of output: 2451

🏁 Script executed:
#!/bin/bash
# Check for all calls to refine_tuning_config in the codebase
rg -n "refine_tuning_config" --type py

Length of output: 922

🏁 Script executed:
#!/bin/bash
# Look for where cls.tuning_config_ is assigned/mutated
rg -n "cls\.tuning_config_|self\.tuning_config_" flashinfer/fused_moe/core.py | head -30

Length of output: 793

🏁 Script executed:
#!/bin/bash
# Examine the test file to see if it calls with varying max_tokens
cat -n tests/moe/test_moe_autotuner_cache_keys.py | sed -n '100,130p'

Length of output: 1471

🏁 Script executed:
#!/bin/bash
# Check the calling context for MoERunner.refine_tuning_config
cat -n flashinfer/fused_moe/core.py | sed -n '1465,1475p'

Length of output: 510

🏁 Script executed:
#!/bin/bash
# Check another calling context
cat -n flashinfer/fused_moe/core.py | sed -n '1610,1620p'

Length of output: 540

🏁 Script executed:
#!/bin/bash
# Check if there's a pattern of multiple different tune_max_num_tokens being used
rg -B5 "refine_tuning_config\(" flashinfer/fused_moe/core.py | head -50

Length of output: 1759

🏁 Script executed:
#!/bin/bash
# Look for any initialization or reset logic for these class attributes
rg -B10 "def __init__|tuning_config_with_hidden_states_scales\s*=" flashinfer/fused_moe/core.py | head -40

Length of output: 1251

🏁 Script executed:
#!/bin/bash
# Check if refine_tuning_config is called at module initialization or only at runtime
rg -n "refine_tuning_config" flashinfer/fused_moe/core.py | grep -v "def refine_tuning_config"

Length of output: 456

🏁 Script executed:
#!/bin/bash
# Check for any __init__ calls or class resets
rg -n "MoETuningSetup\(\)|MoERunner\(\)" flashinfer/fused_moe/core.py tests/

Length of output: 50
Remove lru_cache from side-effectful refine_tuning_config classmethod.
This method mutates 7 class attributes (lines 1063–1097) and returns None. Caching the None return value prevents re-execution when called with the same tune_max_num_tokens after being called with a different value. For example: refine_tuning_config(4096) sets class state; refine_tuning_config(16384) updates it; refine_tuning_config(4096) returns cached None without re-mutating, leaving class attributes at the 16384 state. Since the method provides no return value benefit from caching, the decorator should be removed.
Fix

 @classmethod
-@functools.lru_cache(maxsize=None)
 def refine_tuning_config(cls, tune_max_num_tokens: int):
     mk = lambda idx, dim_idx, inits: cls._make_tuning_config(
         idx,
         dim_idx,
         inits,
         tune_max_num_tokens,
     )

🧰 Tools
🪛 Ruff (0.15.2)
[error] 1052-1057: Do not assign a lambda expression, use a def
Rewrite mk as a def
(E731)
/bot run
[FAILED] Pipeline #44998624: 9/20 passed
Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
Force-pushed: 26a7413 to 1a2f043
♻️ Duplicate comments (1)
flashinfer/fused_moe/core.py (1)
1047-1097: ⚠️ Potential issue | 🟠 Major — Drop lru_cache from refine_tuning_config.
This method mutates class-level tuning configs and returns None, so a cache hit skips the rebuild entirely. A sequence like 4096 -> 16384 -> 4096 leaves the 16384 profiles installed, which can still send the autotuner down the wrong cache-key path.
🔧 Minimal fix
 @classmethod
-@functools.lru_cache(maxsize=None)
 def refine_tuning_config(cls, tune_max_num_tokens: int, **kwargs):
     mk = lambda idx, dim_idx, inits: cls._make_tuning_config(
         idx,
         dim_idx,
         inits,

Run this read-only check to confirm the decorator behavior. The last printed line should show that the state stayed at 16384 after the second 4096 call:

#!/bin/bash
python - <<'PY'
import functools

class Demo:
    state = None

    @classmethod
    @functools.lru_cache(maxsize=None)
    def refine(cls, tokens):
        cls.state = tokens

for tokens in (4096, 16384, 4096):
    Demo.refine(tokens)
    print(f"after refine({tokens}): state={Demo.state}")
PY
rg -n -C2 'refine_tuning_config|@functools\.lru_cache' flashinfer/fused_moe/core.py
🧹 Nitpick comments (1)
tests/moe/test_moe_autotuner_cache_keys.py (1)
110-125: Cover the max-token flip-flop case. This only validates one refine_tuning_config() value per invocation, so it still passes if the configs fail to switch back after a larger max_tokens run. A 4096 -> 16384 -> 4096 sequence would catch that stale-state regression directly.
🧪 Suggested regression case
-@pytest.mark.parametrize("max_tokens", [4096, 16384])
-def test_max_tokens_respected(max_tokens):
-    """Tokens at max_tokens must still hit the cache after refine."""
-    MoETuningSetup.refine_tuning_config(max_tokens)
-    config = MoETuningSetup.select_fp8_tuning_config(
-        has_routing_logits=True,
-        fp8_quantization_type=Fp8QuantizationType.DeepSeekFp8,
-    )
-    output = torch.empty(max_tokens, HIDDEN_SIZE)
-    inputs = MoETuningSetup.build_fp8_moe_inputs(
-        routing_logits=torch.empty(max_tokens, NUM_EXPERTS),
-        hidden_states=torch.empty(max_tokens, HIDDEN_SIZE),
-        hidden_states_scale=torch.empty(SCALE_DIM, max_tokens),
-        output=output,
-    )
-    _assert_cache_key_match(config, inputs)
+def test_max_tokens_respected_after_repeated_refines():
+    for max_tokens in (4096, 16384, 4096):
+        MoETuningSetup.refine_tuning_config(max_tokens)
+        config = MoETuningSetup.select_fp8_tuning_config(
+            has_routing_logits=True,
+            fp8_quantization_type=Fp8QuantizationType.DeepSeekFp8,
+        )
+        output = torch.empty(max_tokens, HIDDEN_SIZE)
+        inputs = MoETuningSetup.build_fp8_moe_inputs(
+            routing_logits=torch.empty(max_tokens, NUM_EXPERTS),
+            hidden_states=torch.empty(max_tokens, HIDDEN_SIZE),
+            hidden_states_scale=torch.empty(SCALE_DIM, max_tokens),
+            output=output,
+        )
+        _assert_cache_key_match(config, inputs)
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 613a2304-e6d1-4244-9cca-c21e5c45f153
📒 Files selected for processing (3)
- flashinfer/autotuner.py
- flashinfer/fused_moe/core.py
- tests/moe/test_moe_autotuner_cache_keys.py
🚧 Files skipped from review as they are similar to previous changes (1)
- flashinfer/autotuner.py
/bot run
@Linda-Stadter is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww
Is the drop in TFLOPs at 4k tokens between untuned and tuned FP8 expected?
I believe this is measurement noise, but it could also be that autotuning does not select the best kernel for the actual data.
/bot run
[CANCELING] Pipeline #45887410: canceled
The performance of FP8 routed MoE looks more concerning to me (3/5 of the cases look worse after tuning). If it's measurement noise, is there a way we can improve the stability? e.g. increasing the repetitions, etc.
Yes, even after measuring a second time, I still see 3/5 cases that are worse after tuning. However, the scope of this PR was not to improve the tuning process, but to fix the logic that retrieves the cached tuned tactics.
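On the stability question raised above, one generic way to reduce benchmark noise is to warm up first and then report the median over many repetitions; the median is robust to occasional slow outlier runs. This is a standalone sketch, not the harness used for the numbers in this PR.

```python
import statistics
import time

def bench(fn, warmup=3, reps=20):
    """Median wall-clock time of fn over reps runs, after warmup calls."""
    for _ in range(warmup):
        fn()
    times = []
    for _ in range(reps):
        t0 = time.perf_counter()
        fn()
        times.append(time.perf_counter() - t0)
    # Median rather than mean: a few outlier runs (clock interrupts,
    # lazy initialization) do not skew the reported figure.
    return statistics.median(times)

print(bench(lambda: sum(range(10_000))) >= 0.0)  # True
```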
📌 Description
This PR fixes two issues:
Issue 1: Could not find tuned tactic for trtllm_fp8_block_scale_moe
2026-02-26 09:26:35,204 - INFO - autotuner.py:444 - flashinfer.jit: [AutoTunner]: Using fallback tactic for flashinfer::trtllm_fp8_block_scale_moe with input shapes (torch.Size([1024, 4096]), torch.Size([1024, 512]), torch.Size([0]), torch.Size([0]), torch.Size([1024, 4096]), torch.Size([32, 1024]))

Tuned with incorrect input:
op=flashinfer::trtllm_fp8_block_scale_moe, profile=((1024, 4096), (1024, 512), (1024,), (1024,), (1024, 4096), (1024, 16384)) -> runner_id=0, tactic=[64, 5]
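The mismatch can be read directly from the two logs above: tuning recorded a profile with shapes (1024,), (1024,), (1024, 16384), while the runtime lookup used (0,), (0,), (32, 1024). A toy shape-keyed cache shows why such a lookup misses; the key function here is illustrative, not the autotuner's actual key construction.

```python
def make_cache_key(op_name, shapes):
    # Toy key: op name plus the exact tuple of input shapes.
    return (op_name, tuple(tuple(s) for s in shapes))

cache = {}
# Shapes the tuner profiled with (from the "Tuned with incorrect input" log).
tuned = [(1024, 4096), (1024, 512), (1024,), (1024,), (1024, 4096), (1024, 16384)]
cache[make_cache_key("trtllm_fp8_block_scale_moe", tuned)] = ("runner_id=0", [64, 5])

# Shapes seen at runtime (from the fallback log): empty optional tensors and
# hidden_states_scale laid out as (32, 1024) instead of (1024, 16384).
runtime = [(1024, 4096), (1024, 512), (0,), (0,), (1024, 4096), (32, 1024)]
print(cache.get(make_cache_key("trtllm_fp8_block_scale_moe", runtime)))
# None -> the autotuner falls back to the default tactic
```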
Issue 2: Crash when autotuning trtllm_fp8_block_scale_routed_moe
Benchmark:
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed (e.g. unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
Bug Fixes
Refactor
Tests