Fix fallback to default tactic (flashinfer autotuner) with trtllm_fp4_block_scale_moe#35088

Merged
vllm-bot merged 1 commit into vllm-project:main from de-inf:fix_trtllm_moe_autotune on Feb 24, 2026

Conversation

Contributor

@danisereb danisereb commented Feb 23, 2026

Purpose

The flashinfer autotuner expects the first dimension of the MoE input tensors to be num_tokens.

The relevant code can be found in the flashinfer function get_trtllm_moe_sm100_module:

```python
        # their first dimension is num_tokens which will be tuned
        tuning_config_with_hidden_states_scales = TuningConfig(
            dynamic_tensor_specs=(
                DynamicTensorSpec(
                    (0, 1, 2, 3, 4, 5),
                    (0, 0, 0, 0, 0, 0),
                    get_last_power_of_2_num_tokens_buckets(8192, 1),
                    lambda x: min(last_positive_power_of_2(x), 8192),
                    dynamic_tensor_initializers,
                ),
            )
        )
```
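The bucketing lambda in the snippet above can be sketched in plain Python. The names last_positive_power_of_2 and the 8192 cap come from the quoted code; this standalone implementation is only an assumption about flashinfer's internals:

```python
# Hedged sketch of how the autotuner likely buckets num_tokens
# (assumption based on: lambda x: min(last_positive_power_of_2(x), 8192)).
def last_positive_power_of_2(x: int) -> int:
    # Largest power of two <= x, for x >= 1.
    return 1 << (x.bit_length() - 1)

def bucket(num_tokens: int) -> int:
    # Round the dynamic num_tokens dimension down to a power-of-2 bucket,
    # capped at 8192, so tuned tactics can be reused across batch sizes.
    return min(last_positive_power_of_2(num_tokens), 8192)

print(bucket(300))    # → 256
print(bucket(20000))  # → 8192
```

Under this reading, tactics tuned for each power-of-2 bucket are looked up by the leading (num_tokens) dimension at runtime, which is why that dimension must survive into the kernel inputs.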

When running vLLM main with the env var FLASHINFER_LOGGING_LEVEL=DEBUG, we can see the following:

```
flashinfer.jit: [AutoTunner]: Using fallback tactic for flashinfer::trtllm_fp4_block_scale_moe with input shapes ...
```

Using reshape(...) instead of flatten() solves the issue (the fallback-tactic messages no longer appear).

This change is aligned with the function Mxfp4MoEMethod.apply_monolithic:

```python
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x.shape[:-1], -1)
```
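A minimal shape-level sketch (NumPy as a stand-in for the torch tensors; the sizes are hypothetical) shows why flatten() hides num_tokens from the tuner while reshape(*x.shape[:-1], -1) preserves it as the leading dimension:

```python
import numpy as np

num_tokens, hidden = 4, 64
x = np.zeros((num_tokens, hidden), dtype=np.uint8)              # stand-in for hidden_states_fp4
x_scale = np.zeros((num_tokens, hidden // 16), dtype=np.uint8)  # one scale per 16-element block

flat = x_scale.flatten()                    # 1-D: leading dim is no longer num_tokens
fixed = x_scale.reshape(*x.shape[:-1], -1)  # leading dim stays num_tokens

print(flat.shape)   # → (16,)
print(fixed.shape)  # → (4, 4)
```

With the flattened scale tensor, the autotuner cannot match dimension 0 against its num_tokens buckets and falls back to the default tactic; the reshaped form keeps dimension 0 equal to num_tokens as the TuningConfig expects.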

Test Plan

Verify that there is no drop in accuracy (lm_eval).

Test Result

Used the following MoE NVFP4 model:
https://huggingface.co/RedHatAI/Llama-4-Scout-17B-16E-Instruct-NVFP4

Use the following env vars to select the 'FLASHINFER_TRTLLM' NvFp4 MoE backend:

```shell
export VLLM_USE_FLASHINFER_MOE_FP4=1
export VLLM_FLASHINFER_MOE_BACKEND=latency
```

lm_eval gsm8k:

```shell
lm_eval \
  --model vllm \
  --model_args pretrained="/my_home/hf_models/RedHatAI/Llama-4-Scout-17B-16E-Instruct-NVFP4",dtype=auto,add_bos_token=True,max_model_len=4096,tensor_parallel_size=2,enable_chunked_prefill=True,enforce_eager=True \
  --tasks gsm8k_llama \
  --apply_chat_template \
  --fewshot_as_multiturn \
  --batch_size auto
```

### vLLM main commit: 
|   Tasks   |Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----------|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k_llama|      3|flexible_extract|     8|exact_match|↑  |0.9371|±  |0.0067|
|           |       |strict_match    |     8|exact_match|↑  |0.9371|±  |0.0067|


### This PR:
|   Tasks   |Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----------|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k_llama|      3|flexible_extract|     8|exact_match|↑  |0.9371|±  |0.0067|
|           |       |strict_match    |     8|exact_match|↑  |0.9371|±  |0.0067|

There was no change in vllm bench serve performance (tok/sec).


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request correctly addresses a performance issue with the flashinfer MoE kernels. The original code used .flatten() on the hidden_states_scale tensor, which collapsed its dimensions and prevented the flashinfer autotuner from identifying the optimal tactic, causing a fallback to a less efficient default. The fix replaces .flatten() with .reshape(*hidden_states_fp4.shape[:-1], -1), which correctly preserves the tensor's leading dimensions (including num_tokens) as expected by the kernel. This change is applied consistently in both flashinfer_trtllm_fp4_moe and flashinfer_trtllm_fp4_routed_moe functions, resolving the performance degradation. The change is correct and well-justified.

@danisereb danisereb changed the title Fix fallback to default flashinfer tactic with trtllm_fp4_block_scale_moe Fix fallback to default tactic (flashinfer autotuner) with trtllm_fp4_block_scale_moe Feb 23, 2026
@danisereb danisereb force-pushed the fix_trtllm_moe_autotune branch 2 times, most recently from e882c41 to 19ec584 Compare February 23, 2026 13:37
@danisereb danisereb marked this pull request as ready for review February 23, 2026 14:59
@mgoin mgoin left a comment (Member)

Makes sense, LGTM

Comment on lines 349 to +351

```diff
 hidden_states_scale=hidden_states_scale_linear_fp4.view(
     torch.float8_e4m3fn
-).flatten(),
+).reshape(*hidden_states_fp4.shape[:-1], -1),
```
Member
Is reshape required or could we use view here? I just ask so we avoid the implicit .contiguous() that reshape has

@danisereb danisereb (Contributor, Author) commented Feb 23, 2026

As I mentioned in the PR description, I followed the reshape from Mxfp4MoEMethod.apply_monolithic.

As far as I understand, reshape() returns a view when possible and makes an implicit contiguous copy only when required, whereas view() may throw an error if the underlying storage/strides do not match the requested shape.
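The contiguity distinction can be illustrated with NumPy as a stand-in (torch's view/reshape behave analogously): reshape silently copies when a zero-copy view is impossible, while the strict view path fails outright on a non-contiguous array.

```python
import numpy as np

a = np.arange(6).reshape(2, 3)
t = a.T                       # transposed: non-contiguous in memory

copied = t.reshape(-1)        # reshape succeeds by making a contiguous copy
print(copied.tolist())        # → [0, 3, 1, 4, 2, 5]

try:
    t.shape = (6,)            # strict in-place view: rejected for non-contiguous data
    raised = False
except AttributeError:
    raised = True
print(raised)                 # → True
```

In the MoE path the scale tensor is expected to be contiguous anyway, so reshape should usually take the zero-copy view branch; following the existing Mxfp4MoEMethod.apply_monolithic pattern just keeps the safe fallback.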

@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Feb 23, 2026
@mgoin mgoin added ready ONLY add when PR is ready to merge/full CI is needed performance Performance-related issues labels Feb 23, 2026
…_block_scale_moe

Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
@danisereb danisereb force-pushed the fix_trtllm_moe_autotune branch from 19ec584 to fd87a3a Compare February 23, 2026 17:08
Contributor Author

danisereb commented Feb 23, 2026

Rebased to pick up #34874, which seems to be related to the CI failures I see.

There's still a bug in certain cases (num tokens not power of 2), but the fix is in flashinfer - see this PR:
flashinfer-ai/flashinfer#2617

@pavanimajety pavanimajety left a comment (Collaborator)

Seems like the test failures are unrelated (model architectures in the failing tests likely won't invoke trtllm FP4 moe kernels)

> There's still a bug in certain cases (num tokens not power of 2), but the fix is in flashinfer - see this PR: flashinfer-ai/flashinfer#2617

With smaller batch sizes that would be enabled through #34936 in --performance-mode latency, we may see an issue. However, since this is a perf improvement, we should be okay to merge.

@pavanimajety pavanimajety enabled auto-merge (squash) February 23, 2026 20:07
@vllm-bot vllm-bot merged commit a0c7081 into vllm-project:main Feb 24, 2026
60 of 62 checks passed
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Feb 24, 2026
tom-zju pushed a commit to tom-zju/vllm that referenced this pull request Feb 26, 2026
llsj14 pushed a commit to llsj14/vllm that referenced this pull request Mar 1, 2026
tunglinwood pushed a commit to tunglinwood/vllm that referenced this pull request Mar 4, 2026
askliar pushed a commit to askliar/vllm that referenced this pull request Mar 9, 2026
Copilot AI pushed a commit to machov/vllm that referenced this pull request Mar 10, 2026
EricccYang pushed a commit to EricccYang/vllm that referenced this pull request Apr 1, 2026
Labels

nvidia, performance (Performance-related issues), ready (ONLY add when PR is ready to merge/full CI is needed)

Projects

Status: Done

4 participants