Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions vllm/model_executor/layers/quantization/modelopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
RoutingMethodType,
fp8_w8a8_moe_quant_config,
nvfp4_moe_quant_config,
)
Expand Down Expand Up @@ -1650,16 +1651,19 @@ def apply(
use_llama4_routing = (
custom_routing_function is Llama4MoE.custom_routing_function
)
routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3
routing_method_type = layer.routing_method_type
if use_llama4_routing:
routing_method_type = flashinfer.RoutingMethodType.Llama4
routing_method_type = RoutingMethodType.Llama4
router_logits = (
router_logits.to(torch.float32)
if routing_method_type == RoutingMethodType.DeepSeekV3
else router_logits
)
routing_bias = e_score_correction_bias
if routing_bias is not None:
routing_bias = routing_bias.to(torch.bfloat16)
out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
routing_logits=router_logits
if use_llama4_routing
else router_logits.to(torch.float32),
routing_logits=router_logits,
routing_bias=routing_bias,
hidden_states=hidden_states_fp4,
hidden_states_scale=hidden_states_scale_linear_fp4.view(
Expand All @@ -1683,8 +1687,8 @@ def apply(
output2_scale_scalar=layer.g2_alphas.data,
num_experts=global_num_experts,
top_k=top_k,
n_group=num_expert_group if num_expert_group is not None else 0,
topk_group=topk_group if topk_group is not None else 0,
n_group=num_expert_group,
topk_group=topk_group,
intermediate_size=layer.intermediate_size_per_partition,
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,5 +291,8 @@ def get_flashinfer_moe_backend() -> FlashinferMoeBackend:

def is_flashinfer_supporting_global_sf(backend: FlashinferMoeBackend | None) -> bool:
    """Return whether ``backend`` is listed as supporting global SF.

    The check is a plain membership test against a hard-coded tuple of
    backends, so any backend not listed there yields ``False``.

    Args:
        backend: The FlashInfer MoE backend to check; the signature also
            allows ``None``, which is not in the tuple.

    Returns:
        ``True`` if ``backend`` is ``FlashinferMoeBackend.CUTLASS`` or
        ``FlashinferMoeBackend.TENSORRT_LLM``, otherwise ``False``.
    """
    # NOTE(review): "sf" presumably means "scale factor" -- confirm with callers.
    # TODO(shuw@nvidia): Update when new backends are added.
    backends_supporting_global_sf = (
        FlashinferMoeBackend.CUTLASS,
        FlashinferMoeBackend.TENSORRT_LLM,
    )
    return backend in backends_supporting_global_sf