fix: autotuner cache key mismatch for trtllm-gen FP8 block scale MoE and FP8 routed MoE#2640

Open
Linda-Stadter wants to merge 2 commits into flashinfer-ai:main from Linda-Stadter:fp8_cache_key_mismatch

Conversation


@Linda-Stadter Linda-Stadter commented Feb 26, 2026

📌 Description

The PR

  • fixes input shape mismatches so the tuning inputs match the autotuner cache key for FP8 MoE
  • enables the autotuner for FP8 block scale routed MoE

Issue 1: Could not find tuned tactic for trtllm_fp8_block_scale_moe
2026-02-26 09:26:35,204 - INFO - autotuner.py:444 - flashinfer.jit: [AutoTunner]: Using fallback tactic for flashinfer::trtllm_fp8_block_scale_moe with input shapes (torch.Size([1024, 4096]), torch.Size([1024, 512]), torch.Size([0]), torch.Size([0]), torch.Size([1024, 4096]), torch.Size([32, 1024]))

Tuned with incorrect input:
op=flashinfer::trtllm_fp8_block_scale_moe, profile=((1024, 4096), (1024, 512), (1024,), (1024,), (1024, 4096), (1024, 16384)) -> runner_id=0, tactic=[64, 5]
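The mismatch is mechanical: the autotuner keys tuned tactics by the shapes of the inputs it was profiled with, so a tactic stored under the profiling shapes is never found when the op is later queried with different shapes. A simplified, hypothetical model of that lookup (not the actual flashinfer autotuner code), using the shapes from the log above:

```python
def cache_key(op, shapes):
    # Tactics are keyed by op name plus the tuple of input shapes; any
    # divergence between profiling shapes and runtime shapes is a miss.
    return (op, tuple(tuple(s) for s in shapes))

cache = {}
# Shapes the op was (incorrectly) profiled with...
tuned = [(1024, 4096), (1024, 512), (1024,), (1024,), (1024, 4096), (1024, 16384)]
cache[cache_key("trtllm_fp8_block_scale_moe", tuned)] = ("runner_0", [64, 5])

# ...versus the shapes the op actually receives at lookup time.
runtime = [(1024, 4096), (1024, 512), (0,), (0,), (1024, 4096), (32, 1024)]
tactic = cache.get(cache_key("trtllm_fp8_block_scale_moe", runtime), "fallback")
# The stored tactic is unreachable, so the lookup falls back.
```

Fixing the tuning inputs so both sides produce the same key is what makes the stored tactic reachable again.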

Issue 2: Crash when autotuning trtllm_fp8_block_scale_routed_moe

  File "/flashinfer/flashinfer/fused_moe/core.py", line 2568, in trtllm_fp8_block_scale_routed_moe
    result = get_trtllm_moe_sm100_module().trtllm_fp8_block_scale_moe(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/flashinfer/flashinfer/fused_moe/core.py", line 1711, in trtllm_fp8_block_scale_moe_op
    _, tactic = tuner.choose_one(
                ^^^^^^^^^^^^^^^^^
  File "/flashinfer/flashinfer/autotuner.py", line 470, in choose_one
    tensors = self._prepare_input_tensors(p, inputs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/flashinfer/flashinfer/autotuner.py", line 792, in _prepare_input_tensors
    tensor = self._create_tensor_like(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/flashinfer/flashinfer/autotuner.py", line 771, in _create_tensor_like
    dtype = origin_tensor.dtype
            ^^^^^^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'dtype' 
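The traceback bottoms out where the tensor-cloning helper reads `.dtype` on an optional input that is `None` on the routed path (no routing logits). The failure mode, and one defensive shape of fix, can be sketched with a stand-in tensor type (hypothetical names, not the flashinfer implementation):

```python
from dataclasses import dataclass

@dataclass
class FakeTensor:
    shape: tuple
    dtype: str = "float8_e4m3"

def create_tensor_like(origin, shape):
    # The failing line read origin.dtype unconditionally; guarding on None
    # lets optional inputs (e.g. absent routing logits on the precomputed
    # routing path) pass through instead of raising AttributeError.
    if origin is None:
        return None
    return FakeTensor(shape=shape, dtype=origin.dtype)

clone = create_tensor_like(FakeTensor(shape=(1024, 4096)), (64, 4096))
missing = create_tensor_like(None, (64, 4096))
```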

Benchmark:

| Tokens | BF16 (ms) | BF16 TFLOPS | FP8 Untuned (ms) | FP8 Untuned TFLOPS | FP8 Tuned (ms) | FP8 Tuned TFLOPS | FP8 routed Untuned (ms) | FP8 routed Untuned TFLOPS | FP8 routed Tuned (ms) | FP8 routed Tuned TFLOPS |
|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|
| 1024 | 1.877 | 137.32 | 1.455 | 177.07 | 1.187 | 217.07 | 1.337 | 192.80 | 1.514 | 170.27 |
| 2048 | 1.952 | 263.99 | 1.692 | 304.65 | 1.425 | 361.77 | 1.548 | 333.04 | 1.662 | 310.09 |
| 4096 | 2.194 | 469.85 | 2.232 | 461.79 | 2.561 | 402.43 | 2.087 | 493.88 | 1.887 | 546.16 |
| 8192 | 3.594 | 573.57 | 3.458 | 596.15 | 3.439 | 599.49 | 3.355 | 614.50 | 3.582 | 575.53 |
| 16384 | 5.423 | 760.37 | 6.329 | 651.47 | 5.852 | 704.53 | 6.026 | 684.17 | 5.670 | 727.18 |

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes

    • Corrected a typo in autotuner debug log messages.
  • Refactor

    • Consolidated MoE tuning configuration and input preparation into a centralized setup, simplifying FP8/FP4 paths, reducing duplication, and improving runtime/shape validation and configurability.
  • Tests

    • Added tests verifying autotuner cache-key behavior across quantization modes and multiple token-count scenarios.


coderabbitai bot commented Feb 26, 2026

📝 Walkthrough

This PR centralizes SM100 MoE tuning via a new MoETuningSetup, refactors MoERunner to use it, updates FP8 MoE input/validation and tuning selection, fixes a log-typo in autotuner.py, and adds tests verifying MoE autotuner cache-key behavior.

Changes

  • Typo Fix — flashinfer/autotuner.py: Fixed debug log typo: "[AutoTunner]" → "[AutoTuner]" (two occurrences).
  • MoE Tuning Configuration & Runner Refactor — flashinfer/fused_moe/core.py: Added MoETuningSetup (dynamic initializers, _make_tuning_config, index groups, precomputed routing configs), refine_tuning_config, select_fp8_tuning_config, and build_fp8_moe_inputs; refactored SM100 FP8/FP4 paths to use these utilities; replaced inlined tuning configs; updated MoERunner to inherit MoETuningSetup + TunableRunner; adjusted runtime/shape checks and token derivation.
  • Tests: MoE Autotuner Cache Keys — tests/moe/test_moe_autotuner_cache_keys.py: New tests exercising tuning-profile cache-key alignment across DeepSeek/MXFp8, routing logits vs precomputed routing, various token counts, hidden-state-scale variants, and tuning-config selection/refinement.

Sequence Diagram(s)

(omitted)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested labels

ready, op: moe-routing

Suggested reviewers

  • yzh119
  • aleozlx
  • cyx-6
  • djmmoss
  • bkryu
  • jiahanc

Poem

🐇 I hopped through configs, neat and bright,
I bundled tunings into one light,
Tokens routed, tests agree,
Cache keys clap — a tuning spree,
Cheers from a rabbit in the night ✨

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 warning

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning — Docstring coverage is 21.05%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Title check — The title clearly and specifically describes the main changes: fixing autotuner cache key mismatches for FP8 block scale MoE operations.
  • Description check — The description provides comprehensive coverage with specific issue descriptions, error messages, benchmark results, and confirmed completion of pre-commit checks and tests.



@gemini-code-assist

Summary of Changes

Hello @Linda-Stadter, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical autotuner cache key mismatch that was causing performance degradation in FP8 block scale Mixture-of-Experts (MoE) operations. The solution involves a significant refactoring of the MoE tuning configuration logic, centralizing it into a dedicated setup class. This new structure allows for more precise and dynamic selection of tuning profiles based on the specific FP8 quantization type and routing method, thereby ensuring that the autotuner can consistently find and apply optimized tactics. The changes are validated with new tests covering various MoE configurations.

Highlights

  • Autotuner Cache Key Mismatch Fix: Resolved an issue where the autotuner for trtllm-gen FP8 block scale MoE operations would encounter cache key mismatches, leading to fallback tactics due to incorrect input shape handling during tuning.
  • Refactored MoE Tuning Configuration: Introduced a new MoETuningSetup class to centralize and manage various tuning configurations for FP8 MoE operations, including different quantization types (DeepSeekFp8, MxFp8) and routing mechanisms (from logits or precomputed).
  • Dynamic Tuning Config Selection: Implemented a select_fp8_tuning_config method to dynamically choose the appropriate tuning configuration based on whether routing logits are provided and the specific FP8 quantization type being used.
  • Improved Input Tensor Handling: Added a build_fp8_moe_inputs static method to correctly prepare input tensors for the autotuner, ensuring proper handling of topk_ids and expert_weights when routing logits are used or when they are precomputed.
  • New Autotuner Cache Key Tests: Added comprehensive unit tests to verify that the autotuner correctly generates and matches cache keys for various FP8 MoE scenarios, ensuring the fix is robust and prevents future regressions.
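The selection logic described above amounts to a small dispatch on routing source and quantization type. A hypothetical sketch of that shape of dispatch (names modeled on the PR summary; the real `select_fp8_tuning_config` lives in `flashinfer/fused_moe/core.py` and returns prepared `TuningConfig` objects rather than strings):

```python
from enum import Enum

class Fp8QuantizationType(Enum):
    DeepSeekFp8 = 0
    MxFp8 = 1

# Hypothetical table: one prepared tuning config per (routing source,
# quantization type) pair; placeholder strings stand in for TuningConfig.
_TUNING_CONFIGS = {
    (True, Fp8QuantizationType.DeepSeekFp8): "deepseek_fp8_from_logits",
    (False, Fp8QuantizationType.DeepSeekFp8): "deepseek_fp8_precomputed",
    (True, Fp8QuantizationType.MxFp8): "mxfp8_from_logits",
    (False, Fp8QuantizationType.MxFp8): "mxfp8_precomputed",
}

def select_fp8_tuning_config(has_routing_logits, fp8_quantization_type):
    # Picking one canonical config per mode keeps tuning and lookup keyed
    # identically, which is the invariant this PR restores.
    return _TUNING_CONFIGS[(has_routing_logits, fp8_quantization_type)]

chosen = select_fp8_tuning_config(False, Fp8QuantizationType.MxFp8)
```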


Changelog
  • flashinfer/autotuner.py
    • Corrected a typo in debug log messages from 'AutoTunner' to 'AutoTuner'.
  • flashinfer/fused_moe/core.py
    • Refactored MoE tuning configuration logic into a new MoETuningSetup class.
    • Defined multiple TuningConfig instances within MoETuningSetup for different FP8 quantization types and routing scenarios.
    • Added select_fp8_tuning_config method to dynamically choose the correct tuning configuration.
    • Implemented build_fp8_moe_inputs static method to prepare input tensors for autotuning.
    • Updated MoERunner to inherit from MoETuningSetup.
    • Modified assertions for topk_ids and expert_weights in forward method to handle empty tensors.
    • Removed redundant refine_tuning_config method from MoERunner.
    • Updated trtllm_fp8_block_scale_moe_op to utilize the new tuning configuration selection and input building methods.
  • tests/moe/test_moe_autotuner_cache_keys.py
    • Added new test file to verify autotuner cache key matching for FP8 MoE operations.
    • Included tests for DeepSeekFp8 and MxFp8 quantization types with both routing from logits and precomputed routing.
    • Added tests for configurations without hidden state scales and for respecting max_tokens limits.
Activity
  • The pull request description includes benchmark results demonstrating performance improvements after the fix, with FP8 Tuned (ms) generally lower than FP8 Untuned (ms) for various token counts.
  • The author has provided a detailed description of the issue, including a specific error message and the reason for the incorrect tuning input.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request effectively addresses an autotuner cache key mismatch for FP8 MoE by refactoring the tuning configuration logic. The introduction of the MoETuningSetup class is a commendable improvement, enhancing modularity and correctness by centralizing the complex tuning configurations. The changes are well-structured, and the addition of comprehensive tests in tests/moe/test_moe_autotuner_cache_keys.py is a valuable contribution that ensures the fix is robust and helps prevent future regressions. I have one suggestion to further improve the robustness of a new helper method.

@Linda-Stadter Linda-Stadter marked this pull request as ready for review February 26, 2026 19:40

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/fused_moe/core.py (1)

1185-1240: ⚠️ Potential issue | 🟠 Major

Precomputed-routing autotuning can crash when routing_logits is None.

trtllm_fp8_block_scale_moe_op now routes a precomputed path with routing_logits=None, but MoERunner.get_valid_tactics (Line 1198) and MoERunner.forward (Line 1239) still unconditionally use routing_logits.shape[0]. In tuning mode this raises before tactic selection.

🐛 Proposed fix
         def get_valid_tactics(
             self,
             inputs: List[torch.Tensor],
             profile: OptimizationProfile,
         ) -> List[int]:
             (
                 output,
                 routing_logits,
                 topk_ids,
                 expert_weights,
                 hidden_states,
                 *extra_inputs,
             ) = inputs
-            num_tokens = routing_logits.shape[0]
+            token_source = routing_logits if routing_logits is not None else topk_ids
+            assert token_source is not None, (
+                "Either routing_logits or topk_ids must be provided."
+            )
+            num_tokens = token_source.shape[0]
@@
         def forward(
             self,
             inputs: List[torch.Tensor],
             tactic: int = -1,
             do_preparation: bool = False,
             **kwargs,
         ):
             (
                 output,
                 routing_logits,
                 topk_ids,
                 expert_weights,
                 hidden_states,
                 *extra_inputs,
             ) = inputs
-            num_tokens = routing_logits.shape[0]
+            token_source = routing_logits if routing_logits is not None else topk_ids
+            assert token_source is not None, (
+                "Either routing_logits or topk_ids must be provided."
+            )
+            num_tokens = token_source.shape[0]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/core.py` around lines 1185 - 1240, get_valid_tactics and
forward assume routing_logits is not None and call routing_logits.shape[0]; when
routing is precomputed routing_logits may be None and causes a crash. Fix both
MoERunner.get_valid_tactics and MoERunner.forward by computing num_tokens
defensively, e.g. use routing_logits.shape[0] if routing_logits is not None else
topk_ids.shape[0] (or another appropriate tensor like topk_ids) before building
instance_key or using num_tokens; update references to num_tokens in
get_valid_tactics and forward accordingly so they no longer access
routing_logits when it's None.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/fused_moe/core.py`:
- Around line 1053-1058: Replace the lambda assignment "mk = lambda idx,
dim_idx, inits: cls._make_tuning_config(...)" with a named function to satisfy
Ruff E731; define a local def (e.g., def mk(idx, dim_idx, inits): return
cls._make_tuning_config(idx, dim_idx, inits, tune_max_num_tokens)) and use that
function in place of the lambda so calls referencing mk remain unchanged and
linting passes.


ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f852eb6 and af2b2de.

📒 Files selected for processing (3)
  • flashinfer/autotuner.py
  • flashinfer/fused_moe/core.py
  • tests/moe/test_moe_autotuner_cache_keys.py


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/fused_moe/core.py`:
- Around line 1049-1057: The refine_tuning_config classmethod is decorated with
functools.lru_cache but mutates multiple class attributes and returns None,
causing stale cached None to prevent re-execution; remove the
@functools.lru_cache(maxsize=None) decorator above refine_tuning_config so the
method always runs and updates class state (refer to refine_tuning_config and
helper _make_tuning_config in class fused_moe.core) and ensure no other callers
depend on its caching behavior.

In `@tests/moe/test_moe_autotuner_cache_keys.py`:
- Around line 17-22: Add an architecture-based skip at the top of
tests/moe/test_moe_autotuner_cache_keys.py using flashinfer.utils (e.g., import
get_compute_capability or is_sm90a_supported from flashinfer.utils) and use
pytest.mark.skipif (or a runtime pytest.skip) to skip the module when the
current GPU compute capability is not supported by the SM-specific FP8 MoE
tuning paths; update the module imports (alongside AutoTuner,
Fp8QuantizationType, MoETuningSetup) and add a skip condition referencing the
utility (get_compute_capability/is_sm90a_supported) so the entire test file is
skipped on unsupported architectures.

ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between af2b2de and 26a7413.

📒 Files selected for processing (2)
  • flashinfer/fused_moe/core.py
  • tests/moe/test_moe_autotuner_cache_keys.py

Comment on lines +1049 to +1057
    @classmethod
    @functools.lru_cache(maxsize=None)
    def refine_tuning_config(cls, tune_max_num_tokens: int):
        mk = lambda idx, dim_idx, inits: cls._make_tuning_config(
            idx,
            dim_idx,
            inits,
            tune_max_num_tokens,
        )

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Scripts executed (read-only verification against flashinfer-ai/flashinfer):

#!/bin/bash
# Locate the file and inspect the method implementation
fd -t f "core.py" flashinfer/fused_moe/
cat -n flashinfer/fused_moe/core.py | sed -n '1040,1100p'

# Find all calls to refine_tuning_config in the codebase
rg -n "refine_tuning_config" --type py
rg -n "refine_tuning_config" flashinfer/fused_moe/core.py | grep -v "def refine_tuning_config"

# Find where the tuning-config class attributes are assigned/mutated
rg -n "cls\.tuning_config_|self\.tuning_config_" flashinfer/fused_moe/core.py | head -30
rg -B10 "def __init__|tuning_config_with_hidden_states_scales\s*=" flashinfer/fused_moe/core.py | head -40

# Check whether the tests call refine_tuning_config with varying max_tokens
cat -n tests/moe/test_moe_autotuner_cache_keys.py | sed -n '100,130p'

# Examine the calling contexts for MoERunner.refine_tuning_config
cat -n flashinfer/fused_moe/core.py | sed -n '1465,1475p'
cat -n flashinfer/fused_moe/core.py | sed -n '1610,1620p'
rg -B5 "refine_tuning_config\(" flashinfer/fused_moe/core.py | head -50

# Check for instance construction or class resets
rg -n "MoETuningSetup\(\)|MoERunner\(\)" flashinfer/fused_moe/core.py tests/


Remove lru_cache from side-effectful refine_tuning_config classmethod.

This method mutates 7 class attributes (lines 1063–1097) and returns None. Caching the None return value prevents re-execution when called with the same tune_max_num_tokens after being called with a different value. For example: refine_tuning_config(4096) sets class state; refine_tuning_config(16384) updates it; refine_tuning_config(4096) returns cached None without re-mutating, leaving class attributes at the 16384 state. Since the method provides no return value benefit from caching, the decorator should be removed.

Fix
     @classmethod
-    @functools.lru_cache(maxsize=None)
     def refine_tuning_config(cls, tune_max_num_tokens: int):
         mk = lambda idx, dim_idx, inits: cls._make_tuning_config(
             idx,
             dim_idx,
             inits,
             tune_max_num_tokens,
         )
🧰 Tools
🪛 Ruff (0.15.2)

[error] 1052-1057: Do not assign a lambda expression, use a def

Rewrite mk as a def

(E731)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/core.py` around lines 1049 - 1057, The
refine_tuning_config classmethod is decorated with functools.lru_cache but
mutates multiple class attributes and returns None, causing stale cached None to
prevent re-execution; remove the `@functools.lru_cache`(maxsize=None) decorator
above refine_tuning_config so the method always runs and updates class state
(refer to refine_tuning_config and helper _make_tuning_config in class
fused_moe.core) and ensure no other callers depend on its caching behavior.

@Linda-Stadter Linda-Stadter changed the title fix: autotuner cache key mismatch for trtllm-gen FP8 block scale MoE fix: autotuner cache key mismatch for trtllm-gen FP8 block scale MoE and FP8 routed MoE Feb 27, 2026
@aleozlx aleozlx self-assigned this Feb 27, 2026
@aleozlx

aleozlx commented Feb 27, 2026

/bot run

@flashinfer-bot

GitLab MR !357 has been created, and the CI pipeline #44998624 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot

[FAILED] Pipeline #44998624: 9/20 passed

Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>

@coderabbitai coderabbitai bot left a comment


♻️ Duplicate comments (1)
flashinfer/fused_moe/core.py (1)

1047-1097: ⚠️ Potential issue | 🟠 Major

Drop lru_cache from refine_tuning_config.

This method mutates class-level tuning configs and returns None, so a cache hit skips the rebuild entirely. A sequence like 4096 -> 16384 -> 4096 leaves the 16384 profiles installed, which can still send the autotuner down the wrong cache-key path.

🔧 Minimal fix
     @classmethod
-    @functools.lru_cache(maxsize=None)
     def refine_tuning_config(cls, tune_max_num_tokens: int, **kwargs):
         mk = lambda idx, dim_idx, inits: cls._make_tuning_config(
             idx,
             dim_idx,
             inits,

Run this read-only check to confirm the decorator behavior. The last printed line should show that the state stayed at 16384 after the second 4096 call:

#!/bin/bash
python - <<'PY'
import functools

class Demo:
    state = None

    @classmethod
    @functools.lru_cache(maxsize=None)
    def refine(cls, tokens):
        cls.state = tokens

for tokens in (4096, 16384, 4096):
    Demo.refine(tokens)
    print(f"after refine({tokens}): state={Demo.state}")
PY

rg -n -C2 'refine_tuning_config|@functools\.lru_cache' flashinfer/fused_moe/core.py
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/core.py` around lines 1047 - 1097, The
refine_tuning_config method is decorated with @functools.lru_cache but mutates
class-level tuning config attributes (e.g.,
tuning_config_with_hidden_states_scales, tuning_config_no_hidden_states_scales,
tuning_config_routing_from_logits, etc.) and returns None, so cached calls skip
rebuilding state; remove the @functools.lru_cache(maxsize=None) decorator from
the classmethod refine_tuning_config to ensure each call reinitializes the class
attributes (keep the method signature and internal logic as-is). If caching of
results is desired instead, implement an explicit cache keyed by
tune_max_num_tokens that stores and returns a copy rather than using lru_cache
on the mutating method.
🧹 Nitpick comments (1)
tests/moe/test_moe_autotuner_cache_keys.py (1)

110-125: Cover the max-token flip-flop case.

This only validates one refine_tuning_config() value per invocation, so it still passes if the configs fail to switch back after a larger max_tokens run. A 4096 -> 16384 -> 4096 sequence would catch that stale-state regression directly.

🧪 Suggested regression case
-@pytest.mark.parametrize("max_tokens", [4096, 16384])
-def test_max_tokens_respected(max_tokens):
-    """Tokens at max_tokens must still hit the cache after refine."""
-    MoETuningSetup.refine_tuning_config(max_tokens)
-    config = MoETuningSetup.select_fp8_tuning_config(
-        has_routing_logits=True,
-        fp8_quantization_type=Fp8QuantizationType.DeepSeekFp8,
-    )
-    output = torch.empty(max_tokens, HIDDEN_SIZE)
-    inputs = MoETuningSetup.build_fp8_moe_inputs(
-        routing_logits=torch.empty(max_tokens, NUM_EXPERTS),
-        hidden_states=torch.empty(max_tokens, HIDDEN_SIZE),
-        hidden_states_scale=torch.empty(SCALE_DIM, max_tokens),
-        output=output,
-    )
-    _assert_cache_key_match(config, inputs)
+def test_max_tokens_respected_after_repeated_refines():
+    for max_tokens in (4096, 16384, 4096):
+        MoETuningSetup.refine_tuning_config(max_tokens)
+        config = MoETuningSetup.select_fp8_tuning_config(
+            has_routing_logits=True,
+            fp8_quantization_type=Fp8QuantizationType.DeepSeekFp8,
+        )
+        output = torch.empty(max_tokens, HIDDEN_SIZE)
+        inputs = MoETuningSetup.build_fp8_moe_inputs(
+            routing_logits=torch.empty(max_tokens, NUM_EXPERTS),
+            hidden_states=torch.empty(max_tokens, HIDDEN_SIZE),
+            hidden_states_scale=torch.empty(SCALE_DIM, max_tokens),
+            output=output,
+        )
+        _assert_cache_key_match(config, inputs)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/moe/test_moe_autotuner_cache_keys.py` around lines 110 - 125, Update
the test_max_tokens_respected to exercise the flip-flop sequence by calling
MoETuningSetup.refine_tuning_config with 4096, then 16384, then 4096 again
before selecting the config via MoETuningSetup.select_fp8_tuning_config and
building inputs with MoETuningSetup.build_fp8_moe_inputs; finally run
_assert_cache_key_match on the resulting config and inputs so the test verifies
the state correctly returns to the smaller max_tokens after the larger run.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@flashinfer/fused_moe/core.py`:
- Around line 1047-1097: The refine_tuning_config method is decorated with
@functools.lru_cache but mutates class-level tuning config attributes (e.g.,
tuning_config_with_hidden_states_scales, tuning_config_no_hidden_states_scales,
tuning_config_routing_from_logits, etc.) and returns None, so cached calls skip
rebuilding state; remove the @functools.lru_cache(maxsize=None) decorator from
the classmethod refine_tuning_config to ensure each call reinitializes the class
attributes (keep the method signature and internal logic as-is). If caching of
results is desired instead, implement an explicit cache keyed by
tune_max_num_tokens that stores and returns a copy rather than using lru_cache
on the mutating method.

---

Nitpick comments:
In `@tests/moe/test_moe_autotuner_cache_keys.py`:
- Around line 110-125: Update the test_max_tokens_respected to exercise the
flip-flop sequence by calling MoETuningSetup.refine_tuning_config with 4096,
then 16384, then 4096 again before selecting the config via
MoETuningSetup.select_fp8_tuning_config and building inputs with
MoETuningSetup.build_fp8_moe_inputs; finally run _assert_cache_key_match on the
resulting config and inputs so the test verifies the state correctly returns to
the smaller max_tokens after the larger run.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 613a2304-e6d1-4244-9cca-c21e5c45f153

📥 Commits

Reviewing files that changed from the base of the PR and between 26a7413 and 1a2f043.

📒 Files selected for processing (3)
  • flashinfer/autotuner.py
  • flashinfer/fused_moe/core.py
  • tests/moe/test_moe_autotuner_cache_keys.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • flashinfer/autotuner.py

@Linda-Stadter

/bot run

@flashinfer-bot

@Linda-Stadter is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

@hypdeb

hypdeb commented Mar 11, 2026

Is the drop in TFLOPs at 4k tokens between untuned and tuned FP8 expected?

@Linda-Stadter

Is the drop in TFLOPs at 4k tokens between untuned and tuned FP8 expected?

I believe this is measured noise. But could always be that autotuning does not select the best kernel for the actual data

@yzh119

yzh119 commented Mar 11, 2026

/bot run

@flashinfer-bot

GitLab MR !357 has been updated with latest changes, and the CI pipeline #45887410 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot

[CANCELING] Pipeline #45887410: canceled

@yzh119

yzh119 commented Mar 14, 2026

The performance of FP8 routed MoE looks more concerning to me (3/5 of the cases look worse after tuning).

If it's measurement noise, is there a way we can improve the stability? e.g. increasing the repetitions, etc.

@Linda-Stadter

The performance of FP8 routed MoE looks more concerning to me (3/5 of the cases look worse after tuning).

If it's measurement noise, is there a way we can improve the stability? e.g. increasing the repetitions, etc.

Yes, even after measuring a second time, I still see 3/5 cases that are worse after tuning. However, the scope of the PR was not to improve the tuning process, but to fix the logic that retrieves the cached tuned tactics.
Or would you rather have the autotuner break for the routed MoE path than retrieve potentially worse tactics?
