diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md
index 18216b5965af..5b35b152ff38 100644
--- a/docs/design/moe_kernel_features.md
+++ b/docs/design/moe_kernel_features.md
@@ -88,7 +88,7 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels
| deep gemm | standard,batched | fp8 | G(128),A,T | silu, gelu | 6 | Y | [`deep_gemm_moe_fp8`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.deep_gemm_moe_fp8],[`DeepGemmExperts`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.DeepGemmExperts],[`BatchedDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe.BatchedDeepGemmExperts] |
| cutlass_fp4 | standard,batched | nvfp4 | A,T | silu | Y | Y | [`CutlassExpertsFp4`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp4] |
| cutlass_fp8 | standard,batched | fp8 | A,T | silu, gelu | Y | Y | [`CutlassExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp8],[`CutlasBatchedExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassBatchedExpertsFp8] |
-| flashinfer | standard | nvfp4,fp8 | T | 5 | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] |
+| flashinfer | standard | nvfp4,fp8 | T | 5 | N | Y | [`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] |
| gpt oss triton | standard | N/A | N/A | 5 | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
| marlin | standard,batched | 3 / N/A | 3 / N/A | silu,swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
| trtllm | standard | mxfp4,nvfp4 | G(16),G(32) | 5 | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py
index a1fd098aee5f..66d4f8d2caf3 100644
--- a/tests/compile/test_fusion_attn.py
+++ b/tests/compile/test_fusion_attn.py
@@ -30,7 +30,7 @@
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
- kNvfp4Quant,
+ kNvfp4Dynamic,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.platforms import current_platform
@@ -202,7 +202,7 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
"""Test model for AttentionNvfp4QuantPattern fusion."""
- quant_key = kNvfp4Quant
+ quant_key = kNvfp4Dynamic
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -449,7 +449,7 @@ def test_attention_quant_pattern(
# Note: for fp8, fully_replaced=False because query quant ops remain in graph.
# Only output quant ops are fused into attention.
- test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Quant)
+ test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Dynamic)
# access the underlying `AttnFusionPass` on the `LazyInitPass`
assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py
index eb0dee8d4e39..446f8c7626b8 100644
--- a/tests/compile/test_silu_mul_quant_fusion.py
+++ b/tests/compile/test_silu_mul_quant_fusion.py
@@ -29,7 +29,7 @@
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
kFp8StaticTensorSym,
- kNvfp4Quant,
+ kNvfp4Dynamic,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
@@ -121,11 +121,11 @@ def forward(self, x):
def ops_in_model_before(self):
return [
SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
- QUANT_OPS[kNvfp4Quant],
+ QUANT_OPS[kNvfp4Dynamic],
]
def ops_in_model_after(self):
- return [FUSED_OPS[kNvfp4Quant]]
+ return [FUSED_OPS[kNvfp4Dynamic]]
class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
diff --git a/tests/kernels/moe/modular_kernel_tools/common.py b/tests/kernels/moe/modular_kernel_tools/common.py
index 537dcae4e74b..dd0519c1953a 100644
--- a/tests/kernels/moe/modular_kernel_tools/common.py
+++ b/tests/kernels/moe/modular_kernel_tools/common.py
@@ -161,15 +161,15 @@ def is_fp8_block_quantized(self):
def is_batched_prepare_finalize(self):
info = prepare_finalize_info(self.prepare_finalize_type)
- return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format
+ return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format()
def is_batched_fused_experts(self):
info = expert_info(self.fused_experts_type)
- return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format
+ return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format()
def is_standard_fused_experts(self):
info = expert_info(self.fused_experts_type)
- return mk.FusedMoEActivationFormat.Standard == info.activation_format
+ return mk.FusedMoEActivationFormat.Standard == info.activation_format()
def fe_supported_types(self):
info = expert_info(self.fused_experts_type)
@@ -574,10 +574,13 @@ def next_power_of_2(x):
num_experts=config.E,
experts_per_token=config.topk,
hidden_dim=config.K,
+ intermediate_size_per_partition=config.N,
num_local_experts=config.num_local_experts,
moe_parallel_config=moe_parallel_config,
in_dtype=config.dtype,
max_num_tokens=next_power_of_2(config.M),
+ activation="silu",
+ device=vllm_config.device_config.device,
)
# make modular kernel
diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py
index 99b168dc7554..f085706bb22c 100644
--- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py
+++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py
@@ -237,7 +237,6 @@ def expert_info(kind) -> ExpertInfo:
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
FlashInferCutlassMoEPrepareAndFinalize,
- create_flashinfer_prepare_finalize,
)
register_prepare_and_finalize(
@@ -389,13 +388,12 @@ def make_prepare_finalize(
quant_config: FusedMoEQuantConfig,
) -> mk.FusedMoEPrepareAndFinalize:
if backend != "naive" and backend is not None:
- prepare_finalize = maybe_make_prepare_finalize(moe, quant_config)
+ # TODO(rob): add defer input quant.
+ prepare_finalize = maybe_make_prepare_finalize(
+ moe, quant_config, allow_new_interface=True
+ )
assert prepare_finalize is not None
return prepare_finalize
- elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
- return create_flashinfer_prepare_finalize(
- use_dp=moe.moe_parallel_config.dp_size > 1
- )
else:
return MoEPrepareAndFinalizeNoEP()
diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py
index bb2f6b873941..2fa00a8effb7 100644
--- a/tests/kernels/moe/test_flashinfer.py
+++ b/tests/kernels/moe/test_flashinfer.py
@@ -127,6 +127,7 @@ def make_moe_tensors_8bit(
ep_rank=0,
use_ep=False,
all2all_backend="naive",
+        is_sequence_parallel=False,
)
# flashinfer expects swapped rows for w13
diff --git a/tests/kernels/moe/test_flashinfer_moe.py b/tests/kernels/moe/test_flashinfer_moe.py
index 1262eea70bab..d9b7f4678534 100644
--- a/tests/kernels/moe/test_flashinfer_moe.py
+++ b/tests/kernels/moe/test_flashinfer_moe.py
@@ -12,15 +12,19 @@
from tests.kernels.utils import torch_moe
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
+from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEConfig,
+ FusedMoEParallelConfig,
+)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
is_valid_flashinfer_cutlass_fused_moe,
)
-from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (
- create_flashinfer_prepare_finalize,
-)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+ MoEPrepareAndFinalizeNoEP,
+)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.torch_utils import set_random_seed
@@ -86,9 +90,39 @@ def test_flashinfer_fp4_moe_no_graph(
assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q)
+ moe_config = FusedMoEConfig(
+ num_experts=e,
+ experts_per_token=topk,
+ hidden_dim=k,
+ intermediate_size_per_partition=n,
+ num_local_experts=e,
+ activation=activation,
+ device="cuda",
+ moe_parallel_config=FusedMoEParallelConfig(
+ tp_size=1,
+ pcp_size=1,
+ dp_size=1,
+ ep_size=1,
+ tp_rank=0,
+ pcp_rank=0,
+ dp_rank=0,
+ ep_rank=0,
+ use_ep=False,
+ all2all_backend="allgather_reducescatter",
+            is_sequence_parallel=False,
+ ),
+ in_dtype=dtype,
+ is_act_and_mul=is_gated_act,
+ )
+
flashinfer_experts = FusedMoEModularKernel(
- create_flashinfer_prepare_finalize(use_dp=False, use_nvfp4=True),
- FlashInferExperts(out_dtype=dtype, quant_config=quant_config),
+ MoEPrepareAndFinalizeNoEP(
+ defer_input_quant=FlashInferExperts.should_pf_defer_input_quant(
+ moe_config=moe_config,
+ quant_config=quant_config,
+ )
+ ),
+ FlashInferExperts(moe_config=moe_config, quant_config=quant_config),
)
fi_activation = {"silu_and_mul": "silu", "relu2": "relu2_no_mul"}[activation]
diff --git a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py
index 873d72117de7..97db526c2ee4 100644
--- a/tests/kernels/moe/test_nvfp4_moe.py
+++ b/tests/kernels/moe/test_nvfp4_moe.py
@@ -93,7 +93,6 @@ def test_cutlass_fp4_moe_no_graph(
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
CutlassExpertsFp4(
out_dtype=dtype,
- max_experts_per_worker=e,
quant_config=quant_config,
),
)
diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py
index f0ce5b3db7ec..8530c0dadb23 100644
--- a/vllm/compilation/activation_quant_fusion.py
+++ b/vllm/compilation/activation_quant_fusion.py
@@ -18,7 +18,7 @@
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
- kNvfp4Quant,
+ kNvfp4Dynamic,
)
from vllm.platforms import current_platform
@@ -41,7 +41,7 @@
torch.ops._C, "silu_and_mul_nvfp4_quant"
)
if silu_and_mul_nvfp4_quant_supported:
- FUSED_OPS[kNvfp4Quant] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501
+ FUSED_OPS[kNvfp4Dynamic] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501
class ActivationQuantPattern(ABC):
@@ -129,7 +129,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
"""
def __init__(self) -> None:
- super().__init__(kNvfp4Quant)
+ super().__init__(kNvfp4Dynamic)
def get_inputs(self) -> list[torch.Tensor]:
result = self.empty_quant(5, 32)
diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py
index e3c6c2f20141..667828cc6a5a 100644
--- a/vllm/compilation/fusion.py
+++ b/vllm/compilation/fusion.py
@@ -20,7 +20,7 @@
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
- kNvfp4Quant,
+ kNvfp4Dynamic,
kStaticTensorScale,
)
from vllm.platforms import current_platform
@@ -63,7 +63,7 @@ def empty_i64(*args: Any, **kwargs: Any) -> torch.Tensor:
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
}
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
- QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default
+ QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default
if current_platform.is_cuda():
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py
index 57448aa0b096..69dc2e3a6868 100644
--- a/vllm/compilation/fusion_attn.py
+++ b/vllm/compilation/fusion_attn.py
@@ -16,7 +16,7 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
- kNvfp4Quant,
+ kNvfp4Dynamic,
kStaticTensorScale,
)
from vllm.platforms import current_platform
@@ -217,7 +217,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
"""
def __init__(self, layer: Attention, dtype: torch.dtype) -> None:
- super().__init__(layer, kNvfp4Quant, dtype)
+ super().__init__(layer, kNvfp4Dynamic, dtype)
def _register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py
index eda12180d493..7bb98db5e9f6 100644
--- a/vllm/compilation/matcher_utils.py
+++ b/vllm/compilation/matcher_utils.py
@@ -21,7 +21,7 @@
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
- kNvfp4Quant,
+ kNvfp4Dynamic,
)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform
@@ -38,7 +38,7 @@
}
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
- QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501
+ QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default # noqa: E501
if current_platform.is_cuda():
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py
index 363078aefc51..ae90d9bd5201 100644
--- a/vllm/config/parallel.py
+++ b/vllm/config/parallel.py
@@ -425,6 +425,7 @@ def stateless_init_dp_group(self) -> ProcessGroup:
# to avoid the excess work.
#
# Not needed for pplx-kernels as it can handle duplicate input tokens.
+ # TODO(rob): investigate 'flashinfer_all2allv'?
@property
def use_sequence_parallel_moe(self) -> bool:
return (
diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py
index 7a4e81cf967d..520bb7613ee9 100644
--- a/vllm/distributed/device_communicators/all2all.py
+++ b/vllm/distributed/device_communicators/all2all.py
@@ -62,10 +62,14 @@ def naive_multicast(
def dispatch(
self,
hidden_states: torch.Tensor,
- router_logits: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
- ) -> tuple[torch.Tensor, torch.Tensor]:
+ ) -> (
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+ | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
+ ):
if extra_tensors is not None:
raise NotImplementedError(
"extra_tensors is not supported for NaiveAll2AllManager"
@@ -78,11 +82,21 @@ def dispatch(
hidden_states = self.naive_multicast(
hidden_states, cu_tokens_across_sp_cpu, is_sequence_parallel
)
- router_logits = self.naive_multicast(
- router_logits, cu_tokens_across_sp_cpu, is_sequence_parallel
+ topk_weights = self.naive_multicast(
+ topk_weights, cu_tokens_across_sp_cpu, is_sequence_parallel
)
+ topk_ids = self.naive_multicast(
+ topk_ids, cu_tokens_across_sp_cpu, is_sequence_parallel
+ )
+
+ if extra_tensors is None:
+ return hidden_states, topk_weights, topk_ids
- return hidden_states, router_logits
+ extra_tensors = [
+ self.naive_multicast(t, cu_tokens_across_sp_cpu, is_sequence_parallel)
+ for t in extra_tensors
+ ]
+ return hidden_states, topk_weights, topk_ids, extra_tensors
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
@@ -117,12 +131,13 @@ def __init__(self, cpu_group):
def dispatch(
self,
hidden_states: torch.Tensor,
- router_logits: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
- tuple[torch.Tensor, torch.Tensor]
- | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+ | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Gather hidden_states and router_logits from all dp ranks.
@@ -134,7 +149,7 @@ def dispatch(
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
- tensors_to_gather = [hidden_states, router_logits]
+ tensors_to_gather = [hidden_states, topk_weights, topk_ids]
if extra_tensors is not None:
tensors_to_gather.extend(extra_tensors)
@@ -144,9 +159,14 @@ def dispatch(
sizes=sizes,
)
- if extra_tensors is not None:
- return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:])
- return gathered_tensors[0], gathered_tensors[1]
+ hidden_states = gathered_tensors[0]
+ topk_weights = gathered_tensors[1]
+ topk_ids = gathered_tensors[2]
+
+ if extra_tensors is None:
+ return hidden_states, topk_weights, topk_ids
+
+ return hidden_states, topk_weights, topk_ids, gathered_tensors[3:]
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
@@ -219,10 +239,11 @@ def get_handle(self, kwargs):
def dispatch(
self,
hidden_states: torch.Tensor,
- router_logits: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
- ) -> tuple[torch.Tensor, torch.Tensor]:
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
raise NotImplementedError
def combine(
@@ -267,10 +288,11 @@ def get_handle(self, kwargs):
def dispatch(
self,
hidden_states: torch.Tensor,
- router_logits: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
- ) -> tuple[torch.Tensor, torch.Tensor]:
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
raise NotImplementedError
def combine(
diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py
index 8bc361741cae..4cc6bc8fef05 100644
--- a/vllm/distributed/device_communicators/base_device_communicator.py
+++ b/vllm/distributed/device_communicators/base_device_communicator.py
@@ -67,7 +67,8 @@ def get_handle(self, kwargs):
def dispatch(
self,
hidden_states: torch.Tensor,
- router_logits: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> Any:
@@ -283,15 +284,16 @@ def prepare_communication_buffer_for_model(self, model: torch.nn.Module) -> None
def dispatch(
self,
hidden_states: torch.Tensor,
- router_logits: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
- ) -> tuple[torch.Tensor, torch.Tensor]:
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
- return hidden_states, router_logits
+ return hidden_states, topk_weights, topk_ids
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py
index 9542498c453e..6d725be241e1 100644
--- a/vllm/distributed/device_communicators/cuda_communicator.py
+++ b/vllm/distributed/device_communicators/cuda_communicator.py
@@ -321,17 +321,19 @@ def _all_gather_single(input_: torch.Tensor, sizes: list[int] | None = None):
def dispatch( # type: ignore[override]
self,
hidden_states: torch.Tensor,
- router_logits: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
- tuple[torch.Tensor, torch.Tensor]
- | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+ | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
assert self.all2all_manager is not None
return self.all2all_manager.dispatch(
hidden_states,
- router_logits,
+ topk_weights,
+ topk_ids,
is_sequence_parallel,
extra_tensors, # type: ignore[call-arg]
)
diff --git a/vllm/distributed/device_communicators/xpu_communicator.py b/vllm/distributed/device_communicators/xpu_communicator.py
index f3d9262d20cf..d7a9d832370f 100644
--- a/vllm/distributed/device_communicators/xpu_communicator.py
+++ b/vllm/distributed/device_communicators/xpu_communicator.py
@@ -76,14 +76,16 @@ def broadcast(self, input_: torch.Tensor, src: int = 0) -> None:
def dispatch(
self,
hidden_states: torch.Tensor,
- router_logits: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
- ) -> tuple[torch.Tensor, torch.Tensor]:
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert self.all2all_manager is not None
return self.all2all_manager.dispatch(
hidden_states,
- router_logits,
+ topk_weights,
+ topk_ids,
is_sequence_parallel,
extra_tensors, # type: ignore[call-arg]
)
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index d8c6ceba309f..2f1e4d005570 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -1003,22 +1003,24 @@ def prepare_communication_buffer_for_model(self, model: torch.nn.Module):
def dispatch(
self,
hidden_states: torch.Tensor,
- router_logits: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
- tuple[torch.Tensor, torch.Tensor]
- | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+ | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
if self.device_communicator is not None:
return self.device_communicator.dispatch( # type: ignore[call-arg]
hidden_states,
- router_logits,
+ topk_weights,
+ topk_ids,
is_sequence_parallel,
extra_tensors,
)
else:
- return hidden_states, router_logits
+ return hidden_states, topk_weights, topk_ids
def combine(
self, hidden_states, is_sequence_parallel: bool = False
diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py
index 036b3cac4cb3..43f75cc31428 100644
--- a/vllm/model_executor/layers/fused_moe/all2all_utils.py
+++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py
@@ -12,9 +12,16 @@
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
+from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (
+ FlashInferAllToAllMoEPrepareAndFinalize,
+)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPrepareAndFinalize,
)
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+ MoEPrepareAndFinalizeNaiveEP,
+ MoEPrepareAndFinalizeNoEP,
+)
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_deep_ep, has_pplx
@@ -68,20 +75,21 @@ def maybe_make_prepare_finalize(
moe: FusedMoEConfig,
quant_config: FusedMoEQuantConfig | None,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
+ defer_input_quant: bool = False,
+ allow_new_interface: bool = False,
) -> FusedMoEPrepareAndFinalize | None:
if not moe.moe_parallel_config.use_all2all_kernels:
- return None
+        # TODO(rob): confirm behavior when DP > 1 and EP == 1.
+ if allow_new_interface:
+ return MoEPrepareAndFinalizeNoEP(defer_input_quant)
+ else:
+ return None
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
prepare_finalize: FusedMoEPrepareAndFinalize | None = None
- # TODO(rob): update this as part of the MoE refactor.
- assert not moe.use_flashinfer_cutlass_kernels, (
- "Must be created in modelopt.py or fp8.py"
- )
-
if moe.use_pplx_kernels:
assert quant_config is not None
@@ -170,4 +178,20 @@ def maybe_make_prepare_finalize(
local_expert_global_ids=local_expert_global_ids,
)
+ elif moe.use_fi_all2allv_kernels:
+ assert quant_config is not None
+ # TODO: audit if this supports all cases.
+ prepare_finalize = FlashInferAllToAllMoEPrepareAndFinalize(
+ use_dp=True,
+ num_dispatchers=all2all_manager.world_size,
+ use_deepseek_fp8_block_scale=quant_config.is_block_quantized,
+ )
+
+ elif moe.use_naive_kernels and allow_new_interface:
+ prepare_finalize = MoEPrepareAndFinalizeNaiveEP(
+ defer_input_quant,
+ is_sequence_parallel=(moe.moe_parallel_config.is_sequence_parallel),
+ num_dispatchers=all2all_manager.world_size,
+ )
+
return prepare_finalize
diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
index e598ec3acb3d..0d82eec3569c 100644
--- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
@@ -7,11 +7,20 @@
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.forward_context import get_forward_context, is_forward_context_available
from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEConfig,
+ FusedMoEParallelConfig,
+ FusedMoEQuantConfig,
+)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate,
)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ QuantKey,
+ kFp8Dynamic128Sym,
+ kFp8Static128BlockSym,
+)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import (
@@ -19,6 +28,7 @@
fp8_m_grouped_gemm_nt_masked,
get_mk_alignment_for_contiguous_layout,
is_deep_gemm_e8m0_used,
+ is_deep_gemm_supported,
)
from vllm.utils.math_utils import cdiv, round_up
@@ -253,8 +263,7 @@ def persistent_masked_m_silu_mul_quant(
class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
- max_num_tokens: int,
- num_dispatchers: int,
+ moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
):
"""
@@ -262,20 +271,37 @@ def __init__(
num_dispatchers: The number of DP dispatchers.
quant_config: Quantization configuration
"""
- super().__init__(quant_config)
+ super().__init__(moe_config, quant_config)
assert self.block_shape == get_mk_alignment_for_contiguous_layout()
assert self.quant_config.use_fp8_w8a8
- self.max_num_tokens = max_num_tokens
- self.num_dispatchers = num_dispatchers
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- return (
- mk.FusedMoEActivationFormat.BatchedExperts,
- mk.FusedMoEActivationFormat.BatchedExperts,
- )
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.BatchedExperts
+
+ @staticmethod
+ def _supports_current_device() -> bool:
+ return is_deep_gemm_supported()
+
+ @staticmethod
+ def _supports_no_act_and_mul() -> bool:
+ return False
+
+ @staticmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ SUPPORTED_W_A = [(kFp8Static128BlockSym, kFp8Dynamic128Sym)]
+ return (weight_key, activation_key) in SUPPORTED_W_A
+
+ @staticmethod
+ def _supports_activation(activation: str) -> bool:
+ return activation in ["silu"]
+
+ @staticmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ return True
def supports_chunking(self) -> bool:
return False
diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py
index 23b86fdca898..b5cd5167e7e1 100644
--- a/vllm/model_executor/layers/fused_moe/config.py
+++ b/vllm/model_executor/layers/fused_moe/config.py
@@ -20,7 +20,6 @@
)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
-from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.import_utils import has_triton_kernels
from vllm.utils.math_utils import cdiv
@@ -858,6 +857,7 @@ class FusedMoEParallelConfig:
use_ep: bool # whether to use EP or not
all2all_backend: str # all2all backend for MoE communication
+ is_sequence_parallel: bool # whether sequence parallelism is used
@property
def use_all2all_kernels(self):
@@ -878,6 +878,18 @@ def use_deepep_ht_kernels(self):
def use_deepep_ll_kernels(self):
return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency"
+ @property
+ def use_fi_all2allv_kernels(self):
+ return (
+ self.use_all2all_kernels and self.all2all_backend == "flashinfer_all2allv"
+ )
+
+ @property
+ def use_naive_kernels(self):
+ return self.use_all2all_kernels and (
+ self.all2all_backend in ["allgather_reducescatter", "naive"]
+ )
+
@staticmethod
def flatten_tp_across_dp_and_pcp(
tp_size: int, dp_size: int, dp_rank: int, pcp_size: int, pcp_rank: int
@@ -995,6 +1007,7 @@ def make(
ep_rank=0,
use_ep=False,
all2all_backend=vllm_parallel_config.all2all_backend,
+ is_sequence_parallel=vllm_parallel_config.use_sequence_parallel_moe,
)
# DP + EP / TP + EP / DP + TP + EP
assert use_ep
@@ -1013,6 +1026,7 @@ def make(
ep_rank=ep_rank,
use_ep=True,
all2all_backend=vllm_parallel_config.all2all_backend,
+ is_sequence_parallel=vllm_parallel_config.use_sequence_parallel_moe,
)
@@ -1022,8 +1036,10 @@ class FusedMoEConfig:
num_experts: int
experts_per_token: int
hidden_dim: int
-
+ intermediate_size_per_partition: int
num_local_experts: int
+ activation: str
+    device: torch.device | str
moe_parallel_config: FusedMoEParallelConfig
# The activation type.
@@ -1100,12 +1116,9 @@ def use_deepep_ll_kernels(self):
return self.moe_parallel_config.use_deepep_ll_kernels
@property
- def use_flashinfer_cutlass_kernels(self):
- """
- Whether to use FlashInfer cutlass kernels for NVFP4 MoE.
- """
- return (
- envs.VLLM_USE_FLASHINFER_MOE_FP4
- and has_flashinfer_cutlass_fused_moe()
- and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput"
- )
+ def use_fi_all2allv_kernels(self):
+ return self.moe_parallel_config.use_fi_all2allv_kernels
+
+ @property
+ def use_naive_kernels(self):
+ return self.moe_parallel_config.use_naive_kernels
diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
index c0ffa38fdb2c..e41fa3f7d50c 100644
--- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -9,7 +9,11 @@
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEConfig,
+ FusedMoEParallelConfig,
+ FusedMoEQuantConfig,
+)
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
moe_permute,
moe_unpermute,
@@ -22,6 +26,16 @@
TopKWeightAndReduceNoOP,
)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ QuantKey,
+ kFp8DynamicTensorSym,
+ kFp8DynamicTokenSym,
+ kFp8StaticChannelSym,
+ kFp8StaticTensorSym,
+)
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+ cutlass_group_gemm_supported,
+)
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
@@ -236,29 +250,54 @@ def run_cutlass_moe_fp8(
class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
- e: int,
- n: int,
- k: int,
- out_dtype: torch.dtype | None,
+ moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
- device: torch.dtype,
):
assert quant_config.use_fp8_w8a8
- super().__init__(quant_config)
+ super().__init__(moe_config=moe_config, quant_config=quant_config)
- # E: num_experts
- # N: intermediate size per partition
- # K: hidden dim
+ e = moe_config.num_local_experts
+ n = moe_config.intermediate_size_per_partition
+ k = moe_config.hidden_dim
+ device = moe_config.device
ab_strides1_c_strides2 = torch.full((e,), k, device=device, dtype=torch.int64)
ab_strides2 = torch.full((e,), n, device=device, dtype=torch.int64)
c_strides1 = torch.full((e,), 2 * n, device=device, dtype=torch.int64)
- self.out_dtype = out_dtype
+ self.out_dtype = moe_config.in_dtype
self.ab_strides1 = ab_strides1_c_strides2
self.ab_strides2 = ab_strides2
self.c_strides1 = c_strides1
self.c_strides2 = ab_strides1_c_strides2
+ @staticmethod
+ def _supports_current_device() -> bool:
+ return cutlass_group_gemm_supported()
+
+ @staticmethod
+ def _supports_no_act_and_mul() -> bool:
+ return False
+
+ @staticmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ SUPPORTED_W_A = [
+ (kFp8StaticChannelSym, kFp8DynamicTokenSym),
+ (kFp8StaticTensorSym, kFp8DynamicTensorSym),
+ (kFp8StaticTensorSym, kFp8StaticTensorSym),
+ ]
+ return (weight_key, activation_key) in SUPPORTED_W_A
+
+ @staticmethod
+ def _supports_activation(activation: str) -> bool:
+ return activation in ["silu"]
+
+ @staticmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ return True
+
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
@@ -291,7 +330,7 @@ def apply(
activation_callable = lambda o, i: self.activation(activation, o, i)
use_batched_format = (
- self.activation_formats[0] == mk.FusedMoEActivationFormat.BatchedExperts
+ self.activation_format() == mk.FusedMoEActivationFormat.BatchedExperts
)
in_dtype = hidden_states.dtype
@@ -324,14 +363,9 @@ def apply(
class CutlassExpertsFp8(CutlassExpertsFp8Base):
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- return (
- mk.FusedMoEActivationFormat.Standard,
- mk.FusedMoEActivationFormat.Standard,
- )
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.Standard
def supports_chunking(self) -> bool:
return True
@@ -365,26 +399,9 @@ def workspace_shapes(
class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
- def __init__(
- self,
- max_experts_per_worker: int,
- num_dispatchers: int,
- *args,
- **kwargs,
- ):
- super().__init__(*args, **kwargs)
- assert max_experts_per_worker > 0
- self.max_experts_per_worker = max_experts_per_worker
- self.num_dispatchers = num_dispatchers
-
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- return (
- mk.FusedMoEActivationFormat.BatchedExperts,
- mk.FusedMoEActivationFormat.BatchedExperts,
- )
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.BatchedExperts
def supports_chunking(self) -> bool:
return False
@@ -408,14 +425,15 @@ def workspace_shapes(
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
num_dp = self.num_dispatchers
assert num_dp is not None
+ experts_per_worker = self.moe_config.num_local_experts
activation_out_dim = self.adjust_N_for_activation(N, activation)
- workspace1 = (self.max_experts_per_worker, M * num_dp, max(N, K))
+ workspace1 = (experts_per_worker, M * num_dp, max(N, K))
workspace2 = (
- self.max_experts_per_worker,
+ experts_per_worker,
M * num_dp,
max(activation_out_dim, K),
)
- output = (self.max_experts_per_worker, M, K)
+ output = (experts_per_worker, M, K)
return (workspace1, workspace2, output)
@@ -597,30 +615,16 @@ def run_cutlass_moe_fp4(
class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
- max_experts_per_worker: int,
- out_dtype: torch.dtype,
+ moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
- use_batched_format: bool = False,
):
- super().__init__(quant_config)
- self.max_experts_per_worker = max_experts_per_worker
- self.out_dtype = out_dtype
- self.use_batched_format = use_batched_format
+ super().__init__(moe_config=moe_config, quant_config=quant_config)
+ self.max_experts_per_worker = moe_config.num_local_experts
+ self.out_dtype = moe_config.in_dtype
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- if self.use_batched_format:
- return (
- mk.FusedMoEActivationFormat.BatchedExperts,
- mk.FusedMoEActivationFormat.BatchedExperts,
- )
- else:
- return (
- mk.FusedMoEActivationFormat.Standard,
- mk.FusedMoEActivationFormat.Standard,
- )
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.Standard
def supports_expert_map(self) -> bool:
return False
@@ -649,7 +653,8 @@ def workspace_shapes(
workspace1: tuple[int, ...] = ()
workspace2: tuple[int, ...] = ()
output: tuple[int, ...] = ()
- if self.use_batched_format:
+        # TODO(review): dead code — use_batched_format was removed from this class.
+        if False:
workspace1 = (self.max_experts_per_worker, M, max(N, K))
workspace2 = (self.max_experts_per_worker, M, activation_out_dim)
output = (self.max_experts_per_worker, M, K)
@@ -860,10 +865,11 @@ def __init__(
c_strides2: torch.Tensor,
s_strides1: torch.Tensor,
s_strides2: torch.Tensor,
+ moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
group_size: int,
):
- super().__init__(quant_config)
+ super().__init__(moe_config=moe_config, quant_config=quant_config)
self.out_dtype = out_dtype
self.a_strides1 = a_strides1
self.a_strides2 = a_strides2
@@ -875,13 +881,46 @@ def __init__(
self.s_strides2 = s_strides2
self.group_size = group_size
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- return (
- mk.FusedMoEActivationFormat.Standard,
- mk.FusedMoEActivationFormat.Standard,
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.Standard
+
+ @staticmethod
+ def _supports_current_device() -> bool:
+ raise NotImplementedError(
+ "CutlassExpertsW4A8Fp8 is not yet used by an Oracle. "
+ "This method should not be called."
+ )
+
+ @staticmethod
+ def _supports_no_act_and_mul() -> bool:
+ raise NotImplementedError(
+ "CutlassExpertsW4A8Fp8 is not yet used by an Oracle. "
+ "This method should not be called."
+ )
+
+ @staticmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ raise NotImplementedError(
+ "CutlassExpertsW4A8Fp8 is not yet used by an Oracle. "
+ "This method should not be called."
+ )
+
+ @staticmethod
+ def _supports_activation(activation: str) -> bool:
+ raise NotImplementedError(
+ "CutlassExpertsW4A8Fp8 is not yet used by an Oracle. "
+ "This method should not be called."
+ )
+
+ @staticmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ raise NotImplementedError(
+ "CutlassExpertsW4A8Fp8 is not yet used by an Oracle. "
+ "This method should not be called."
)
def supports_chunking(self) -> bool:
@@ -939,7 +978,7 @@ def apply(
activation_callable = lambda o, i: self.activation(activation, o, i)
use_batched_format = (
- self.activation_formats[0] == mk.FusedMoEActivationFormat.BatchedExperts
+ self.activation_format() == mk.FusedMoEActivationFormat.BatchedExperts
)
assert not use_batched_format, "batched format not supported"
@@ -995,6 +1034,7 @@ def cutlass_moe_w4a8_fp8(
s_strides1: torch.Tensor,
s_strides2: torch.Tensor,
quant_config: FusedMoEQuantConfig,
+ moe_config: FusedMoEConfig,
activation: str = "silu",
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
@@ -1068,6 +1108,7 @@ def cutlass_moe_w4a8_fp8(
c_strides2=c_strides2,
s_strides1=s_strides1,
s_strides2=s_strides2,
+ moe_config=moe_config,
quant_config=quant_config,
group_size=group_size,
),
diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
index a2e5a07fbfd2..a86f39e96959 100644
--- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
@@ -6,17 +6,15 @@
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEConfig,
+ FusedMoEParallelConfig,
FusedMoEQuantConfig,
- fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
compute_aligned_M,
deepgemm_moe_permute,
deepgemm_unpermute_and_reduce,
)
-from vllm.model_executor.layers.fused_moe.prepare_finalize import (
- MoEPrepareAndFinalizeNoEP,
-)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
@@ -26,9 +24,15 @@
per_token_group_quant_fp8_packed_for_deepgemm,
silu_mul_per_token_group_quant_fp8_colmajor,
)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ QuantKey,
+ kFp8Dynamic128Sym,
+ kFp8Static128BlockSym,
+)
from vllm.utils.deep_gemm import (
DeepGemmQuantScaleFMT,
get_mk_alignment_for_contiguous_layout,
+ is_deep_gemm_supported,
m_grouped_fp8_gemm_nt_contiguous,
)
from vllm.utils.import_utils import has_deep_gemm
@@ -109,21 +113,42 @@ def _valid_deep_gemm(
class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
- def __init__(self, quant_config: FusedMoEQuantConfig):
- super().__init__(quant_config)
+ def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig):
+ super().__init__(moe_config=moe_config, quant_config=quant_config)
assert quant_config.block_shape == get_mk_alignment_for_contiguous_layout()
assert quant_config.quant_dtype == torch.float8_e4m3fn
assert not quant_config.per_act_token_quant
assert not quant_config.per_out_ch_quant
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- return (
- mk.FusedMoEActivationFormat.Standard,
- mk.FusedMoEActivationFormat.Standard,
- )
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.Standard
+
+ @staticmethod
+ def _supports_current_device() -> bool:
+ return is_deep_gemm_supported()
+
+ @staticmethod
+ def _supports_no_act_and_mul() -> bool:
+ return False
+
+ @staticmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ SUPPORTED_W_A = [
+ (kFp8Static128BlockSym, kFp8Dynamic128Sym),
+ ]
+ return (weight_key, activation_key) in SUPPORTED_W_A
+
+ @staticmethod
+ def _supports_activation(activation: str) -> bool:
+ return activation in ["silu"]
+
+ @staticmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ return True
def supports_chunking(self) -> bool:
return True
@@ -338,27 +363,27 @@ def deep_gemm_moe_fp8(
Returns:
- torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
"""
- quant_config = fp8_w8a8_moe_quant_config(
- w1_scale=w1_scale,
- w2_scale=w2_scale,
- a1_scale=a1_scale,
- a2_scale=a2_scale,
- block_shape=get_mk_alignment_for_contiguous_layout(),
- )
-
- fn = mk.FusedMoEModularKernel(
- MoEPrepareAndFinalizeNoEP(),
- DeepGemmExperts(quant_config),
- )
- return fn(
- hidden_states,
- w1,
- w2,
- topk_weights,
- topk_ids,
- inplace=inplace,
- activation=activation,
- global_num_experts=global_num_experts,
- expert_map=expert_map,
- apply_router_weight_on_input=apply_router_weight_on_input,
- )
+    raise NotImplementedError(
+        "deep_gemm_moe_fp8 is disabled pending modular-kernel oracle migration"
+    )
+ # a1_scale=a1_scale,
+ # a2_scale=a2_scale,
+ # block_shape=get_mk_alignment_for_contiguous_layout(),
+ # )
+
+ # fn = mk.FusedMoEModularKernel(
+ # MoEPrepareAndFinalizeNoEP(),
+ # DeepGemmExperts(quant_config),
+ # )
+ # return fn(
+ # hidden_states,
+ # w1,
+ # w2,
+ # topk_weights,
+ # topk_ids,
+ # inplace=inplace,
+ # activation=activation,
+ # global_num_experts=global_num_experts,
+ # expert_map=expert_map,
+ # apply_router_weight_on_input=apply_router_weight_on_input,
+ # )
diff --git a/vllm/model_executor/layers/fused_moe/fallback.py b/vllm/model_executor/layers/fused_moe/fallback.py
index 4556392144a0..9c9e416f366a 100644
--- a/vllm/model_executor/layers/fused_moe/fallback.py
+++ b/vllm/model_executor/layers/fused_moe/fallback.py
@@ -6,28 +6,73 @@
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
+from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
"""Base class for runtime dispatching of expert implementations."""
+ _experts_cls = mk.FusedMoEPermuteExpertsUnpermute
+ _fallback_cls = mk.FusedMoEPermuteExpertsUnpermute
+
def __init__(
self,
experts: mk.FusedMoEPermuteExpertsUnpermute,
fallback_experts: mk.FusedMoEPermuteExpertsUnpermute,
):
- super().__init__(experts.quant_config)
+ super().__init__(
+ moe_config=experts.moe_config, quant_config=experts.quant_config
+ )
self.fallback_experts = fallback_experts
self.experts = experts
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
assert (
- self.fallback_experts.activation_formats == self.experts.activation_formats
+ FallbackExperts._experts_cls.activation_format()
+ == FallbackExperts._fallback_cls.activation_format()
+ )
+ return FallbackExperts._experts_cls.activation_format()
+
+ @staticmethod
+ def _supports_current_device() -> bool:
+ return (
+ FallbackExperts._experts_cls._supports_current_device()
+ and FallbackExperts._fallback_cls._supports_current_device()
+ )
+
+ @staticmethod
+ def _supports_no_act_and_mul() -> bool:
+ return (
+ FallbackExperts._experts_cls._supports_no_act_and_mul()
+ and FallbackExperts._fallback_cls._supports_no_act_and_mul()
+ )
+
+ @staticmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ return FallbackExperts._experts_cls._supports_quant_scheme(
+ weight_key, activation_key
+ ) and FallbackExperts._fallback_cls._supports_quant_scheme(
+ weight_key, activation_key
+ )
+
+ @staticmethod
+ def _supports_activation(activation: str) -> bool:
+ return FallbackExperts._experts_cls._supports_activation(
+ activation
+ ) and FallbackExperts._fallback_cls._supports_activation(activation)
+
+ @staticmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ return FallbackExperts._experts_cls._supports_parallel_config(
+ moe_parallel_config
+ ) and FallbackExperts._fallback_cls._supports_parallel_config(
+ moe_parallel_config
)
- return self.fallback_experts.activation_formats
def supports_chunking(self) -> bool:
assert (
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py
index 1651f3530eef..82ba560f63d3 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py
@@ -6,13 +6,22 @@
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs
from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEConfig,
+ FusedMoEParallelConfig,
+ FusedMoEQuantConfig,
+)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate,
)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ QuantKey,
+ kNvfp4Dynamic,
+ kNvfp4Static,
+)
+from vllm.platforms import current_platform
from vllm.utils.flashinfer import (
flashinfer_cutedsl_grouped_gemm_nt_masked,
- has_flashinfer_cutedsl_grouped_gemm_nt_masked,
scaled_fp4_grouped_quantize,
silu_and_mul_scaled_nvfp4_experts_quantize,
)
@@ -20,54 +29,48 @@
logger = init_logger(__name__)
-def is_valid_flashinfer_cutedsl_fused_moe(
- hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor
-) -> bool:
- """
- Check if the given problem size is supported by the FlashInfer CuteDSL MoE
- kernel.
- """
- if not has_flashinfer_cutedsl_grouped_gemm_nt_masked():
- logger.debug_once(
- "FlashInferCuteDSLExperts disabled: "
- "flashinfer_cutedsl_fused_moe not available."
- )
- return False
- # Data type checks
- if (
- w1.dtype != torch.uint8
- or w2.dtype != torch.uint8
- or hidden_states.dtype not in [torch.float32, torch.float16, torch.bfloat16]
- ):
- logger.debug_once(
- "FlashInferCuteDSLExperts disabled: w1/w2 must be torch.uint8 "
- f"(got w1={w1.dtype}, w2={w2.dtype}), hidden_states must be "
- f"float32, float16, or bfloat16 (got {hidden_states.dtype})."
- )
- return False
- return True
-
-
class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
- out_dtype: torch.dtype,
+ moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
):
- super().__init__(quant_config)
+        super().__init__(moe_config=moe_config, quant_config=quant_config)
assert quant_config.quant_dtype == "nvfp4", (
"Only nvfp4 quantization are currently supported."
)
- self.out_dtype = out_dtype
+ self.out_dtype = moe_config.in_dtype
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- return (
- mk.FusedMoEActivationFormat.BatchedExperts,
- mk.FusedMoEActivationFormat.BatchedExperts,
- )
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.BatchedExperts
+
+ @staticmethod
+ def _supports_current_device() -> bool:
+        # TODO(review): also verify that the CuteDSL grouped-GEMM kernel is available.
+ return current_platform.has_device_capability((10, 0))
+
+ @staticmethod
+ def _supports_no_act_and_mul() -> bool:
+ return False
+
+ @staticmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ SUPPORTED_W_A = [
+ (kNvfp4Static, kNvfp4Dynamic),
+ ]
+ return (weight_key, activation_key) in SUPPORTED_W_A
+
+ @staticmethod
+ def _supports_activation(activation: str) -> bool:
+ return activation in ["silu"]
+
+ @staticmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ return True
def supports_expert_map(self) -> bool:
return False
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
index ae60e15db841..5d68e8f0cce9 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
@@ -5,13 +5,22 @@
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
- create_flashinfer_prepare_finalize,
+from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEParallelConfig,
+ FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ QuantKey,
+ kFp8Dynamic128Sym,
+ kFp8Static128BlockSym,
+ kFp8StaticTensorSym,
+ kNvfp4Dynamic,
+ kNvfp4Static,
+)
+from vllm.platforms import current_platform
from vllm.utils.flashinfer import (
flashinfer_cutlass_fused_moe,
has_flashinfer_cutlass_fused_moe,
@@ -50,40 +59,92 @@ def is_valid_flashinfer_cutlass_fused_moe(
class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
- out_dtype: torch.dtype,
+ moe_config: mk.FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
- ep_rank: int = 0,
- ep_size: int = 1,
- tp_rank: int = 0,
- tp_size: int = 1,
- use_dp: bool = False,
- use_deepseek_fp8_block_scale: bool = False,
):
- super().__init__(quant_config)
+        super().__init__(moe_config=moe_config, quant_config=quant_config)
assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn, None), (
"Only nvfp4, fp8, bfloat16 and"
" float16 quantization are currently supported."
)
- self.ep_rank = ep_rank
- self.ep_size = ep_size
- self.tp_rank = tp_rank
- self.tp_size = tp_size
- self.out_dtype = out_dtype
- self.use_dp = use_dp
+ self.ep_rank = moe_config.moe_parallel_config.ep_rank
+ self.ep_size = moe_config.moe_parallel_config.ep_size
+ self.tp_rank = moe_config.moe_parallel_config.tp_rank
+ self.tp_size = moe_config.moe_parallel_config.tp_size
+ self.out_dtype = moe_config.in_dtype
+ self.use_dp = moe_config.moe_parallel_config.dp_size > 1
# Enables DeepSeek-style FP8 block-scale path:
# - pass per-block weight scales to the kernel
# - skip input activation quantization (kernel applies scaling)
- self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale
+ self.use_deepseek_fp8_block_scale = quant_config.is_block_quantized
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
+ @staticmethod
+ def should_pf_defer_input_quant(
+ moe_config: mk.FusedMoEConfig, quant_config: FusedMoEQuantConfig
+ ) -> bool:
+ # NVFP4 TP kernels and FP8 block-quantized kernels apply
+ # input quantization inside FusedMoEPermuteExpertsUnpermute.
+ return (
+ quant_config.use_nvfp4_w4a4
+ and not moe_config.moe_parallel_config.use_all2all_kernels
+ ) or (quant_config.use_fp8_w8a8 and quant_config.is_block_quantized)
+
+ @staticmethod
+ def _supports_current_device() -> bool:
return (
- mk.FusedMoEActivationFormat.Standard,
- mk.FusedMoEActivationFormat.Standard,
+ current_platform.is_cuda()
+            # TODO(review): confirm supported capabilities (9.0+ vs only 9.x/10.x).
+ and current_platform.has_device_capability((9, 0))
+ and has_flashinfer_cutlass_fused_moe()
)
+ @staticmethod
+ def _supports_no_act_and_mul() -> bool:
+ return False
+
+ @staticmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ # The following are supported by FlashInferExperts:
+ # * unquantized
+ # * fp8 static per-tensor on 9.0+
+ # * fp8 block on 9.0
+ # * nvfp4 on 10.0+
+
+ p = current_platform
+ scheme = (weight_key, activation_key)
+ return (
+ (
+ scheme
+ in [
+ (None, None),
+ (kFp8StaticTensorSym, kFp8StaticTensorSym),
+ ]
+ )
+ or (
+ (scheme == (kFp8Static128BlockSym, kFp8Dynamic128Sym))
+ and (p.is_device_capability((9, 0)))
+ )
+ or (
+ (scheme == (kNvfp4Static, kNvfp4Dynamic))
+                and (p.is_device_capability((10, 0)))  # TODO(review): Blackwell-only?
+ )
+ )
+
+ @staticmethod
+ def _supports_activation(activation: str) -> bool:
+ return activation in ["silu", "relu2_no_mul"]
+
+ @staticmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ return True
+
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.Standard
+
def supports_expert_map(self) -> bool:
return False
@@ -226,85 +287,3 @@ def apply(
# Informs FlashInfer to use the block-scale decoding path when True
use_deepseek_fp8_block_scale=self.use_deepseek_fp8_block_scale,
)
-
-
-def flashinfer_cutlass_moe_fp4(
- hidden_states: torch.Tensor,
- w1: torch.Tensor,
- w2: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- quant_config: FusedMoEQuantConfig,
- inplace: bool = False,
- activation: str = "silu",
- global_num_experts: int = -1,
- expert_map: torch.Tensor | None = None,
- apply_router_weight_on_input: bool = False,
-) -> torch.Tensor:
- fused_experts = mk.FusedMoEModularKernel(
- create_flashinfer_prepare_finalize(
- use_dp=False, use_nvfp4=True, enable_alltoallv=False
- ),
- FlashInferExperts(
- out_dtype=hidden_states.dtype,
- quant_config=quant_config,
- use_dp=False,
- ),
- )
-
- return fused_experts(
- hidden_states=hidden_states,
- w1=w1,
- w2=w2,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- inplace=inplace,
- activation=activation,
- global_num_experts=global_num_experts,
- expert_map=expert_map,
- apply_router_weight_on_input=apply_router_weight_on_input,
- )
-
-
-def flashinfer_cutlass_moe(
- hidden_states: torch.Tensor,
- w1: torch.Tensor,
- w2: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- quant_config: FusedMoEQuantConfig,
- inplace: bool = False,
- activation: str = "silu",
- global_num_experts: int = -1,
- expert_map: torch.Tensor | None = None,
- apply_router_weight_on_input: bool = False,
- tp_rank: int = 0,
- tp_size: int = 1,
- ep_rank: int = 0,
- ep_size: int = 1,
- use_dp: bool = False,
-) -> torch.Tensor:
- fused_experts = mk.FusedMoEModularKernel(
- create_flashinfer_prepare_finalize(use_dp=use_dp),
- FlashInferExperts(
- out_dtype=hidden_states.dtype,
- quant_config=quant_config,
- tp_rank=tp_rank,
- tp_size=tp_size,
- ep_rank=ep_rank,
- ep_size=ep_size,
- ),
- )
-
- return fused_experts(
- hidden_states=hidden_states,
- w1=w1,
- w2=w2,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- inplace=inplace,
- activation=activation,
- global_num_experts=global_num_experts,
- expert_map=expert_map,
- apply_router_weight_on_input=apply_router_weight_on_input,
- )
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
index dfff860750d6..44f7b3b9bc0a 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
@@ -4,18 +4,12 @@
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
-from vllm.distributed import get_dp_group, get_ep_group
+from vllm.distributed import get_ep_group
from vllm.distributed.device_communicators.base_device_communicator import (
All2AllManagerBase,
)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.fused_moe.prepare_finalize import (
- MoEPrepareAndFinalizeNoEP,
-)
-from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
- TopKWeightAndReduceNoOP,
-)
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.utils.flashinfer import nvfp4_block_scale_interleave
@@ -162,90 +156,6 @@ def finalize(
output.copy_(fused_expert_output)
-class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFinalize):
- def __init__(
- self,
- use_dp: bool,
- num_dispatchers: int = 1,
- use_deepseek_fp8_block_scale: bool = False,
- ):
- super().__init__(use_dp, num_dispatchers, use_deepseek_fp8_block_scale)
-
- def prepare(
- self,
- a1: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- num_experts: int,
- expert_map: torch.Tensor | None,
- apply_router_weight_on_input: bool,
- quant_config: FusedMoEQuantConfig,
- ) -> mk.PrepareResultType:
- self._apply_router_weight_on_input(
- a1, topk_weights, topk_ids, apply_router_weight_on_input
- )
- is_nvfp4 = quant_config.quant_dtype == "nvfp4"
- if not self.use_dp and is_nvfp4:
- return a1, None, None, topk_ids, topk_weights
-
- if not self.use_deepseek_fp8_block_scale:
- a1q, a1q_scale = moe_kernel_quantize_input(
- a1,
- quant_config.a1_gscale if is_nvfp4 else quant_config.a1_scale,
- quant_config.quant_dtype,
- quant_config.per_act_token_quant,
- quant_config.block_shape,
- is_fp4_scale_swizzled=not self.use_dp,
- )
- else:
- # Block-scale path: pass activations through, omit per-token scales
- a1q = a1
- a1q_scale = None
-
- if self.use_dp:
- # Build gather list conditionally - omit a1q_scale if None
- # (block-scale path)
- gather_list = [topk_weights, topk_ids, a1q]
- if a1q_scale is not None:
- gather_list.append(a1q_scale)
- gathered = get_dp_group().all_gatherv(
- gather_list,
- dim=0,
- sizes=get_local_sizes(),
- )
- topk_weights, topk_ids, a1q, a1q_scale = gathered
- else:
- gathered = get_dp_group().all_gatherv(
- gather_list,
- dim=0,
- sizes=get_local_sizes(),
- )
- topk_weights, topk_ids, a1q = gathered
- a1q_scale = None
-
- if is_nvfp4 and a1q_scale is not None:
- a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
-
- return a1q, a1q_scale, None, topk_ids, topk_weights
-
- def finalize(
- self,
- output: torch.Tensor,
- fused_expert_output: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- apply_router_weight_on_input: bool,
- weight_and_reduce_impl: mk.TopKWeightAndReduce,
- ) -> None:
- assert isinstance(weight_and_reduce_impl, TopKWeightAndReduceNoOP)
-
- if self.use_dp:
- fused_expert_output = get_dp_group().reduce_scatterv(
- fused_expert_output, dim=0, sizes=get_local_sizes()
- )
- output.copy_(fused_expert_output)
-
-
def flashinfer_alltoall_dispatch(
all2all_manager: All2AllManagerBase,
global_num_tokens_cpu: list[int],
@@ -346,26 +256,3 @@ def flashinfer_alltoall_combine(
top_k=top_k,
token_count=token_count,
)
-
-
-def create_flashinfer_prepare_finalize(
- use_dp: bool,
- use_nvfp4: bool = False,
- enable_alltoallv: bool = False,
- use_deepseek_fp8_block_scale: bool = False,
-) -> FlashInferCutlassMoEPrepareAndFinalize | MoEPrepareAndFinalizeNoEP:
- """Factory function to create the appropriate FlashInfer implementation."""
-
- if use_dp:
- if enable_alltoallv:
- assert use_nvfp4
- return FlashInferAllToAllMoEPrepareAndFinalize(use_dp)
- return FlashInferAllGatherMoEPrepareAndFinalize(
- use_dp=True,
- use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
- )
- else:
- # CUTLASS FP8 BLOCK and CUTLASS NVFP4 apply input quantization
- # in a single call with the MoE experts kernel.
- defer_input_quant = use_deepseek_fp8_block_scale or use_nvfp4
- return MoEPrepareAndFinalizeNoEP(defer_input_quant=defer_input_quant)
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
index 3bb5a23abb7b..30af9d912dff 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
@@ -3,7 +3,12 @@
import torch
-from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEConfig,
+ FusedMoEParallelConfig,
+ RoutingMethodType,
+)
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
calculate_tile_tokens_dim,
@@ -11,8 +16,82 @@
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ QuantKey,
+ kFp8Dynamic128Sym,
+ kFp8Static128BlockSym,
+ kFp8StaticTensorSym,
+)
+from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
+#
+# Methods used by the oracle for kernel selection.
+#
+
+
+def _supports_current_device() -> bool:
+ """Supports only Blackwell-family GPUs."""
+ p = current_platform
+    # TODO: also check that the flashinfer trtllm kernels are available
+ return p.is_cuda() and p.is_device_capability_family(100)
+
+
+def _supports_no_act_and_mul() -> bool:
+    """Does not support non-gated MoE (e.g. Nemotron-Nano)."""
+ return False
+
+
+def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+) -> bool:
+ """Supports Fp8 per-tensor and Fp8 block."""
+ SUPPORTED_W_A = [
+ (kFp8Static128BlockSym, kFp8Dynamic128Sym),
+ (kFp8StaticTensorSym, kFp8StaticTensorSym),
+ ]
+ return (weight_key, activation_key) in SUPPORTED_W_A
+
+
+def _supports_activation(activation: str) -> bool:
+ """Supports silu activation only."""
+ return activation in ["silu"]
+
+
+def _supports_moe_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ """Supports EP."""
+ return True
+
+
+def is_supported_config_trtllm(
+ moe_config: FusedMoEConfig,
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ activation_format: mk.FusedMoEActivationFormat,
+) -> tuple[bool, str | None]:
+ """
+ This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config
+ """
+
+ def _make_reason(reason: str) -> str:
+ return f"kernel does not support {reason}"
+
+ if not _supports_current_device():
+ return False, _make_reason("current device")
+ elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()):
+ return False, _make_reason("no act_and_mul MLP layer")
+ elif not _supports_activation(moe_config.activation):
+ return False, _make_reason(f"{moe_config.activation} activation")
+ elif not _supports_quant_scheme(weight_key, activation_key):
+ return False, _make_reason("quantization scheme")
+ elif not _supports_moe_parallel_config(moe_config.moe_parallel_config):
+ return False, _make_reason("parallel config")
+ elif activation_format != mk.FusedMoEActivationFormat.Standard:
+        return False, _make_reason(f"{activation_format.value} activation format")
+
+ return True, None
+
def flashinfer_fused_moe_blockscale_fp8(
routing_logits: torch.Tensor,
diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
index fb93464392ea..fbb30d408457 100644
--- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
@@ -5,7 +5,11 @@
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
-from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEConfig,
+ FusedMoEParallelConfig,
+ FusedMoEQuantConfig,
+)
from vllm.model_executor.layers.fused_moe.fused_moe import try_get_optimal_moe_config
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate,
@@ -17,7 +21,11 @@
normalize_batched_scales_shape,
normalize_scales_shape,
)
-from vllm.model_executor.layers.quantization.utils.quant_utils import group_broadcast
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ QuantKey,
+ group_broadcast,
+)
+from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
@@ -633,26 +641,41 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
- max_num_tokens: int,
- num_dispatchers: int,
+ moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
):
- super().__init__(quant_config)
+ super().__init__(moe_config, quant_config)
assert not self.quant_config.use_int8_w8a8, "NYI"
assert not self.quant_config.use_int8_w8a16, "NYI"
assert not self.quant_config.use_int4_w4a16, "NYI"
assert self.quant_config.ocp_mx_scheme is None, "NYI"
- self.max_num_tokens = max_num_tokens
- self.num_dispatchers = num_dispatchers
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- return (
- mk.FusedMoEActivationFormat.BatchedExperts,
- mk.FusedMoEActivationFormat.BatchedExperts,
- )
+    @staticmethod
+    def activation_format() -> mk.FusedMoEActivationFormat:
+        return mk.FusedMoEActivationFormat.BatchedExperts
+
+ @staticmethod
+ def _supports_current_device() -> bool:
+ return True
+
+ @staticmethod
+ def _supports_no_act_and_mul() -> bool:
+ return False
+
+ @staticmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ return False
+
+ @staticmethod
+ def _supports_activation(activation: str) -> bool:
+ return activation in ["silu", "swigluoai"]
+
+ @staticmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ return True
def supports_chunking(self) -> bool:
return False
@@ -826,28 +849,48 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
- max_num_tokens: int,
- num_dispatchers: int,
+ moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
):
- super().__init__(quant_config)
+ super().__init__(moe_config, quant_config)
assert not self.quant_config.use_int8_w8a8, "NYI"
assert not self.quant_config.use_int8_w8a16, "NYI"
assert not self.quant_config.use_int4_w4a16, "NYI"
assert self.quant_config.ocp_mx_scheme is None, "NYI"
- assert max_num_tokens > 0
- assert num_dispatchers > 0
- self.max_num_tokens = max_num_tokens
- self.num_dispatchers = num_dispatchers
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- return (
- mk.FusedMoEActivationFormat.BatchedExperts,
- mk.FusedMoEActivationFormat.BatchedExperts,
- )
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.BatchedExperts
+
+ @staticmethod
+ def _supports_current_device() -> bool:
+ return current_platform.is_cuda_alike()
+
+ @staticmethod
+ def _supports_no_act_and_mul() -> bool:
+ return False
+
+ @staticmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ # p = current_platform
+ # device_supports_fp8 = p.is_rocm() or (
+ # p.is_cuda() and p.has_device_capability((9, 0))
+ # )
+ # return quant_scheme.is_unquantized or (
+ # quant_scheme.is_fp8_w8a8 and device_supports_fp8
+ # )
+ return False
+
+ @staticmethod
+ def _supports_activation(activation: str) -> bool:
+ return activation in ["silu", "gelu", "swigluoai"]
+
+ @staticmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ return True
def supports_chunking(self) -> bool:
return False
diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
index 77c6b97eaea3..5b552dc93287 100644
--- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -8,7 +8,11 @@
import vllm._custom_ops as ops
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
-from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEConfig,
+ FusedMoEParallelConfig,
+ FusedMoEQuantConfig,
+)
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
batched_moe_align_block_size,
moe_align_block_size,
@@ -23,6 +27,12 @@
marlin_moe_intermediate_size,
marlin_quant_input,
)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ QuantKey,
+ kFp8Static128BlockSym,
+ kFp8StaticChannelSym,
+ kNvfp4Static,
+)
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
@@ -530,6 +540,7 @@ def batched_fused_marlin_moe(
class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
+ moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
w13_g_idx: torch.Tensor | None = None,
w2_g_idx: torch.Tensor | None = None,
@@ -549,7 +560,38 @@ def __init__(
self.w13_g_idx_sort_indices = w13_g_idx_sort_indices
self.w2_g_idx_sort_indices = w2_g_idx_sort_indices
self.is_k_full = is_k_full
- super().__init__(quant_config)
+ super().__init__(moe_config, quant_config)
+
+ @staticmethod
+ def _supports_current_device() -> bool:
+ p = current_platform
+ return p.is_cuda() and p.has_device_capability((7, 5))
+
+ @staticmethod
+ def _supports_no_act_and_mul() -> bool:
+ return False
+
+ @staticmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ # TODO(rob): add int4, mxfp4, int8 as integrations
+ # are migrated to use the oracle one-by-one.
+ SUPPORTED_W = [
+ kFp8Static128BlockSym,
+ kFp8StaticChannelSym,
+ kNvfp4Static,
+ ]
+ return weight_key in SUPPORTED_W
+
+ @staticmethod
+ def _supports_activation(activation: str) -> bool:
+ return activation in ["silu", "swigluoai"]
+
+ @staticmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ return True
@property
def quant_type_id(self) -> int:
@@ -595,38 +637,15 @@ def moe_problem_size(
class MarlinExperts(MarlinExpertsBase):
- def __init__(
- self,
- quant_config: FusedMoEQuantConfig,
- w13_g_idx: torch.Tensor | None = None,
- w2_g_idx: torch.Tensor | None = None,
- w13_g_idx_sort_indices: torch.Tensor | None = None,
- w2_g_idx_sort_indices: torch.Tensor | None = None,
- is_k_full: bool = True,
- ):
- super().__init__(
- quant_config,
- w13_g_idx,
- w2_g_idx,
- w13_g_idx_sort_indices,
- w2_g_idx_sort_indices,
- is_k_full,
- )
-
def supports_expert_map(self) -> bool:
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP()
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- return (
- mk.FusedMoEActivationFormat.Standard,
- mk.FusedMoEActivationFormat.Standard,
- )
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.Standard
def supports_chunking(self) -> bool:
return True
@@ -720,42 +739,15 @@ def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
class BatchedMarlinExperts(MarlinExpertsBase):
- def __init__(
- self,
- max_num_tokens: int,
- num_dispatchers: int,
- quant_config: FusedMoEQuantConfig,
- w13_g_idx: torch.Tensor | None = None,
- w2_g_idx: torch.Tensor | None = None,
- w13_g_idx_sort_indices: torch.Tensor | None = None,
- w2_g_idx_sort_indices: torch.Tensor | None = None,
- is_k_full: bool = True,
- ):
- super().__init__(
- quant_config,
- w13_g_idx,
- w2_g_idx,
- w13_g_idx_sort_indices,
- w2_g_idx_sort_indices,
- is_k_full,
- )
- self.max_num_tokens = max_num_tokens
- self.num_dispatchers = num_dispatchers
-
def supports_expert_map(self) -> bool:
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceDelegate()
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- return (
- mk.FusedMoEActivationFormat.BatchedExperts,
- mk.FusedMoEActivationFormat.BatchedExperts,
- )
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.BatchedExperts
def supports_chunking(self) -> bool:
return False
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index f2d9463e9183..4564e156041f 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -21,6 +21,8 @@
)
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
+ FusedMoEConfig,
+ FusedMoEParallelConfig,
FusedMoEQuantConfig,
_get_config_dtype_str,
)
@@ -49,6 +51,14 @@
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4
from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ QuantKey,
+ kFp8Dynamic128Sym,
+ kFp8DynamicTokenSym,
+ kFp8Static128BlockSym,
+ kFp8StaticChannelSym,
+ kFp8StaticTensorSym,
+)
from vllm.model_executor.utils import maybe_disable_graph_partition
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
@@ -2299,19 +2309,52 @@ def fused_experts_impl(
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
+ moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
):
- super().__init__(quant_config)
-
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- return (
- mk.FusedMoEActivationFormat.Standard,
- mk.FusedMoEActivationFormat.Standard,
+ super().__init__(moe_config, quant_config)
+
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.Standard
+
+ @staticmethod
+ def _supports_current_device() -> bool:
+ return current_platform.is_cuda_alike()
+
+ @staticmethod
+ def _supports_no_act_and_mul() -> bool:
+ return False
+
+ @staticmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ p = current_platform
+ device_supports_fp8 = p.is_rocm() or (
+ p.is_cuda() and p.has_device_capability((9, 0))
)
+ if not device_supports_fp8:
+ return (weight_key, activation_key) == (None, None)
+
+ SUPPORTED_W_A = [
+ (kFp8Static128BlockSym, kFp8Dynamic128Sym),
+ (kFp8StaticChannelSym, kFp8DynamicTokenSym),
+ (kFp8StaticTensorSym, kFp8DynamicTokenSym),
+ (kFp8StaticTensorSym, kFp8StaticTensorSym),
+ ]
+ return (weight_key, activation_key) in SUPPORTED_W_A
+
+ @staticmethod
+ def _supports_activation(activation: str) -> bool:
+ return activation in ["silu", "gelu", "swigluoai"]
+
+ @staticmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ return True
+
def supports_chunking(self) -> bool:
return True
@@ -2486,11 +2529,12 @@ def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
class TritonWNA16Experts(TritonExperts):
- def __init__(
- self,
- quant_config: FusedMoEQuantConfig,
- ):
- super().__init__(quant_config)
+ @staticmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ return False
def apply(
self,
@@ -2629,10 +2673,12 @@ def apply(
def modular_triton_fused_moe(
- quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None
+ moe_config: FusedMoEConfig,
+ quant_config: FusedMoEQuantConfig,
+ shared_experts: torch.nn.Module | None = None,
) -> mk.FusedMoEModularKernel:
return mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
- TritonExperts(quant_config),
+ TritonExperts(moe_config, quant_config),
shared_experts,
)
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
index 389ccf358c56..b3085d43c0b2 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
@@ -30,6 +30,17 @@ def __init__(self, moe: FusedMoEConfig):
self.moe: FusedMoEConfig = moe
self.moe_quant_config: FusedMoEQuantConfig | None = None
+ @property
+ def supports_mk_interally(self) -> bool:
+ """
+ Returns True if this method supports using modular kernels (MK)
+ internally for MoE operations, False otherwise.
+
+ This method should be overridden by subclasses that support
+ modular kernels internally.
+ """
+ return False
+
@abstractmethod
def create_weights(
self,
diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
index c4bc1824aa1f..b209820cdfa9 100644
--- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
+++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
@@ -9,12 +9,16 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
+ FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ QuantKey,
+)
from vllm.triton_utils import tl, triton
from vllm.utils.import_utils import has_triton_kernels
@@ -241,8 +245,43 @@ def make_routing_data(
class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
- def __init__(self, quant_config: FusedMoEQuantConfig):
- super().__init__(quant_config)
+ @staticmethod
+ def _supports_current_device() -> bool:
+ raise NotImplementedError(
+ "OAITritonExperts is not yet used by an Oracle. "
+ "This method should not be called."
+ )
+
+ @staticmethod
+ def _supports_no_act_and_mul() -> bool:
+ raise NotImplementedError(
+ "OAITritonExperts is not yet used by an Oracle. "
+ "This method should not be called."
+ )
+
+ @staticmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ raise NotImplementedError(
+ "OAITritonExperts is not yet used by an Oracle. "
+ "This method should not be called."
+ )
+
+ @staticmethod
+ def _supports_activation(activation: str) -> bool:
+ raise NotImplementedError(
+ "OAITritonExperts is not yet used by an Oracle. "
+ "This method should not be called."
+ )
+
+ @staticmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ raise NotImplementedError(
+ "OAITritonExperts is not yet used by an Oracle. "
+ "This method should not be called."
+ )
def supports_expert_map(self) -> bool:
return True
@@ -297,19 +336,9 @@ def _make_routing_data(
class OAITritonExperts(BaseOAITritonExperts):
- def __init__(self, quant_config: FusedMoEQuantConfig):
- # TODO (varun) : Enable activation quantization
- assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
- super().__init__(quant_config)
-
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- return (
- mk.FusedMoEActivationFormat.Standard,
- mk.FusedMoEActivationFormat.Standard,
- )
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.Standard
def supports_chunking(self) -> bool:
return True
@@ -391,19 +420,9 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
One use case for it is to inject LoRA modules on the activation and moe_sum.
"""
- def __init__(self, quant_config: FusedMoEQuantConfig):
- # TODO (varun) : Enable activation quantization
- assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
- super().__init__(quant_config)
-
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- return (
- mk.FusedMoEActivationFormat.Standard,
- mk.FusedMoEActivationFormat.Standard,
- )
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.Standard
def supports_chunking(self) -> bool:
return True
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 3b3a789f6d6e..77898da75b1d 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -360,7 +360,7 @@ def __init__(
enable_eplb: bool = False,
num_redundant_experts: int = 0,
has_bias: bool = False,
- is_sequence_parallel=False,
+ is_sequence_parallel: bool = False,
expert_mapping: list[tuple[str, str, int, str]] | None = None,
n_shared_experts: int | None = None,
routing_method_type: RoutingMethodType | None = None,
@@ -587,6 +587,7 @@ def __init__(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
+ intermediate_size_per_partition=self.intermediate_size_per_partition,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
in_dtype=moe_in_dtype,
@@ -595,9 +596,8 @@ def __init__(
has_bias=has_bias,
is_act_and_mul=is_act_and_mul,
is_lora_enabled=vllm_config.lora_config is not None,
- )
- self.moe_config_use_flashinfer_cutlass_kernels = (
- self.moe_config.use_flashinfer_cutlass_kernels
+ activation=activation,
+ device=vllm_config.device_config.device,
)
self.quant_config = quant_config
@@ -685,6 +685,10 @@ def _get_quant_method() -> FusedMoEMethodBase:
# This is called after all weight loading and post-processing, so it
# should be safe to swap out the quant_method.
def maybe_init_modular_kernel(self) -> None:
+ if self.quant_method.supports_mk_interally:
+ logger.info_once("DEBUG: SKIPPING MK INIT: Handled Internally!!!!")
+ return
+
self.ensure_moe_quant_config_init()
# routing_tables only needed for round-robin expert placement with
# DeepEP all2all backend.
@@ -763,14 +767,6 @@ def use_deepep_ht_kernels(self):
def use_deepep_ll_kernels(self):
return self.moe_parallel_config.use_deepep_ll_kernels
- @property
- def use_flashinfer_cutlass_kernels(self):
- return (
- self.moe_quant_config is not None
- and self.moe_quant_config.quant_dtype == "nvfp4"
- and self.moe_config_use_flashinfer_cutlass_kernels
- )
-
@property
def use_marlin_kernels(self):
return getattr(self.quant_method, "use_marlin", False)
@@ -780,7 +776,7 @@ def use_dp_chunking(self) -> bool:
return (
self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels
- or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels)
+ or self.moe_parallel_config.use_fi_all2allv_kernels
) and envs.VLLM_ENABLE_MOE_DP_CHUNK
@property
@@ -1688,6 +1684,7 @@ def must_reduce_shared_expert_outputs(self) -> bool:
early.
"""
assert self.quant_method is not None
+ # TODO(rob): investigate this.
return (
isinstance(self.quant_method, FusedMoEModularMethod)
and self.quant_method.fused_experts.output_is_reduced()
@@ -1894,8 +1891,15 @@ def forward_impl(
self.ensure_moe_quant_config_init()
self.ensure_dp_chunking_init()
+ # TODO: figure out a better way to express who is responsible for the SE.
+ mk_has_shared_expert = (
+ self.quant_method.supports_mk_interally
+ and hasattr(self.quant_method, "kernel")
+ and getattr(self.quant_method.kernel, "shared_experts", None) is not None
+ )
has_separate_shared_experts = (
not isinstance(self.quant_method, FusedMoEModularMethod)
+ and not mk_has_shared_expert
and self.shared_experts is not None
)
@@ -1919,8 +1923,9 @@ def forward_impl(
hidden_states, router_logits, has_separate_shared_experts
)
- do_naive_dispatch_combine: bool = self.dp_size > 1 and not isinstance(
- self.quant_method, FusedMoEModularMethod
+ do_naive_dispatch_combine: bool = self.dp_size > 1 and not (
+ isinstance(self.quant_method, FusedMoEModularMethod)
+ or self.quant_method.supports_mk_interally
)
ctx = get_forward_context()
diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py
index a6df2b20af9c..fad57edcd16b 100644
--- a/vllm/model_executor/layers/fused_moe/modular_kernel.py
+++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -13,6 +13,7 @@
from vllm.forward_context import get_forward_context, is_forward_context_available
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
@@ -22,6 +23,9 @@
count_expert_num_tokens,
disable_inplace,
)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ QuantKey,
+)
from vllm.utils.math_utils import cdiv
from vllm.v1.worker.ubatching import (
dbo_enabled,
@@ -374,18 +378,85 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
def __init__(
self,
+ moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
):
"""
+ moe_config: MoE layer configuration.
quant_config: Quantization parameters for this experts instance.
"""
+ self.moe_config = moe_config
self.quant_config = quant_config
+ self._max_num_tokens: int | None = None
+ self._num_dispatchers: int | None = None
+
+ @staticmethod
+ def should_pf_defer_input_quant(
+ moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig
+ ) -> bool:
+ """
+ Whether or not the PrepareFinalize should defer input quantization
+ in the prepare step. If True, then the Experts kernel will
+ execute the input quantization itself.
+
+ Sample subclasses that override are AITER and FlashInfer CUTLASS.
+ """
+ return False
+
+ def _init_batched_experts_addl_params(
+ self,
+ max_num_tokens: int,
+ num_dispatchers: int,
+ ):
+ """
+ Initialize any additional parameters needed for batched experts.
+ """
+ self._max_num_tokens = max_num_tokens
+ self._num_dispatchers = num_dispatchers
@property
+ def max_num_tokens(self) -> int:
+ if self._max_num_tokens is None:
+ raise AttributeError("max_num_tokens only valid for BatchedExperts")
+ return self._max_num_tokens
+
+ @property
+ def num_dispatchers(self) -> int:
+ if self._num_dispatchers is None:
+ raise AttributeError("num_dispatchers only valid for BatchedExperts")
+ return self._num_dispatchers
+
+ @classmethod
+ def make_standard_experts(
+ cls,
+ moe_config: FusedMoEConfig,
+ quant_config: FusedMoEQuantConfig,
+ ) -> "FusedMoEPermuteExpertsUnpermute":
+ """
+ Factory method to create an instance of this class.
+ """
+ assert cls.activation_format() == FusedMoEActivationFormat.Standard
+ return cls(moe_config, quant_config)
+
+ @classmethod
+ def make_batched_experts(
+ cls,
+ moe_config: FusedMoEConfig,
+ quant_config: FusedMoEQuantConfig,
+ max_num_tokens: int,
+ num_dispatchers: int,
+ ) -> "FusedMoEPermuteExpertsUnpermute":
+ """
+ Factory method to create an instance of this class.
+ """
+ assert cls.activation_format() == FusedMoEActivationFormat.BatchedExperts
+ instance = cls(moe_config, quant_config)
+ instance._init_batched_experts_addl_params(max_num_tokens, num_dispatchers)
+ return instance
+
+ @staticmethod
@abstractmethod
- def activation_formats(
- self,
- ) -> tuple[FusedMoEActivationFormat, FusedMoEActivationFormat]:
+ def activation_format() -> FusedMoEActivationFormat:
"""
A property which is a tuple of the input and output activation formats
for the 'apply' method.
@@ -435,6 +506,78 @@ def moe_problem_size(
return E, M, N, K, topk
+ #
+ # Various helpers for registering support for various features.
+ # Used by the oracle to select a particular kernel for a deployment.
+ #
+
+ @staticmethod
+ def is_supported_config(
+ cls: type["FusedMoEPermuteExpertsUnpermute"],
+ moe_config: FusedMoEConfig,
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ activation_format: FusedMoEActivationFormat,
+ ) -> tuple[bool, str | None]:
+ def _make_reason(reason: str) -> str:
+ return f"kernel does not support {reason}"
+
+ if not cls._supports_current_device():
+ return False, _make_reason("current device")
+ elif not (moe_config.is_act_and_mul or cls._supports_no_act_and_mul()):
+ return False, _make_reason("no act_and_mul MLP layer")
+ elif not cls._supports_activation(moe_config.activation):
+ return False, _make_reason(f"{moe_config.activation} activation")
+ elif not cls._supports_quant_scheme(weight_key, activation_key):
+ return False, _make_reason("quantization scheme")
+ elif not cls._supports_parallel_config(moe_config.moe_parallel_config):
+ return False, _make_reason("parallel config")
+ elif activation_format != cls.activation_format():
+ return False, _make_reason(f"{activation_format.value} activation format")
+ return True, None
+
+ @staticmethod
+ @abstractmethod
+ def _supports_current_device() -> bool:
+ """
+ Whether the kernel supports the current device type
+    (compute capability and current platform).
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ @abstractmethod
+ def _supports_no_act_and_mul() -> bool:
+ """
+ Whether the kernel supports act_and_mul=False, i.e.
+ non-gated MoE models like Nemotron-Nano.
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ @abstractmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ raise NotImplementedError
+
+ @staticmethod
+ @abstractmethod
+ def _supports_activation(activation: str) -> bool:
+ """
+ Whether the kernel supports a particular act function.
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ @abstractmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ """
+ Whether the kernel supports deployment in expert parallel.
+ """
+ raise NotImplementedError
+
#
# Various helpers for accessing quantization parameters from the
# quant_config.
@@ -715,12 +858,12 @@ def __init__(
self._post_init_setup()
assert (
- prepare_finalize.activation_format == fused_experts.activation_formats[0]
+ prepare_finalize.activation_format == fused_experts.activation_format()
), (
f"{prepare_finalize.__class__.__name__}."
f"{prepare_finalize.activation_format} == "
f"{fused_experts.__class__.__name__}."
- f"{fused_experts.activation_formats[0]}"
+ f"{fused_experts.activation_format()}"
)
def _post_init_setup(self):
diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py
index f5c3b9af611f..fe6d5cc68bc9 100644
--- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py
+++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
+from typing import TYPE_CHECKING
import torch
@@ -8,12 +9,18 @@
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.all2all_utils import (
+ maybe_make_prepare_finalize,
+)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
fp8_w8a8_moe_quant_config,
fp8_w8a16_moe_quant_config,
)
+from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import (
+ is_supported_config_trtllm,
+)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
get_flashinfer_moe_backend,
@@ -26,128 +33,279 @@
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_fp8_moe_layer_for_marlin,
)
-from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
- cutlass_group_gemm_supported,
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ QuantKey,
)
-from vllm.platforms import current_platform
-from vllm.utils.deep_gemm import is_deep_gemm_supported
-from vllm.utils.flashinfer import has_flashinfer_moe
-from vllm.utils.import_utils import has_deep_gemm
+
+if TYPE_CHECKING:
+ from vllm.model_executor.layers.fused_moe.layer import FusedMoE
logger = init_logger(__name__)
class Fp8MoeBackend(Enum):
NONE = 0
- FLASHINFER_TRTLLM = 1
- FLASHINFER_CUTLASS = 2
- DEEPGEMM = 3
- MARLIN = 4
- TRITON = 5
- AITER = 6
- VLLM_CUTLASS = 7
+ FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM"
+ FLASHINFER_CUTLASS = "FLASHINFER_CUTLASS"
+ DEEPGEMM = "DEEPGEMM"
+ BATCHED_DEEPGEMM = "BATCHED_DEEPGEMM"
+ MARLIN = "MARLIN"
+ TRITON = "TRITON"
+ BATCHED_TRITON = "BATCHED_TRITON"
+ AITER = "AITER"
+ VLLM_CUTLASS = "VLLM_CUTLASS"
+
+
+def backend_2_kernel_cls(
+ backend: Fp8MoeBackend,
+) -> type[mk.FusedMoEPermuteExpertsUnpermute]:
+ if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
+ raise NotImplementedError
+
+ elif backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
+ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
+ FlashInferExperts,
+ )
+
+ return FlashInferExperts
+
+ elif backend == Fp8MoeBackend.DEEPGEMM:
+ from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
+ TritonOrDeepGemmExperts,
+ )
+
+ return TritonOrDeepGemmExperts
+
+ elif backend == Fp8MoeBackend.BATCHED_DEEPGEMM:
+ from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
+ BatchedDeepGemmExperts,
+ )
+
+ return BatchedDeepGemmExperts
+
+ elif backend == Fp8MoeBackend.MARLIN:
+ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+ MarlinExperts,
+ )
+
+ return MarlinExperts
+
+ elif backend == Fp8MoeBackend.TRITON:
+ from vllm.model_executor.layers.fused_moe.fused_moe import (
+ TritonExperts,
+ )
+
+ return TritonExperts
+
+ elif backend == Fp8MoeBackend.BATCHED_TRITON:
+ from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
+ BatchedTritonExperts,
+ )
+
+ return BatchedTritonExperts
+
+ elif backend == Fp8MoeBackend.AITER:
+ from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+ AiterExperts,
+ )
+
+ return AiterExperts
+
+ elif backend == Fp8MoeBackend.VLLM_CUTLASS:
+ from vllm.model_executor.layers.fused_moe.triton_cutlass_moe import (
+ TritonOrCutlassExperts,
+ )
+
+ return TritonOrCutlassExperts
+
+ else:
+ raise ValueError(f"Unknown FP8 MoE backend: {backend.value}")
def select_fp8_moe_backend(
- block_quant: bool,
- tp_size: int,
- with_lora_support: bool,
- is_act_and_mul: bool = True,
+ config: FusedMoEConfig,
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
allow_vllm_cutlass: bool = False,
-) -> Fp8MoeBackend:
+) -> tuple[Fp8MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]:
"""
Select the primary FP8 MoE backend
Note: Shape-specific fallbacks may still occur at runtime.
"""
- # TODO(rob): in a future PR, we will query each mk for
- # supported features and return the mk directly, just like
- # we do for the Attention Backend.
-
- if with_lora_support:
- return Fp8MoeBackend.TRITON
-
- def _make_log_backend(backend_name: str):
- return f"Using {backend_name} backend for FP8 MoE"
-
- # Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
- if (
- current_platform.is_cuda()
- and (
- current_platform.is_device_capability_family(100)
- or current_platform.is_device_capability(90)
+
+ if config.is_lora_enabled:
+ return Fp8MoeBackend.TRITON, backend_2_kernel_cls(Fp8MoeBackend.TRITON)
+
+ # NOTE(rob): this is kind of a hack. We need to peek into
+ # the prepare-finalize selection to determine if we are using
+ # the batched or standard expert format.
+ use_batched = (
+ config.moe_parallel_config.use_deepep_ll_kernels
+ or config.moe_parallel_config.use_pplx_kernels
+ )
+ activation_format = (
+ mk.FusedMoEActivationFormat.BatchedExperts
+ if use_batched
+ else mk.FusedMoEActivationFormat.Standard
+ )
+
+ def _make_log_backend(backend: Fp8MoeBackend):
+ return f"Using `{backend.value}` backend for FP8 MoE"
+
+ def _make_log_unsupported(backend: Fp8MoeBackend, reason: str | None) -> str:
+ if reason:
+ return (
+ f"FP8 MoE backend `{backend.value}` does not support the "
+ f"deployment configuration since {reason}."
+ )
+ else:
+ return (
+ f"FP8 MoE backend `{backend.value}` does not support the "
+ "deployment configuration."
+ )
+
+ def _return_or_raise(
+ backend: Fp8MoeBackend,
+ config: FusedMoEConfig,
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ activation_format: mk.FusedMoEActivationFormat,
+ ) -> tuple[Fp8MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]:
+ k_cls = backend_2_kernel_cls(backend)
+ supported, reason = k_cls.is_supported_config(
+ k_cls, config, weight_key, activation_key, activation_format
)
- and envs.VLLM_USE_FLASHINFER_MOE_FP8
- and has_flashinfer_moe()
- ):
- backend = get_flashinfer_moe_backend()
- if backend == FlashinferMoeBackend.TENSORRT_LLM:
- logger.info_once(_make_log_backend("FlashInfer TRTLLM"))
- if not is_act_and_mul:
- raise ValueError(
- "FlashInfer TRTLLM FP8 MoE backend only supports "
- "act_and_mul gate_up_project fusion. Please set "
- "VLLM_USE_FLASHINFER_MOE_FP8=throughput to use the "
- "FlashInfer CUTLASS backend instead."
+ if supported:
+ logger.info_once(_make_log_backend(backend))
+ return backend, k_cls
+ raise ValueError(_make_log_unsupported(backend, reason))
+
+ # NOTE: the kernels are selected in the following order.
+ AVAILABLE_BACKENDS = [
+ Fp8MoeBackend.AITER,
+ Fp8MoeBackend.FLASHINFER_TRTLLM,
+ Fp8MoeBackend.FLASHINFER_CUTLASS,
+ Fp8MoeBackend.DEEPGEMM,
+ Fp8MoeBackend.BATCHED_DEEPGEMM,
+ Fp8MoeBackend.VLLM_CUTLASS,
+ Fp8MoeBackend.TRITON,
+ Fp8MoeBackend.BATCHED_TRITON,
+ Fp8MoeBackend.MARLIN,
+ ]
+
+ # Handle explicit FlashInfer FP8 configuration.
+ if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP8"):
+ if not envs.VLLM_USE_FLASHINFER_MOE_FP8:
+ # If the user rejects FlashInfer remove those backends.
+ AVAILABLE_BACKENDS.remove(Fp8MoeBackend.FLASHINFER_TRTLLM)
+ AVAILABLE_BACKENDS.remove(Fp8MoeBackend.FLASHINFER_CUTLASS)
+
+ elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"):
+ # If user is explicit about backend, validate it.
+ fi_backend = get_flashinfer_moe_backend()
+
+ if fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
+ backend = Fp8MoeBackend.FLASHINFER_TRTLLM
+ supported, reason = is_supported_config_trtllm(
+ config, weight_key, activation_key, activation_format
)
- return Fp8MoeBackend.FLASHINFER_TRTLLM
- else:
- if block_quant and current_platform.is_device_capability_family(100):
- raise ValueError(
- "FlashInfer FP8 MoE throughput backend does not "
- "support block quantization on SM100. Please use "
- "VLLM_FLASHINFER_MOE_BACKEND=latency to use the "
- "FlashInfer TRTLLM backend instead."
+ if supported:
+ logger.info_once(_make_log_backend(backend))
+ return backend, None
+ else:
+ raise ValueError(_make_log_unsupported(backend, reason))
+
+ elif fi_backend == FlashinferMoeBackend.CUTLASS:
+ backend = Fp8MoeBackend.FLASHINFER_CUTLASS
+ return _return_or_raise(
+ backend, config, weight_key, activation_key, activation_format
)
- logger.info_once(_make_log_backend("FlashInfer CUTLASS"))
- return Fp8MoeBackend.FLASHINFER_CUTLASS
-
- # weight-only path for older GPUs without native FP8
- if (
- current_platform.is_cuda() and not current_platform.has_device_capability(89)
- ) or envs.VLLM_TEST_FORCE_FP8_MARLIN:
- logger.info_once(_make_log_backend("Marlin"), scope="local")
- return Fp8MoeBackend.MARLIN
-
- # Determine if we should use DeepGEMM with block-quantized weights:
- # - If explicitly set by user, respect their choice
- # - If not explicitly set (default), disable when TP size is >= 8
- moe_use_deep_gemm = envs.VLLM_MOE_USE_DEEP_GEMM
- if not envs.is_set("VLLM_MOE_USE_DEEP_GEMM") and tp_size >= 8:
- moe_use_deep_gemm = False
- logger.info_once(
- "DeepGEMM MoE is disabled by default when TP size is >= 8. "
- "Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it.",
- scope="local",
- )
- use_deep_gemm = envs.VLLM_USE_DEEP_GEMM
- if not is_deep_gemm_supported():
- use_deep_gemm = False
- logger.info_once(
- "DeepGEMM is disabled because the platform does not support it.",
- scope="local",
+ else:
+ assert fi_backend == FlashinferMoeBackend.CUTEDSL
+ raise ValueError("FlashInfer MaskedGEMM not supported for FP8")
+
+ else:
+ # If the user is not explicit about the backend, try both.
+ for backend in [
+ Fp8MoeBackend.FLASHINFER_TRTLLM,
+ Fp8MoeBackend.FLASHINFER_CUTLASS,
+ ]:
+ k_cls = backend_2_kernel_cls(backend)
+ if k_cls.is_supported_config(
+ k_cls, config, weight_key, activation_key, activation_format
+ ):
+ logger.info_once(_make_log_backend(backend))
+ return backend, k_cls
+
+ raise NotImplementedError(
+ "Found VLLM_USE_FLASHINFER_MOE_FP8=1, but no "
+ "FlashInfer FP8 MoE backend supports the configuration."
+ )
+
+ # Handle explicit DeepGEMM FP8 configuration.
+ if envs.is_set("VLLM_USE_DEEP_GEMM") or envs.is_set("VLLM_MOE_USE_DEEP_GEMM"):
+ if not envs.VLLM_USE_DEEP_GEMM or not envs.VLLM_MOE_USE_DEEP_GEMM:
+ AVAILABLE_BACKENDS.remove(Fp8MoeBackend.DEEPGEMM)
+ else:
+ backend = (
+ Fp8MoeBackend.DEEPGEMM
+ if activation_format == mk.FusedMoEActivationFormat.Standard
+ else Fp8MoeBackend.BATCHED_DEEPGEMM
+ )
+ return _return_or_raise(
+ backend, config, weight_key, activation_key, activation_format
+ )
+
+ # Handle explicit MARLIN FP8 configuration.
+ if envs.VLLM_TEST_FORCE_FP8_MARLIN:
+ backend = Fp8MoeBackend.MARLIN
+ return _return_or_raise(
+ backend, config, weight_key, activation_key, activation_format
)
- if use_deep_gemm and moe_use_deep_gemm and block_quant:
- if not has_deep_gemm():
- logger.warning_once(
- "DeepGEMM backend requested but not available.", scope="local"
+ # Handle explicit AITER FP8 configuration.
+ if envs.is_set("VLLM_ROCM_USE_AITER") or envs.is_set("VLLM_ROCM_USE_AITER_MOE"):
+ if not envs.VLLM_ROCM_USE_AITER or not envs.VLLM_ROCM_USE_AITER_MOE:
+ AVAILABLE_BACKENDS.remove(Fp8MoeBackend.AITER)
+ else:
+ backend = Fp8MoeBackend.AITER
+ return _return_or_raise(
+ backend, config, weight_key, activation_key, activation_format
)
- elif is_deep_gemm_supported():
- logger.info_once(_make_log_backend("DeepGEMM"), scope="local")
- return Fp8MoeBackend.DEEPGEMM
- if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MOE:
- logger.info_once(_make_log_backend("ROCm AITER"), scope="local")
- return Fp8MoeBackend.AITER
+ if not allow_vllm_cutlass:
+ AVAILABLE_BACKENDS.remove(Fp8MoeBackend.VLLM_CUTLASS)
+
+ # Select kernels in order of backend.
+ for backend in AVAILABLE_BACKENDS:
+ if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
+ k_cls = None # type: ignore[assignment]
+ supported, reason = is_supported_config_trtllm(
+ config,
+ weight_key,
+ activation_key,
+ activation_format,
+ )
+ else:
+ k_cls = backend_2_kernel_cls(backend)
+ supported, reason = k_cls.is_supported_config(
+ k_cls,
+ config,
+ weight_key,
+ activation_key,
+ activation_format,
+ )
- if allow_vllm_cutlass and not block_quant and cutlass_group_gemm_supported():
- logger.info_once(_make_log_backend("vLLM CUTLASS"), scope="local")
- return Fp8MoeBackend.VLLM_CUTLASS
+ if supported:
+ logger.info_once(_make_log_backend(backend), scope="local")
+ return backend, k_cls
+ else:
+ logger.info_once(_make_log_unsupported(backend, reason), scope="local")
- # default to Triton
- logger.info_once(_make_log_backend("Triton"), scope="local")
- return Fp8MoeBackend.TRITON
+ raise NotImplementedError(
+ "No FP8 MoE backend supports the deployment configuration."
+ )
def convert_to_fp8_moe_kernel_format(
@@ -205,6 +363,8 @@ def make_fp8_moe_quant_config(
a1_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
block_shape: list[int] | None = None,
+ per_act_token_quant: bool = False,
+ per_out_ch_quant: bool = False,
) -> FusedMoEQuantConfig | None:
"""
Create FusedMoEQuantConfig for the specifed FP8 Backend.
@@ -257,102 +417,63 @@ def make_fp8_moe_quant_config(
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
+ per_act_token_quant=per_act_token_quant,
+ per_out_ch_quant=per_out_ch_quant,
)
def make_fp8_moe_kernel(
- layer: torch.nn.Module,
+ layer: "FusedMoE",
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
fp8_backend: Fp8MoeBackend,
+ experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
) -> tuple[mk.FusedMoEModularKernel, bool]:
- # Delayed import is required since the oracle is imported
- # by CPU backends which cannot import all of these experts.
- # TODO: update the experts to make this not happen.
- from vllm.model_executor.layers.fused_moe.prepare_finalize import (
- MoEPrepareAndFinalizeNoEP,
+ # Create Prepare/Finalize.
+ prepare_finalize = maybe_make_prepare_finalize(
+ moe=moe_config,
+ quant_config=moe_quant_config,
+ routing_tables=None, # TODO: init routing tables here?
+ defer_input_quant=experts_cls.should_pf_defer_input_quant(
+ moe_config, moe_quant_config
+ ),
+ allow_new_interface=True,
)
+ assert prepare_finalize is not None
- # NOTE(rob): this is a WIP refactor. We are first migrating
- # all of the kernels in the TP case to use mk. Once this is
- # done, then we will initialzie the TP case and DP/EP case
- # via the same code path (i.e. via maybe_init_modular_kernel).
- # NOTE(rob): in progress migrating all into this format.
- use_inplace = True
- if fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
- from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
- FlashInferExperts,
- )
+ logger.info_once("Using %s", prepare_finalize.__class__.__name__)
- kernel = mk.FusedMoEModularKernel(
- MoEPrepareAndFinalizeNoEP(
- defer_input_quant=moe_quant_config.is_block_quantized
- ),
- FlashInferExperts(
- out_dtype=layer.orig_dtype,
- quant_config=moe_quant_config,
- ep_rank=moe_config.ep_rank,
- ep_size=moe_config.ep_size,
- tp_rank=moe_config.tp_rank,
- tp_size=moe_config.tp_size,
- use_dp=(moe_config.dp_size > 1),
- use_deepseek_fp8_block_scale=moe_quant_config.is_block_quantized,
- ),
- )
- use_inplace = False
-
- elif fp8_backend == Fp8MoeBackend.AITER:
- from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
- AiterExperts,
- )
-
- kernel = mk.FusedMoEModularKernel(
- # TODO: make defer_input_quant an attr of the AiterExperts
- MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
- AiterExperts(quant_config=moe_quant_config),
- )
- elif fp8_backend == Fp8MoeBackend.MARLIN:
- from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
- MarlinExperts,
- )
-
- kernel = mk.FusedMoEModularKernel(
- MoEPrepareAndFinalizeNoEP(),
- MarlinExperts(quant_config=moe_quant_config),
- )
- elif fp8_backend == Fp8MoeBackend.VLLM_CUTLASS:
- from vllm.model_executor.layers.fused_moe.triton_cutlass_moe import (
- TritonOrCutlassExperts,
- )
-
- kernel = mk.FusedMoEModularKernel(
- MoEPrepareAndFinalizeNoEP(),
- TritonOrCutlassExperts(
- out_dtype=moe_config.in_dtype,
- e=layer.local_num_experts,
- n=layer.intermediate_size_per_partition,
- k=layer.hidden_size,
- device=layer.w13_weight.device,
- quant_config=moe_quant_config,
- ),
- )
- elif fp8_backend == Fp8MoeBackend.DEEPGEMM:
- from vllm.model_executor.layers.fused_moe import (
- TritonOrDeepGemmExperts,
- )
-
- kernel = mk.FusedMoEModularKernel(
- MoEPrepareAndFinalizeNoEP(),
- TritonOrDeepGemmExperts(quant_config=moe_quant_config),
+ # Create Experts.
+ if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.Standard:
+ experts = experts_cls.make_standard_experts(
+ moe_config=moe_config,
+ quant_config=moe_quant_config,
)
else:
- from vllm.model_executor.layers.fused_moe.fused_moe import (
- TritonExperts,
+ max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
+ assert max_num_tokens_per_rank is not None
+ experts = experts_cls.make_batched_experts(
+ moe_config=moe_config,
+ quant_config=moe_quant_config,
+ max_num_tokens=max_num_tokens_per_rank,
+ num_dispatchers=prepare_finalize.num_dispatchers(),
)
- assert fp8_backend == Fp8MoeBackend.TRITON
- kernel = mk.FusedMoEModularKernel(
- MoEPrepareAndFinalizeNoEP(),
- TritonExperts(quant_config=moe_quant_config),
- )
- return kernel, use_inplace
+ # NOTE(rob): we only want the ModularKernel to control the SharedExpert
+ # if we are using all2all (for SBO). Need to make a change somewhere
+ # else to prevent double running the Shared Expert.
+ # This needs to be refactored.
+ kernel = mk.FusedMoEModularKernel(
+ prepare_finalize,
+ experts,
+ shared_experts=(
+ getattr(layer, "shared_expert", None)
+ if moe_config.moe_parallel_config.use_all2all_kernels
+ else None
+ ),
+ moe_parallel_config=moe_config.moe_parallel_config,
+ )
+
+ # TODO(rob): update inplace logic to be part of the kernel.
+ inplace = fp8_backend != Fp8MoeBackend.FLASHINFER_CUTLASS
+ return kernel, inplace
diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
index 547a2a795d19..e3d8bb4ab062 100644
--- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
+++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
@@ -1,33 +1,24 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
+from typing import TYPE_CHECKING
import torch
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.all2all_utils import (
+ maybe_make_prepare_finalize,
+)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
nvfp4_moe_quant_config,
nvfp4_w4a16_moe_quant_config,
)
-from vllm.model_executor.layers.fused_moe.cutlass_moe import (
- CutlassExpertsFp4,
-)
-from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
- FlashInferExperts,
-)
-from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
- MarlinExperts,
-)
-from vllm.model_executor.layers.fused_moe.prepare_finalize import (
- MoEPrepareAndFinalizeNoEP,
-)
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
- is_flashinfer_fp4_cutedsl_moe_available,
- is_flashinfer_fp4_cutlass_moe_available,
+ is_supported_config_trtllm,
prepare_nvfp4_moe_layer_for_fi_or_cutlass,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
@@ -35,22 +26,24 @@
get_flashinfer_moe_backend,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
- is_fp4_marlin_supported,
prepare_nvfp4_moe_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
- cutlass_fp4_supported,
+ QuantKey,
)
+if TYPE_CHECKING:
+ from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+
logger = init_logger(__name__)
class NvFp4MoeBackend(Enum):
- FLASHINFER_CUTLASS = "FlashInfer CUTLASS"
- FLASHINFER_TRTLLM = "FlashInfer TRTLLM"
- FLASHINFER_CUTEDSL = "FlashInfer CUTEDSL"
- VLLM_CUTLASS = "vLLM CUTASS"
- MARLIN = "vLLM MARLIN"
+ FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM"
+ FLASHINFER_CUTLASS = "FLASHINFER_CUTLASS"
+ FLASHINFER_CUTEDSL = "FLASHINFER_CUTEDSL"
+ VLLM_CUTLASS = "VLLM_CUTLASS"
+ MARLIN = "MARLIN"
FLASHINFER_NVFP4_MOE_BACKENDS = [
@@ -71,44 +64,184 @@ def is_global_sf_supported_for_nvfp4_backend(backend: NvFp4MoeBackend) -> bool:
# of all experts in Expert Parallel Mode when all experts are not
# on the same rank.
- return backend in [
- NvFp4MoeBackend.FLASHINFER_CUTLASS,
- NvFp4MoeBackend.FLASHINFER_TRTLLM,
- ]
+ return backend in FLASHINFER_NVFP4_MOE_BACKENDS
+
+
+def backend_2_kernel_cls(
+ backend: NvFp4MoeBackend,
+) -> type[mk.FusedMoEPermuteExpertsUnpermute]:
+ if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
+ raise NotImplementedError
+
+ elif backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
+ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
+ FlashInferExperts,
+ )
+
+ return FlashInferExperts
+
+ elif backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL:
+ from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (
+ FlashInferCuteDSLExperts,
+ )
+ return FlashInferCuteDSLExperts
+
+ elif backend == NvFp4MoeBackend.VLLM_CUTLASS:
+ from vllm.model_executor.layers.fused_moe.cutlass_moe import (
+ CutlassExpertsFp4,
+ )
+
+ return CutlassExpertsFp4
+
+ elif backend == NvFp4MoeBackend.MARLIN:
+ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+ MarlinExperts,
+ )
+
+ return MarlinExperts
+ else:
+ raise ValueError(f"Unknown NvFP4 MoE backend: {backend.value}")
+
+
+def select_nvfp4_moe_backend(
+ config: FusedMoEConfig,
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]:
+ """
+ Select the primary NvFP4 MoE backend
+ Note: Shape-specific fallbacks may still occur at runtime.
+ """
+
+ # NOTE(rob): this is kind of a hack. We need to peek into
+ # the prepare-finalize selection to determine if we are using
+ # the batched or standard expert format.
+ use_batched = (
+ config.moe_parallel_config.use_deepep_ll_kernels
+ or config.moe_parallel_config.use_pplx_kernels
+ )
+ activation_format = (
+ mk.FusedMoEActivationFormat.BatchedExperts
+ if use_batched
+ else mk.FusedMoEActivationFormat.Standard
+ )
-def select_nvfp4_moe_backend() -> NvFp4MoeBackend:
def _make_log_backend(backend: NvFp4MoeBackend):
- return f"Using {backend.value} backend for NvFp4 MoE"
+ return f"Using '{backend.value}' backend for NvFp4 MoE"
+
+ def _make_log_unsupported(backend: NvFp4MoeBackend, reason: str | None) -> str:
+ if reason:
+ return (
+ f"NvFP4 MoE backend '{backend.value}' does not support the "
+ f"deployment configuration since {reason}."
+ )
+ else:
+ return (
+ f"NvFP4 MoE backend '{backend.value}' does not support the "
+ "deployment configuration."
+ )
- if cutlass_fp4_supported() and not envs.VLLM_TEST_FORCE_FP8_MARLIN:
- allow_flashinfer = (
- is_flashinfer_fp4_cutlass_moe_available()
- or is_flashinfer_fp4_cutedsl_moe_available()
+ def _return_or_raise(
+ backend: NvFp4MoeBackend,
+ config: FusedMoEConfig,
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ activation_format: mk.FusedMoEActivationFormat,
+ ) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]:
+ k_cls = backend_2_kernel_cls(backend)
+ supported, reason = k_cls.is_supported_config(
+ k_cls, config, weight_key, activation_key, activation_format
)
- if allow_flashinfer and envs.VLLM_USE_FLASHINFER_MOE_FP4:
- backend = fi_2_vllm_backend_map[get_flashinfer_moe_backend()]
+ if supported:
+ logger.info_once(_make_log_backend(backend))
+ return backend, k_cls
+ raise ValueError(_make_log_unsupported(backend, reason))
+
+ # NOTE: the kernels are selected in the following order.
+ AVAILABLE_BACKENDS = [
+ NvFp4MoeBackend.FLASHINFER_TRTLLM,
+ NvFp4MoeBackend.FLASHINFER_CUTEDSL,
+ NvFp4MoeBackend.FLASHINFER_CUTLASS,
+ NvFp4MoeBackend.MARLIN,
+ # NvFp4MoeBackend.VLLM_CUTLASS,
+ ]
+
+ if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP4"):
+ if not envs.VLLM_USE_FLASHINFER_MOE_FP4:
+ # If the user rejects FlashInfer remove those backends.
+ for fi_backend in FLASHINFER_NVFP4_MOE_BACKENDS:
+ AVAILABLE_BACKENDS.remove(fi_backend)
+
+ elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"):
+ # If user is explicit about backend, validate it.
+ fi_backend = get_flashinfer_moe_backend()
+
+ if fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
+ backend = NvFp4MoeBackend.FLASHINFER_TRTLLM
+ supported, reason = is_supported_config_trtllm(
+ config, weight_key, activation_key, activation_format
+ )
+ if supported:
+ logger.info_once(_make_log_backend(backend))
+ return backend, None
+ else:
+ raise ValueError(_make_log_unsupported(backend, reason))
+ else:
+ backend = fi_2_vllm_backend_map[fi_backend]
+ return _return_or_raise(
+ backend, config, weight_key, activation_key, activation_format
+ )
else:
- backend = NvFp4MoeBackend.VLLM_CUTLASS
- elif is_fp4_marlin_supported():
- backend = NvFp4MoeBackend.MARLIN
- else:
- raise ValueError("No NvFp4 kernel backend available for NvFp4 MoE.")
+ # If the user is not explicit about the backend, try each.
+ for backend in FLASHINFER_NVFP4_MOE_BACKENDS:
+ k_cls = backend_2_kernel_cls(backend)
+ if k_cls.is_supported_config(
+ k_cls, config, weight_key, activation_key, activation_format
+ ):
+ logger.info_once(_make_log_backend(backend))
+ return backend, k_cls
- # Log warning if FI backend requested but not available.
- if (
- backend not in FLASHINFER_NVFP4_MOE_BACKENDS
- and envs.VLLM_USE_FLASHINFER_MOE_FP4
- ):
- logger.warning_once(
- "Requested FlashInfer backend for NvFp4 MoE, but it's not available. "
- "Falling back to %s for NvFp4 MoE",
- backend.value,
- scope="local",
+ raise NotImplementedError(
+ "Found VLLM_USE_FLASHINFER_MOE_FP4=1, but no "
+ "FlashInfer NVFP4 MoE backend supports the configuration."
+ )
+
+ if envs.VLLM_TEST_FORCE_FP8_MARLIN:
+ backend = NvFp4MoeBackend.MARLIN
+ return _return_or_raise(
+ backend, config, weight_key, activation_key, activation_format
)
- else:
- logger.info_once(_make_log_backend(backend), scope="local")
- return backend
+
+ # Select kernels in order of backend.
+ for backend in AVAILABLE_BACKENDS:
+ if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
+ k_cls = None # type: ignore[assignment]
+ supported, reason = is_supported_config_trtllm(
+ config,
+ weight_key,
+ activation_key,
+ activation_format,
+ )
+ else:
+ k_cls = backend_2_kernel_cls(backend)
+ supported, reason = k_cls.is_supported_config(
+ k_cls,
+ config,
+ weight_key,
+ activation_key,
+ activation_format,
+ )
+
+ if supported:
+ logger.info_once(_make_log_backend(backend), scope="local")
+ return backend, k_cls
+ else:
+ logger.info_once(_make_log_unsupported(backend, reason), scope="local")
+
+ raise NotImplementedError(
+ "No NvFp4 MoE backend supports the deployment configuration."
+ )
def convert_to_nvfp4_moe_kernel_format(
@@ -227,54 +360,54 @@ def make_nvfp4_moe_quant_config(
def make_nvfp4_moe_kernel(
- backend: NvFp4MoeBackend,
- quant_config: FusedMoEQuantConfig,
+ layer: "FusedMoE",
+ moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
-) -> mk.FusedMoEModularKernel | None:
- assert moe_config.dp_size == 1
-
- UNSUPPORTED_BACKENDS = [
- # TRTLLM does not use the modular kernl abstraction.
- NvFp4MoeBackend.FLASHINFER_TRTLLM,
- # CUTEDSL is used with BATCHED (masked) format only.
- # TODO: add here once we support dp/ep via the oracle.
- NvFp4MoeBackend.FLASHINFER_CUTEDSL,
- ]
+ experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
+) -> mk.FusedMoEModularKernel:
+ # Create Prepare/Finalize.
+ prepare_finalize = maybe_make_prepare_finalize(
+ moe=moe_config,
+ quant_config=moe_quant_config,
+ routing_tables=None, # TODO: init routing tables here?
+ defer_input_quant=experts_cls.should_pf_defer_input_quant(
+ moe_config, moe_quant_config
+ ),
+ allow_new_interface=True,
+ )
+ assert prepare_finalize is not None
- if backend in UNSUPPORTED_BACKENDS:
- return None
+ logger.info_once("Using %s", prepare_finalize.__class__.__name__)
- elif backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
- return mk.FusedMoEModularKernel(
- MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
- FlashInferExperts(
- out_dtype=moe_config.in_dtype,
- quant_config=quant_config,
- ep_rank=moe_config.ep_rank,
- ep_size=moe_config.ep_size,
- tp_rank=moe_config.tp_rank,
- tp_size=moe_config.tp_size,
- use_dp=False,
- use_deepseek_fp8_block_scale=False,
- ),
+ # Create Experts.
+ if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.Standard:
+ experts = experts_cls.make_standard_experts(
+ moe_config=moe_config,
+ quant_config=moe_quant_config,
)
-
- elif backend == NvFp4MoeBackend.VLLM_CUTLASS:
- return mk.FusedMoEModularKernel(
- MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
- CutlassExpertsFp4(
- out_dtype=moe_config.in_dtype,
- # TODO(rob): see what impact this has on expert map?
- max_experts_per_worker=moe_config.num_experts,
- quant_config=quant_config,
- ),
+ else:
+ max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
+ assert max_num_tokens_per_rank is not None
+ experts = experts_cls.make_batched_experts(
+ moe_config=moe_config,
+ quant_config=moe_quant_config,
+ max_num_tokens=max_num_tokens_per_rank,
+ num_dispatchers=prepare_finalize.num_dispatchers(),
)
- elif backend == NvFp4MoeBackend.MARLIN:
- return mk.FusedMoEModularKernel(
- MoEPrepareAndFinalizeNoEP(),
- MarlinExperts(quant_config=quant_config),
- )
+ # NOTE(rob): we only want the ModularKernel to control the SharedExpert
+ # if we are using all2all (for SBO). Need to make a change somewhere
+ # else to prevent double running the Shared Expert.
+ # This needs to be refactored.
+ kernel = mk.FusedMoEModularKernel(
+ prepare_finalize,
+ experts,
+ shared_experts=(
+ getattr(layer, "shared_expert", None)
+ if moe_config.moe_parallel_config.use_all2all_kernels
+ else None
+ ),
+ moe_parallel_config=moe_config.moe_parallel_config,
+ )
- else:
- raise ValueError(f"Unknown NvFp4 MoE backend: {backend}")
+ return kernel
diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py
index 5d806fa843a3..a4d38d184e9b 100644
--- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py
@@ -4,12 +4,143 @@
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.distributed import get_ep_group
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceContiguous,
TopKWeightAndReduceDelegate,
)
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
+from vllm.utils.flashinfer import nvfp4_block_scale_interleave
+
+
+class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize):
+ def __init__(
+ self,
+ defer_input_quant: bool = False,
+ is_sequence_parallel: bool = False,
+ num_dispatchers: int = 1,
+ ) -> None:
+ super().__init__()
+ self.defer_input_quant = defer_input_quant
+ self.is_sequence_parallel = is_sequence_parallel
+ self._num_dispatchers = num_dispatchers
+
+ @property
+ def activation_format(self) -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.Standard
+
+ def max_num_tokens_per_rank(self) -> int | None:
+ return None
+
+ def topk_indices_dtype(self) -> torch.dtype | None:
+ return None
+
+ def num_dispatchers(self) -> int:
+ return self._num_dispatchers
+
+ def output_is_reduced(self) -> bool:
+ return False
+
+ def prepare(
+ self,
+ a1: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ num_experts: int,
+ expert_map: torch.Tensor | None,
+ apply_router_weight_on_input: bool,
+ quant_config: FusedMoEQuantConfig,
+ ) -> mk.PrepareResultType:
+ if apply_router_weight_on_input:
+ topk = topk_ids.size(1)
+ assert topk == 1, (
+ "apply_router_weight_on_input is only implemented for topk=1"
+ )
+ # Note: do not use inplace for shared experts overlap
+ a1 = a1 * topk_weights.to(a1.dtype)
+
+ # Defer input quantization to the MoE kernel.
+ if self.defer_input_quant:
+ a1q = a1
+ a1q_scale = None
+ else:
+ use_nvfp4 = quant_config.use_nvfp4_w4a4
+ a1q, a1q_scale = moe_kernel_quantize_input(
+ a1,
+ quant_config.a1_gscale if use_nvfp4 else quant_config.a1_scale,
+ quant_config.quant_dtype,
+ quant_config.per_act_token_quant,
+ quant_config.block_shape,
+ is_fp4_scale_swizzled=False,
+ )
+
+ # TODO - this is just for deepgemm?
+ expert_tokens_meta = None
+
+ from vllm.platforms import current_platform
+
+ # The torch ops do not support fp8, so use an int8 view.
+ # Since dispatch does not do a reduce, this is safe to do.
+ use_int8_view = a1q.dtype == current_platform.fp8_dtype()
+ if use_int8_view:
+ a1q = a1q.view(torch.int8)
+
+ # Skip gathering scales if we have static quantization
+ # (the scale is a scalar, replicated on all ranks) or
+ # if quantization is deferred.
+ skip_gather_scales = a1q_scale is None or a1q_scale.ndim == 0
+ scales = None if skip_gather_scales else [a1q_scale]
+
+ res = get_ep_group().dispatch(
+ a1q,
+ topk_weights,
+ topk_ids,
+ is_sequence_parallel=self.is_sequence_parallel,
+ extra_tensors=scales,
+ )
+ if skip_gather_scales:
+ a1q, topk_weights, topk_ids = res
+ else:
+ a1q, topk_weights, topk_ids, scales = res
+ assert scales is not None and len(scales) == 1
+ a1q_scale = scales[0]
+
+ if use_int8_view:
+ a1q = a1q.view(current_platform.fp8_dtype())
+
+ # NOTE: shuffle into format expected by FLASHINFER_CUTLASS
+ # There are currently no other kernels that use this P/F
+ # with nvfp4. If we add other kernels in the future, we
+ # can register a shuffle that gets called here.
+ if use_nvfp4:
+ a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
+
+ return a1q, a1q_scale, expert_tokens_meta, topk_ids, topk_weights
+
+ def finalize(
+ self,
+ output: torch.Tensor,
+ fused_expert_output: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ apply_router_weight_on_input: bool,
+ weight_and_reduce_impl: mk.TopKWeightAndReduce,
+ ) -> None:
+ if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
+ weight_and_reduce_impl = TopKWeightAndReduceContiguous()
+
+ out = weight_and_reduce_impl.apply(
+ output=None,
+ fused_expert_output=fused_expert_output,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ apply_router_weight_on_input=apply_router_weight_on_input,
+ )
+
+ output.copy_(
+ get_ep_group().combine(out, is_sequence_parallel=self.is_sequence_parallel)
+ )
class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
index b78794c6bd83..cb1a211e9925 100644
--- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -9,11 +9,15 @@
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
+ FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ QuantKey,
+)
class QuantMethod(IntEnum):
@@ -269,17 +273,46 @@ def rocm_aiter_fused_experts(
class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
- def __init__(self, quant_config):
- super().__init__(quant_config)
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.Standard
+
+ @staticmethod
+ def should_pf_defer_input_quant(
+ fused_moe_config: mk.FusedMoEConfig, quant_config: FusedMoEQuantConfig
+ ) -> bool:
+ """
+ AITER Fused MoE kernels handle input quantization.
+ """
+ return True
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- return (
- mk.FusedMoEActivationFormat.Standard,
- mk.FusedMoEActivationFormat.Standard,
- )
+ @staticmethod
+ def _supports_current_device() -> bool:
+ return rocm_aiter_ops.is_fused_moe_enabled()
+
+ @staticmethod
+ def _supports_no_act_and_mul() -> bool:
+ return False
+
+ @staticmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ # return (
+ # quant_scheme.is_unquantized
+ # or quant_scheme.is_fp8_w8a8
+ # or quant_scheme.is_mxfp4_w4a4
+ # )
+ return False
+
+ @staticmethod
+ def _supports_activation(activation: str) -> bool:
+ return activation in ["silu", "gelu"]
+
+ @staticmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ return True
def supports_expert_map(self):
return True
diff --git a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py
index a143347b19f2..c3e3744935aa 100644
--- a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py
@@ -37,7 +37,7 @@ def __init__(
use_overlapped
and not (
(self.enable_eplb and backend != "allgather_reducescatter")
- or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1)
+ or self.moe_config.use_fi_all2allv_kernels
)
and self._shared_experts is not None
)
diff --git a/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py
index 09d5e45c1ec2..e21d17226671 100644
--- a/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py
@@ -5,7 +5,10 @@
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
-from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEConfig,
+ FusedMoEQuantConfig,
+)
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fallback import FallbackExperts
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
@@ -15,19 +18,18 @@
class TritonOrCutlassExperts(FallbackExperts):
"""Cutlass with fallback to Triton for low latency shapes on SM100."""
+ _experts_cls = CutlassExpertsFp8
+ _fallback_cls = TritonExperts
+
def __init__(
self,
- e: int,
- n: int,
- k: int,
- out_dtype: torch.dtype | None,
+ moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
- device: torch.dtype,
):
self.is_sm100 = current_platform.has_device_capability(100)
super().__init__(
- experts=CutlassExpertsFp8(e, n, k, out_dtype, quant_config, device),
- fallback_experts=TritonExperts(quant_config),
+ experts=CutlassExpertsFp8(moe_config, quant_config),
+ fallback_experts=TritonExperts(moe_config, quant_config),
)
def workspace_shapes(
diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
index 55b1e1211b0a..6c13903fccd3 100644
--- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
@@ -4,7 +4,10 @@
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
-from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEConfig,
+ FusedMoEQuantConfig,
+)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts,
_valid_deep_gemm,
@@ -20,10 +23,13 @@
class TritonOrDeepGemmExperts(FallbackExperts):
"""DeepGemm with fallback to Triton for low latency shapes."""
- def __init__(self, quant_config: FusedMoEQuantConfig):
+ _experts_cls = DeepGemmExperts
+ _fallback_cls = TritonExperts
+
+ def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig):
super().__init__(
- experts=DeepGemmExperts(quant_config),
- fallback_experts=TritonExperts(quant_config),
+ experts=DeepGemmExperts(moe_config, quant_config),
+ fallback_experts=TritonExperts(moe_config, quant_config),
)
def workspace_shapes(
diff --git a/vllm/model_executor/layers/fused_moe/trtllm_moe.py b/vllm/model_executor/layers/fused_moe/trtllm_moe.py
index c46f59564930..e7cec11c237d 100644
--- a/vllm/model_executor/layers/fused_moe/trtllm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/trtllm_moe.py
@@ -6,38 +6,61 @@
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
+ FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ QuantKey,
+)
+from vllm.platforms import current_platform
class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
- moe: FusedMoEConfig,
+ moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
gemm1_alpha,
gemm1_beta,
gemm1_clamp_limit,
max_capture_size,
):
- super().__init__(quant_config)
- self.moe = moe
+ super().__init__(moe_config, quant_config)
self.gemm1_alpha = gemm1_alpha
self.gemm1_beta = gemm1_beta
self.gemm1_clamp_limit = gemm1_clamp_limit
self.max_capture_size = max_capture_size
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- return (
- mk.FusedMoEActivationFormat.Standard,
- mk.FusedMoEActivationFormat.Standard,
- )
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.Standard
+
+ @staticmethod
+ def _supports_current_device() -> bool:
+ return current_platform.has_device_capability((10, 0))
+
+ @staticmethod
+ def _supports_no_act_and_mul() -> bool:
+ return False
+
+ @staticmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ # return quant_scheme.is_mxfp4_w4a16 or quant_scheme.is_mxfp4_w4a8
+ return False
+
+ @staticmethod
+ def _supports_activation(activation: str) -> bool:
+ return activation in ["swigluoai"]
+
+ @staticmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ return True
def supports_chunking(self) -> bool:
return True
@@ -86,7 +109,7 @@ def apply(
topk = topk_ids.size(-1)
local_num_experts = w1.size(0)
intermediate_size = w2.size(1)
- local_expert_offset = self.moe.ep_rank * local_num_experts
+ local_expert_offset = self.moe_config.ep_rank * local_num_experts
x_quant = hidden_states
x_scale = a1q_scale
diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
index c31c0223eb06..ef0f79a18846 100644
--- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
+++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
@@ -121,14 +121,18 @@ def select_gemm_impl(
== FusedMoEActivationFormat.BatchedExperts
):
logger.debug("BatchedTritonExperts %s", self.moe)
- return BatchedTritonExperts(
+ return BatchedTritonExperts.make_batched_experts(
+ moe_config=self.moe,
+ quant_config=self.moe_quant_config,
max_num_tokens=self.moe.max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
- quant_config=self.moe_quant_config,
)
else:
logger.debug("TritonExperts %s", self.moe)
- return TritonExperts(self.moe_quant_config)
+ return TritonExperts(
+ moe_config=self.moe,
+ quant_config=self.moe_quant_config,
+ )
def create_weights(
self,
@@ -257,8 +261,10 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.use_inplace = True
self.kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
- AiterExperts(self.moe_quant_config),
- shared_experts=None,
+ AiterExperts(
+ moe_config=self.moe,
+ quant_config=self.moe_quant_config,
+ ),
)
elif self.flashinfer_cutlass_moe_enabled:
@@ -270,19 +276,18 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
FlashInferExperts(
- out_dtype=layer.params_dtype,
+ moe_config=self.moe,
quant_config=self.moe_quant_config,
- tp_rank=self.moe.moe_parallel_config.tp_rank,
- tp_size=self.moe.moe_parallel_config.tp_size,
- ep_rank=self.moe.moe_parallel_config.ep_rank,
- ep_size=self.moe.moe_parallel_config.ep_size,
),
)
else:
self.use_inplace = True
self.kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
- TritonExperts(self.moe_quant_config),
+ TritonExperts(
+ moe_config=self.moe,
+ quant_config=self.moe_quant_config,
+ ),
shared_experts=None,
)
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index b71921c5c3d7..a6060ce01c4c 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -19,16 +19,15 @@
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoEActivationFormat,
- FusedMoEConfig,
FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute,
FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod,
)
from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEConfig,
FusedMoEQuantConfig,
- fp8_w8a8_moe_quant_config,
- fp8_w8a16_moe_quant_config,
+ RoutingMethodType,
int4_w4a16_moe_quant_config,
int4_w4afp8_moe_quant_config,
int8_w8a8_moe_quant_config,
@@ -45,10 +44,10 @@
Fp8MoeBackend,
convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel,
+ make_fp8_moe_quant_config,
select_fp8_moe_backend,
)
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
- FLASHINFER_NVFP4_MOE_BACKENDS,
NvFp4MoeBackend,
convert_to_nvfp4_moe_kernel_format,
is_global_sf_supported_for_nvfp4_backend,
@@ -61,10 +60,11 @@
WNA16_SUPPORTED_TYPES_MAP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
- build_flashinfer_fp4_cutlass_moe_prepare_finalize,
flashinfer_trtllm_fp4_moe,
flashinfer_trtllm_fp4_routed_moe,
- select_nvfp4_gemm_impl,
+)
+from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
+ apply_fi_trtllm_fp8_per_tensor_moe,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_input_tensor_strategy_moe,
@@ -77,12 +77,16 @@
marlin_make_workspace_new,
marlin_moe_permute_scales,
)
-from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
- is_fp4_marlin_supported,
-)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
convert_bf16_scales_to_fp8,
convert_packed_uint4b8_to_signed_int4_inplace,
+ kFp8Dynamic128Sym,
+ kFp8DynamicTokenSym,
+ kFp8Static128BlockSym,
+ kFp8StaticChannelSym,
+ kFp8StaticTensorSym,
+ kNvfp4Dynamic,
+ kNvfp4Static,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
normalize_e4m3fn_to_e4m3fnuz,
@@ -195,7 +199,7 @@ def get_moe_method(
f"or None for NVFP4A16, found {input_quant}",
)
return CompressedTensorsW4A4Nvfp4MoEMethod(
- layer.moe_config, layer_name, use_marlin=input_quant is None
+ layer.moe_config, layer_name, use_a16=(input_quant is None)
)
elif (
quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
@@ -229,7 +233,7 @@ def __init__(
self,
moe: FusedMoEConfig,
layer_name: str | None = None,
- use_marlin: bool = False,
+ use_a16: bool = False,
):
if not moe.is_act_and_mul:
raise ValueError(
@@ -239,20 +243,31 @@ def __init__(
super().__init__(moe)
self.group_size = 16
- if use_marlin:
- if is_fp4_marlin_supported():
- self.nvfp4_backend = NvFp4MoeBackend.MARLIN
- else:
- raise ValueError(
- "Marlin FP4 MoE kernel requested but not ",
- "supported on current platform.",
- )
- else:
- self.nvfp4_backend = select_nvfp4_moe_backend()
+
+ # Select experts implementation.
+ self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend(
+ config=self.moe,
+ weight_key=kNvfp4Static,
+ activation_key=None if use_a16 else kNvfp4Dynamic,
+ )
+
+ # Delay creation of the kernel until after process-weights.
+ self.kernel: mk.FusedMoEModularKernel | None = None
+
+ # Used for weight loading.
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
self.nvfp4_backend
)
- self.kernel: mk.FusedMoEModularKernel | None = None
+
+ @property
+ def topk_indices_dtype(self) -> torch.dtype | None:
+ if self.kernel is not None:
+ return self.kernel.prepare_finalize.topk_indices_dtype()
+ return None
+
+ @property
+ def supports_mk_interally(self) -> bool:
+ return True
def create_weights(
self,
@@ -425,50 +440,34 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.w13_input_scale = a13_scale
layer.w2_input_scale = a2_scale
- # Initialize the kernel that will be called in apply().
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
- use_dp = self.moe.dp_size > 1
- if self.moe_quant_config is not None and not use_dp:
+ if self.moe_quant_config is not None:
+ assert self.experts_cls is not None
self.kernel = make_nvfp4_moe_kernel(
- backend=self.nvfp4_backend,
- quant_config=self.moe_quant_config,
+ layer=layer,
+ moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
+ experts_cls=self.experts_cls,
)
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
- UNSUPPORTED = [NvFp4MoeBackend.MARLIN, NvFp4MoeBackend.FLASHINFER_TRTLLM]
- if self.nvfp4_backend in UNSUPPORTED:
- return None
- elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
- # TP case: avoid convert to ModularKernelMethod - to be refactored.
- if self.moe.dp_size == 1:
- return None
- # For now, fp4 moe only works with the flashinfer dispatcher.
- prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
- self.moe
- )
- logger.debug_once("%s", prepare_finalize.__class__.__name__)
- return prepare_finalize
- else:
- return super().maybe_make_prepare_finalize(routing_tables)
+ raise ValueError(
+ "CompressedTensorsW4A4NvFp4MoE uses the new modular kernel "
+ "initialization logic. This function should not be called."
+ )
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
- ) -> mk.FusedMoEPermuteExpertsUnpermute:
- assert self.moe_quant_config is not None
- """Return the appropriate GEMM experts implementation."""
- experts = select_nvfp4_gemm_impl(
- self.moe,
- self.moe_quant_config,
- allow_flashinfer=(self.nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS),
+ ) -> FusedMoEPermuteExpertsUnpermute:
+ raise ValueError(
+ "CompressedTensorsW4A4NvFp4MoE uses the new modular kernel "
+ "initialization logic. This function should not be called."
)
- logger.debug_once("Using %s", experts.__class__.__name__)
- return experts
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
@@ -579,32 +578,45 @@ def __init__(
"For FP8 Fused MoE layer, we require either per tensor or "
"channelwise, dynamic per token quantization."
)
- self.fp8_backend = select_fp8_moe_backend(
- block_quant=self.block_quant,
- tp_size=moe.tp_size,
- with_lora_support=moe.is_lora_enabled,
- # TODO(rob): enable selecting this externally.
+
+ ct2vllm_weight = {
+ QuantizationStrategy.CHANNEL: kFp8StaticChannelSym,
+ QuantizationStrategy.TENSOR: kFp8StaticTensorSym,
+ QuantizationStrategy.BLOCK: kFp8Static128BlockSym,
+ }
+ ct2vllm_act = {
+ QuantizationStrategy.TOKEN: kFp8DynamicTokenSym,
+ QuantizationStrategy.TENSOR: (
+ kFp8StaticTensorSym if self.static_input_scales else kFp8Dynamic128Sym
+ ),
+ }
+ weight_key = ct2vllm_weight[self.weight_quant.strategy]
+ if weight_key == kFp8Static128BlockSym:
+ activation_key = kFp8Dynamic128Sym
+ else:
+ activation_key = ct2vllm_act[self.input_quant.strategy]
+
+ # Select Fp8 MoE backend
+ self.fp8_backend, self.experts_cls = select_fp8_moe_backend(
+ config=self.moe,
+ weight_key=weight_key,
+ activation_key=activation_key,
allow_vllm_cutlass=True,
)
- if self.fp8_backend != Fp8MoeBackend.MARLIN:
- per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
- per_channel_quant = (
- self.weight_quant.strategy == QuantizationStrategy.CHANNEL
- )
- if per_act_token != per_channel_quant:
- raise NotImplementedError(
- "For FP8 Fused MoE layers, per-token and per-channel must be "
- "used together."
- )
- # TODO(rob): hook this up in a follow up PR.
- if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
- raise NotImplementedError(
- "FlashInfer TRTLLM backend not supported for compressed-tensors yet."
- )
- self.disable_expert_map = False
+ # Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
+ @property
+ def topk_indices_dtype(self) -> torch.dtype | None:
+ if self.kernel is not None:
+ return self.kernel.prepare_finalize.topk_indices_dtype()
+ return None
+
+ @property
+ def supports_mk_interally(self) -> bool:
+ return True
+
def create_weights(
self,
layer: torch.nn.Module,
@@ -818,138 +830,55 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
+ assert self.experts_cls is not None
self.kernel, self.use_inplace = make_fp8_moe_kernel(
layer=layer,
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
+ experts_cls=self.experts_cls,
)
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
- if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]:
- return None
- else:
- return super().maybe_make_prepare_finalize(routing_tables)
+ raise ValueError(
+ "CompressedTensorsW8A8Fp8MoE uses the new modular kernel "
+ "initialization logic. This function should not be called."
+ )
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
- # cutlass path
- assert self.moe_quant_config is not None
- if self.fp8_backend == Fp8MoeBackend.VLLM_CUTLASS:
- from vllm.model_executor.layers.fused_moe import (
- CutlassBatchedExpertsFp8,
- CutlassExpertsFp8,
- )
-
- experts: FusedMoEPermuteExpertsUnpermute
-
- num_dispatchers = prepare_finalize.num_dispatchers()
-
- if (
- prepare_finalize.activation_format
- == FusedMoEActivationFormat.BatchedExperts
- ):
- logger.debug("CutlassBatchedExpertsFp8(%s)", self.__class__.__name__)
- experts = CutlassBatchedExpertsFp8(
- max_experts_per_worker=self.moe.num_local_experts,
- num_dispatchers=num_dispatchers,
- out_dtype=self.moe.in_dtype,
- e=layer.local_num_experts,
- n=layer.intermediate_size_per_partition,
- k=layer.hidden_size,
- device=layer.w13_weight.device,
- quant_config=self.moe_quant_config,
- )
- else:
- logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
- experts = CutlassExpertsFp8(
- out_dtype=self.moe.in_dtype,
- e=layer.local_num_experts,
- n=layer.intermediate_size_per_partition,
- k=layer.hidden_size,
- device=layer.w13_weight.device,
- quant_config=self.moe_quant_config,
- )
-
- # TODO(rob): investigate disable_expert_map
- self.disable_expert_map = (
- num_dispatchers > 1 or not experts.supports_expert_map()
- )
-
- return experts
-
- from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
- BatchedDeepGemmExperts,
- )
- from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
- BatchedTritonExperts,
- )
- from vllm.model_executor.layers.fused_moe.fused_moe import (
- TritonExperts,
+ raise ValueError(
+ "CompressedTensorsW8A8Fp8MoE uses the new modular kernel "
+ "initialization logic. This function should not be called."
)
- from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
- TritonOrDeepGemmExperts,
- )
-
- assert self.fp8_backend not in [Fp8MoeBackend.AITER, Fp8MoeBackend.MARLIN]
-
- if (
- prepare_finalize.activation_format
- == FusedMoEActivationFormat.BatchedExperts
- ):
- max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
- assert max_num_tokens_per_rank is not None
-
- if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
- logger.debug("BatchedDeepGemmExperts(%s)", self.__class__.__name__)
- return BatchedDeepGemmExperts(
- max_num_tokens=max_num_tokens_per_rank,
- num_dispatchers=prepare_finalize.num_dispatchers(),
- quant_config=self.moe_quant_config,
- )
- else:
- logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
- return BatchedTritonExperts(
- max_num_tokens=max_num_tokens_per_rank,
- num_dispatchers=prepare_finalize.num_dispatchers(),
- quant_config=self.moe_quant_config,
- )
-
- else:
- if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
- logger.debug("TritonOrDeepGemmExperts(%s)", self.__class__.__name__)
- return TritonOrDeepGemmExperts(self.moe_quant_config)
- else:
- logger.debug("TritonExperts(%s)", self.__class__.__name__)
- return TritonExperts(self.moe_quant_config)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
- if self.fp8_backend == Fp8MoeBackend.MARLIN:
- return fp8_w8a16_moe_quant_config(
- w1_scale=layer.w13_weight_scale,
- w2_scale=layer.w2_weight_scale,
- block_shape=self.weight_block_size,
- )
-
- per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
- per_channel_quant = self.weight_quant.strategy == QuantizationStrategy.CHANNEL
+ w1_scale = layer.w13_weight_scale
+ w2_scale = layer.w2_weight_scale
+ a1_scale = layer.w13_input_scale
+ a2_scale = layer.w2_input_scale
- return fp8_w8a8_moe_quant_config(
- w1_scale=layer.w13_weight_scale,
- w2_scale=layer.w2_weight_scale,
- a1_scale=layer.w13_input_scale,
- a2_scale=layer.w2_input_scale,
- per_act_token_quant=per_act_token,
- per_out_ch_quant=per_channel_quant,
- block_shape=layer.weight_block_size,
+ return make_fp8_moe_quant_config(
+ fp8_backend=self.fp8_backend,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ a1_scale=a1_scale,
+ a2_scale=a2_scale,
+ per_act_token_quant=(
+ self.input_quant.strategy == QuantizationStrategy.TOKEN
+ ),
+ per_out_ch_quant=(
+ self.weight_quant.strategy == QuantizationStrategy.CHANNEL
+ ),
+ block_shape=self.weight_block_size,
)
def apply(
@@ -959,6 +888,55 @@ def apply(
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+ if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
+ if layer.enable_eplb:
+ raise NotImplementedError(
+ "EPLB not supported for `FlashInfer TRTLLM FP8 MoE`."
+ )
+ assert layer.activation == "silu"
+
+ if self.block_quant:
+ import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
+
+ e_score_correction_bias = (
+ layer.e_score_correction_bias.to(x.dtype)
+ if layer.e_score_correction_bias is not None
+ else None
+ )
+ routing_method_type = layer.routing_method_type
+ return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
+ routing_logits=router_logits.to(torch.float32)
+ if routing_method_type == RoutingMethodType.DeepSeekV3
+ else router_logits,
+ routing_bias=e_score_correction_bias,
+ x=x,
+ w13_weight=layer.w13_weight,
+ w13_weight_scale_inv=layer.w13_weight_scale,
+ w2_weight=layer.w2_weight,
+ w2_weight_scale_inv=layer.w2_weight_scale,
+ global_num_experts=layer.global_num_experts,
+ top_k=layer.top_k,
+ num_expert_group=layer.num_expert_group,
+ topk_group=layer.topk_group,
+ intermediate_size=layer.intermediate_size_per_partition,
+ expert_offset=layer.ep_rank * layer.local_num_experts,
+ local_num_experts=layer.local_num_experts,
+ block_shape=self.weight_block_size,
+ routing_method_type=routing_method_type,
+ routed_scaling=layer.routed_scaling_factor,
+ )
+ else:
+ return apply_fi_trtllm_fp8_per_tensor_moe(
+ layer=layer,
+ hidden_states=x,
+ router_logits=router_logits,
+ routing_bias=layer.e_score_correction_bias,
+ global_num_experts=layer.global_num_experts,
+ top_k=layer.top_k,
+ num_expert_group=layer.num_expert_group,
+ topk_group=layer.topk_group,
+ apply_router_weight_on_input=layer.apply_router_weight_on_input,
+ )
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
@@ -976,7 +954,7 @@ def apply(
global_num_experts=layer.global_num_experts,
# TODO(rob): investigate the disable_expert_map introduced by:
# https://github.com/vllm-project/vllm/commit/84166fee9770e6fba71a96978b3e7d149392fb28 # noqa: E501
- expert_map=None if self.disable_expert_map else layer.expert_map,
+ expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
@@ -1435,8 +1413,7 @@ def select_gemm_impl(
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None
return BatchedMarlinExperts(
- max_num_tokens=max_num_tokens_per_rank,
- num_dispatchers=prepare_finalize.num_dispatchers(),
+ moe_config=self.moe,
quant_config=self.moe_quant_config,
w13_g_idx=layer.w13_weight_g_idx,
w2_g_idx=layer.w2_weight_g_idx,
@@ -1446,6 +1423,7 @@ def select_gemm_impl(
)
else:
return MarlinExperts(
+ moe_config=self.moe,
quant_config=self.moe_quant_config,
w13_g_idx=layer.w13_weight_g_idx,
w2_g_idx=layer.w2_weight_g_idx,
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 1c0c35bf6f41..7fd0fa4f81cb 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -19,7 +19,6 @@
)
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
- FusedMoEActivationFormat,
FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize,
@@ -51,8 +50,6 @@
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe,
- build_flashinfer_fp8_cutlass_moe_prepare_finalize,
- select_cutlass_fp8_gemm_impl,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
@@ -76,6 +73,9 @@
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
is_layer_skipped,
+ kFp8Dynamic128Sym,
+ kFp8Static128BlockSym,
+ kFp8StaticTensorSym,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
@@ -633,37 +633,40 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
self.weight_scale_name = (
"weight_scale_inv" if self.block_quant else "weight_scale"
)
- self.fp8_backend = select_fp8_moe_backend(
- block_quant=self.block_quant,
- tp_size=layer.moe_parallel_config.tp_size,
- with_lora_support=self.moe.is_lora_enabled,
- )
- if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
- if self.block_quant and self.weight_block_size != [128, 128]:
- raise NotImplementedError(
- "FlashInfer CUTLASS FP8 MoE backend only supports block "
- "size [128, 128]."
- )
- if layer.activation != "silu":
- raise NotImplementedError(
- "FlashInfer CUTLASS FP8 MoE backend only supports SiLU "
- "activation function, but got {layer.activation}."
- )
- dynamic_per_token = (
- not self.block_quant and self.quant_config.activation_scheme != "static"
- )
- if dynamic_per_token and self.fp8_backend in [
- Fp8MoeBackend.FLASHINFER_TRTLLM,
- Fp8MoeBackend.FLASHINFER_CUTLASS,
- ]:
- raise NotImplementedError(
- "FlashInfer FP8 MoE backend does not support dynamic per token "
- "activation quantization."
+ # Set weight key and activation key for kernel compatibility
+ if self.block_quant:
+ weight_key = kFp8Static128BlockSym
+ activation_key = kFp8Dynamic128Sym
+ else:
+ weight_key = kFp8StaticTensorSym
+ activation_key = (
+ kFp8StaticTensorSym
+ if self.quant_config.activation_scheme == "static"
+ else kFp8Dynamic128Sym
)
+ # Select Fp8 MoE backend
+ self.fp8_backend, self.experts_cls = select_fp8_moe_backend(
+ config=self.moe,
+ weight_key=weight_key,
+ activation_key=activation_key,
+ allow_vllm_cutlass=False,
+ )
+
+ # Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
+ @property
+ def topk_indices_dtype(self) -> torch.dtype | None:
+ if self.kernel is not None:
+ return self.kernel.prepare_finalize.topk_indices_dtype()
+ return None
+
+ @property
+ def supports_mk_interally(self) -> bool:
+ return True
+
def create_weights(
self,
layer: Module,
@@ -819,11 +822,13 @@ def _setup_kernel(
# Setup modular kernel for TP case.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
+ assert self.experts_cls is not None
self.kernel, self.use_inplace = make_fp8_moe_kernel(
layer=layer,
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
+ experts_cls=self.experts_cls,
)
def process_weights_after_loading(self, layer: Module) -> None:
@@ -878,93 +883,28 @@ def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
- if self.fp8_backend in [
- Fp8MoeBackend.AITER,
- Fp8MoeBackend.MARLIN,
- Fp8MoeBackend.FLASHINFER_TRTLLM,
- ]:
- return None
- elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
- prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
- self.moe,
- use_deepseek_fp8_block_scale=self.block_quant,
- )
- logger.debug_once("%s", prepare_finalize.__class__.__name__)
- return prepare_finalize
- return super().maybe_make_prepare_finalize(routing_tables)
+ raise ValueError(
+ "Fp8FusedMoE uses the new modular kernel initialization logic. "
+ "This function should not be called."
+ )
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
- from vllm.model_executor.layers.fused_moe import (
- BatchedDeepGemmExperts,
- BatchedTritonExperts,
- TritonExperts,
- TritonOrDeepGemmExperts,
- )
-
- if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]:
- raise NotImplementedError(
- "Marlin and ROCm AITER are not supported with all2all yet."
+ # TODO(rob): look into whether we can just not do this for LoRA.
+ if self.moe.is_lora_enabled:
+ from vllm.model_executor.layers.fused_moe import (
+ TritonExperts,
)
- assert self.moe_quant_config is not None
-
- if (
- prepare_finalize.activation_format
- == FusedMoEActivationFormat.BatchedExperts
- ):
- max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
- assert max_num_tokens_per_rank is not None
-
- experts_impl = (
- BatchedDeepGemmExperts
- if self.fp8_backend == Fp8MoeBackend.DEEPGEMM
- else BatchedTritonExperts
- )
- logger.debug(
- "%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
- experts_impl.__name__,
- self.__class__.__name__,
- max_num_tokens_per_rank,
- self.weight_block_size,
- False,
- )
- return experts_impl(
- max_num_tokens=max_num_tokens_per_rank,
- num_dispatchers=prepare_finalize.num_dispatchers(),
- quant_config=self.moe_quant_config,
- )
- elif self.moe.is_lora_enabled:
return TritonExperts(quant_config=self.moe_quant_config)
- elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
- # Select GEMM experts with block-scale when weights are block-quantized
- experts = select_cutlass_fp8_gemm_impl(
- self.moe,
- self.moe_quant_config,
- use_deepseek_fp8_block_scale=self.block_quant,
- )
- logger.debug_once("Using %s", experts.__class__.__name__)
- return experts
- elif self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
- logger.debug(
- "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
- self.__class__.__name__,
- self.weight_block_size,
- False,
- )
- return TritonOrDeepGemmExperts(self.moe_quant_config)
- else:
- assert self.fp8_backend == Fp8MoeBackend.TRITON
- logger.debug(
- "TritonExperts(%s): block_size=%s, per_act_token=%s",
- self.__class__.__name__,
- self.weight_block_size,
- False,
- )
- return TritonExperts(self.moe_quant_config)
+
+ raise ValueError(
+ "Fp8FusedMoE uses the new modular kernel initialization logic. "
+ "This function should not be called."
+ )
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index bcda7b42c2ec..d999351d8264 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -31,7 +31,6 @@
select_fp8_moe_backend,
)
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
- FLASHINFER_NVFP4_MOE_BACKENDS,
NvFp4MoeBackend,
convert_to_nvfp4_moe_kernel_format,
is_global_sf_supported_for_nvfp4_backend,
@@ -51,15 +50,11 @@
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
- build_flashinfer_fp4_cutlass_moe_prepare_finalize,
flashinfer_trtllm_fp4_moe,
flashinfer_trtllm_fp4_routed_moe,
- select_nvfp4_gemm_impl,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe,
- build_flashinfer_fp8_cutlass_moe_prepare_finalize,
- select_cutlass_fp8_gemm_impl,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
@@ -78,6 +73,9 @@
GroupShape,
cutlass_fp4_supported,
is_layer_skipped,
+ kFp8StaticTensorSym,
+ kNvfp4Dynamic,
+ kNvfp4Static,
swizzle_blockscale,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@@ -729,45 +727,45 @@ def __init__(
super().__init__(moe_config)
self.quant_config = quant_config
assert self.quant_config.is_checkpoint_fp8_serialized
- self.fp8_backend = select_fp8_moe_backend(
- block_quant=False,
- tp_size=moe_config.moe_parallel_config.tp_size,
- with_lora_support=self.moe.is_lora_enabled,
+
+ # Select Fp8 MoE backend
+ self.fp8_backend, self.experts_cls = select_fp8_moe_backend(
+ config=self.moe,
+ weight_key=kFp8StaticTensorSym,
+ activation_key=kFp8StaticTensorSym,
)
+
+ # Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
+ @property
+ def topk_indices_dtype(self) -> torch.dtype | None:
+ if self.kernel is not None:
+ return self.kernel.prepare_finalize.topk_indices_dtype()
+ return None
+
+ @property
+ def supports_mk_interally(self) -> bool:
+ return True
+
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
- # TRT LLM not supported with all2all yet.
- if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
- return None
- elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
- # TP case: avoid convert to ModularKernelMethod - to be refactored.
- if self.moe.dp_size == 1:
- return None
-
- prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
- self.moe,
- use_deepseek_fp8_block_scale=False,
- )
- logger.debug_once("%s", prepare_finalize.__class__.__name__)
- return prepare_finalize
- return super().maybe_make_prepare_finalize(routing_tables)
+ raise ValueError(
+ "ModelOptFp8MoEMethod uses the new modular kernel initialization "
+ "logic. This function should not be called."
+ )
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
- assert self.moe_quant_config is not None
- experts = select_cutlass_fp8_gemm_impl(
- self.moe,
- self.moe_quant_config,
+ raise ValueError(
+ "ModelOptFp8MoEMethod uses the new modular kernel initialization "
+ "logic. This function should not be called."
)
- logger.debug_once("Using %s", experts.__class__.__name__)
- return experts
def create_weights(
self,
@@ -879,14 +877,16 @@ def _setup_kernel(
replace_parameter(layer, "w13_weight_scale", w13_scale)
replace_parameter(layer, "w2_weight_scale", w2_scale)
- # Setup modular kernel for TP case.
+ # Setup modular kernel.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
+ assert self.experts_cls is not None
self.kernel, self.use_inplace = make_fp8_moe_kernel(
layer=layer,
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
+ experts_cls=self.experts_cls,
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
@@ -1335,55 +1335,49 @@ def __init__(
) -> None:
super().__init__(moe_config)
self.quant_config = quant_config
- self.nvfp4_backend = select_nvfp4_moe_backend()
- # TODO: move this type of check into the oracle.
- if (
- not self.moe.is_act_and_mul
- and not self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS
- ):
- raise NotImplementedError(
- "Non-gated activations are only supported by FlashInfer "
- "CUTLASS NvFP4 MoE backend."
- )
+ # Select NvFp4 MoE backend and experts implementation.
+ self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend(
+ config=self.moe,
+ weight_key=kNvfp4Static,
+ activation_key=kNvfp4Dynamic,
+ )
+ # Delay creation of the kernel until after process-weights.
+ self.kernel: mk.FusedMoEModularKernel | None = None
+
+ # Used for weight loading.
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
self.nvfp4_backend
)
- self.kernel: mk.FusedMoEModularKernel | None = None
+
+ @property
+ def topk_indices_dtype(self) -> torch.dtype | None:
+ if self.kernel is not None:
+ return self.kernel.prepare_finalize.topk_indices_dtype()
+ return None
+
+ @property
+ def supports_mk_interally(self) -> bool:
+ return True
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
- UNSUPPORTED = [NvFp4MoeBackend.MARLIN, NvFp4MoeBackend.FLASHINFER_TRTLLM]
- if self.nvfp4_backend in UNSUPPORTED:
- return None
- elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
- # TP case: avoid convert to ModularKernelMethod - to be refactored.
- if self.moe.dp_size == 1:
- return None
- # For now, fp4 moe only works with the flashinfer dispatcher.
- prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
- self.moe
- )
- logger.debug_once("%s", prepare_finalize.__class__.__name__)
- return prepare_finalize
- else:
- return super().maybe_make_prepare_finalize(routing_tables)
+ raise ValueError(
+ "ModelOptNvFp4FusedMoE uses the new modular kernel "
+ "initialization logic. This function should not be called."
+ )
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
- assert self.moe_quant_config is not None
- experts = select_nvfp4_gemm_impl(
- self.moe,
- self.moe_quant_config,
- allow_flashinfer=self.nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS,
+ raise ValueError(
+ "ModelOptNvFp4FusedMoE uses the new modular kernel "
+ "initialization logic. This function should not be called."
)
- logger.debug_once("Using %s", experts.__class__.__name__)
- return experts
def uses_weight_scale_2_pattern(self) -> bool:
"""
@@ -1556,12 +1550,13 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
replace_parameter(layer, "w2_input_scale", a2_scale)
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
- use_dp = self.moe.dp_size > 1
- if self.moe_quant_config is not None and not use_dp:
+ if self.moe_quant_config is not None:
+ assert self.experts_cls is not None
self.kernel = make_nvfp4_moe_kernel(
- backend=self.nvfp4_backend,
- quant_config=self.moe_quant_config,
+ layer=layer,
+ moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
+ experts_cls=self.experts_cls,
)
def prepare_dp_allgather_tensor(
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
index 912ff5a4a12a..8517162a9512 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
@@ -11,19 +11,13 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
- FusedMoEQuantConfig,
+ FusedMoEParallelConfig,
RoutingMethodType,
)
-from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (
- FlashInferCuteDSLExperts,
-)
-from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
- FlashInferExperts,
-)
-from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
- create_flashinfer_prepare_finalize,
-)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ QuantKey,
+ kNvfp4Dynamic,
+ kNvfp4Static,
swizzle_blockscale,
)
from vllm.platforms import current_platform
@@ -44,9 +38,74 @@
"is_flashinfer_fp4_cutlass_moe_available",
"is_flashinfer_fp4_cutedsl_moe_available",
"reorder_w1w3_to_w3w1",
- "build_flashinfer_fp4_cutlass_moe_prepare_finalize",
]
+#
+# Methods used by the oracle for kernel selection.
+#
+
+
+def _supports_current_device() -> bool:
+ """Supports only Blackwell-family GPUs."""
+ p = current_platform
+ # TODO: also check that the FlashInfer TRT-LLM kernels are available.
+ return p.is_cuda() and p.is_device_capability_family(100)
+
+
+def _supports_no_act_and_mul() -> bool:
+ """Does not support non-gated MoE (e.g. Nemotron-Nano)."""
+ return False
+
+
+def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+) -> bool:
+ """Supports Nvfp4 quantization."""
+ SUPPORTED_W_A = [
+ (kNvfp4Static, kNvfp4Dynamic),
+ ]
+ return (weight_key, activation_key) in SUPPORTED_W_A
+
+
+def _supports_activation(activation: str) -> bool:
+ """Supports silu activation only."""
+ return activation in ["silu"]
+
+
+def _supports_moe_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ """All parallel configurations (including EP) are supported."""
+ return True
+
+
+def is_supported_config_trtllm(
+ moe_config: FusedMoEConfig,
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ activation_format: mk.FusedMoEActivationFormat,
+) -> tuple[bool, str | None]:
+ """
+ This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config
+ """
+
+ def _make_reason(reason: str) -> str:
+ return f"kernel does not support {reason}"
+
+ if not _supports_current_device():
+ return False, _make_reason("current device")
+ elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()):
+ return False, _make_reason("no act_and_mul MLP layer")
+ elif not _supports_activation(moe_config.activation):
+ return False, _make_reason(f"{moe_config.activation} activation")
+ elif not _supports_quant_scheme(weight_key, activation_key):
+ return False, _make_reason("quantization scheme")
+ elif not _supports_moe_parallel_config(moe_config.moe_parallel_config):
+ return False, _make_reason("parallel config")
+ elif activation_format != mk.FusedMoEActivationFormat.Standard:
+ return False, _make_reason("activation format")
+
+ return True, None
+
def is_flashinfer_fp4_cutlass_moe_available() -> bool:
"""Return `True` when FlashInfer CUTLASS NV-FP4 kernels can be used."""
@@ -85,48 +144,6 @@ def reorder_w1w3_to_w3w1(
)
-def build_flashinfer_fp4_cutlass_moe_prepare_finalize(
- moe: FusedMoEConfig,
-) -> mk.FusedMoEPrepareAndFinalize:
- """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
- use_dp = moe.moe_parallel_config.dp_size > 1
- enable_alltoallv = moe.moe_parallel_config.all2all_backend == "flashinfer_all2allv"
- return create_flashinfer_prepare_finalize(
- use_dp=use_dp, use_nvfp4=True, enable_alltoallv=enable_alltoallv
- )
-
-
-def select_nvfp4_gemm_impl(
- moe: FusedMoEConfig,
- moe_quant_config: FusedMoEQuantConfig,
- allow_flashinfer: bool,
-) -> mk.FusedMoEPermuteExpertsUnpermute:
- """Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers"""
-
- if allow_flashinfer:
- if envs.VLLM_FLASHINFER_MOE_BACKEND == "masked_gemm":
- return FlashInferCuteDSLExperts(
- out_dtype=moe.in_dtype,
- quant_config=moe_quant_config,
- )
- elif envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput":
- return FlashInferExperts(
- out_dtype=moe.in_dtype,
- quant_config=moe_quant_config,
- ep_rank=moe.moe_parallel_config.ep_rank,
- ep_size=moe.moe_parallel_config.ep_size,
- tp_rank=moe.moe_parallel_config.tp_rank,
- tp_size=moe.moe_parallel_config.tp_size,
- use_dp=moe.moe_parallel_config.dp_size > 1,
- )
-
- # native cutlass experts currently don't support DP; TP case won't call this
- raise ValueError(
- "CutlassExpertsFp4 doesn't support DP. Use flashinfer CUTLASS "
- "Fused MoE backend instead (set VLLM_USE_FLASHINFER_MOE_FP4=1)"
- )
-
-
def prepare_static_weights_for_trtllm_fp4_moe(
# args_dequant,
# args,
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
index 799854479823..19e7124684ad 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
@@ -4,19 +4,8 @@
import torch
-import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs
from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe.config import (
- FusedMoEConfig,
- FusedMoEQuantConfig,
-)
-from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
- FlashInferExperts,
-)
-from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
- create_flashinfer_prepare_finalize,
-)
from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up
@@ -191,45 +180,6 @@ def make_fp8_moe_alpha_scales_for_fi(
return g1_alphas, g2_alphas
-def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
- moe: FusedMoEConfig | None, use_deepseek_fp8_block_scale: bool = False
-) -> mk.FusedMoEPrepareAndFinalize:
- """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
- use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False
- # Propagate block-scale flag so prepare/finalize can skip act quantization
- # and inform the kernel to consume per-block weight scales.
- return create_flashinfer_prepare_finalize(
- use_dp, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
- )
-
-
-def select_cutlass_fp8_gemm_impl(
- moe: FusedMoEConfig | None,
- quant_config: FusedMoEQuantConfig,
- out_dtype: torch.dtype | None = None,
- use_deepseek_fp8_block_scale: bool = False,
-) -> mk.FusedMoEPermuteExpertsUnpermute:
- """Return a GEMM *experts* implementation for fused-MoE layers"""
-
- if moe is not None:
- return FlashInferExperts(
- out_dtype=moe.in_dtype,
- quant_config=quant_config,
- ep_rank=moe.moe_parallel_config.ep_rank,
- ep_size=moe.moe_parallel_config.ep_size,
- tp_rank=moe.moe_parallel_config.tp_rank,
- tp_size=moe.moe_parallel_config.tp_size,
- use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
- )
-
- assert out_dtype is not None, "If moe config is None, out_dtype must be passed"
- return FlashInferExperts(
- out_dtype=out_dtype,
- quant_config=quant_config,
- use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
- )
-
-
def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
backend_map = {
"throughput": FlashinferMoeBackend.CUTLASS,
diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py
index 08262ed1a314..15b7f3444cde 100644
--- a/vllm/model_executor/layers/quantization/utils/quant_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py
@@ -48,6 +48,7 @@ class GroupShape(_GroupShape):
# Aliases for common quantization group shapes
PER_TENSOR: ClassVar["GroupShape"]
PER_TOKEN: ClassVar["GroupShape"]
+ PER_CHANNEL: ClassVar["GroupShape"]
def is_per_tensor(self) -> bool:
return self.row == -1 and self.col == -1
@@ -55,12 +56,16 @@ def is_per_tensor(self) -> bool:
def is_per_token(self) -> bool:
return self.row == 1 and self.col == -1
+ def is_per_channel(self) -> bool:
+ return self.row == -1 and self.col == 1
+
def is_per_group(self) -> bool:
return self.row == 1 and self.col >= 1
GroupShape.PER_TENSOR = GroupShape(-1, -1)
GroupShape.PER_TOKEN = GroupShape(1, -1)
+GroupShape.PER_CHANNEL = GroupShape(-1, 1)
@dataclass(frozen=True)
@@ -77,16 +82,12 @@ class ScaleDesc:
group_shape: GroupShape
def __str__(self):
- group_shape = (
- "per_tensor"
- if self.group_shape == GroupShape.PER_TENSOR
- else (
- "per_token"
- if self.group_shape == GroupShape.PER_TOKEN
- else str(self.group_shape)
- )
- )
-
+ d = {
+ GroupShape.PER_TENSOR: "per_tensor",
+ GroupShape.PER_TOKEN: "per_token",
+ GroupShape.PER_CHANNEL: "per_channel",
+ }
+ group_shape = d.get(self.group_shape, str(self.group_shape))
return (
f"{fx.graph.dtype_abbrs[self.dtype]},"
f"{'static' if self.static else 'dynamic'},{group_shape}"
@@ -123,15 +124,28 @@ def __str__(self):
kDynamicTensorScale = ScaleDesc(torch.float32, False, GroupShape.PER_TENSOR)
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, kDynamicTensorScale, symmetric=True)
+kStaticChannelScale = ScaleDesc(torch.float32, True, GroupShape.PER_CHANNEL)
+kFp8StaticChannelSym = QuantKey(FP8_DTYPE, kStaticChannelScale, symmetric=True)
+
kDynamicTokenScale = ScaleDesc(torch.float32, False, GroupShape.PER_TOKEN)
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True)
-kNvfp4GroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16))
-kNvfp4Quant = QuantKey(FP4_DTYPE, scale=kNvfp4GroupScale, scale2=kStaticTensorScale)
+kNvfp4DynamicGroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16))
+kNvfp4Dynamic = QuantKey(
+ FP4_DTYPE, scale=kNvfp4DynamicGroupScale, scale2=kStaticTensorScale
+)
+
+kNvfp4StaticGroupScale = ScaleDesc(FP8_DTYPE, True, GroupShape(1, 16))
+kNvfp4Static = QuantKey(
+ FP4_DTYPE, scale=kNvfp4StaticGroupScale, scale2=kStaticTensorScale
+)
kDynamic128Scale = ScaleDesc(torch.float32, False, GroupShape(1, 128))
kFp8Dynamic128Sym = QuantKey(FP8_DTYPE, kDynamic128Scale, symmetric=True)
+kStatic128BlockScale = ScaleDesc(torch.float32, True, GroupShape(128, 128))
+kFp8Static128BlockSym = QuantKey(FP8_DTYPE, kStatic128BlockScale, symmetric=True)
+
kDynamic64Scale = ScaleDesc(torch.float32, False, GroupShape(1, 64))
kFp8Dynamic64Sym = QuantKey(FP8_DTYPE, kDynamic64Scale, symmetric=True)
diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py
index dde6db7c204b..7cc170638428 100644
--- a/vllm/model_executor/models/llama4.py
+++ b/vllm/model_executor/models/llama4.py
@@ -434,6 +434,7 @@ def load_moe_expert_weights(
# Whether the MoE expert weights are loaded successfully.
expert_param_loaded = False
+ loaded_weight = loaded_weight.to("cuda")
# If fused is True, the loaded weight is in the layout of:
# [num_experts, hidden_in, hidden_out], so we must transpose the last
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 9892c360d3d6..5ed860cc56a9 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -29,7 +29,7 @@
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
- kNvfp4Quant,
+ kNvfp4Dynamic,
)
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
@@ -1184,7 +1184,7 @@ def fused_output_quant_supported(self, quant_key: QuantKey):
return (
self.support_trtllm_attn
and self.kv_cache_dtype.startswith("fp8")
- and quant_key in (kFp8StaticTensorSym, kNvfp4Quant)
+ and quant_key in (kFp8StaticTensorSym, kNvfp4Dynamic)
)
# FlashInfer requires attention sinks to be float32