Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 52 additions & 43 deletions vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,27 +380,32 @@ def _apply_block_scale(
use_shuffled_weight = False
hidden_states_scale = a1q_scale.t().contiguous()

return flashinfer.fused_moe.trtllm_fp8_block_scale_moe(
routing_logits=router_logits,
routing_bias=e_score_correction_bias,
hidden_states=hidden_states,
hidden_states_scale=hidden_states_scale,
gemm1_weights=w1,
gemm1_weights_scale=self.quant_config.w1_scale,
gemm2_weights=w2,
gemm2_weights_scale=self.quant_config.w2_scale,
num_experts=global_num_experts,
top_k=self.topk,
n_group=(num_expert_group or 0),
topk_group=(topk_group or 0),
intermediate_size=self.intermediate_size_per_partition,
local_expert_offset=self.ep_rank * self.local_num_experts,
local_num_experts=self.local_num_experts,
routed_scaling_factor=routed_scaling_factor,
routing_method_type=self.routing_method_type,
use_shuffled_weight=use_shuffled_weight,
fp8_quantization_type=fp8_quant_type,
)
# Disable autotune until
# https://github.com/flashinfer-ai/flashinfer/issues/2023 is resolved.
from vllm.utils.flashinfer import autotune

with autotune(False):
return flashinfer.fused_moe.trtllm_fp8_block_scale_moe(
routing_logits=router_logits,
routing_bias=e_score_correction_bias,
hidden_states=hidden_states,
hidden_states_scale=hidden_states_scale,
gemm1_weights=w1,
gemm1_weights_scale=self.quant_config.w1_scale,
gemm2_weights=w2,
gemm2_weights_scale=self.quant_config.w2_scale,
num_experts=global_num_experts,
top_k=self.topk,
n_group=(num_expert_group or 0),
topk_group=(topk_group or 0),
intermediate_size=self.intermediate_size_per_partition,
local_expert_offset=self.ep_rank * self.local_num_experts,
local_num_experts=self.local_num_experts,
routed_scaling_factor=routed_scaling_factor,
routing_method_type=self.routing_method_type,
use_shuffled_weight=use_shuffled_weight,
fp8_quantization_type=fp8_quant_type,
)

def _apply_per_tensor(
self,
Expand Down Expand Up @@ -437,28 +442,32 @@ def _apply_per_tensor(
if self.routing_method_type == RoutingMethodType.DeepSeekV3:
router_logits = router_logits.to(torch.float32)

out = flashinfer.fused_moe.trtllm_fp8_per_tensor_scale_moe(
routing_logits=router_logits,
routing_bias=e_score_correction_bias,
hidden_states=hidden_states,
gemm1_weights=w1,
output1_scales_scalar=self._g1_scale_c,
output1_scales_gate_scalar=self._g1_alphas,
gemm2_weights=w2,
output2_scales_scalar=self._g2_alphas,
num_experts=global_num_experts,
top_k=self.topk,
n_group=num_expert_group or 0,
topk_group=topk_group or 0,
intermediate_size=self.intermediate_size_per_partition,
local_expert_offset=self.ep_rank * self.local_num_experts,
local_num_experts=self.local_num_experts,
routed_scaling_factor=routed_scaling_factor,
use_routing_scales_on_input=apply_router_weight_on_input,
routing_method_type=self.routing_method_type,
activation_type=activation_type,
)
return out
# Disable autotune until
# https://github.com/flashinfer-ai/flashinfer/issues/2023 is resolved.
from vllm.utils.flashinfer import autotune

with autotune(False):
return flashinfer.fused_moe.trtllm_fp8_per_tensor_scale_moe(
routing_logits=router_logits,
routing_bias=e_score_correction_bias,
hidden_states=hidden_states,
gemm1_weights=w1,
output1_scales_scalar=self._g1_scale_c,
output1_scales_gate_scalar=self._g1_alphas,
gemm2_weights=w2,
output2_scales_scalar=self._g2_alphas,
num_experts=global_num_experts,
top_k=self.topk,
n_group=num_expert_group or 0,
topk_group=topk_group or 0,
intermediate_size=self.intermediate_size_per_partition,
local_expert_offset=self.ep_rank * self.local_num_experts,
local_num_experts=self.local_num_experts,
routed_scaling_factor=routed_scaling_factor,
use_routing_scales_on_input=apply_router_weight_on_input,
routing_method_type=self.routing_method_type,
activation_type=activation_type,
)

def apply(
self,
Expand Down
72 changes: 41 additions & 31 deletions vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,34 +309,44 @@ def apply(
)

# Invoke kernel.
return flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
routing_logits=router_logits,
routing_bias=routing_bias,
hidden_states=hidden_states,
hidden_states_scale=a1q_scale.view(torch.float8_e4m3fn).reshape(
*hidden_states.shape[:-1], -1
),
gemm1_weights=w1,
gemm1_weights_scale=self.quant_config.w1_scale.view(torch.float8_e4m3fn),
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=w2,
gemm2_weights_scale=self.quant_config.w2_scale.view(torch.float8_e4m3fn),
gemm2_bias=None,
output1_scale_scalar=self.g1_scale_c,
output1_scale_gate_scalar=self.quant_config.g1_alphas,
output2_scale_scalar=self.quant_config.g2_alphas,
num_experts=global_num_experts,
top_k=self.topk,
n_group=(num_expert_group or 0),
topk_group=(topk_group or 0),
intermediate_size=self.intermediate_size_per_partition,
local_expert_offset=self.ep_rank * self.local_num_experts,
local_num_experts=self.local_num_experts,
routed_scaling_factor=routed_scaling_factor,
routing_method_type=self.routing_method_type,
do_finalize=True,
activation_type=activation_to_flashinfer_int(activation),
)[0]
# Disable autotune until
# https://github.com/flashinfer-ai/flashinfer/issues/2023 is resolved.
from vllm.utils.flashinfer import autotune

with autotune(False):
# Enable autotune when flashinfer#2023 is resolved.
return flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
routing_logits=router_logits,
routing_bias=routing_bias,
hidden_states=hidden_states,
hidden_states_scale=a1q_scale.view(torch.float8_e4m3fn).reshape(
*hidden_states.shape[:-1], -1
),
gemm1_weights=w1,
gemm1_weights_scale=self.quant_config.w1_scale.view(
torch.float8_e4m3fn
),
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=w2,
gemm2_weights_scale=self.quant_config.w2_scale.view(
torch.float8_e4m3fn
),
gemm2_bias=None,
output1_scale_scalar=self.g1_scale_c,
output1_scale_gate_scalar=self.quant_config.g1_alphas,
output2_scale_scalar=self.quant_config.g2_alphas,
num_experts=global_num_experts,
top_k=self.topk,
n_group=(num_expert_group or 0),
topk_group=(topk_group or 0),
intermediate_size=self.intermediate_size_per_partition,
local_expert_offset=self.ep_rank * self.local_num_experts,
local_num_experts=self.local_num_experts,
routed_scaling_factor=routed_scaling_factor,
routing_method_type=self.routing_method_type,
do_finalize=True,
activation_type=activation_to_flashinfer_int(activation),
)[0]
40 changes: 26 additions & 14 deletions vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,13 @@ def apply(
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool | None,
):
# Skip this kernel during autotuning dummy run to avoid
# CUDA errors from incompatible autotuning (flashinfer#2023).
import vllm.utils.flashinfer as fi_utils

if fi_utils._is_fi_autotuning:
return

assert self.quant_dtype == "nvfp4", (
"Only nvfp4 quantization are currently supported."
)
Expand All @@ -165,20 +172,25 @@ def apply(
if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH
else hidden_states
)
flashinfer_cutedsl_moe_masked(
hidden_states=flashinfer_hidden_states,
input_global_scale=input_global_scale,
w1=w1,
w1_blockscale=self.w1_scale,
w1_alpha=self.g1_alphas,
w2=w2,
a2_global_scale=self.a2_gscale,
w2_blockscale=self.w2_scale,
w2_alpha=self.g2_alphas,
masked_m=expert_num_tokens,
workspace=workspace2,
out=output,
)
# Disable autotune until
# https://github.com/flashinfer-ai/flashinfer/issues/2023 is resolved.
from vllm.utils.flashinfer import autotune

with autotune(False):
flashinfer_cutedsl_moe_masked(
hidden_states=flashinfer_hidden_states,
input_global_scale=input_global_scale,
w1=w1,
w1_blockscale=self.w1_scale,
w1_alpha=self.g1_alphas,
w2=w2,
a2_global_scale=self.a2_gscale,
w2_blockscale=self.w2_scale,
w2_alpha=self.g2_alphas,
masked_m=expert_num_tokens,
workspace=workspace2,
out=output,
)


def get_cute_dtype(input: torch.Tensor) -> str:
Expand Down
65 changes: 39 additions & 26 deletions vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,13 @@ def apply(
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool | None,
):
# Skip this kernel during autotuning dummy run to avoid
# CUDA errors from incompatible autotuning (flashinfer#2023).
import vllm.utils.flashinfer as fi_utils

if fi_utils._is_fi_autotuning:
return

from flashinfer.fused_moe.core import ActivationType

activation_str_to_value_map = {
Expand Down Expand Up @@ -366,32 +373,38 @@ def apply(
fc1_expert_weights = w1
fc2_expert_weights = w2

_ = flashinfer_cutlass_fused_moe(
input=hidden_states,
token_selected_experts=topk_ids.to(torch.int),
token_final_scales=topk_weights,
fc1_expert_weights=fc1_expert_weights,
fc2_expert_weights=fc2_expert_weights,
fc1_expert_biases=fc1_expert_biases,
fc2_expert_biases=fc2_expert_biases,
swiglu_alpha=swiglu_alpha,
swiglu_beta=swiglu_beta,
swiglu_limit=swiglu_limit,
output=output,
output_dtype=self.out_dtype,
quant_scales=quant_scales,
input_sf=a1q_scale,
tp_size=self.tp_size,
tp_rank=self.tp_rank,
ep_size=self.ep_size,
ep_rank=self.ep_rank,
activation_type=activation_str_to_value_map[activation],
# Informs FlashInfer to use the block-scale decoding path when True
use_deepseek_fp8_block_scale=self.use_deepseek_fp8_block_scale,
use_mxfp8_act_scaling=use_mxfp8_act_scaling,
use_w4_group_scaling=use_w4_group_scaling,
tune_max_num_tokens=max(self.max_capture_size, 1),
)
# Disable autotune until
# https://github.com/flashinfer-ai/flashinfer/issues/2023 is resolved.
from vllm.utils.flashinfer import autotune

with autotune(False):
flashinfer_cutlass_fused_moe(
input=hidden_states,
token_selected_experts=topk_ids.to(torch.int),
token_final_scales=topk_weights,
fc1_expert_weights=fc1_expert_weights,
fc2_expert_weights=fc2_expert_weights,
fc1_expert_biases=fc1_expert_biases,
fc2_expert_biases=fc2_expert_biases,
swiglu_alpha=swiglu_alpha,
swiglu_beta=swiglu_beta,
swiglu_limit=swiglu_limit,
output=output,
output_dtype=self.out_dtype,
quant_scales=quant_scales,
input_sf=a1q_scale,
tp_size=self.tp_size,
tp_rank=self.tp_rank,
ep_size=self.ep_size,
ep_rank=self.ep_rank,
activation_type=activation_str_to_value_map[activation],
# Informs FlashInfer to use the block-scale decoding
# path when True
use_deepseek_fp8_block_scale=(self.use_deepseek_fp8_block_scale),
use_mxfp8_act_scaling=use_mxfp8_act_scaling,
use_w4_group_scaling=use_w4_group_scaling,
tune_max_num_tokens=max(self.max_capture_size, 1),
)

def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
# No support for LoRA in flashinfer_cutlass_fused_moe.
Expand Down
39 changes: 21 additions & 18 deletions vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,24 +94,27 @@ def flashinfer_fused_moe_bf16(
routing_method_type: int,
tune_max_num_tokens: int = 8192,
) -> torch.Tensor:
from vllm.utils.flashinfer import flashinfer_trtllm_bf16_moe

return flashinfer_trtllm_bf16_moe(
routing_logits=routing_logits,
routing_bias=routing_bias,
hidden_states=hidden_states,
gemm1_weights=gemm1_weights,
gemm2_weights=gemm2_weights,
num_experts=num_experts,
top_k=top_k,
n_group=n_group,
topk_group=topk_group,
intermediate_size=intermediate_size,
local_expert_offset=local_expert_offset,
local_num_experts=local_num_experts,
routing_method_type=routing_method_type,
tune_max_num_tokens=tune_max_num_tokens,
)
from vllm.utils.flashinfer import autotune, flashinfer_trtllm_bf16_moe

# Disable autotune until
# https://github.com/flashinfer-ai/flashinfer/issues/2023 is resolved.
with autotune(False):
return flashinfer_trtllm_bf16_moe(
routing_logits=routing_logits,
routing_bias=routing_bias,
hidden_states=hidden_states,
gemm1_weights=gemm1_weights,
gemm2_weights=gemm2_weights,
num_experts=num_experts,
top_k=top_k,
n_group=n_group,
topk_group=topk_group,
intermediate_size=intermediate_size,
local_expert_offset=local_expert_offset,
local_num_experts=local_num_experts,
routing_method_type=routing_method_type,
tune_max_num_tokens=tune_max_num_tokens,
)


def flashinfer_fused_moe_bf16_fake(
Expand Down
Loading
Loading