[NVIDIA] Explicitly disable shuffled weights for flashinfer blockscale moe fp8 kernels#21411
Conversation
Signed-off-by: kaixih <kaixih@nvidia.com>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited set of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add 🚀
Code Review
This pull request aims to restore previous behavior for FlashInfer MoE kernels by explicitly disabling a new use_shuffled_weight flag. While the change is correct in its intent, it introduces a critical backward compatibility issue for users with older versions of FlashInfer. I've provided a comment with a suggested fix to address this.
tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k,
                                     global_num_experts),
routing_method_type=2,  # DeepSeek-styled routing method
use_shuffled_weight=False,
This change explicitly adds the use_shuffled_weight argument, which was introduced in FlashInfer v0.2.9. This will cause a TypeError for users with older versions of FlashInfer, breaking backward compatibility.
To fix this, we should only pass the argument if the installed FlashInfer version supports it. This can be done with a version check, which requires refactoring the function call to use a kwargs dictionary.
Here is a suggested implementation to replace lines 1097-1131:
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe

a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
# NOTE: scales of hidden states have to be transposed!
a_sf_t = a_sf.t().contiguous()
kwargs = dict(
    routing_logits=routing_logits,
    routing_bias=routing_bias,
    hidden_states=a_q,
    hidden_states_scale=a_sf_t,
    gemm1_weights=w13_weight,
    gemm1_weights_scale=w13_weight_scale_inv,
    gemm2_weights=w2_weight,
    gemm2_weights_scale=w2_weight_scale_inv,
    num_experts=global_num_experts,
    top_k=top_k,
    n_group=num_expert_group,
    topk_group=topk_group,
    intermediate_size=intermediate_size,
    local_expert_offset=expert_offset,
    local_num_experts=local_num_experts,
    routed_scaling_factor=routed_scaling,
    tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k,
                                         global_num_experts),
    routing_method_type=2,  # DeepSeek-styled routing method
)
try:
    import flashinfer
    from packaging.version import Version

    # The use_shuffled_weight argument was added in flashinfer v0.2.9
    if Version(flashinfer.__version__) >= Version("0.2.9"):
        kwargs["use_shuffled_weight"] = False
except (ImportError, AttributeError):
    # Older flashinfer version or flashinfer not installed.
    # The lazy loader will handle the ImportError later if it's missing.
    pass
return flashinfer_trtllm_fp8_block_scale_moe(**kwargs)
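An alternative to comparing version strings is feature detection: introspect the callable's signature and only pass the keyword if it is accepted. The sketch below is illustrative, not vLLM code; `supports_kwarg` is a hypothetical helper and `fake_moe` merely stands in for the real kernel entry point.

```python
import inspect


def supports_kwarg(fn, name: str) -> bool:
    """Return True if `fn` accepts keyword argument `name` (or **kwargs)."""
    try:
        params = inspect.signature(fn).parameters
    except (TypeError, ValueError):
        # Some builtins/extensions expose no signature; assume unsupported.
        return False
    if name in params:
        return True
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


def fake_moe(hidden_states, top_k, use_shuffled_weight=True):
    # Stand-in for the real kernel; returns the flag so we can observe it.
    return use_shuffled_weight


kwargs = {"hidden_states": None, "top_k": 8}
if supports_kwarg(fake_moe, "use_shuffled_weight"):
    kwargs["use_shuffled_weight"] = False  # restore pre-flag behavior

print(fake_moe(**kwargs))  # False
```

This avoids hard-coding a release number, at the cost of an `inspect` call; the version check in the suggestion above is simpler when the release that added the argument is known.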
Thanks Kaixi, should we wait for the 0.2.9 release to land this?
@mgoin Sure.
@mgoin Can we merge this PR now that flashinfer 0.2.9rc1 is in?
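One subtlety worth noting when gating on versions: under PEP 440 (as implemented by the `packaging` library), a release candidate compares *below* the final release, so a strict `>= "0.2.9"` check would skip `0.2.9rc1`. A minimal sketch, assuming `packaging` is installed; `shuffled_weight_kwargs` is a hypothetical helper, not vLLM code:

```python
from packaging.version import Version


def shuffled_weight_kwargs(installed: str) -> dict:
    """Only pass the flag on versions that accept it (added in 0.2.9)."""
    kwargs = {}
    if Version(installed) >= Version("0.2.9"):
        kwargs["use_shuffled_weight"] = False
    return kwargs


print(shuffled_weight_kwargs("0.2.8"))     # {}
print(shuffled_weight_kwargs("0.2.9rc1"))  # {} -- rc1 sorts below the final 0.2.9
print(shuffled_weight_kwargs("0.2.9"))     # {'use_shuffled_weight': False}
```

Gating on `>= Version("0.2.9rc1")` (or using feature detection on the function signature) would include the release candidate as well.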
…e moe fp8 kernels (vllm-project#21411) Signed-off-by: kaixih <kaixih@nvidia.com>
The latest FlashInfer (PR) introduces a new flag to the trtllm_fp8_block_scale_moe API, which defaults to True. This PR explicitly disables it to restore the previous behavior. I have verified performance and accuracy against top-of-tree, and we recommend using flashinfer v0.2.9.
cc. @kushanam @mgoin