Fix fallback to default tactic (flashinfer autotuner) with trtllm_fp4_block_scale_moe (#35088)
Conversation
Code Review
This pull request correctly addresses a performance issue with the flashinfer MoE kernels. The original code used .flatten() on the hidden_states_scale tensor, which collapsed its dimensions and prevented the flashinfer autotuner from identifying the optimal tactic, causing a fallback to a less efficient default. The fix replaces .flatten() with .reshape(*hidden_states_fp4.shape[:-1], -1), which correctly preserves the tensor's leading dimensions (including num_tokens) as expected by the kernel. This change is applied consistently in both flashinfer_trtllm_fp4_moe and flashinfer_trtllm_fp4_routed_moe functions, resolving the performance degradation. The change is correct and well-justified.
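The shape effect described in the review can be sketched with NumPy as a stand-in (the real tensors are FP4/FP8 torch tensors; the dimensions here are illustrative, not taken from the kernel):

```python
import numpy as np

num_tokens, hidden_dim, group = 5, 128, 16
# Illustrative stand-ins for hidden_states_fp4 (packed, 2 values per byte)
# and its per-group scale tensor
hidden_states_fp4 = np.zeros((num_tokens, hidden_dim // 2), dtype=np.uint8)
scales = np.zeros((num_tokens, hidden_dim // group), dtype=np.uint8)

# .flatten() collapses everything to 1-D: num_tokens is no longer shape[0]
assert scales.flatten().shape == (num_tokens * hidden_dim // group,)

# The fix keeps the leading dims of hidden_states_fp4, so shape[0] stays num_tokens
fixed = scales.reshape(*hidden_states_fp4.shape[:-1], -1)
assert fixed.shape == (num_tokens, hidden_dim // group)
```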
```diff
 hidden_states_scale=hidden_states_scale_linear_fp4.view(
     torch.float8_e4m3fn
-).flatten(),
+).reshape(*hidden_states_fp4.shape[:-1], -1),
```
Is reshape required or could we use view here? I just ask so we avoid the implicit .contiguous() that reshape has
Like I mentioned in the PR description, I followed the reshape from Mxfp4MoEMethod.apply_monolithic.
As far as I understand, reshape() returns a view when possible and falls back to an implicit contiguous copy only when required, whereas view() raises an error if the underlying storage/strides do not match the requested shape.
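The view-when-possible behavior can be illustrated with NumPy's reshape, which follows the same rule (torch's .view() would raise on the non-contiguous case instead of copying):

```python
import numpy as np

a = np.arange(12).reshape(3, 4)

# Contiguous source: reshape returns a view, no data is copied
b = a.reshape(4, 3)
assert np.shares_memory(a, b)

# Non-contiguous source (transpose): reshape silently copies.
# torch.Tensor.view() would raise a RuntimeError here instead.
c = a.T.reshape(12)
assert not np.shares_memory(a, c)
```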
…_block_scale_moe Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
Rebased to pick up this change: it seems related to the CI failures I see. There's still a bug in certain cases (num_tokens not a power of 2), but the fix is in flashinfer - see this PR:
pavanimajety left a comment
Seems like the test failures are unrelated (model architectures in the failing tests likely won't invoke trtllm FP4 moe kernels)
> There's still a bug in certain cases (num tokens not power of 2), but the fix is in flashinfer - see this PR: flashinfer-ai/flashinfer#2617
With smaller batch sizes that would be enabled through #34936 in --performance-mode latency, we may see an issue. However, since this is a perf improvement, we should be okay to merge.
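Why the batch size (num_tokens) matters can be sketched as a toy tactic cache keyed on shape[0]. This is a hypothetical simplification of what an autotuner does, not flashinfer's actual API:

```python
# Toy model of tactic selection keyed on the token dimension.
# All names here are hypothetical, for illustration only.
def pick_tactic(tuned, tensor_shape):
    num_tokens = tensor_shape[0]
    return tuned.get(num_tokens, "default_fallback_tactic")

tuned = {1: "tactic_m1", 8: "tactic_m8"}

# 2-D scale tensor: shape[0] is num_tokens, so the tuned tactic is found
assert pick_tactic(tuned, (8, 16)) == "tactic_m8"

# Flattened tensor: shape[0] is num_tokens * groups, the lookup misses
# and the autotuner falls back to the slow default
assert pick_tactic(tuned, (8 * 16,)) == "default_fallback_tactic"
```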
Purpose
The flashinfer autotuner expects the first dimension of the MoE tensors to be num_tokens. The relevant code can be found in the flashinfer function get_trtllm_moe_sm100_module.
When running vLLM main with the env var FLASHINFER_LOGGING_LEVEL=DEBUG, the log shows the autotuner falling back to the default tactic. Using reshape(...) instead of flatten() solves the issue (no more "using fallback tactic" prints).
This change is aligned with the function Mxfp4MoEMethod.apply_monolithic.
Test Plan
Verify that there is no drop in accuracy (lm_eval).
Test Result
Used the following MoE NVFP4 model:
https://huggingface.co/RedHatAI/Llama-4-Scout-17B-16E-Instruct-NVFP4
Use env vars to select the 'FLASHINFER_TRTLLM' NvFp4 MoE backend.
lm_eval gsm8k:
There was no change in `vllm bench serve` performance (tok/sec).