diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index 18216b5965af..5b35b152ff38 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -88,7 +88,7 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels | deep gemm | standard,
batched | fp8 | G(128),A,T | silu, gelu | 6 | Y | [`deep_gemm_moe_fp8`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.deep_gemm_moe_fp8],
[`DeepGemmExperts`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.DeepGemmExperts],
[`BatchedDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe.BatchedDeepGemmExperts] | | cutlass_fp4 | standard,
batched | nvfp4 | A,T | silu | Y | Y | [`CutlassExpertsFp4`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp4] | | cutlass_fp8 | standard,
batched | fp8 | A,T | silu, gelu | Y | Y | [`CutlassExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp8],
[`CutlasBatchedExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassBatchedExpertsFp8] | -| flashinfer | standard | nvfp4,
fp8 | T | 5 | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],
[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] | +| flashinfer | standard | nvfp4,
fp8 | T | 5 | N | Y | [`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] | | gpt oss triton | standard | N/A | N/A | 5 | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],
[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] | | marlin | standard,
batched | 3 / N/A | 3 / N/A | silu,
swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],
[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],
[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] | | trtllm | standard | mxfp4,
nvfp4 | G(16),G(32) | 5 | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] | diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index a1fd098aee5f..66d4f8d2caf3 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -30,7 +30,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kFp8StaticTensorSym, - kNvfp4Quant, + kNvfp4Dynamic, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.platforms import current_platform @@ -202,7 +202,7 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel): """Test model for AttentionNvfp4QuantPattern fusion.""" - quant_key = kNvfp4Quant + quant_key = kNvfp4Dynamic def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -449,7 +449,7 @@ def test_attention_quant_pattern( # Note: for fp8, fully_replaced=False because query quant ops remain in graph. # Only output quant ops are fused into attention. 
- test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Quant) + test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Dynamic) # access the underlying `AttnFusionPass` on the `LazyInitPass` assert attn_pass.pass_.matched_count == sum(attn_fusion_supported) diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index eb0dee8d4e39..446f8c7626b8 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -29,7 +29,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, kFp8StaticTensorSym, - kNvfp4Quant, + kNvfp4Dynamic, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp, @@ -121,11 +121,11 @@ def forward(self, x): def ops_in_model_before(self): return [ SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul, - QUANT_OPS[kNvfp4Quant], + QUANT_OPS[kNvfp4Dynamic], ] def ops_in_model_after(self): - return [FUSED_OPS[kNvfp4Quant]] + return [FUSED_OPS[kNvfp4Dynamic]] class TestSiluMulGroupFp8QuantModel(torch.nn.Module): diff --git a/tests/kernels/moe/modular_kernel_tools/common.py b/tests/kernels/moe/modular_kernel_tools/common.py index 537dcae4e74b..dd0519c1953a 100644 --- a/tests/kernels/moe/modular_kernel_tools/common.py +++ b/tests/kernels/moe/modular_kernel_tools/common.py @@ -161,15 +161,15 @@ def is_fp8_block_quantized(self): def is_batched_prepare_finalize(self): info = prepare_finalize_info(self.prepare_finalize_type) - return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format + return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format() def is_batched_fused_experts(self): info = expert_info(self.fused_experts_type) - return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format + return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format() def is_standard_fused_experts(self): 
info = expert_info(self.fused_experts_type) - return mk.FusedMoEActivationFormat.Standard == info.activation_format + return mk.FusedMoEActivationFormat.Standard == info.activation_format() def fe_supported_types(self): info = expert_info(self.fused_experts_type) @@ -574,10 +574,13 @@ def next_power_of_2(x): num_experts=config.E, experts_per_token=config.topk, hidden_dim=config.K, + intermediate_size_per_partition=config.N, num_local_experts=config.num_local_experts, moe_parallel_config=moe_parallel_config, in_dtype=config.dtype, max_num_tokens=next_power_of_2(config.M), + activation="silu", + device=vllm_config.device_config.device, ) # make modular kernel diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index 99b168dc7554..f085706bb22c 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -237,7 +237,6 @@ def expert_info(kind) -> ExpertInfo: ) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 FlashInferCutlassMoEPrepareAndFinalize, - create_flashinfer_prepare_finalize, ) register_prepare_and_finalize( @@ -389,13 +388,12 @@ def make_prepare_finalize( quant_config: FusedMoEQuantConfig, ) -> mk.FusedMoEPrepareAndFinalize: if backend != "naive" and backend is not None: - prepare_finalize = maybe_make_prepare_finalize(moe, quant_config) + # TODO(rob): add defer input quant. 
+ prepare_finalize = maybe_make_prepare_finalize( + moe, quant_config, allow_new_interface=True + ) assert prepare_finalize is not None return prepare_finalize - elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize: - return create_flashinfer_prepare_finalize( - use_dp=moe.moe_parallel_config.dp_size > 1 - ) else: return MoEPrepareAndFinalizeNoEP() diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py index bb2f6b873941..2fa00a8effb7 100644 --- a/tests/kernels/moe/test_flashinfer.py +++ b/tests/kernels/moe/test_flashinfer.py @@ -127,6 +127,7 @@ def make_moe_tensors_8bit( ep_rank=0, use_ep=False, all2all_backend="naive", + isequence_parallel=False, ) # flashinfer expects swapped rows for w13 diff --git a/tests/kernels/moe/test_flashinfer_moe.py b/tests/kernels/moe/test_flashinfer_moe.py index 1262eea70bab..d9b7f4678534 100644 --- a/tests/kernels/moe/test_flashinfer_moe.py +++ b/tests/kernels/moe/test_flashinfer_moe.py @@ -12,15 +12,19 @@ from tests.kernels.utils import torch_moe from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, +) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( FlashInferExperts, is_valid_flashinfer_cutlass_fused_moe, ) -from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( - create_flashinfer_prepare_finalize, -) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP, +) from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.utils.torch_utils import set_random_seed @@ -86,9 +90,39 @@ def 
test_flashinfer_fp4_moe_no_graph( assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q) + moe_config = FusedMoEConfig( + num_experts=e, + experts_per_token=topk, + hidden_dim=k, + intermediate_size_per_partition=n, + num_local_experts=e, + activation=activation, + device="cuda", + moe_parallel_config=FusedMoEParallelConfig( + tp_size=1, + pcp_size=1, + dp_size=1, + ep_size=1, + tp_rank=0, + pcp_rank=0, + dp_rank=0, + ep_rank=0, + use_ep=False, + all2all_backend="allgather_reducescatter", + isequence_parallel=False, + ), + in_dtype=dtype, + is_act_and_mul=is_gated_act, + ) + flashinfer_experts = FusedMoEModularKernel( - create_flashinfer_prepare_finalize(use_dp=False, use_nvfp4=True), - FlashInferExperts(out_dtype=dtype, quant_config=quant_config), + MoEPrepareAndFinalizeNoEP( + defer_input_quant=FlashInferExperts.should_pf_defer_input_quant( + moe_config=moe_config, + quant_config=quant_config, + ) + ), + FlashInferExperts(moe_config=moe_config, quant_config=quant_config), ) fi_activation = {"silu_and_mul": "silu", "relu2": "relu2_no_mul"}[activation] diff --git a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py index 873d72117de7..97db526c2ee4 100644 --- a/tests/kernels/moe/test_nvfp4_moe.py +++ b/tests/kernels/moe/test_nvfp4_moe.py @@ -93,7 +93,6 @@ def test_cutlass_fp4_moe_no_graph( MoEPrepareAndFinalizeNoEP(defer_input_quant=True), CutlassExpertsFp4( out_dtype=dtype, - max_experts_per_worker=e, quant_config=quant_config, ), ) diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py index f0ce5b3db7ec..8530c0dadb23 100644 --- a/vllm/compilation/activation_quant_fusion.py +++ b/vllm/compilation/activation_quant_fusion.py @@ -18,7 +18,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kFp8StaticTensorSym, - kNvfp4Quant, + kNvfp4Dynamic, ) from vllm.platforms import current_platform @@ -41,7 +41,7 @@ torch.ops._C, "silu_and_mul_nvfp4_quant" ) if 
silu_and_mul_nvfp4_quant_supported: - FUSED_OPS[kNvfp4Quant] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501 + FUSED_OPS[kNvfp4Dynamic] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501 class ActivationQuantPattern(ABC): @@ -129,7 +129,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern): """ def __init__(self) -> None: - super().__init__(kNvfp4Quant) + super().__init__(kNvfp4Dynamic) def get_inputs(self) -> list[torch.Tensor]: result = self.empty_quant(5, 32) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index e3c6c2f20141..667828cc6a5a 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -20,7 +20,7 @@ kFp8DynamicTensorSym, kFp8DynamicTokenSym, kFp8StaticTensorSym, - kNvfp4Quant, + kNvfp4Dynamic, kStaticTensorScale, ) from vllm.platforms import current_platform @@ -63,7 +63,7 @@ def empty_i64(*args: Any, **kwargs: Any) -> torch.Tensor: kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 } if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): - QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default + QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default if current_platform.is_cuda(): QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py index 57448aa0b096..69dc2e3a6868 100644 --- a/vllm/compilation/fusion_attn.py +++ b/vllm/compilation/fusion_attn.py @@ -16,7 +16,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, - kNvfp4Quant, + kNvfp4Dynamic, kStaticTensorScale, ) from vllm.platforms import current_platform @@ -217,7 +217,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): """ def __init__(self, layer: Attention, dtype: 
torch.dtype) -> None: - super().__init__(layer, kNvfp4Quant, dtype) + super().__init__(layer, kNvfp4Dynamic, dtype) def _register(self, pm_pass: PatternMatcherPass) -> None: def pattern( diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index eda12180d493..7bb98db5e9f6 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -21,7 +21,7 @@ kFp8DynamicTensorSym, kFp8DynamicTokenSym, kFp8StaticTensorSym, - kNvfp4Quant, + kNvfp4Dynamic, ) from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.platforms import current_platform @@ -38,7 +38,7 @@ } if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): - QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501 + QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default # noqa: E501 if current_platform.is_cuda(): QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 363078aefc51..ae90d9bd5201 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -425,6 +425,7 @@ def stateless_init_dp_group(self) -> ProcessGroup: # to avoid the excess work. # # Not needed for pplx-kernels as it can handle duplicate input tokens. + # TODO(rob): investigate 'flashinfer_all2allv'? 
@property def use_sequence_parallel_moe(self) -> bool: return ( diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 7a4e81cf967d..520bb7613ee9 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -62,10 +62,14 @@ def naive_multicast( def dispatch( self, hidden_states: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, is_sequence_parallel: bool = False, extra_tensors: list[torch.Tensor] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> ( + tuple[torch.Tensor, torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]] + ): if extra_tensors is not None: raise NotImplementedError( "extra_tensors is not supported for NaiveAll2AllManager" @@ -78,11 +82,21 @@ def dispatch( hidden_states = self.naive_multicast( hidden_states, cu_tokens_across_sp_cpu, is_sequence_parallel ) - router_logits = self.naive_multicast( - router_logits, cu_tokens_across_sp_cpu, is_sequence_parallel + topk_weights = self.naive_multicast( + topk_weights, cu_tokens_across_sp_cpu, is_sequence_parallel ) + topk_ids = self.naive_multicast( + topk_ids, cu_tokens_across_sp_cpu, is_sequence_parallel + ) + + if extra_tensors is None: + return hidden_states, topk_weights, topk_ids - return hidden_states, router_logits + extra_tensors = [ + self.naive_multicast(t, cu_tokens_across_sp_cpu, is_sequence_parallel) + for t in extra_tensors + ] + return hidden_states, topk_weights, topk_ids, extra_tensors def combine( self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False @@ -117,12 +131,13 @@ def __init__(self, cpu_group): def dispatch( self, hidden_states: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, is_sequence_parallel: bool = False, extra_tensors: list[torch.Tensor] | None = None, ) -> ( - 
tuple[torch.Tensor, torch.Tensor] - | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] + tuple[torch.Tensor, torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]] ): """ Gather hidden_states and router_logits from all dp ranks. @@ -134,7 +149,7 @@ def dispatch( dist_group = get_ep_group() if is_sequence_parallel else get_dp_group() assert sizes[dist_group.rank_in_group] == hidden_states.shape[0] - tensors_to_gather = [hidden_states, router_logits] + tensors_to_gather = [hidden_states, topk_weights, topk_ids] if extra_tensors is not None: tensors_to_gather.extend(extra_tensors) @@ -144,9 +159,14 @@ def dispatch( sizes=sizes, ) - if extra_tensors is not None: - return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:]) - return gathered_tensors[0], gathered_tensors[1] + hidden_states = gathered_tensors[0] + topk_weights = gathered_tensors[1] + topk_ids = gathered_tensors[2] + + if extra_tensors is None: + return hidden_states, topk_weights, topk_ids + + return hidden_states, topk_weights, topk_ids, gathered_tensors[3:] def combine( self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False @@ -219,10 +239,11 @@ def get_handle(self, kwargs): def dispatch( self, hidden_states: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, is_sequence_parallel: bool = False, extra_tensors: list[torch.Tensor] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: raise NotImplementedError def combine( @@ -267,10 +288,11 @@ def get_handle(self, kwargs): def dispatch( self, hidden_states: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, is_sequence_parallel: bool = False, extra_tensors: list[torch.Tensor] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: raise NotImplementedError def combine( 
diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 8bc361741cae..4cc6bc8fef05 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -67,7 +67,8 @@ def get_handle(self, kwargs): def dispatch( self, hidden_states: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, is_sequence_parallel: bool = False, extra_tensors: list[torch.Tensor] | None = None, ) -> Any: @@ -283,15 +284,16 @@ def prepare_communication_buffer_for_model(self, model: torch.nn.Module) -> None def dispatch( self, hidden_states: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, is_sequence_parallel: bool = False, extra_tensors: list[torch.Tensor] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Dispatch the hidden states and router logits to the appropriate device. This is a no-op in the base class. 
""" - return hidden_states, router_logits + return hidden_states, topk_weights, topk_ids def combine( self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 9542498c453e..6d725be241e1 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -321,17 +321,19 @@ def _all_gather_single(input_: torch.Tensor, sizes: list[int] | None = None): def dispatch( # type: ignore[override] self, hidden_states: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, is_sequence_parallel: bool = False, extra_tensors: list[torch.Tensor] | None = None, ) -> ( - tuple[torch.Tensor, torch.Tensor] - | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] + tuple[torch.Tensor, torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]] ): assert self.all2all_manager is not None return self.all2all_manager.dispatch( hidden_states, - router_logits, + topk_weights, + topk_ids, is_sequence_parallel, extra_tensors, # type: ignore[call-arg] ) diff --git a/vllm/distributed/device_communicators/xpu_communicator.py b/vllm/distributed/device_communicators/xpu_communicator.py index f3d9262d20cf..d7a9d832370f 100644 --- a/vllm/distributed/device_communicators/xpu_communicator.py +++ b/vllm/distributed/device_communicators/xpu_communicator.py @@ -76,14 +76,16 @@ def broadcast(self, input_: torch.Tensor, src: int = 0) -> None: def dispatch( self, hidden_states: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, is_sequence_parallel: bool = False, extra_tensors: list[torch.Tensor] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: assert self.all2all_manager is not None return 
self.all2all_manager.dispatch( hidden_states, - router_logits, + topk_weights, + topk_ids, is_sequence_parallel, extra_tensors, # type: ignore[call-arg] ) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index d8c6ceba309f..2f1e4d005570 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1003,22 +1003,24 @@ def prepare_communication_buffer_for_model(self, model: torch.nn.Module): def dispatch( self, hidden_states: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, is_sequence_parallel: bool = False, extra_tensors: list[torch.Tensor] | None = None, ) -> ( - tuple[torch.Tensor, torch.Tensor] - | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] + tuple[torch.Tensor, torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]] ): if self.device_communicator is not None: return self.device_communicator.dispatch( # type: ignore[call-arg] hidden_states, - router_logits, + topk_weights, + topk_ids, is_sequence_parallel, extra_tensors, ) else: - return hidden_states, router_logits + return hidden_states, topk_weights, topk_ids def combine( self, hidden_states, is_sequence_parallel: bool = False diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py index 036b3cac4cb3..43f75cc31428 100644 --- a/vllm/model_executor/layers/fused_moe/all2all_utils.py +++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py @@ -12,9 +12,16 @@ FusedMoEParallelConfig, FusedMoEQuantConfig, ) +from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( + FlashInferAllToAllMoEPrepareAndFinalize, +) from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEPrepareAndFinalize, ) +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNaiveEP, + MoEPrepareAndFinalizeNoEP, +) from 
vllm.platforms import current_platform from vllm.utils.import_utils import has_deep_ep, has_pplx @@ -68,20 +75,21 @@ def maybe_make_prepare_finalize( moe: FusedMoEConfig, quant_config: FusedMoEQuantConfig | None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, + defer_input_quant: bool = False, + allow_new_interface: bool = False, ) -> FusedMoEPrepareAndFinalize | None: if not moe.moe_parallel_config.use_all2all_kernels: - return None + # What happens if DP>1 and EP=1? + if allow_new_interface: + return MoEPrepareAndFinalizeNoEP(defer_input_quant) + else: + return None all2all_manager = get_ep_group().device_communicator.all2all_manager assert all2all_manager is not None prepare_finalize: FusedMoEPrepareAndFinalize | None = None - # TODO(rob): update this as part of the MoE refactor. - assert not moe.use_flashinfer_cutlass_kernels, ( - "Must be created in modelopt.py or fp8.py" - ) - if moe.use_pplx_kernels: assert quant_config is not None @@ -170,4 +178,20 @@ def maybe_make_prepare_finalize( local_expert_global_ids=local_expert_global_ids, ) + elif moe.use_fi_all2allv_kernels: + assert quant_config is not None + # TODO: audit if this supports all cases. 
+ prepare_finalize = FlashInferAllToAllMoEPrepareAndFinalize( + use_dp=True, + num_dispatchers=all2all_manager.world_size, + use_deepseek_fp8_block_scale=quant_config.is_block_quantized, + ) + + elif moe.use_naive_kernels and allow_new_interface: + prepare_finalize = MoEPrepareAndFinalizeNaiveEP( + defer_input_quant, + is_sequence_parallel=(moe.moe_parallel_config.is_sequence_parallel), + num_dispatchers=all2all_manager.world_size, + ) + return prepare_finalize diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index e598ec3acb3d..0d82eec3569c 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -7,11 +7,20 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.forward_context import get_forward_context, is_forward_context_available from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, + FusedMoEQuantConfig, +) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate, ) from vllm.model_executor.layers.fused_moe.utils import _resize_cache +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kFp8Dynamic128Sym, + kFp8Static128BlockSym, +) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils.deep_gemm import ( @@ -19,6 +28,7 @@ fp8_m_grouped_gemm_nt_masked, get_mk_alignment_for_contiguous_layout, is_deep_gemm_e8m0_used, + is_deep_gemm_supported, ) from vllm.utils.math_utils import cdiv, round_up @@ -253,8 +263,7 @@ def persistent_masked_m_silu_mul_quant( class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - max_num_tokens: int, - num_dispatchers: int, 
+ moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, ): """ @@ -262,20 +271,37 @@ def __init__( num_dispatchers: The number of DP dispatchers. quant_config: Quantization configuration """ - super().__init__(quant_config) + super().__init__(moe_config, quant_config) assert self.block_shape == get_mk_alignment_for_contiguous_layout() assert self.quant_config.use_fp8_w8a8 - self.max_num_tokens = max_num_tokens - self.num_dispatchers = num_dispatchers - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts, - ) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.BatchedExperts + + @staticmethod + def _supports_current_device() -> bool: + return is_deep_gemm_supported() + + @staticmethod + def _supports_no_act_and_mul() -> bool: + return False + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + SUPPORTED_W_A = [(kFp8Static128BlockSym, kFp8Dynamic128Sym)] + return (weight_key, activation_key) in SUPPORTED_W_A + + @staticmethod + def _supports_activation(activation: str) -> bool: + return activation in ["silu"] + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + return True def supports_chunking(self) -> bool: return False diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 23b86fdca898..b5cd5167e7e1 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -20,7 +20,6 @@ ) from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.platforms import current_platform -from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.utils.import_utils 
import has_triton_kernels from vllm.utils.math_utils import cdiv @@ -858,6 +857,7 @@ class FusedMoEParallelConfig: use_ep: bool # whether to use EP or not all2all_backend: str # all2all backend for MoE communication + is_sequence_parallel: bool # whether sequence parallelism is used @property def use_all2all_kernels(self): @@ -878,6 +878,18 @@ def use_deepep_ht_kernels(self): def use_deepep_ll_kernels(self): return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency" + @property + def use_fi_all2allv_kernels(self): + return ( + self.use_all2all_kernels and self.all2all_backend == "flashinfer_all2allv" + ) + + @property + def use_naive_kernels(self): + return self.use_all2all_kernels and ( + self.all2all_backend in ["allgather_reducescatter", "naive"] + ) + @staticmethod def flatten_tp_across_dp_and_pcp( tp_size: int, dp_size: int, dp_rank: int, pcp_size: int, pcp_rank: int @@ -995,6 +1007,7 @@ def make( ep_rank=0, use_ep=False, all2all_backend=vllm_parallel_config.all2all_backend, + is_sequence_parallel=vllm_parallel_config.use_sequence_parallel_moe, ) # DP + EP / TP + EP / DP + TP + EP assert use_ep @@ -1013,6 +1026,7 @@ def make( ep_rank=ep_rank, use_ep=True, all2all_backend=vllm_parallel_config.all2all_backend, + is_sequence_parallel=vllm_parallel_config.use_sequence_parallel_moe, ) @@ -1022,8 +1036,10 @@ class FusedMoEConfig: num_experts: int experts_per_token: int hidden_dim: int - + intermediate_size_per_partition: int num_local_experts: int + activation: str + device: torch.dtype moe_parallel_config: FusedMoEParallelConfig # The activation type. @@ -1100,12 +1116,9 @@ def use_deepep_ll_kernels(self): return self.moe_parallel_config.use_deepep_ll_kernels @property - def use_flashinfer_cutlass_kernels(self): - """ - Whether to use FlashInfer cutlass kernels for NVFP4 MoE. 
- """ - return ( - envs.VLLM_USE_FLASHINFER_MOE_FP4 - and has_flashinfer_cutlass_fused_moe() - and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput" - ) + def use_fi_all2allv_kernels(self): + return self.moe_parallel_config.use_fi_all2allv_kernels + + @property + def use_naive_kernels(self): + return self.moe_parallel_config.use_naive_kernels diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index c0ffa38fdb2c..e41fa3f7d50c 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -9,7 +9,11 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, + FusedMoEQuantConfig, +) from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( moe_permute, moe_unpermute, @@ -22,6 +26,16 @@ TopKWeightAndReduceNoOP, ) from vllm.model_executor.layers.fused_moe.utils import _resize_cache +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kFp8DynamicTensorSym, + kFp8DynamicTokenSym, + kFp8StaticChannelSym, + kFp8StaticTensorSym, +) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + cutlass_group_gemm_supported, +) from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -236,29 +250,54 @@ def run_cutlass_moe_fp8( class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - e: int, - n: int, - k: int, - out_dtype: torch.dtype | None, + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, - device: torch.dtype, ): assert quant_config.use_fp8_w8a8 - super().__init__(quant_config) + super().__init__(moe_config=moe_config, quant_config=quant_config) - # E: num_experts - 
# N: intermediate size per partition - # K: hidden dim + e = moe_config.num_local_experts + n = moe_config.intermediate_size_per_partition + k = moe_config.hidden_dim + device = moe_config.device ab_strides1_c_strides2 = torch.full((e,), k, device=device, dtype=torch.int64) ab_strides2 = torch.full((e,), n, device=device, dtype=torch.int64) c_strides1 = torch.full((e,), 2 * n, device=device, dtype=torch.int64) - self.out_dtype = out_dtype + self.out_dtype = moe_config.in_dtype self.ab_strides1 = ab_strides1_c_strides2 self.ab_strides2 = ab_strides2 self.c_strides1 = c_strides1 self.c_strides2 = ab_strides1_c_strides2 + @staticmethod + def _supports_current_device() -> bool: + return cutlass_group_gemm_supported() + + @staticmethod + def _supports_no_act_and_mul() -> bool: + return False + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + SUPPORTED_W_A = [ + (kFp8StaticChannelSym, kFp8DynamicTokenSym), + (kFp8StaticTensorSym, kFp8DynamicTensorSym), + (kFp8StaticTensorSym, kFp8StaticTensorSym), + ] + return (weight_key, activation_key) in SUPPORTED_W_A + + @staticmethod + def _supports_activation(activation: str) -> bool: + return activation in ["silu"] + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + return True + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: # Let PrepareAndFinalize::finalize() decide the impl. 
return TopKWeightAndReduceDelegate() @@ -291,7 +330,7 @@ def apply( activation_callable = lambda o, i: self.activation(activation, o, i) use_batched_format = ( - self.activation_formats[0] == mk.FusedMoEActivationFormat.BatchedExperts + self.activation_format() == mk.FusedMoEActivationFormat.BatchedExperts ) in_dtype = hidden_states.dtype @@ -324,14 +363,9 @@ def apply( class CutlassExpertsFp8(CutlassExpertsFp8Base): - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard, - ) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard def supports_chunking(self) -> bool: return True @@ -365,26 +399,9 @@ def workspace_shapes( class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): - def __init__( - self, - max_experts_per_worker: int, - num_dispatchers: int, - *args, - **kwargs, - ): - super().__init__(*args, **kwargs) - assert max_experts_per_worker > 0 - self.max_experts_per_worker = max_experts_per_worker - self.num_dispatchers = num_dispatchers - - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts, - ) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.BatchedExperts def supports_chunking(self) -> bool: return False @@ -408,14 +425,15 @@ def workspace_shapes( ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: num_dp = self.num_dispatchers assert num_dp is not None + experts_per_worker = self.moe_config.num_local_experts activation_out_dim = self.adjust_N_for_activation(N, activation) - workspace1 = (self.max_experts_per_worker, M * num_dp, max(N, K)) + workspace1 = (experts_per_worker, M * num_dp, max(N, K)) workspace2 
= ( - self.max_experts_per_worker, + experts_per_worker, M * num_dp, max(activation_out_dim, K), ) - output = (self.max_experts_per_worker, M, K) + output = (experts_per_worker, M, K) return (workspace1, workspace2, output) @@ -597,30 +615,16 @@ def run_cutlass_moe_fp4( class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - max_experts_per_worker: int, - out_dtype: torch.dtype, + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, - use_batched_format: bool = False, ): - super().__init__(quant_config) - self.max_experts_per_worker = max_experts_per_worker - self.out_dtype = out_dtype - self.use_batched_format = use_batched_format + super().__init__(moe_config=moe_config, quant_config=quant_config) + self.max_experts_per_worker = moe_config.num_local_experts + self.out_dtype = moe_config.in_dtype - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - if self.use_batched_format: - return ( - mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts, - ) - else: - return ( - mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard, - ) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard def supports_expert_map(self) -> bool: return False @@ -649,7 +653,8 @@ def workspace_shapes( workspace1: tuple[int, ...] = () workspace2: tuple[int, ...] = () output: tuple[int, ...] 
= () - if self.use_batched_format: + if False: + # if self.use_batched_format: workspace1 = (self.max_experts_per_worker, M, max(N, K)) workspace2 = (self.max_experts_per_worker, M, activation_out_dim) output = (self.max_experts_per_worker, M, K) @@ -860,10 +865,11 @@ def __init__( c_strides2: torch.Tensor, s_strides1: torch.Tensor, s_strides2: torch.Tensor, + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, group_size: int, ): - super().__init__(quant_config) + super().__init__(moe_config=moe_config, quant_config=quant_config) self.out_dtype = out_dtype self.a_strides1 = a_strides1 self.a_strides2 = a_strides2 @@ -875,13 +881,46 @@ def __init__( self.s_strides2 = s_strides2 self.group_size = group_size - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard, + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + @staticmethod + def _supports_current_device() -> bool: + raise NotImplementedError( + "CutlassExpertsW4A8Fp8 is not yet used by an Oracle. " + "This method should not be called." + ) + + @staticmethod + def _supports_no_act_and_mul() -> bool: + raise NotImplementedError( + "CutlassExpertsW4A8Fp8 is not yet used by an Oracle. " + "This method should not be called." + ) + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + raise NotImplementedError( + "CutlassExpertsW4A8Fp8 is not yet used by an Oracle. " + "This method should not be called." + ) + + @staticmethod + def _supports_activation(activation: str) -> bool: + raise NotImplementedError( + "CutlassExpertsW4A8Fp8 is not yet used by an Oracle. " + "This method should not be called." 
+ ) + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + raise NotImplementedError( + "CutlassExpertsW4A8Fp8 is not yet used by an Oracle. " + "This method should not be called." ) def supports_chunking(self) -> bool: @@ -939,7 +978,7 @@ def apply( activation_callable = lambda o, i: self.activation(activation, o, i) use_batched_format = ( - self.activation_formats[0] == mk.FusedMoEActivationFormat.BatchedExperts + self.activation_format() == mk.FusedMoEActivationFormat.BatchedExperts ) assert not use_batched_format, "batched format not supported" @@ -995,6 +1034,7 @@ def cutlass_moe_w4a8_fp8( s_strides1: torch.Tensor, s_strides2: torch.Tensor, quant_config: FusedMoEQuantConfig, + moe_config: FusedMoEConfig, activation: str = "silu", expert_map: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, @@ -1068,6 +1108,7 @@ def cutlass_moe_w4a8_fp8( c_strides2=c_strides2, s_strides1=s_strides1, s_strides2=s_strides2, + moe_config=moe_config, quant_config=quant_config, group_size=group_size, ), diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index a2e5a07fbfd2..a86f39e96959 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -6,17 +6,15 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, FusedMoEQuantConfig, - fp8_w8a8_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.deep_gemm_utils import ( compute_aligned_M, deepgemm_moe_permute, deepgemm_unpermute_and_reduce, ) -from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP, -) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP, ) @@ -26,9 +24,15 @@ 
per_token_group_quant_fp8_packed_for_deepgemm, silu_mul_per_token_group_quant_fp8_colmajor, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kFp8Dynamic128Sym, + kFp8Static128BlockSym, +) from vllm.utils.deep_gemm import ( DeepGemmQuantScaleFMT, get_mk_alignment_for_contiguous_layout, + is_deep_gemm_supported, m_grouped_fp8_gemm_nt_contiguous, ) from vllm.utils.import_utils import has_deep_gemm @@ -109,21 +113,42 @@ def _valid_deep_gemm( class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__(self, quant_config: FusedMoEQuantConfig): - super().__init__(quant_config) + def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig): + super().__init__(moe_config=moe_config, quant_config=quant_config) assert quant_config.block_shape == get_mk_alignment_for_contiguous_layout() assert quant_config.quant_dtype == torch.float8_e4m3fn assert not quant_config.per_act_token_quant assert not quant_config.per_out_ch_quant - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard, - ) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + @staticmethod + def _supports_current_device() -> bool: + return is_deep_gemm_supported() + + @staticmethod + def _supports_no_act_and_mul() -> bool: + return False + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + SUPPORTED_W_A = [ + (kFp8Static128BlockSym, kFp8Dynamic128Sym), + ] + return (weight_key, activation_key) in SUPPORTED_W_A + + @staticmethod + def _supports_activation(activation: str) -> bool: + return activation in ["silu"] + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + return True def 
supports_chunking(self) -> bool: return True @@ -338,27 +363,27 @@ def deep_gemm_moe_fp8( Returns: - torch.Tensor: The bfloat16 output tensor after applying the MoE layer. """ - quant_config = fp8_w8a8_moe_quant_config( - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, - block_shape=get_mk_alignment_for_contiguous_layout(), - ) - - fn = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), - DeepGemmExperts(quant_config), - ) - return fn( - hidden_states, - w1, - w2, - topk_weights, - topk_ids, - inplace=inplace, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) + # quant_config = fp8_w8a8_moe_quant_config( + # w1_scale=w1_scale, + # w2_scale=w2_scale, + # a1_scale=a1_scale, + # a2_scale=a2_scale, + # block_shape=get_mk_alignment_for_contiguous_layout(), + # ) + + # fn = mk.FusedMoEModularKernel( + # MoEPrepareAndFinalizeNoEP(), + # DeepGemmExperts(quant_config), + # ) + # return fn( + # hidden_states, + # w1, + # w2, + # topk_weights, + # topk_ids, + # inplace=inplace, + # activation=activation, + # global_num_experts=global_num_experts, + # expert_map=expert_map, + # apply_router_weight_on_input=apply_router_weight_on_input, + # ) diff --git a/vllm/model_executor/layers/fused_moe/fallback.py b/vllm/model_executor/layers/fused_moe/fallback.py index 4556392144a0..9c9e416f366a 100644 --- a/vllm/model_executor/layers/fused_moe/fallback.py +++ b/vllm/model_executor/layers/fused_moe/fallback.py @@ -6,28 +6,73 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig +from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC): """Base class for runtime dispatching of expert implementations.""" + _experts_cls = mk.FusedMoEPermuteExpertsUnpermute + 
_fallback_cls = mk.FusedMoEPermuteExpertsUnpermute + def __init__( self, experts: mk.FusedMoEPermuteExpertsUnpermute, fallback_experts: mk.FusedMoEPermuteExpertsUnpermute, ): - super().__init__(experts.quant_config) + super().__init__( + moe_config=experts.moe_config, quant_config=experts.quant_config + ) self.fallback_experts = fallback_experts self.experts = experts - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: assert ( - self.fallback_experts.activation_formats == self.experts.activation_formats + FallbackExperts._experts_cls.activation_format() + == FallbackExperts._fallback_cls.activation_format() + ) + return FallbackExperts._experts_cls.activation_format() + + @staticmethod + def _supports_current_device() -> bool: + return ( + FallbackExperts._experts_cls._supports_current_device() + and FallbackExperts._fallback_cls._supports_current_device() + ) + + @staticmethod + def _supports_no_act_and_mul() -> bool: + return ( + FallbackExperts._experts_cls._supports_no_act_and_mul() + and FallbackExperts._fallback_cls._supports_no_act_and_mul() + ) + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + return FallbackExperts._experts_cls._supports_quant_scheme( + weight_key, activation_key + ) and FallbackExperts._fallback_cls._supports_quant_scheme( + weight_key, activation_key + ) + + @staticmethod + def _supports_activation(activation: str) -> bool: + return FallbackExperts._experts_cls._supports_activation( + activation + ) and FallbackExperts._fallback_cls._supports_activation(activation) + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + return FallbackExperts._experts_cls._supports_parallel_config( + moe_parallel_config + ) and FallbackExperts._fallback_cls._supports_parallel_config( + 
moe_parallel_config ) - return self.fallback_experts.activation_formats def supports_chunking(self) -> bool: assert ( diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py index 1651f3530eef..82ba560f63d3 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py @@ -6,13 +6,22 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import envs from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, + FusedMoEQuantConfig, +) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kNvfp4Dynamic, + kNvfp4Static, +) +from vllm.platforms import current_platform from vllm.utils.flashinfer import ( flashinfer_cutedsl_grouped_gemm_nt_masked, - has_flashinfer_cutedsl_grouped_gemm_nt_masked, scaled_fp4_grouped_quantize, silu_and_mul_scaled_nvfp4_experts_quantize, ) @@ -20,54 +29,48 @@ logger = init_logger(__name__) -def is_valid_flashinfer_cutedsl_fused_moe( - hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor -) -> bool: - """ - Check if the given problem size is supported by the FlashInfer CuteDSL MoE - kernel. - """ - if not has_flashinfer_cutedsl_grouped_gemm_nt_masked(): - logger.debug_once( - "FlashInferCuteDSLExperts disabled: " - "flashinfer_cutedsl_fused_moe not available." 
- ) - return False - # Data type checks - if ( - w1.dtype != torch.uint8 - or w2.dtype != torch.uint8 - or hidden_states.dtype not in [torch.float32, torch.float16, torch.bfloat16] - ): - logger.debug_once( - "FlashInferCuteDSLExperts disabled: w1/w2 must be torch.uint8 " - f"(got w1={w1.dtype}, w2={w2.dtype}), hidden_states must be " - f"float32, float16, or bfloat16 (got {hidden_states.dtype})." - ) - return False - return True - - class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - out_dtype: torch.dtype, + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, ): - super().__init__(quant_config) + super().__init__(moe_config, quant_config) assert quant_config.quant_dtype == "nvfp4", ( "Only nvfp4 quantization are currently supported." ) - self.out_dtype = out_dtype + self.out_dtype = moe_config.in_dtype - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts, - ) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.BatchedExperts + + @staticmethod + def _supports_current_device() -> bool: + # TODO: add check cutedsl support? 
+ return current_platform.has_device_capability((10, 0)) + + @staticmethod + def _supports_no_act_and_mul() -> bool: + return False + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + SUPPORTED_W_A = [ + (kNvfp4Static, kNvfp4Dynamic), + ] + return (weight_key, activation_key) in SUPPORTED_W_A + + @staticmethod + def _supports_activation(activation: str) -> bool: + return activation in ["silu"] + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + return True def supports_expert_map(self) -> bool: return False diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index ae60e15db841..5d68e8f0cce9 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -5,13 +5,22 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 - create_flashinfer_prepare_finalize, +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEParallelConfig, + FusedMoEQuantConfig, ) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kFp8Dynamic128Sym, + kFp8Static128BlockSym, + kFp8StaticTensorSym, + kNvfp4Dynamic, + kNvfp4Static, +) +from vllm.platforms import current_platform from vllm.utils.flashinfer import ( flashinfer_cutlass_fused_moe, has_flashinfer_cutlass_fused_moe, @@ -50,40 +59,92 @@ def is_valid_flashinfer_cutlass_fused_moe( class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - 
out_dtype: torch.dtype, + moe_config: mk.FusedMoEConfig, quant_config: FusedMoEQuantConfig, - ep_rank: int = 0, - ep_size: int = 1, - tp_rank: int = 0, - tp_size: int = 1, - use_dp: bool = False, - use_deepseek_fp8_block_scale: bool = False, ): - super().__init__(quant_config) + super().__init__(moe_config, quant_config) assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn, None), ( "Only nvfp4, fp8, bfloat16 and" " float16 quantization are currently supported." ) - self.ep_rank = ep_rank - self.ep_size = ep_size - self.tp_rank = tp_rank - self.tp_size = tp_size - self.out_dtype = out_dtype - self.use_dp = use_dp + self.ep_rank = moe_config.moe_parallel_config.ep_rank + self.ep_size = moe_config.moe_parallel_config.ep_size + self.tp_rank = moe_config.moe_parallel_config.tp_rank + self.tp_size = moe_config.moe_parallel_config.tp_size + self.out_dtype = moe_config.in_dtype + self.use_dp = moe_config.moe_parallel_config.dp_size > 1 # Enables DeepSeek-style FP8 block-scale path: # - pass per-block weight scales to the kernel # - skip input activation quantization (kernel applies scaling) - self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale + self.use_deepseek_fp8_block_scale = quant_config.is_block_quantized - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + @staticmethod + def should_pf_defer_input_quant( + moe_config: mk.FusedMoEConfig, quant_config: FusedMoEQuantConfig + ) -> bool: + # NVFP4 TP kernels and FP8 block-quantized kernels apply + # input quantization inside FusedMoEPermuteExpertsUnpermute. + return ( + quant_config.use_nvfp4_w4a4 + and not moe_config.moe_parallel_config.use_all2all_kernels + ) or (quant_config.use_fp8_w8a8 and quant_config.is_block_quantized) + + @staticmethod + def _supports_current_device() -> bool: return ( - mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard, + current_platform.is_cuda() + # Is this right? 
Or 9.0+10.0 only? + and current_platform.has_device_capability((9, 0)) + and has_flashinfer_cutlass_fused_moe() ) + @staticmethod + def _supports_no_act_and_mul() -> bool: + return False + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + # The following are supported by FlashInferExperts: + # * unquantized + # * fp8 static per-tensor on 9.0+ + # * fp8 block on 9.0 + # * nvfp4 on 10.0+ + + p = current_platform + scheme = (weight_key, activation_key) + return ( + ( + scheme + in [ + (None, None), + (kFp8StaticTensorSym, kFp8StaticTensorSym), + ] + ) + or ( + (scheme == (kFp8Static128BlockSym, kFp8Dynamic128Sym)) + and (p.is_device_capability((9, 0))) + ) + or ( + (scheme == (kNvfp4Static, kNvfp4Dynamic)) + and (p.is_device_capability((10, 0))) # GB? + ) + ) + + @staticmethod + def _supports_activation(activation: str) -> bool: + return activation in ["silu", "relu2_no_mul"] + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + return True + + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + def supports_expert_map(self) -> bool: return False @@ -226,85 +287,3 @@ def apply( # Informs FlashInfer to use the block-scale decoding path when True use_deepseek_fp8_block_scale=self.use_deepseek_fp8_block_scale, ) - - -def flashinfer_cutlass_moe_fp4( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - quant_config: FusedMoEQuantConfig, - inplace: bool = False, - activation: str = "silu", - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, -) -> torch.Tensor: - fused_experts = mk.FusedMoEModularKernel( - create_flashinfer_prepare_finalize( - use_dp=False, use_nvfp4=True, enable_alltoallv=False - ), - FlashInferExperts( - 
out_dtype=hidden_states.dtype, - quant_config=quant_config, - use_dp=False, - ), - ) - - return fused_experts( - hidden_states=hidden_states, - w1=w1, - w2=w2, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=inplace, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - - -def flashinfer_cutlass_moe( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - quant_config: FusedMoEQuantConfig, - inplace: bool = False, - activation: str = "silu", - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - tp_rank: int = 0, - tp_size: int = 1, - ep_rank: int = 0, - ep_size: int = 1, - use_dp: bool = False, -) -> torch.Tensor: - fused_experts = mk.FusedMoEModularKernel( - create_flashinfer_prepare_finalize(use_dp=use_dp), - FlashInferExperts( - out_dtype=hidden_states.dtype, - quant_config=quant_config, - tp_rank=tp_rank, - tp_size=tp_size, - ep_rank=ep_rank, - ep_size=ep_size, - ), - ) - - return fused_experts( - hidden_states=hidden_states, - w1=w1, - w2=w2, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=inplace, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index dfff860750d6..44f7b3b9bc0a 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -4,18 +4,12 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.distributed import get_dp_group, get_ep_group +from vllm.distributed import 
get_ep_group from vllm.distributed.device_communicators.base_device_communicator import ( All2AllManagerBase, ) from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP, -) -from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceNoOP, -) from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.utils.flashinfer import nvfp4_block_scale_interleave @@ -162,90 +156,6 @@ def finalize( output.copy_(fused_expert_output) -class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFinalize): - def __init__( - self, - use_dp: bool, - num_dispatchers: int = 1, - use_deepseek_fp8_block_scale: bool = False, - ): - super().__init__(use_dp, num_dispatchers, use_deepseek_fp8_block_scale) - - def prepare( - self, - a1: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int, - expert_map: torch.Tensor | None, - apply_router_weight_on_input: bool, - quant_config: FusedMoEQuantConfig, - ) -> mk.PrepareResultType: - self._apply_router_weight_on_input( - a1, topk_weights, topk_ids, apply_router_weight_on_input - ) - is_nvfp4 = quant_config.quant_dtype == "nvfp4" - if not self.use_dp and is_nvfp4: - return a1, None, None, topk_ids, topk_weights - - if not self.use_deepseek_fp8_block_scale: - a1q, a1q_scale = moe_kernel_quantize_input( - a1, - quant_config.a1_gscale if is_nvfp4 else quant_config.a1_scale, - quant_config.quant_dtype, - quant_config.per_act_token_quant, - quant_config.block_shape, - is_fp4_scale_swizzled=not self.use_dp, - ) - else: - # Block-scale path: pass activations through, omit per-token scales - a1q = a1 - a1q_scale = None - - if self.use_dp: - # Build gather list conditionally - omit a1q_scale if None - # (block-scale path) - gather_list = [topk_weights, topk_ids, a1q] - if 
a1q_scale is not None: - gather_list.append(a1q_scale) - gathered = get_dp_group().all_gatherv( - gather_list, - dim=0, - sizes=get_local_sizes(), - ) - topk_weights, topk_ids, a1q, a1q_scale = gathered - else: - gathered = get_dp_group().all_gatherv( - gather_list, - dim=0, - sizes=get_local_sizes(), - ) - topk_weights, topk_ids, a1q = gathered - a1q_scale = None - - if is_nvfp4 and a1q_scale is not None: - a1q_scale = nvfp4_block_scale_interleave(a1q_scale) - - return a1q, a1q_scale, None, topk_ids, topk_weights - - def finalize( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - ) -> None: - assert isinstance(weight_and_reduce_impl, TopKWeightAndReduceNoOP) - - if self.use_dp: - fused_expert_output = get_dp_group().reduce_scatterv( - fused_expert_output, dim=0, sizes=get_local_sizes() - ) - output.copy_(fused_expert_output) - - def flashinfer_alltoall_dispatch( all2all_manager: All2AllManagerBase, global_num_tokens_cpu: list[int], @@ -346,26 +256,3 @@ def flashinfer_alltoall_combine( top_k=top_k, token_count=token_count, ) - - -def create_flashinfer_prepare_finalize( - use_dp: bool, - use_nvfp4: bool = False, - enable_alltoallv: bool = False, - use_deepseek_fp8_block_scale: bool = False, -) -> FlashInferCutlassMoEPrepareAndFinalize | MoEPrepareAndFinalizeNoEP: - """Factory function to create the appropriate FlashInfer implementation.""" - - if use_dp: - if enable_alltoallv: - assert use_nvfp4 - return FlashInferAllToAllMoEPrepareAndFinalize(use_dp) - return FlashInferAllGatherMoEPrepareAndFinalize( - use_dp=True, - use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale, - ) - else: - # CUTLASS FP8 BLOCK and CUTLASS NVFP4 apply input quantization - # in a single call with the MoE experts kernel. 
- defer_input_quant = use_deepseek_fp8_block_scale or use_nvfp4 - return MoEPrepareAndFinalizeNoEP(defer_input_quant=defer_input_quant) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index 3bb5a23abb7b..30af9d912dff 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -3,7 +3,12 @@ import torch -from vllm.model_executor.layers.fused_moe.config import RoutingMethodType +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, + RoutingMethodType, +) from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( calculate_tile_tokens_dim, @@ -11,8 +16,82 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kFp8Dynamic128Sym, + kFp8Static128BlockSym, + kFp8StaticTensorSym, +) +from vllm.platforms import current_platform from vllm.utils.torch_utils import direct_register_custom_op +# +# Methods used by the oracle for kernel selection. +# + + +def _supports_current_device() -> bool: + """Supports only Blackwell-family GPUs.""" + p = current_platform + # Add check flashinfer trtllm is available + return p.is_cuda() and p.is_device_capability_family(100) + + +def _supports_no_act_and_mul() -> bool: + """Does not support non-gated MoE (i.e. 
Nanotron-Mini).""" + return False + + +def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, +) -> bool: + """Supports Fp8 per-tensor and Fp8 block.""" + SUPPORTED_W_A = [ + (kFp8Static128BlockSym, kFp8Dynamic128Sym), + (kFp8StaticTensorSym, kFp8StaticTensorSym), + ] + return (weight_key, activation_key) in SUPPORTED_W_A + + +def _supports_activation(activation: str) -> bool: + """Supports silu activation only.""" + return activation in ["silu"] + + +def _supports_moe_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + """Supports EP.""" + return True + + +def is_supported_config_trtllm( + moe_config: FusedMoEConfig, + weight_key: QuantKey | None, + activation_key: QuantKey | None, + activation_format: mk.FusedMoEActivationFormat, +) -> tuple[bool, str | None]: + """ + This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config + """ + + def _make_reason(reason: str) -> str: + return f"kernel does not support {reason}" + + if not _supports_current_device(): + return False, _make_reason("current device") + elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()): + return False, _make_reason("no act_and_mul MLP layer") + elif not _supports_activation(moe_config.activation): + return False, _make_reason(f"{moe_config.activation} activation") + elif not _supports_quant_scheme(weight_key, activation_key): + return False, _make_reason("quantization scheme") + elif not _supports_moe_parallel_config(moe_config.moe_parallel_config): + return False, _make_reason("parallel config") + elif activation_format != mk.FusedMoEActivationFormat.Standard: + return False, _make_reason("activation format") + + return True, None + def flashinfer_fused_moe_blockscale_fp8( routing_logits: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index fb93464392ea..fbb30d408457 100644 --- 
a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -5,7 +5,11 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, + FusedMoEQuantConfig, +) from vllm.model_executor.layers.fused_moe.fused_moe import try_get_optimal_moe_config from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate, @@ -17,7 +21,11 @@ normalize_batched_scales_shape, normalize_scales_shape, ) -from vllm.model_executor.layers.quantization.utils.quant_utils import group_broadcast +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + group_broadcast, +) +from vllm.platforms import current_platform from vllm.triton_utils import tl, triton @@ -633,26 +641,41 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - max_num_tokens: int, - num_dispatchers: int, + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, ): - super().__init__(quant_config) + super().__init__(moe_config, quant_config) assert not self.quant_config.use_int8_w8a8, "NYI" assert not self.quant_config.use_int8_w8a16, "NYI" assert not self.quant_config.use_int4_w4a16, "NYI" assert self.quant_config.ocp_mx_scheme is None, "NYI" - self.max_num_tokens = max_num_tokens - self.num_dispatchers = num_dispatchers - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts, - ) + @staticmethod + def activation_formats() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.BatchedExperts + + @staticmethod + def _supports_current_device() -> bool: + return True + + @staticmethod 
+ def _supports_no_act_and_mul() -> bool: + return False + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + return False + + @staticmethod + def _supports_activation(activation: str) -> bool: + return activation in ["silu", "swigluoai"] + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + return True def supports_chunking(self) -> bool: return False @@ -826,28 +849,48 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - max_num_tokens: int, - num_dispatchers: int, + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, ): - super().__init__(quant_config) + super().__init__(moe_config, quant_config) assert not self.quant_config.use_int8_w8a8, "NYI" assert not self.quant_config.use_int8_w8a16, "NYI" assert not self.quant_config.use_int4_w4a16, "NYI" assert self.quant_config.ocp_mx_scheme is None, "NYI" - assert max_num_tokens > 0 - assert num_dispatchers > 0 - self.max_num_tokens = max_num_tokens - self.num_dispatchers = num_dispatchers - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts, - ) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.BatchedExperts + + @staticmethod + def _supports_current_device() -> bool: + return current_platform.is_cuda_alike() + + @staticmethod + def _supports_no_act_and_mul() -> bool: + return False + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + # p = current_platform + # device_supports_fp8 = p.is_rocm() or ( + # p.is_cuda() and p.has_device_capability((9, 0)) + # ) + # return quant_scheme.is_unquantized or ( + # quant_scheme.is_fp8_w8a8 and 
device_supports_fp8 + # ) + return False + + @staticmethod + def _supports_activation(activation: str) -> bool: + return activation in ["silu", "gelu", "swigluoai"] + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + return True def supports_chunking(self) -> bool: return False diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 77c6b97eaea3..5b552dc93287 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -8,7 +8,11 @@ import vllm._custom_ops as ops import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, + FusedMoEQuantConfig, +) from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( batched_moe_align_block_size, moe_align_block_size, @@ -23,6 +27,12 @@ marlin_moe_intermediate_size, marlin_quant_input, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kFp8Static128BlockSym, + kFp8StaticChannelSym, + kNvfp4Static, +) from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types @@ -530,6 +540,7 @@ def batched_fused_marlin_moe( class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, w13_g_idx: torch.Tensor | None = None, w2_g_idx: torch.Tensor | None = None, @@ -549,7 +560,38 @@ def __init__( self.w13_g_idx_sort_indices = w13_g_idx_sort_indices self.w2_g_idx_sort_indices = w2_g_idx_sort_indices self.is_k_full = is_k_full - super().__init__(quant_config) + super().__init__(moe_config, quant_config) + + @staticmethod + def _supports_current_device() -> bool: + p = current_platform + 
return p.is_cuda() and p.has_device_capability((7, 5)) + + @staticmethod + def _supports_no_act_and_mul() -> bool: + return False + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + # TODO(rob): add int4, mxfp4, int8 as integrations + # are migrated to use the oracle one-by-one. + SUPPORTED_W = [ + kFp8Static128BlockSym, + kFp8StaticChannelSym, + kNvfp4Static, + ] + return weight_key in SUPPORTED_W + + @staticmethod + def _supports_activation(activation: str) -> bool: + return activation in ["silu", "swigluoai"] + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + return True @property def quant_type_id(self) -> int: @@ -595,38 +637,15 @@ def moe_problem_size( class MarlinExperts(MarlinExpertsBase): - def __init__( - self, - quant_config: FusedMoEQuantConfig, - w13_g_idx: torch.Tensor | None = None, - w2_g_idx: torch.Tensor | None = None, - w13_g_idx_sort_indices: torch.Tensor | None = None, - w2_g_idx_sort_indices: torch.Tensor | None = None, - is_k_full: bool = True, - ): - super().__init__( - quant_config, - w13_g_idx, - w2_g_idx, - w13_g_idx_sort_indices, - w2_g_idx_sort_indices, - is_k_full, - ) - def supports_expert_map(self) -> bool: return True def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: return TopKWeightAndReduceNoOP() - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard, - ) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard def supports_chunking(self) -> bool: return True @@ -720,42 +739,15 @@ def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None: class BatchedMarlinExperts(MarlinExpertsBase): - def __init__( - self, - max_num_tokens: int, - num_dispatchers: int, - 
quant_config: FusedMoEQuantConfig, - w13_g_idx: torch.Tensor | None = None, - w2_g_idx: torch.Tensor | None = None, - w13_g_idx_sort_indices: torch.Tensor | None = None, - w2_g_idx_sort_indices: torch.Tensor | None = None, - is_k_full: bool = True, - ): - super().__init__( - quant_config, - w13_g_idx, - w2_g_idx, - w13_g_idx_sort_indices, - w2_g_idx_sort_indices, - is_k_full, - ) - self.max_num_tokens = max_num_tokens - self.num_dispatchers = num_dispatchers - def supports_expert_map(self) -> bool: return True def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: return TopKWeightAndReduceDelegate() - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts, - ) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.BatchedExperts def supports_chunking(self) -> bool: return False diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index f2d9463e9183..4564e156041f 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -21,6 +21,8 @@ ) from vllm.model_executor.layers.fused_moe.config import ( FUSED_MOE_UNQUANTIZED_CONFIG, + FusedMoEConfig, + FusedMoEParallelConfig, FusedMoEQuantConfig, _get_config_dtype_str, ) @@ -49,6 +51,14 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4 from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6 from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kFp8Dynamic128Sym, + kFp8DynamicTokenSym, + kFp8Static128BlockSym, + kFp8StaticChannelSym, + kFp8StaticTensorSym, +) from vllm.model_executor.utils 
import maybe_disable_graph_partition from vllm.platforms import current_platform from vllm.triton_utils import tl, triton @@ -2299,19 +2309,52 @@ def fused_experts_impl( class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, ): - super().__init__(quant_config) - - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard, + super().__init__(moe_config, quant_config) + + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + @staticmethod + def _supports_current_device() -> bool: + return current_platform.is_cuda_alike() + + @staticmethod + def _supports_no_act_and_mul() -> bool: + return False + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + p = current_platform + device_supports_fp8 = p.is_rocm() or ( + p.is_cuda() and p.has_device_capability((9, 0)) ) + if not device_supports_fp8: + return (weight_key, activation_key) == (None, None) + + SUPPORTED_W_A = [ + (kFp8Static128BlockSym, kFp8Dynamic128Sym), + (kFp8StaticChannelSym, kFp8DynamicTokenSym), + (kFp8StaticTensorSym, kFp8DynamicTokenSym), + (kFp8StaticTensorSym, kFp8StaticTensorSym), + ] + return (weight_key, activation_key) in SUPPORTED_W_A + + @staticmethod + def _supports_activation(activation: str) -> bool: + return activation in ["silu", "gelu", "swigluoai"] + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + return True + def supports_chunking(self) -> bool: return True @@ -2486,11 +2529,12 @@ def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None: class TritonWNA16Experts(TritonExperts): - def __init__( - self, - quant_config: FusedMoEQuantConfig, - ): - 
super().__init__(quant_config) + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + return False def apply( self, @@ -2629,10 +2673,12 @@ def apply( def modular_triton_fused_moe( - quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None + moe_config: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, + shared_experts: torch.nn.Module | None = None, ) -> mk.FusedMoEModularKernel: return mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), - TritonExperts(quant_config), + TritonExperts(moe_config, quant_config), shared_experts, ) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py index 389ccf358c56..b3085d43c0b2 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py @@ -30,6 +30,17 @@ def __init__(self, moe: FusedMoEConfig): self.moe: FusedMoEConfig = moe self.moe_quant_config: FusedMoEQuantConfig | None = None + @property + def supports_mk_interally(self) -> bool: + """ + Returns True if this method supports using modular kernels (MK) + internally for MoE operations, False otherwise. + + This method should be overridden by subclasses that support + modular kernels internally. 
+ """ + return False + @abstractmethod def create_weights( self, diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index c4bc1824aa1f..b209820cdfa9 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -9,12 +9,16 @@ from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( FUSED_MOE_UNQUANTIZED_CONFIG, + FusedMoEParallelConfig, FusedMoEQuantConfig, ) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP, ) from vllm.model_executor.layers.fused_moe.utils import _resize_cache +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, +) from vllm.triton_utils import tl, triton from vllm.utils.import_utils import has_triton_kernels @@ -241,8 +245,43 @@ def make_routing_data( class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__(self, quant_config: FusedMoEQuantConfig): - super().__init__(quant_config) + @staticmethod + def _supports_current_device() -> bool: + raise NotImplementedError( + "OAITritonExperts is not yet used by an Oracle. " + "This method should not be called." + ) + + @staticmethod + def _supports_no_act_and_mul() -> bool: + raise NotImplementedError( + "OAITritonExperts is not yet used by an Oracle. " + "This method should not be called." + ) + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + raise NotImplementedError( + "OAITritonExperts is not yet used by an Oracle. " + "This method should not be called." + ) + + @staticmethod + def _supports_activation(activation: str) -> bool: + raise NotImplementedError( + "OAITritonExperts is not yet used by an Oracle. " + "This method should not be called." 
+ ) + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + raise NotImplementedError( + "OAITritonExperts is not yet used by an Oracle. " + "This method should not be called." + ) def supports_expert_map(self) -> bool: return True @@ -297,19 +336,9 @@ def _make_routing_data( class OAITritonExperts(BaseOAITritonExperts): - def __init__(self, quant_config: FusedMoEQuantConfig): - # TODO (varun) : Enable activation quantization - assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16" - super().__init__(quant_config) - - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard, - ) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard def supports_chunking(self) -> bool: return True @@ -391,19 +420,9 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts): One use case for it is to inject LoRA modules on the activation and moe_sum. 
""" - def __init__(self, quant_config: FusedMoEQuantConfig): - # TODO (varun) : Enable activation quantization - assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16" - super().__init__(quant_config) - - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard, - ) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard def supports_chunking(self) -> bool: return True diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 3b3a789f6d6e..77898da75b1d 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -360,7 +360,7 @@ def __init__( enable_eplb: bool = False, num_redundant_experts: int = 0, has_bias: bool = False, - is_sequence_parallel=False, + is_sequence_parallel: bool = False, expert_mapping: list[tuple[str, str, int, str]] | None = None, n_shared_experts: int | None = None, routing_method_type: RoutingMethodType | None = None, @@ -587,6 +587,7 @@ def __init__( num_experts=self.global_num_experts, experts_per_token=top_k, hidden_dim=hidden_size, + intermediate_size_per_partition=self.intermediate_size_per_partition, num_local_experts=self.local_num_experts, moe_parallel_config=self.moe_parallel_config, in_dtype=moe_in_dtype, @@ -595,9 +596,8 @@ def __init__( has_bias=has_bias, is_act_and_mul=is_act_and_mul, is_lora_enabled=vllm_config.lora_config is not None, - ) - self.moe_config_use_flashinfer_cutlass_kernels = ( - self.moe_config.use_flashinfer_cutlass_kernels + activation=activation, + device=vllm_config.device_config.device, ) self.quant_config = quant_config @@ -685,6 +685,10 @@ def _get_quant_method() -> FusedMoEMethodBase: # This is called after all weight loading and post-processing, so it # should be safe to swap out 
the quant_method. def maybe_init_modular_kernel(self) -> None: + if self.quant_method.supports_mk_interally: + logger.info_once("DEBUG: SKIPPING MK INIT: Handled Internally!!!!") + return + self.ensure_moe_quant_config_init() # routing_tables only needed for round-robin expert placement with # DeepEP all2all backend. @@ -763,14 +767,6 @@ def use_deepep_ht_kernels(self): def use_deepep_ll_kernels(self): return self.moe_parallel_config.use_deepep_ll_kernels - @property - def use_flashinfer_cutlass_kernels(self): - return ( - self.moe_quant_config is not None - and self.moe_quant_config.quant_dtype == "nvfp4" - and self.moe_config_use_flashinfer_cutlass_kernels - ) - @property def use_marlin_kernels(self): return getattr(self.quant_method, "use_marlin", False) @@ -780,7 +776,7 @@ def use_dp_chunking(self) -> bool: return ( self.moe_parallel_config.use_pplx_kernels or self.moe_parallel_config.use_deepep_ll_kernels - or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels) + or self.moe_parallel_config.use_fi_all2allv_kernels ) and envs.VLLM_ENABLE_MOE_DP_CHUNK @property @@ -1688,6 +1684,7 @@ def must_reduce_shared_expert_outputs(self) -> bool: early. """ assert self.quant_method is not None + # TODO(rob): investigate this. return ( isinstance(self.quant_method, FusedMoEModularMethod) and self.quant_method.fused_experts.output_is_reduced() @@ -1894,8 +1891,15 @@ def forward_impl( self.ensure_moe_quant_config_init() self.ensure_dp_chunking_init() + # TODO: figure out a better way to express who is responsible for the SE. 
+ mk_has_shared_expert = ( + self.quant_method.supports_mk_interally + and hasattr(self.quant_method, "kernel") + and getattr(self.quant_method.kernel, "shared_experts", None) is not None + ) has_separate_shared_experts = ( not isinstance(self.quant_method, FusedMoEModularMethod) + and not mk_has_shared_expert and self.shared_experts is not None ) @@ -1919,8 +1923,9 @@ def forward_impl( hidden_states, router_logits, has_separate_shared_experts ) - do_naive_dispatch_combine: bool = self.dp_size > 1 and not isinstance( - self.quant_method, FusedMoEModularMethod + do_naive_dispatch_combine: bool = self.dp_size > 1 and not ( + isinstance(self.quant_method, FusedMoEModularMethod) + or self.quant_method.supports_mk_interally ) ctx = get_forward_context() diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index a6df2b20af9c..fad57edcd16b 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -13,6 +13,7 @@ from vllm.forward_context import get_forward_context, is_forward_context_available from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig, ) @@ -22,6 +23,9 @@ count_expert_num_tokens, disable_inplace, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, +) from vllm.utils.math_utils import cdiv from vllm.v1.worker.ubatching import ( dbo_enabled, @@ -374,18 +378,85 @@ class FusedMoEPermuteExpertsUnpermute(ABC): def __init__( self, + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, ): """ + moe_config: MoE layer configuration. quant_config: Quantization parameters for this experts instance. 
""" + self.moe_config = moe_config self.quant_config = quant_config + self._max_num_tokens: int | None = None + self._num_dispatchers: int | None = None + + @staticmethod + def should_pf_defer_input_quant( + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig + ) -> bool: + """ + Whether or not the PrepareFinalize should defer input quantization + in the prepare step. If True, then the Experts kernel will + execute the input quantization itself. + + Sample subclasses that override are AITER and FlashInfer CUTLASS. + """ + return False + + def _init_batched_experts_addl_params( + self, + max_num_tokens: int, + num_dispatchers: int, + ): + """ + Initialize any additional parameters needed for batched experts. + """ + self._max_num_tokens = max_num_tokens + self._num_dispatchers = num_dispatchers @property + def max_num_tokens(self) -> int: + if self._max_num_tokens is None: + raise AttributeError("max_num_tokens only valid for BatchedExperts") + return self._max_num_tokens + + @property + def num_dispatchers(self) -> int: + if self._num_dispatchers is None: + raise AttributeError("num_dispatchers only valid for BatchedExperts") + return self._num_dispatchers + + @classmethod + def make_standard_experts( + cls, + moe_config: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, + ) -> "FusedMoEPermuteExpertsUnpermute": + """ + Factory method to create an instance of this class. + """ + assert cls.activation_format() == FusedMoEActivationFormat.Standard + return cls(moe_config, quant_config) + + @classmethod + def make_batched_experts( + cls, + moe_config: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, + max_num_tokens: int, + num_dispatchers: int, + ) -> "FusedMoEPermuteExpertsUnpermute": + """ + Factory method to create an instance of this class. 
+ """ + assert cls.activation_format() == FusedMoEActivationFormat.BatchedExperts + instance = cls(moe_config, quant_config) + instance._init_batched_experts_addl_params(max_num_tokens, num_dispatchers) + return instance + + @staticmethod @abstractmethod - def activation_formats( - self, - ) -> tuple[FusedMoEActivationFormat, FusedMoEActivationFormat]: + def activation_format() -> FusedMoEActivationFormat: """ A property which is a tuple of the input and output activation formats for the 'apply' method. @@ -435,6 +506,78 @@ def moe_problem_size( return E, M, N, K, topk + # + # Various helpers for registering support for various features. + # Used by the oracle to select a particular kernel for a deployment. + # + + @staticmethod + def is_supported_config( + cls: type["FusedMoEPermuteExpertsUnpermute"], + moe_config: FusedMoEConfig, + weight_key: QuantKey | None, + activation_key: QuantKey | None, + activation_format: FusedMoEActivationFormat, + ) -> tuple[bool, str | None]: + def _make_reason(reason: str) -> str: + return f"kernel does not support {reason}" + + if not cls._supports_current_device(): + return False, _make_reason("current device") + elif not (moe_config.is_act_and_mul or cls._supports_no_act_and_mul()): + return False, _make_reason("no act_and_mul MLP layer") + elif not cls._supports_activation(moe_config.activation): + return False, _make_reason(f"{moe_config.activation} activation") + elif not cls._supports_quant_scheme(weight_key, activation_key): + return False, _make_reason("quantization scheme") + elif not cls._supports_parallel_config(moe_config.moe_parallel_config): + return False, _make_reason("parallel config") + elif activation_format != cls.activation_format(): + return False, _make_reason(f"{activation_format.value} activation format") + return True, None + + @staticmethod + @abstractmethod + def _supports_current_device() -> bool: + """ + Whether the kernel supports the current device type + (compute cability and current platform). 
+ """ + raise NotImplementedError + + @staticmethod + @abstractmethod + def _supports_no_act_and_mul() -> bool: + """ + Whether the kernel supports act_and_mul=False, i.e. + non-gated MoE models like Nemotron-Nano. + """ + raise NotImplementedError + + @staticmethod + @abstractmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + raise NotImplementedError + + @staticmethod + @abstractmethod + def _supports_activation(activation: str) -> bool: + """ + Whether the kernel supports a particular act function. + """ + raise NotImplementedError + + @staticmethod + @abstractmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + """ + Whether the kernel supports deployment in expert parallel. + """ + raise NotImplementedError + # # Various helpers for accessing quantization parameters from the # quant_config. @@ -715,12 +858,12 @@ def __init__( self._post_init_setup() assert ( - prepare_finalize.activation_format == fused_experts.activation_formats[0] + prepare_finalize.activation_format == fused_experts.activation_format() ), ( f"{prepare_finalize.__class__.__name__}." f"{prepare_finalize.activation_format} == " f"{fused_experts.__class__.__name__}." 
- f"{fused_experts.activation_formats[0]}" + f"{fused_experts.activation_format()}" ) def _post_init_setup(self): diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py index f5c3b9af611f..fe6d5cc68bc9 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from enum import Enum +from typing import TYPE_CHECKING import torch @@ -8,12 +9,18 @@ from vllm import envs from vllm._aiter_ops import rocm_aiter_ops from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.all2all_utils import ( + maybe_make_prepare_finalize, +) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, fp8_w8a8_moe_quant_config, fp8_w8a16_moe_quant_config, ) +from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import ( + is_supported_config_trtllm, +) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( FlashinferMoeBackend, get_flashinfer_moe_backend, @@ -26,128 +33,279 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( prepare_fp8_moe_layer_for_marlin, ) -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - cutlass_group_gemm_supported, +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, ) -from vllm.platforms import current_platform -from vllm.utils.deep_gemm import is_deep_gemm_supported -from vllm.utils.flashinfer import has_flashinfer_moe -from vllm.utils.import_utils import has_deep_gemm + +if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.layer import FusedMoE logger = init_logger(__name__) class Fp8MoeBackend(Enum): NONE = 0 - FLASHINFER_TRTLLM = 1 - FLASHINFER_CUTLASS = 2 - DEEPGEMM = 3 - MARLIN = 4 - TRITON = 5 - AITER = 6 - VLLM_CUTLASS = 7 + 
FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM" + FLASHINFER_CUTLASS = "FLASHINFER_CUTLASS" + DEEPGEMM = "DEEPGEMM" + BATCHED_DEEPGEMM = "BATCHED_DEEPGEMM" + MARLIN = "MARLIN" + TRITON = "TRITON" + BATCHED_TRITON = "BATCHED_TRITON" + AITER = "AITER" + VLLM_CUTLASS = "VLLM_CUTLASS" + + +def backend_2_kernel_cls( + backend: Fp8MoeBackend, +) -> type[mk.FusedMoEPermuteExpertsUnpermute]: + if backend == Fp8MoeBackend.FLASHINFER_TRTLLM: + raise NotImplementedError + + elif backend == Fp8MoeBackend.FLASHINFER_CUTLASS: + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( + FlashInferExperts, + ) + + return FlashInferExperts + + elif backend == Fp8MoeBackend.DEEPGEMM: + from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( + TritonOrDeepGemmExperts, + ) + + return TritonOrDeepGemmExperts + + elif backend == Fp8MoeBackend.BATCHED_DEEPGEMM: + from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( + BatchedDeepGemmExperts, + ) + + return BatchedDeepGemmExperts + + elif backend == Fp8MoeBackend.MARLIN: + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + MarlinExperts, + ) + + return MarlinExperts + + elif backend == Fp8MoeBackend.TRITON: + from vllm.model_executor.layers.fused_moe.fused_moe import ( + TritonExperts, + ) + + return TritonExperts + + elif backend == Fp8MoeBackend.BATCHED_TRITON: + from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + BatchedTritonExperts, + ) + + return BatchedTritonExperts + + elif backend == Fp8MoeBackend.AITER: + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + AiterExperts, + ) + + return AiterExperts + + elif backend == Fp8MoeBackend.VLLM_CUTLASS: + from vllm.model_executor.layers.fused_moe.triton_cutlass_moe import ( + TritonOrCutlassExperts, + ) + + return TritonOrCutlassExperts + + else: + raise ValueError(f"Unknown FP8 MoE backend: {backend.value}") def select_fp8_moe_backend( - block_quant: bool, - tp_size: int, - 
with_lora_support: bool, - is_act_and_mul: bool = True, + config: FusedMoEConfig, + weight_key: QuantKey | None, + activation_key: QuantKey | None, allow_vllm_cutlass: bool = False, -) -> Fp8MoeBackend: +) -> tuple[Fp8MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]: """ Select the primary FP8 MoE backend Note: Shape-specific fallbacks may still occur at runtime. """ - # TODO(rob): in a future PR, we will query each mk for - # supported features and return the mk directly, just like - # we do for the Attention Backend. - - if with_lora_support: - return Fp8MoeBackend.TRITON - - def _make_log_backend(backend_name: str): - return f"Using {backend_name} backend for FP8 MoE" - - # Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100. - if ( - current_platform.is_cuda() - and ( - current_platform.is_device_capability_family(100) - or current_platform.is_device_capability(90) + + if config.is_lora_enabled: + return Fp8MoeBackend.TRITON, backend_2_kernel_cls(Fp8MoeBackend.TRITON) + + # NOTE(rob): this is kind of a hack. We need to peak into + # the prepare-finalize selection to determine if we are using + # the batched or standard expert format. + use_batched = ( + config.moe_parallel_config.use_deepep_ll_kernels + or config.moe_parallel_config.use_pplx_kernels + ) + activation_format = ( + mk.FusedMoEActivationFormat.BatchedExperts + if use_batched + else mk.FusedMoEActivationFormat.Standard + ) + + def _make_log_backend(backend: Fp8MoeBackend): + return f"Using `{backend.value}` backend for FP8 MoE" + + def _make_log_unsupported(backend: Fp8MoeBackend, reason: str | None) -> str: + if reason: + return ( + f"FP8 MoE backend `{backend.value}` does not support the " + f"deployment configuration since {reason}." + ) + else: + return ( + f"FP8 MoE backend `{backend.value}` does not support the " + "deployment configuration." 
+ ) + + def _return_or_raise( + backend: Fp8MoeBackend, + config: FusedMoEConfig, + weight_key: QuantKey | None, + activation_key: QuantKey | None, + activation_format: mk.FusedMoEActivationFormat, + ) -> tuple[Fp8MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]: + k_cls = backend_2_kernel_cls(backend) + supported, reason = k_cls.is_supported_config( + k_cls, config, weight_key, activation_key, activation_format ) - and envs.VLLM_USE_FLASHINFER_MOE_FP8 - and has_flashinfer_moe() - ): - backend = get_flashinfer_moe_backend() - if backend == FlashinferMoeBackend.TENSORRT_LLM: - logger.info_once(_make_log_backend("FlashInfer TRTLLM")) - if not is_act_and_mul: - raise ValueError( - "FlashInfer TRTLLM FP8 MoE backend only supports " - "act_and_mul gate_up_project fusion. Please set " - "VLLM_USE_FLASHINFER_MOE_FP8=throughput to use the " - "FlashInfer CUTLASS backend instead." + if supported: + logger.info_once(_make_log_backend(backend)) + return backend, k_cls + raise ValueError(_make_log_unsupported(backend, reason)) + + # NOTE: the kernels are selected in the following order. + AVAILABLE_BACKENDS = [ + Fp8MoeBackend.AITER, + Fp8MoeBackend.FLASHINFER_TRTLLM, + Fp8MoeBackend.FLASHINFER_CUTLASS, + Fp8MoeBackend.DEEPGEMM, + Fp8MoeBackend.BATCHED_DEEPGEMM, + Fp8MoeBackend.VLLM_CUTLASS, + Fp8MoeBackend.TRITON, + Fp8MoeBackend.BATCHED_TRITON, + Fp8MoeBackend.MARLIN, + ] + + # Handle explicit FlashInfer FP8 configuration. + if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP8"): + if not envs.VLLM_USE_FLASHINFER_MOE_FP8: + # If the user rejects FlashInfer remove those backends. + AVAILABLE_BACKENDS.remove(Fp8MoeBackend.FLASHINFER_TRTLLM) + AVAILABLE_BACKENDS.remove(Fp8MoeBackend.FLASHINFER_CUTLASS) + + elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"): + # If user is explicit about backend, validate it. 
+ fi_backend = get_flashinfer_moe_backend() + + if fi_backend == FlashinferMoeBackend.TENSORRT_LLM: + backend = Fp8MoeBackend.FLASHINFER_TRTLLM + supported, reason = is_supported_config_trtllm( + config, weight_key, activation_key, activation_format ) - return Fp8MoeBackend.FLASHINFER_TRTLLM - else: - if block_quant and current_platform.is_device_capability_family(100): - raise ValueError( - "FlashInfer FP8 MoE throughput backend does not " - "support block quantization on SM100. Please use " - "VLLM_FLASHINFER_MOE_BACKEND=latency to use the " - "FlashInfer TRTLLM backend instead." + if supported: + logger.info_once(_make_log_backend(backend)) + return backend, None + else: + raise ValueError(_make_log_unsupported(backend, reason)) + + elif fi_backend == FlashinferMoeBackend.CUTLASS: + backend = Fp8MoeBackend.FLASHINFER_CUTLASS + return _return_or_raise( + backend, config, weight_key, activation_key, activation_format ) - logger.info_once(_make_log_backend("FlashInfer CUTLASS")) - return Fp8MoeBackend.FLASHINFER_CUTLASS - - # weight-only path for older GPUs without native FP8 - if ( - current_platform.is_cuda() and not current_platform.has_device_capability(89) - ) or envs.VLLM_TEST_FORCE_FP8_MARLIN: - logger.info_once(_make_log_backend("Marlin"), scope="local") - return Fp8MoeBackend.MARLIN - - # Determine if we should use DeepGEMM with block-quantized weights: - # - If explicitly set by user, respect their choice - # - If not explicitly set (default), disable when TP size is >= 8 - moe_use_deep_gemm = envs.VLLM_MOE_USE_DEEP_GEMM - if not envs.is_set("VLLM_MOE_USE_DEEP_GEMM") and tp_size >= 8: - moe_use_deep_gemm = False - logger.info_once( - "DeepGEMM MoE is disabled by default when TP size is >= 8. 
" - "Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it.", - scope="local", - ) - use_deep_gemm = envs.VLLM_USE_DEEP_GEMM - if not is_deep_gemm_supported(): - use_deep_gemm = False - logger.info_once( - "DeepGEMM is disabled because the platform does not support it.", - scope="local", + else: + assert fi_backend == FlashinferMoeBackend.CUTEDSL + raise ValueError("FlashInfer MaskedGEMM not supported for FP8") + + else: + # If the user is not explicit about the backend, try both. + for backend in [ + Fp8MoeBackend.FLASHINFER_TRTLLM, + Fp8MoeBackend.FLASHINFER_CUTLASS, + ]: + k_cls = backend_2_kernel_cls(backend) + if k_cls.is_supported_config( + k_cls, config, weight_key, activation_key, activation_format + ): + logger.info_once(_make_log_backend(backend)) + return backend, k_cls + + raise NotImplementedError( + "Found VLLM_USE_FLASHINFER_MOE_FP8=1, but no " + "FlashInfer FP8 MoE backend supports the configuration." + ) + + # Handle explicit DeepGEMM FP8 configuration. + if envs.is_set("VLLM_USE_DEEP_GEMM") or envs.is_set("VLLM_MOE_USE_DEEP_GEMM"): + if not envs.VLLM_USE_DEEP_GEMM or not envs.VLLM_MOE_USE_DEEP_GEMM: + AVAILABLE_BACKENDS.remove(Fp8MoeBackend.DEEPGEMM) + else: + backend = ( + Fp8MoeBackend.DEEPGEMM + if activation_format == mk.FusedMoEActivationFormat.Standard + else Fp8MoeBackend.BATCHED_DEEPGEMM + ) + return _return_or_raise( + backend, config, weight_key, activation_key, activation_format + ) + + # Handle explicit MARLIN FP8 configuration. + if envs.VLLM_TEST_FORCE_FP8_MARLIN: + backend = Fp8MoeBackend.MARLIN + return _return_or_raise( + backend, config, weight_key, activation_key, activation_format ) - if use_deep_gemm and moe_use_deep_gemm and block_quant: - if not has_deep_gemm(): - logger.warning_once( - "DeepGEMM backend requested but not available.", scope="local" + # Handle explicit AITER FP8 configuration. 
+ if envs.is_set("VLLM_ROCM_USE_AITER") or envs.is_set("VLLM_ROCM_USE_AITER_MOE"): + if not envs.VLLM_ROCM_USE_AITER or not envs.VLLM_ROCM_USE_AITER_MOE: + AVAILABLE_BACKENDS.remove(Fp8MoeBackend.AITER) + else: + backend = Fp8MoeBackend.AITER + return _return_or_raise( + backend, config, weight_key, activation_key, activation_format ) - elif is_deep_gemm_supported(): - logger.info_once(_make_log_backend("DeepGEMM"), scope="local") - return Fp8MoeBackend.DEEPGEMM - if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MOE: - logger.info_once(_make_log_backend("ROCm AITER"), scope="local") - return Fp8MoeBackend.AITER + if not allow_vllm_cutlass: + AVAILABLE_BACKENDS.remove(Fp8MoeBackend.VLLM_CUTLASS) + + # Select kernels in order of backend. + for backend in AVAILABLE_BACKENDS: + if backend == Fp8MoeBackend.FLASHINFER_TRTLLM: + k_cls = None # type: ignore[assignment] + supported, reason = is_supported_config_trtllm( + config, + weight_key, + activation_key, + activation_format, + ) + else: + k_cls = backend_2_kernel_cls(backend) + supported, reason = k_cls.is_supported_config( + k_cls, + config, + weight_key, + activation_key, + activation_format, + ) - if allow_vllm_cutlass and not block_quant and cutlass_group_gemm_supported(): - logger.info_once(_make_log_backend("vLLM CUTLASS"), scope="local") - return Fp8MoeBackend.VLLM_CUTLASS + if supported: + logger.info_once(_make_log_backend(backend), scope="local") + return backend, k_cls + else: + logger.info_once(_make_log_unsupported(backend, reason), scope="local") - # default to Triton - logger.info_once(_make_log_backend("Triton"), scope="local") - return Fp8MoeBackend.TRITON + raise NotImplementedError( + "No FP8 MoE backend supports the deployment configuration." 
+ ) def convert_to_fp8_moe_kernel_format( @@ -205,6 +363,8 @@ def make_fp8_moe_quant_config( a1_scale: torch.Tensor | None, a2_scale: torch.Tensor | None, block_shape: list[int] | None = None, + per_act_token_quant: bool = False, + per_out_ch_quant: bool = False, ) -> FusedMoEQuantConfig | None: """ Create FusedMoEQuantConfig for the specifed FP8 Backend. @@ -257,102 +417,63 @@ def make_fp8_moe_quant_config( a1_scale=a1_scale, a2_scale=a2_scale, block_shape=block_shape, + per_act_token_quant=per_act_token_quant, + per_out_ch_quant=per_out_ch_quant, ) def make_fp8_moe_kernel( - layer: torch.nn.Module, + layer: "FusedMoE", moe_quant_config: FusedMoEQuantConfig, moe_config: FusedMoEConfig, fp8_backend: Fp8MoeBackend, + experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute], ) -> tuple[mk.FusedMoEModularKernel, bool]: - # Delayed import is required since the oracle is imported - # by CPU backends which cannot import all of these experts. - # TODO: update the experts to make this not happen. - from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP, + # Create Prepare/Finalize. + prepare_finalize = maybe_make_prepare_finalize( + moe=moe_config, + quant_config=moe_quant_config, + routing_tables=None, # TODO: init routing tables here? + defer_input_quant=experts_cls.should_pf_defer_input_quant( + moe_config, moe_quant_config + ), + allow_new_interface=True, ) + assert prepare_finalize is not None - # NOTE(rob): this is a WIP refactor. We are first migrating - # all of the kernels in the TP case to use mk. Once this is - # done, then we will initialzie the TP case and DP/EP case - # via the same code path (i.e. via maybe_init_modular_kernel). - # NOTE(rob): in progress migrating all into this format. 
- use_inplace = True - if fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: - from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( - FlashInferExperts, - ) + logger.info_once("Using %s", prepare_finalize.__class__.__name__) - kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP( - defer_input_quant=moe_quant_config.is_block_quantized - ), - FlashInferExperts( - out_dtype=layer.orig_dtype, - quant_config=moe_quant_config, - ep_rank=moe_config.ep_rank, - ep_size=moe_config.ep_size, - tp_rank=moe_config.tp_rank, - tp_size=moe_config.tp_size, - use_dp=(moe_config.dp_size > 1), - use_deepseek_fp8_block_scale=moe_quant_config.is_block_quantized, - ), - ) - use_inplace = False - - elif fp8_backend == Fp8MoeBackend.AITER: - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - AiterExperts, - ) - - kernel = mk.FusedMoEModularKernel( - # TODO: make defer_input_quant an attr of the AiterExperts - MoEPrepareAndFinalizeNoEP(defer_input_quant=True), - AiterExperts(quant_config=moe_quant_config), - ) - elif fp8_backend == Fp8MoeBackend.MARLIN: - from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( - MarlinExperts, - ) - - kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), - MarlinExperts(quant_config=moe_quant_config), - ) - elif fp8_backend == Fp8MoeBackend.VLLM_CUTLASS: - from vllm.model_executor.layers.fused_moe.triton_cutlass_moe import ( - TritonOrCutlassExperts, - ) - - kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), - TritonOrCutlassExperts( - out_dtype=moe_config.in_dtype, - e=layer.local_num_experts, - n=layer.intermediate_size_per_partition, - k=layer.hidden_size, - device=layer.w13_weight.device, - quant_config=moe_quant_config, - ), - ) - elif fp8_backend == Fp8MoeBackend.DEEPGEMM: - from vllm.model_executor.layers.fused_moe import ( - TritonOrDeepGemmExperts, - ) - - kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), - 
TritonOrDeepGemmExperts(quant_config=moe_quant_config), + # Create Experts. + if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.Standard: + experts = experts_cls.make_standard_experts( + moe_config=moe_config, + quant_config=moe_quant_config, ) else: - from vllm.model_executor.layers.fused_moe.fused_moe import ( - TritonExperts, + max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() + assert max_num_tokens_per_rank is not None + experts = experts_cls.make_batched_experts( + moe_config=moe_config, + quant_config=moe_quant_config, + max_num_tokens=max_num_tokens_per_rank, + num_dispatchers=prepare_finalize.num_dispatchers(), ) - assert fp8_backend == Fp8MoeBackend.TRITON - kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), - TritonExperts(quant_config=moe_quant_config), - ) - return kernel, use_inplace + # NOTE(rob): we only want the ModularKernel to control the SharedExpert + # if we are using all2all (for SBO). Need to make a change somewhere + # else to prevent double running the Shared Expert. + # This needs to be refactored. + kernel = mk.FusedMoEModularKernel( + prepare_finalize, + experts, + shared_experts=( + getattr(layer, "shared_expert", None) + if moe_config.moe_parallel_config.use_all2all_kernels + else None + ), + moe_parallel_config=moe_config.moe_parallel_config, + ) + + # TODO(rob): update inplace logic to be part of the kernel. 
+ inplace = fp8_backend != Fp8MoeBackend.FLASHINFER_CUTLASS + return kernel, inplace diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py index 547a2a795d19..e3d8bb4ab062 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py @@ -1,33 +1,24 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from enum import Enum +from typing import TYPE_CHECKING import torch import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.all2all_utils import ( + maybe_make_prepare_finalize, +) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, nvfp4_moe_quant_config, nvfp4_w4a16_moe_quant_config, ) -from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - CutlassExpertsFp4, -) -from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( - FlashInferExperts, -) -from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( - MarlinExperts, -) -from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP, -) from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( - is_flashinfer_fp4_cutedsl_moe_available, - is_flashinfer_fp4_cutlass_moe_available, + is_supported_config_trtllm, prepare_nvfp4_moe_layer_for_fi_or_cutlass, ) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( @@ -35,22 +26,24 @@ get_flashinfer_moe_backend, ) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - is_fp4_marlin_supported, prepare_nvfp4_moe_layer_for_marlin, ) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - cutlass_fp4_supported, + QuantKey, ) +if TYPE_CHECKING: + from 
vllm.model_executor.layers.fused_moe.layer import FusedMoE + logger = init_logger(__name__) class NvFp4MoeBackend(Enum): - FLASHINFER_CUTLASS = "FlashInfer CUTLASS" - FLASHINFER_TRTLLM = "FlashInfer TRTLLM" - FLASHINFER_CUTEDSL = "FlashInfer CUTEDSL" - VLLM_CUTLASS = "vLLM CUTASS" - MARLIN = "vLLM MARLIN" + FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM" + FLASHINFER_CUTLASS = "FLASHINFER_CUTLASS" + FLASHINFER_CUTEDSL = "FLASHINFER_CUTEDSL" + VLLM_CUTLASS = "VLLM_CUTLASS" + MARLIN = "MARLIN" FLASHINFER_NVFP4_MOE_BACKENDS = [ @@ -71,44 +64,184 @@ def is_global_sf_supported_for_nvfp4_backend(backend: NvFp4MoeBackend) -> bool: # of all experts in Expert Parallel Mode when all experts are not # on the same rank. - return backend in [ - NvFp4MoeBackend.FLASHINFER_CUTLASS, - NvFp4MoeBackend.FLASHINFER_TRTLLM, - ] + return backend in FLASHINFER_NVFP4_MOE_BACKENDS + + +def backend_2_kernel_cls( + backend: NvFp4MoeBackend, +) -> type[mk.FusedMoEPermuteExpertsUnpermute]: + if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: + raise NotImplementedError + + elif backend == NvFp4MoeBackend.FLASHINFER_CUTLASS: + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( + FlashInferExperts, + ) + + return FlashInferExperts + + elif backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL: + from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import ( + FlashInferCuteDSLExperts, + ) + return FlashInferCuteDSLExperts + + elif backend == NvFp4MoeBackend.VLLM_CUTLASS: + from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + CutlassExpertsFp4, + ) + + return CutlassExpertsFp4 + + elif backend == NvFp4MoeBackend.MARLIN: + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + MarlinExperts, + ) + + return MarlinExperts + else: + raise ValueError(f"Unknown NvFP4 MoE backend: {backend.value}") + + +def select_nvfp4_moe_backend( + config: FusedMoEConfig, + weight_key: QuantKey | None, + activation_key: QuantKey | None, +) -> tuple[NvFp4MoeBackend, 
type[mk.FusedMoEPermuteExpertsUnpermute] | None]: + """ + Select the primary NvFP4 MoE backend + Note: Shape-specific fallbacks may still occur at runtime. + """ + + # NOTE(rob): this is kind of a hack. We need to peak into + # the prepare-finalize selection to determine if we are using + # the batched or standard expert format. + use_batched = ( + config.moe_parallel_config.use_deepep_ll_kernels + or config.moe_parallel_config.use_pplx_kernels + ) + activation_format = ( + mk.FusedMoEActivationFormat.BatchedExperts + if use_batched + else mk.FusedMoEActivationFormat.Standard + ) -def select_nvfp4_moe_backend() -> NvFp4MoeBackend: def _make_log_backend(backend: NvFp4MoeBackend): - return f"Using {backend.value} backend for NvFp4 MoE" + return f"Using '{backend.value}' backend for NvFp4 MoE" + + def _make_log_unsupported(backend: NvFp4MoeBackend, reason: str | None) -> str: + if reason: + return ( + f"NvFP4 MoE backend '{backend.value}' does not support the " + f"deployment configuration since {reason}." + ) + else: + return ( + f"NvFP4 MoE backend '{backend.value}' does not support the " + "deployment configuration." 
+ ) - if cutlass_fp4_supported() and not envs.VLLM_TEST_FORCE_FP8_MARLIN: - allow_flashinfer = ( - is_flashinfer_fp4_cutlass_moe_available() - or is_flashinfer_fp4_cutedsl_moe_available() + def _return_or_raise( + backend: NvFp4MoeBackend, + config: FusedMoEConfig, + weight_key: QuantKey | None, + activation_key: QuantKey | None, + activation_format: mk.FusedMoEActivationFormat, + ) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]: + k_cls = backend_2_kernel_cls(backend) + supported, reason = k_cls.is_supported_config( + k_cls, config, weight_key, activation_key, activation_format ) - if allow_flashinfer and envs.VLLM_USE_FLASHINFER_MOE_FP4: - backend = fi_2_vllm_backend_map[get_flashinfer_moe_backend()] + if supported: + logger.info_once(_make_log_backend(backend)) + return backend, k_cls + raise ValueError(_make_log_unsupported(backend, reason)) + + # NOTE: the kernels are selected in the following order. + AVAILABLE_BACKENDS = [ + NvFp4MoeBackend.FLASHINFER_TRTLLM, + NvFp4MoeBackend.FLASHINFER_CUTEDSL, + NvFp4MoeBackend.FLASHINFER_CUTLASS, + NvFp4MoeBackend.MARLIN, + # NvFp4MoeBackend.VLLM_CUTLASS, + ] + + if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP4"): + if not envs.VLLM_USE_FLASHINFER_MOE_FP4: + # If the user rejects FlashInfer remove those backends. + for fi_backend in FLASHINFER_NVFP4_MOE_BACKENDS: + AVAILABLE_BACKENDS.remove(fi_backend) + + elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"): + # If user is explicit about backend, validate it. 
+ fi_backend = get_flashinfer_moe_backend() + + if fi_backend == FlashinferMoeBackend.TENSORRT_LLM: + backend = NvFp4MoeBackend.FLASHINFER_TRTLLM + supported, reason = is_supported_config_trtllm( + config, weight_key, activation_key, activation_format + ) + if supported: + logger.info_once(_make_log_backend(backend)) + return backend, None + else: + raise ValueError(_make_log_unsupported(backend, reason)) + else: + backend = fi_2_vllm_backend_map[fi_backend] + return _return_or_raise( + backend, config, weight_key, activation_key, activation_format + ) else: - backend = NvFp4MoeBackend.VLLM_CUTLASS - elif is_fp4_marlin_supported(): - backend = NvFp4MoeBackend.MARLIN - else: - raise ValueError("No NvFp4 kernel backend available for NvFp4 MoE.") + # If the user is not explicit about the backend, try each. + for backend in FLASHINFER_NVFP4_MOE_BACKENDS: + k_cls = backend_2_kernel_cls(backend) + if k_cls.is_supported_config( + k_cls, config, weight_key, activation_key, activation_format + ): + logger.info_once(_make_log_backend(backend)) + return backend, k_cls - # Log warning if FI backend requested but not available. - if ( - backend not in FLASHINFER_NVFP4_MOE_BACKENDS - and envs.VLLM_USE_FLASHINFER_MOE_FP4 - ): - logger.warning_once( - "Requested FlashInfer backend for NvFp4 MoE, but it's not available. " - "Falling back to %s for NvFp4 MoE", - backend.value, - scope="local", + raise NotImplementedError( + "Found VLLM_USE_FLASHINFER_MOE_FP4=1, but no " + "FlashInfer NVFP4 MoE backend supports the configuration." + ) + + if envs.VLLM_TEST_FORCE_FP8_MARLIN: + backend = NvFp4MoeBackend.MARLIN + return _return_or_raise( + backend, config, weight_key, activation_key, activation_format ) - else: - logger.info_once(_make_log_backend(backend), scope="local") - return backend + + # Select kernels in order of backend. 
+ for backend in AVAILABLE_BACKENDS: + if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: + k_cls = None # type: ignore[assignment] + supported, reason = is_supported_config_trtllm( + config, + weight_key, + activation_key, + activation_format, + ) + else: + k_cls = backend_2_kernel_cls(backend) + supported, reason = k_cls.is_supported_config( + k_cls, + config, + weight_key, + activation_key, + activation_format, + ) + + if supported: + logger.info_once(_make_log_backend(backend), scope="local") + return backend, k_cls + else: + logger.info_once(_make_log_unsupported(backend, reason), scope="local") + + raise NotImplementedError( + "No NvFp4 MoE backend supports the deployment configuration." + ) def convert_to_nvfp4_moe_kernel_format( @@ -227,54 +360,54 @@ def make_nvfp4_moe_quant_config( def make_nvfp4_moe_kernel( - backend: NvFp4MoeBackend, - quant_config: FusedMoEQuantConfig, + layer: "FusedMoE", + moe_quant_config: FusedMoEQuantConfig, moe_config: FusedMoEConfig, -) -> mk.FusedMoEModularKernel | None: - assert moe_config.dp_size == 1 - - UNSUPPORTED_BACKENDS = [ - # TRTLLM does not use the modular kernl abstraction. - NvFp4MoeBackend.FLASHINFER_TRTLLM, - # CUTEDSL is used with BATCHED (masked) format only. - # TODO: add here once we support dp/ep via the oracle. - NvFp4MoeBackend.FLASHINFER_CUTEDSL, - ] + experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute], +) -> mk.FusedMoEModularKernel: + # Create Prepare/Finalize. + prepare_finalize = maybe_make_prepare_finalize( + moe=moe_config, + quant_config=moe_quant_config, + routing_tables=None, # TODO: init routing tables here? 
+ defer_input_quant=experts_cls.should_pf_defer_input_quant( + moe_config, moe_quant_config + ), + allow_new_interface=True, + ) + assert prepare_finalize is not None - if backend in UNSUPPORTED_BACKENDS: - return None + logger.info_once("Using %s", prepare_finalize.__class__.__name__) - elif backend == NvFp4MoeBackend.FLASHINFER_CUTLASS: - return mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(defer_input_quant=True), - FlashInferExperts( - out_dtype=moe_config.in_dtype, - quant_config=quant_config, - ep_rank=moe_config.ep_rank, - ep_size=moe_config.ep_size, - tp_rank=moe_config.tp_rank, - tp_size=moe_config.tp_size, - use_dp=False, - use_deepseek_fp8_block_scale=False, - ), + # Create Experts. + if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.Standard: + experts = experts_cls.make_standard_experts( + moe_config=moe_config, + quant_config=moe_quant_config, ) - - elif backend == NvFp4MoeBackend.VLLM_CUTLASS: - return mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(defer_input_quant=True), - CutlassExpertsFp4( - out_dtype=moe_config.in_dtype, - # TODO(rob): see what impact this has on expert map? - max_experts_per_worker=moe_config.num_experts, - quant_config=quant_config, - ), + else: + max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() + assert max_num_tokens_per_rank is not None + experts = experts_cls.make_batched_experts( + moe_config=moe_config, + quant_config=moe_quant_config, + max_num_tokens=max_num_tokens_per_rank, + num_dispatchers=prepare_finalize.num_dispatchers(), ) - elif backend == NvFp4MoeBackend.MARLIN: - return mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), - MarlinExperts(quant_config=quant_config), - ) + # NOTE(rob): we only want the ModularKernel to control the SharedExpert + # if we are using all2all (for SBO). Need to make a change somewhere + # else to prevent double running the Shared Expert. + # This needs to be refactored. 
+ kernel = mk.FusedMoEModularKernel( + prepare_finalize, + experts, + shared_experts=( + getattr(layer, "shared_expert", None) + if moe_config.moe_parallel_config.use_all2all_kernels + else None + ), + moe_parallel_config=moe_config.moe_parallel_config, + ) - else: - raise ValueError(f"Unknown NvFp4 MoE backend: {backend}") + return kernel diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index 5d806fa843a3..a4d38d184e9b 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -4,12 +4,143 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.distributed import get_ep_group from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate, ) from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input +from vllm.utils.flashinfer import nvfp4_block_scale_interleave + + +class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize): + def __init__( + self, + defer_input_quant: bool = False, + is_sequence_parallel: bool = False, + num_dispatchers: int = 1, + ) -> None: + super().__init__() + self.defer_input_quant = defer_input_quant + self.is_sequence_parallel = is_sequence_parallel + self._num_dispatchers = num_dispatchers + + @property + def activation_format(self) -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + def max_num_tokens_per_rank(self) -> int | None: + return None + + def topk_indices_dtype(self) -> torch.dtype | None: + return None + + def num_dispatchers(self) -> int: + return self._num_dispatchers + + def output_is_reduced(self) -> bool: + return False + + def prepare( + self, + a1: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + 
num_experts: int, + expert_map: torch.Tensor | None, + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> mk.PrepareResultType: + if apply_router_weight_on_input: + topk = topk_ids.size(1) + assert topk == 1, ( + "apply_router_weight_on_input is only implemented for topk=1" + ) + # Note: do not use inplace for shared experts overlap + a1 = a1 * topk_weights.to(a1.dtype) + + # Defer input quantization to the MoE kernel. + if self.defer_input_quant: + a1q = a1 + a1q_scale = None + else: + use_nvfp4 = quant_config.use_nvfp4_w4a4 + a1q, a1q_scale = moe_kernel_quantize_input( + a1, + quant_config.a1_gscale if use_nvfp4 else quant_config.a1_scale, + quant_config.quant_dtype, + quant_config.per_act_token_quant, + quant_config.block_shape, + is_fp4_scale_swizzled=False, + ) + + # TODO - this is just for deepgemm? + expert_tokens_meta = None + + from vllm.platforms import current_platform + + # The torch ops do not support fp8, so use an int8 view. + # Since dispatch does not do a reduce, this is safe to do. + use_int8_view = a1q.dtype == current_platform.fp8_dtype() + if use_int8_view: + a1q = a1q.view(torch.int8) + + # Skip gathering scales if we have static quantization + # (the scale is a scalar, replicated on all ranks) or + # if quantization is deferred. + skip_gather_scales = a1q_scale is None or a1q_scale.ndim == 0 + scales = None if skip_gather_scales else [a1q_scale] + + res = get_ep_group().dispatch( + a1q, + topk_weights, + topk_ids, + is_sequence_parallel=self.is_sequence_parallel, + extra_tensors=scales, + ) + if skip_gather_scales: + a1q, topk_weights, topk_ids = res + else: + a1q, topk_weights, topk_ids, scales = res + assert scales is not None and len(scales) == 1 + a1q_scale = scales[0] + + if use_int8_view: + a1q = a1q.view(current_platform.fp8_dtype()) + + # NOTE: shuffle into format expected by FLASHINFER_CUTLASS + # There are currently no other kernels that use this P/F + # with nvfp4. 
 If we add other kernels in the future, we + # can register a shuffle that gets called here. + if use_nvfp4: + a1q_scale = nvfp4_block_scale_interleave(a1q_scale) + + return a1q, a1q_scale, expert_tokens_meta, topk_ids, topk_weights + + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: + if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): + weight_and_reduce_impl = TopKWeightAndReduceContiguous() + + out = weight_and_reduce_impl.apply( + output=None, + fused_expert_output=fused_expert_output, + topk_weights=topk_weights, + topk_ids=topk_ids, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + + output.copy_( + get_ep_group().combine(out, is_sequence_parallel=self.is_sequence_parallel) + ) class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index b78794c6bd83..cb1a211e9925 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -9,11 +9,15 @@ from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.layers.fused_moe.config import ( FUSED_MOE_UNQUANTIZED_CONFIG, + FusedMoEParallelConfig, FusedMoEQuantConfig, ) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, +) class QuantMethod(IntEnum): @@ -269,17 +273,46 @@ def rocm_aiter_fused_experts( class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__(self, quant_config): - super().__init__(quant_config) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return 
mk.FusedMoEActivationFormat.Standard + + @staticmethod + def should_pf_defer_input_quant( + fused_moe_config: mk.FusedMoEConfig, quant_config: FusedMoEQuantConfig + ) -> bool: + """ + AITER Fused MoE kernels handle input quantization. + """ + return True - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard, - ) + @staticmethod + def _supports_current_device() -> bool: + return rocm_aiter_ops.is_fused_moe_enabled() + + @staticmethod + def _supports_no_act_and_mul() -> bool: + return False + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + # return ( + # quant_scheme.is_unquantized + # or quant_scheme.is_fp8_w8a8 + # or quant_scheme.is_mxfp4_w4a4 + # ) + return False + + @staticmethod + def _supports_activation(activation: str) -> bool: + return activation in ["silu", "gelu"] + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + return True def supports_expert_map(self): return True diff --git a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py index a143347b19f2..c3e3744935aa 100644 --- a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py @@ -37,7 +37,7 @@ def __init__( use_overlapped and not ( (self.enable_eplb and backend != "allgather_reducescatter") - or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1) + or self.moe_config.use_fi_all2allv_kernels ) and self._shared_experts is not None ) diff --git a/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py index 09d5e45c1ec2..e21d17226671 100644 --- a/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py +++ 
b/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py @@ -5,7 +5,10 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, +) from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 from vllm.model_executor.layers.fused_moe.fallback import FallbackExperts from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts @@ -15,19 +18,18 @@ class TritonOrCutlassExperts(FallbackExperts): """Cutlass with fallback to Triton for low latency shapes on SM100.""" + _experts_cls = CutlassExpertsFp8 + _fallback_cls = TritonExperts + def __init__( self, - e: int, - n: int, - k: int, - out_dtype: torch.dtype | None, + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, - device: torch.dtype, ): self.is_sm100 = current_platform.has_device_capability(100) super().__init__( - experts=CutlassExpertsFp8(e, n, k, out_dtype, quant_config, device), - fallback_experts=TritonExperts(quant_config), + experts=CutlassExpertsFp8(moe_config, quant_config), + fallback_experts=TritonExperts(moe_config, quant_config), ) def workspace_shapes( diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 55b1e1211b0a..6c13903fccd3 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -4,7 +4,10 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, +) from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( DeepGemmExperts, _valid_deep_gemm, @@ -20,10 +23,13 @@ class 
TritonOrDeepGemmExperts(FallbackExperts): """DeepGemm with fallback to Triton for low latency shapes.""" - def __init__(self, quant_config: FusedMoEQuantConfig): + _experts_cls = DeepGemmExperts + _fallback_cls = TritonExperts + + def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig): super().__init__( - experts=DeepGemmExperts(quant_config), - fallback_experts=TritonExperts(quant_config), + experts=DeepGemmExperts(moe_config, quant_config), + fallback_experts=TritonExperts(moe_config, quant_config), ) def workspace_shapes( diff --git a/vllm/model_executor/layers/fused_moe/trtllm_moe.py b/vllm/model_executor/layers/fused_moe/trtllm_moe.py index c46f59564930..e7cec11c237d 100644 --- a/vllm/model_executor/layers/fused_moe/trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/trtllm_moe.py @@ -6,38 +6,61 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, + FusedMoEParallelConfig, FusedMoEQuantConfig, ) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, +) +from vllm.platforms import current_platform class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - moe: FusedMoEConfig, + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, max_capture_size, ): - super().__init__(quant_config) - self.moe = moe + super().__init__(moe_config, quant_config) self.gemm1_alpha = gemm1_alpha self.gemm1_beta = gemm1_beta self.gemm1_clamp_limit = gemm1_clamp_limit self.max_capture_size = max_capture_size - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard, - ) + @staticmethod + def activation_format() -> 
mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + @staticmethod + def _supports_current_device() -> bool: + return current_platform.has_device_capability((10, 0)) + + @staticmethod + def _supports_no_act_and_mul() -> bool: + return False + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + # return quant_scheme.is_mxfp4_w4a16 or quant_scheme.is_mxfp4_w4a8 + return False + + @staticmethod + def _supports_activation(activation: str) -> bool: + return activation in ["swigluoai"] + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + return True def supports_chunking(self) -> bool: return True @@ -86,7 +109,7 @@ def apply( topk = topk_ids.size(-1) local_num_experts = w1.size(0) intermediate_size = w2.size(1) - local_expert_offset = self.moe.ep_rank * local_num_experts + local_expert_offset = self.moe_config.ep_rank * local_num_experts x_quant = hidden_states x_scale = a1q_scale diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index c31c0223eb06..ef0f79a18846 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -121,14 +121,18 @@ def select_gemm_impl( == FusedMoEActivationFormat.BatchedExperts ): logger.debug("BatchedTritonExperts %s", self.moe) - return BatchedTritonExperts( + return BatchedTritonExperts.make_batched_experts( + moe_config=self.moe, + quant_config=self.moe_quant_config, max_num_tokens=self.moe.max_num_tokens, num_dispatchers=prepare_finalize.num_dispatchers(), - quant_config=self.moe_quant_config, ) else: logger.debug("TritonExperts %s", self.moe) - return TritonExperts(self.moe_quant_config) + return TritonExperts( + moe_config=self.moe, + quant_config=self.moe_quant_config, + ) def 
create_weights( self, @@ -257,8 +261,10 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: self.use_inplace = True self.kernel = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), - AiterExperts(self.moe_quant_config), - shared_experts=None, + AiterExperts( + moe_config=self.moe, + quant_config=self.moe_quant_config, + ), ) elif self.flashinfer_cutlass_moe_enabled: @@ -270,19 +276,18 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: self.kernel = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), FlashInferExperts( - out_dtype=layer.params_dtype, + moe_config=self.moe, quant_config=self.moe_quant_config, - tp_rank=self.moe.moe_parallel_config.tp_rank, - tp_size=self.moe.moe_parallel_config.tp_size, - ep_rank=self.moe.moe_parallel_config.ep_rank, - ep_size=self.moe.moe_parallel_config.ep_size, ), ) else: self.use_inplace = True self.kernel = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), - TritonExperts(self.moe_quant_config), + TritonExperts( + moe_config=self.moe, + quant_config=self.moe_quant_config, + ), shared_experts=None, ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index b71921c5c3d7..a6060ce01c4c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -19,16 +19,15 @@ from vllm.model_executor.layers.fused_moe import ( FusedMoE, FusedMoEActivationFormat, - FusedMoEConfig, FusedMoEMethodBase, FusedMoEPermuteExpertsUnpermute, FusedMoeWeightScaleSupported, UnquantizedFusedMoEMethod, ) from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, FusedMoEQuantConfig, - fp8_w8a8_moe_quant_config, - fp8_w8a16_moe_quant_config, + RoutingMethodType, int4_w4a16_moe_quant_config, int4_w4afp8_moe_quant_config, 
int8_w8a8_moe_quant_config, @@ -45,10 +44,10 @@ Fp8MoeBackend, convert_to_fp8_moe_kernel_format, make_fp8_moe_kernel, + make_fp8_moe_quant_config, select_fp8_moe_backend, ) from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( - FLASHINFER_NVFP4_MOE_BACKENDS, NvFp4MoeBackend, convert_to_nvfp4_moe_kernel_format, is_global_sf_supported_for_nvfp4_backend, @@ -61,10 +60,11 @@ WNA16_SUPPORTED_TYPES_MAP, ) from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( - build_flashinfer_fp4_cutlass_moe_prepare_finalize, flashinfer_trtllm_fp4_moe, flashinfer_trtllm_fp4_routed_moe, - select_nvfp4_gemm_impl, +) +from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( + apply_fi_trtllm_fp8_per_tensor_moe, ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( process_fp8_input_tensor_strategy_moe, @@ -77,12 +77,16 @@ marlin_make_workspace_new, marlin_moe_permute_scales, ) -from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - is_fp4_marlin_supported, -) from vllm.model_executor.layers.quantization.utils.quant_utils import ( convert_bf16_scales_to_fp8, convert_packed_uint4b8_to_signed_int4_inplace, + kFp8Dynamic128Sym, + kFp8DynamicTokenSym, + kFp8Static128BlockSym, + kFp8StaticChannelSym, + kFp8StaticTensorSym, + kNvfp4Dynamic, + kNvfp4Static, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( normalize_e4m3fn_to_e4m3fnuz, @@ -195,7 +199,7 @@ def get_moe_method( f"or None for NVFP4A16, found {input_quant}", ) return CompressedTensorsW4A4Nvfp4MoEMethod( - layer.moe_config, layer_name, use_marlin=input_quant is None + layer.moe_config, layer_name, use_a16=(input_quant is None) ) elif ( quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) @@ -229,7 +233,7 @@ def __init__( self, moe: FusedMoEConfig, layer_name: str | None = None, - use_marlin: bool = False, + use_a16: bool = False, ): if not moe.is_act_and_mul: raise ValueError( @@ -239,20 +243,31 @@ def __init__( 
super().__init__(moe) self.group_size = 16 - if use_marlin: - if is_fp4_marlin_supported(): - self.nvfp4_backend = NvFp4MoeBackend.MARLIN - else: - raise ValueError( - "Marlin FP4 MoE kernel requested but not ", - "supported on current platform.", - ) - else: - self.nvfp4_backend = select_nvfp4_moe_backend() + + # Select experts implementation. + self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend( + config=self.moe, + weight_key=kNvfp4Static, + activation_key=None if use_a16 else kNvfp4Dynamic, + ) + + # Delay creation of the kernel until after process-weights. + self.kernel: mk.FusedMoEModularKernel | None = None + + # Used for weight loading. self.use_global_sf = is_global_sf_supported_for_nvfp4_backend( self.nvfp4_backend ) - self.kernel: mk.FusedMoEModularKernel | None = None + + @property + def topk_indices_dtype(self) -> torch.dtype | None: + if self.kernel is not None: + return self.kernel.prepare_finalize.topk_indices_dtype() + return None + + @property + def supports_mk_interally(self) -> bool: + return True def create_weights( self, @@ -425,50 +440,34 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w13_input_scale = a13_scale layer.w2_input_scale = a2_scale - # Initialize the kernel that will be called in apply(). 
self.moe_quant_config = self.get_fused_moe_quant_config(layer) - use_dp = self.moe.dp_size > 1 - if self.moe_quant_config is not None and not use_dp: + if self.moe_quant_config is not None: + assert self.experts_cls is not None self.kernel = make_nvfp4_moe_kernel( - backend=self.nvfp4_backend, - quant_config=self.moe_quant_config, + layer=layer, + moe_quant_config=self.moe_quant_config, moe_config=self.moe, + experts_cls=self.experts_cls, ) def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPrepareAndFinalize | None: - UNSUPPORTED = [NvFp4MoeBackend.MARLIN, NvFp4MoeBackend.FLASHINFER_TRTLLM] - if self.nvfp4_backend in UNSUPPORTED: - return None - elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS: - # TP case: avoid convert to ModularKernelMethod - to be refactored. - if self.moe.dp_size == 1: - return None - # For now, fp4 moe only works with the flashinfer dispatcher. - prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize( - self.moe - ) - logger.debug_once("%s", prepare_finalize.__class__.__name__) - return prepare_finalize - else: - return super().maybe_make_prepare_finalize(routing_tables) + raise ValueError( + "CompressedTensorsW4A4NvFp4MoE uses the new modular kernel " + "initialization logic. This function should not be called." + ) def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalize, layer: torch.nn.Module, - ) -> mk.FusedMoEPermuteExpertsUnpermute: - assert self.moe_quant_config is not None - """Return the appropriate GEMM experts implementation.""" - experts = select_nvfp4_gemm_impl( - self.moe, - self.moe_quant_config, - allow_flashinfer=(self.nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS), + ) -> FusedMoEPermuteExpertsUnpermute: + raise ValueError( + "CompressedTensorsW4A4NvFp4MoE uses the new modular kernel " + "initialization logic. This function should not be called." 
) - logger.debug_once("Using %s", experts.__class__.__name__) - return experts def get_fused_moe_quant_config( self, layer: torch.nn.Module @@ -579,32 +578,45 @@ def __init__( "For FP8 Fused MoE layer, we require either per tensor or " "channelwise, dynamic per token quantization." ) - self.fp8_backend = select_fp8_moe_backend( - block_quant=self.block_quant, - tp_size=moe.tp_size, - with_lora_support=moe.is_lora_enabled, - # TODO(rob): enable selecting this externally. + + ct2vllm_weight = { + QuantizationStrategy.CHANNEL: kFp8StaticChannelSym, + QuantizationStrategy.TENSOR: kFp8StaticTensorSym, + QuantizationStrategy.BLOCK: kFp8Static128BlockSym, + } + ct2vllm_act = { + QuantizationStrategy.TOKEN: kFp8DynamicTokenSym, + QuantizationStrategy.TENSOR: ( + kFp8StaticTensorSym if self.static_input_scales else kFp8Dynamic128Sym + ), + } + weight_key = ct2vllm_weight[self.weight_quant.strategy] + if weight_key == kFp8Static128BlockSym: + activation_key = kFp8Dynamic128Sym + else: + activation_key = ct2vllm_act[self.input_quant.strategy] + + # Select Fp8 MoE backend + self.fp8_backend, self.experts_cls = select_fp8_moe_backend( + config=self.moe, + weight_key=weight_key, + activation_key=activation_key, allow_vllm_cutlass=True, ) - if self.fp8_backend != Fp8MoeBackend.MARLIN: - per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN - per_channel_quant = ( - self.weight_quant.strategy == QuantizationStrategy.CHANNEL - ) - if per_act_token != per_channel_quant: - raise NotImplementedError( - "For FP8 Fused MoE layers, per-token and per-channel must be " - "used together." - ) - # TODO(rob): hook this up in a follow up PR. - if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: - raise NotImplementedError( - "FlashInfer TRTLLM backend not supported for compressed-tensors yet." - ) - self.disable_expert_map = False + # Delay creation of the kernel until after process-weights. 
self.kernel: mk.FusedMoEModularKernel | None = None + @property + def topk_indices_dtype(self) -> torch.dtype | None: + if self.kernel is not None: + return self.kernel.prepare_finalize.topk_indices_dtype() + return None + + @property + def supports_mk_interally(self) -> bool: + return True + def create_weights( self, layer: torch.nn.Module, @@ -818,138 +830,55 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: self.moe_quant_config = self.get_fused_moe_quant_config(layer) if self.moe_quant_config: + assert self.experts_cls is not None self.kernel, self.use_inplace = make_fp8_moe_kernel( layer=layer, moe_quant_config=self.moe_quant_config, moe_config=self.moe, fp8_backend=self.fp8_backend, + experts_cls=self.experts_cls, ) def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPrepareAndFinalize | None: - if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]: - return None - else: - return super().maybe_make_prepare_finalize(routing_tables) + raise ValueError( + "CompressedTensorsW8A8Fp8MoE uses the new modular kernel " + "initialization logic. This function should not be called." 
+ ) def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalize, layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: - # cutlass path - assert self.moe_quant_config is not None - if self.fp8_backend == Fp8MoeBackend.VLLM_CUTLASS: - from vllm.model_executor.layers.fused_moe import ( - CutlassBatchedExpertsFp8, - CutlassExpertsFp8, - ) - - experts: FusedMoEPermuteExpertsUnpermute - - num_dispatchers = prepare_finalize.num_dispatchers() - - if ( - prepare_finalize.activation_format - == FusedMoEActivationFormat.BatchedExperts - ): - logger.debug("CutlassBatchedExpertsFp8(%s)", self.__class__.__name__) - experts = CutlassBatchedExpertsFp8( - max_experts_per_worker=self.moe.num_local_experts, - num_dispatchers=num_dispatchers, - out_dtype=self.moe.in_dtype, - e=layer.local_num_experts, - n=layer.intermediate_size_per_partition, - k=layer.hidden_size, - device=layer.w13_weight.device, - quant_config=self.moe_quant_config, - ) - else: - logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__) - experts = CutlassExpertsFp8( - out_dtype=self.moe.in_dtype, - e=layer.local_num_experts, - n=layer.intermediate_size_per_partition, - k=layer.hidden_size, - device=layer.w13_weight.device, - quant_config=self.moe_quant_config, - ) - - # TODO(rob): investigate disable_expert_map - self.disable_expert_map = ( - num_dispatchers > 1 or not experts.supports_expert_map() - ) - - return experts - - from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - BatchedDeepGemmExperts, - ) - from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedTritonExperts, - ) - from vllm.model_executor.layers.fused_moe.fused_moe import ( - TritonExperts, + raise ValueError( + "CompressedTensorsW8A8Fp8MoE uses the new modular kernel " + "initialization logic. This function should not be called." 
) - from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( - TritonOrDeepGemmExperts, - ) - - assert self.fp8_backend not in [Fp8MoeBackend.AITER, Fp8MoeBackend.MARLIN] - - if ( - prepare_finalize.activation_format - == FusedMoEActivationFormat.BatchedExperts - ): - max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() - assert max_num_tokens_per_rank is not None - - if self.fp8_backend == Fp8MoeBackend.DEEPGEMM: - logger.debug("BatchedDeepGemmExperts(%s)", self.__class__.__name__) - return BatchedDeepGemmExperts( - max_num_tokens=max_num_tokens_per_rank, - num_dispatchers=prepare_finalize.num_dispatchers(), - quant_config=self.moe_quant_config, - ) - else: - logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__) - return BatchedTritonExperts( - max_num_tokens=max_num_tokens_per_rank, - num_dispatchers=prepare_finalize.num_dispatchers(), - quant_config=self.moe_quant_config, - ) - - else: - if self.fp8_backend == Fp8MoeBackend.DEEPGEMM: - logger.debug("TritonOrDeepGemmExperts(%s)", self.__class__.__name__) - return TritonOrDeepGemmExperts(self.moe_quant_config) - else: - logger.debug("TritonExperts(%s)", self.__class__.__name__) - return TritonExperts(self.moe_quant_config) def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: - if self.fp8_backend == Fp8MoeBackend.MARLIN: - return fp8_w8a16_moe_quant_config( - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - block_shape=self.weight_block_size, - ) - - per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN - per_channel_quant = self.weight_quant.strategy == QuantizationStrategy.CHANNEL + w1_scale = layer.w13_weight_scale + w2_scale = layer.w2_weight_scale + a1_scale = layer.w13_input_scale + a2_scale = layer.w2_input_scale - return fp8_w8a8_moe_quant_config( - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - 
per_act_token_quant=per_act_token, - per_out_ch_quant=per_channel_quant, - block_shape=layer.weight_block_size, + return make_fp8_moe_quant_config( + fp8_backend=self.fp8_backend, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + per_act_token_quant=( + self.input_quant.strategy == QuantizationStrategy.TOKEN + ), + per_out_ch_quant=( + self.weight_quant.strategy == QuantizationStrategy.CHANNEL + ), + block_shape=self.weight_block_size, ) def apply( @@ -959,6 +888,55 @@ def apply( x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: + if layer.enable_eplb: + raise NotImplementedError( + "EPLB not supported for `FlashInfer TRTLLM FP8 MoE`." + ) + assert layer.activation == "silu" + + if self.block_quant: + import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 + + e_score_correction_bias = ( + layer.e_score_correction_bias.to(x.dtype) + if layer.e_score_correction_bias is not None + else None + ) + routing_method_type = layer.routing_method_type + return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( + routing_logits=router_logits.to(torch.float32) + if routing_method_type == RoutingMethodType.DeepSeekV3 + else router_logits, + routing_bias=e_score_correction_bias, + x=x, + w13_weight=layer.w13_weight, + w13_weight_scale_inv=layer.w13_weight_scale, + w2_weight=layer.w2_weight, + w2_weight_scale_inv=layer.w2_weight_scale, + global_num_experts=layer.global_num_experts, + top_k=layer.top_k, + num_expert_group=layer.num_expert_group, + topk_group=layer.topk_group, + intermediate_size=layer.intermediate_size_per_partition, + expert_offset=layer.ep_rank * layer.local_num_experts, + local_num_experts=layer.local_num_experts, + block_shape=self.weight_block_size, + routing_method_type=routing_method_type, + routed_scaling=layer.routed_scaling_factor, + ) + else: + result = 
apply_fi_trtllm_fp8_per_tensor_moe( + layer=layer, + hidden_states=x, + router_logits=router_logits, + routing_bias=layer.e_score_correction_bias, + global_num_experts=layer.global_num_experts, + top_k=layer.top_k, + num_expert_group=layer.num_expert_group, + topk_group=layer.topk_group, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + ) topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, @@ -976,7 +954,7 @@ def apply( global_num_experts=layer.global_num_experts, # TODO(rob): investigate the disable_expert_map introduced by: # https://github.com/vllm-project/vllm/commit/84166fee9770e6fba71a96978b3e7d149392fb28 # noqa: E501 - expert_map=None if self.disable_expert_map else layer.expert_map, + expert_map=layer.expert_map, apply_router_weight_on_input=layer.apply_router_weight_on_input, ) @@ -1435,8 +1413,7 @@ def select_gemm_impl( max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() assert max_num_tokens_per_rank is not None return BatchedMarlinExperts( - max_num_tokens=max_num_tokens_per_rank, - num_dispatchers=prepare_finalize.num_dispatchers(), + moe_config=self.moe, quant_config=self.moe_quant_config, w13_g_idx=layer.w13_weight_g_idx, w2_g_idx=layer.w2_weight_g_idx, @@ -1446,6 +1423,7 @@ def select_gemm_impl( ) else: return MarlinExperts( + moe_config=self.moe, quant_config=self.moe_quant_config, w13_g_idx=layer.w13_weight_g_idx, w2_g_idx=layer.w2_weight_g_idx, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 1c0c35bf6f41..7fd0fa4f81cb 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -19,7 +19,6 @@ ) from vllm.model_executor.layers.fused_moe import ( FusedMoE, - FusedMoEActivationFormat, FusedMoEMethodBase, FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize, @@ -51,8 +50,6 @@ from vllm.model_executor.layers.quantization.kv_cache import 
BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( apply_fi_trtllm_fp8_per_tensor_moe, - build_flashinfer_fp8_cutlass_moe_prepare_finalize, - select_cutlass_fp8_gemm_impl, ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, @@ -76,6 +73,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, is_layer_skipped, + kFp8Dynamic128Sym, + kFp8Static128BlockSym, + kFp8StaticTensorSym, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp, @@ -633,37 +633,40 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): self.weight_scale_name = ( "weight_scale_inv" if self.block_quant else "weight_scale" ) - self.fp8_backend = select_fp8_moe_backend( - block_quant=self.block_quant, - tp_size=layer.moe_parallel_config.tp_size, - with_lora_support=self.moe.is_lora_enabled, - ) - if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: - if self.block_quant and self.weight_block_size != [128, 128]: - raise NotImplementedError( - "FlashInfer CUTLASS FP8 MoE backend only supports block " - "size [128, 128]." - ) - if layer.activation != "silu": - raise NotImplementedError( - "FlashInfer CUTLASS FP8 MoE backend only supports SiLU " - "activation function, but got {layer.activation}." - ) - dynamic_per_token = ( - not self.block_quant and self.quant_config.activation_scheme != "static" - ) - if dynamic_per_token and self.fp8_backend in [ - Fp8MoeBackend.FLASHINFER_TRTLLM, - Fp8MoeBackend.FLASHINFER_CUTLASS, - ]: - raise NotImplementedError( - "FlashInfer FP8 MoE backend does not support dynamic per token " - "activation quantization." 
+ # Set weight key and activation key for kernel compatibility + if self.block_quant: + weight_key = kFp8Static128BlockSym + activation_key = kFp8Dynamic128Sym + else: + weight_key = kFp8StaticTensorSym + activation_key = ( + kFp8StaticTensorSym + if self.quant_config.activation_scheme == "static" + else kFp8Dynamic128Sym ) + # Select Fp8 MoE backend + self.fp8_backend, self.experts_cls = select_fp8_moe_backend( + config=self.moe, + weight_key=weight_key, + activation_key=activation_key, + allow_vllm_cutlass=False, + ) + + # Delay creation of the kernel until after process-weights. self.kernel: mk.FusedMoEModularKernel | None = None + @property + def topk_indices_dtype(self) -> torch.dtype | None: + if self.kernel is not None: + return self.kernel.prepare_finalize.topk_indices_dtype() + return None + + @property + def supports_mk_interally(self) -> bool: + return True + def create_weights( self, layer: Module, @@ -819,11 +822,13 @@ def _setup_kernel( # Setup modular kernel for TP case. self.moe_quant_config = self.get_fused_moe_quant_config(layer) if self.moe_quant_config: + assert self.experts_cls is not None self.kernel, self.use_inplace = make_fp8_moe_kernel( layer=layer, moe_quant_config=self.moe_quant_config, moe_config=self.moe, fp8_backend=self.fp8_backend, + experts_cls=self.experts_cls, ) def process_weights_after_loading(self, layer: Module) -> None: @@ -878,93 +883,28 @@ def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPrepareAndFinalize | None: - if self.fp8_backend in [ - Fp8MoeBackend.AITER, - Fp8MoeBackend.MARLIN, - Fp8MoeBackend.FLASHINFER_TRTLLM, - ]: - return None - elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: - prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( - self.moe, - use_deepseek_fp8_block_scale=self.block_quant, - ) - logger.debug_once("%s", prepare_finalize.__class__.__name__) - return prepare_finalize - return 
super().maybe_make_prepare_finalize(routing_tables) + raise ValueError( + "Fp8FusedMoE uses the new modular kernel initialization logic. " + "This function should not be called." + ) def select_gemm_impl( self, prepare_finalize: FusedMoEPrepareAndFinalize, layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: - from vllm.model_executor.layers.fused_moe import ( - BatchedDeepGemmExperts, - BatchedTritonExperts, - TritonExperts, - TritonOrDeepGemmExperts, - ) - - if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]: - raise NotImplementedError( - "Marlin and ROCm AITER are not supported with all2all yet." + # TODO(rob): look into whether we can just not do this for LoRA. + if self.moe.is_lora_enabled: + from vllm.model_executor.layers.fused_moe import ( + TritonExperts, ) - assert self.moe_quant_config is not None - - if ( - prepare_finalize.activation_format - == FusedMoEActivationFormat.BatchedExperts - ): - max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() - assert max_num_tokens_per_rank is not None - - experts_impl = ( - BatchedDeepGemmExperts - if self.fp8_backend == Fp8MoeBackend.DEEPGEMM - else BatchedTritonExperts - ) - logger.debug( - "%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s", - experts_impl.__name__, - self.__class__.__name__, - max_num_tokens_per_rank, - self.weight_block_size, - False, - ) - return experts_impl( - max_num_tokens=max_num_tokens_per_rank, - num_dispatchers=prepare_finalize.num_dispatchers(), - quant_config=self.moe_quant_config, - ) - elif self.moe.is_lora_enabled: return TritonExperts(quant_config=self.moe_quant_config) - elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: - # Select GEMM experts with block-scale when weights are block-quantized - experts = select_cutlass_fp8_gemm_impl( - self.moe, - self.moe_quant_config, - use_deepseek_fp8_block_scale=self.block_quant, - ) - logger.debug_once("Using %s", experts.__class__.__name__) - return experts - elif 
self.fp8_backend == Fp8MoeBackend.DEEPGEMM: - logger.debug( - "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s", - self.__class__.__name__, - self.weight_block_size, - False, - ) - return TritonOrDeepGemmExperts(self.moe_quant_config) - else: - assert self.fp8_backend == Fp8MoeBackend.TRITON - logger.debug( - "TritonExperts(%s): block_size=%s, per_act_token=%s", - self.__class__.__name__, - self.weight_block_size, - False, - ) - return TritonExperts(self.moe_quant_config) + + raise ValueError( + "Fp8FusedMoE uses the new modular kernel initialization logic. " + "This function should not be called." + ) def get_fused_moe_quant_config( self, layer: torch.nn.Module diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index bcda7b42c2ec..d999351d8264 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -31,7 +31,6 @@ select_fp8_moe_backend, ) from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( - FLASHINFER_NVFP4_MOE_BACKENDS, NvFp4MoeBackend, convert_to_nvfp4_moe_kernel_format, is_global_sf_supported_for_nvfp4_backend, @@ -51,15 +50,11 @@ ) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( - build_flashinfer_fp4_cutlass_moe_prepare_finalize, flashinfer_trtllm_fp4_moe, flashinfer_trtllm_fp4_routed_moe, - select_nvfp4_gemm_impl, ) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( apply_fi_trtllm_fp8_per_tensor_moe, - build_flashinfer_fp8_cutlass_moe_prepare_finalize, - select_cutlass_fp8_gemm_impl, ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, @@ -78,6 +73,9 @@ GroupShape, cutlass_fp4_supported, is_layer_skipped, + kFp8StaticTensorSym, + kNvfp4Dynamic, + kNvfp4Static, swizzle_blockscale, ) from 
vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -729,45 +727,45 @@ def __init__( super().__init__(moe_config) self.quant_config = quant_config assert self.quant_config.is_checkpoint_fp8_serialized - self.fp8_backend = select_fp8_moe_backend( - block_quant=False, - tp_size=moe_config.moe_parallel_config.tp_size, - with_lora_support=self.moe.is_lora_enabled, + + # Select Fp8 MoE backend + self.fp8_backend, self.experts_cls = select_fp8_moe_backend( + config=self.moe, + weight_key=kFp8StaticTensorSym, + activation_key=kFp8StaticTensorSym, ) + + # Delay creation of the kernel until after process-weights. self.kernel: mk.FusedMoEModularKernel | None = None + @property + def topk_indices_dtype(self) -> torch.dtype | None: + if self.kernel is not None: + return self.kernel.prepare_finalize.topk_indices_dtype() + return None + + @property + def supports_mk_interally(self) -> bool: + return True + def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPrepareAndFinalize | None: - # TRT LLM not supported with all2all yet. - if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: - return None - elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: - # TP case: avoid convert to ModularKernelMethod - to be refactored. - if self.moe.dp_size == 1: - return None - - prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( - self.moe, - use_deepseek_fp8_block_scale=False, - ) - logger.debug_once("%s", prepare_finalize.__class__.__name__) - return prepare_finalize - return super().maybe_make_prepare_finalize(routing_tables) + raise ValueError( + "ModelOptFp8MoEMethod uses the new modular kernel initialization " + "logic. This function should not be called." 
+ ) def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalize, layer: torch.nn.Module, ) -> mk.FusedMoEPermuteExpertsUnpermute: - assert self.moe_quant_config is not None - experts = select_cutlass_fp8_gemm_impl( - self.moe, - self.moe_quant_config, + raise ValueError( + "ModelOptFp8MoEMethod uses the new modular kernel initialization " + "logic. This function should not be called." ) - logger.debug_once("Using %s", experts.__class__.__name__) - return experts def create_weights( self, @@ -879,14 +877,16 @@ def _setup_kernel( replace_parameter(layer, "w13_weight_scale", w13_scale) replace_parameter(layer, "w2_weight_scale", w2_scale) - # Setup modular kernel for TP case. + # Setup modular kernel. self.moe_quant_config = self.get_fused_moe_quant_config(layer) if self.moe_quant_config: + assert self.experts_cls is not None self.kernel, self.use_inplace = make_fp8_moe_kernel( layer=layer, moe_quant_config=self.moe_quant_config, moe_config=self.moe, fp8_backend=self.fp8_backend, + experts_cls=self.experts_cls, ) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: @@ -1335,55 +1335,49 @@ def __init__( ) -> None: super().__init__(moe_config) self.quant_config = quant_config - self.nvfp4_backend = select_nvfp4_moe_backend() - # TODO: move this type of check into the oracle. - if ( - not self.moe.is_act_and_mul - and not self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS - ): - raise NotImplementedError( - "Non-gated activations are only supported by FlashInfer " - "CUTLASS NvFP4 MoE backend." - ) + # Select experts implementation. + self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend( + config=self.moe, + weight_key=kNvfp4Static, + activation_key=kNvfp4Dynamic, + ) + # Delay creation of the kernel until after process-weights. + self.kernel: mk.FusedMoEModularKernel | None = None + + # Used for weight loading. 
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend( self.nvfp4_backend ) - self.kernel: mk.FusedMoEModularKernel | None = None + + @property + def topk_indices_dtype(self) -> torch.dtype | None: + if self.kernel is not None: + return self.kernel.prepare_finalize.topk_indices_dtype() + return None + + @property + def supports_mk_interally(self) -> bool: + return True def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPrepareAndFinalize | None: - UNSUPPORTED = [NvFp4MoeBackend.MARLIN, NvFp4MoeBackend.FLASHINFER_TRTLLM] - if self.nvfp4_backend in UNSUPPORTED: - return None - elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS: - # TP case: avoid convert to ModularKernelMethod - to be refactored. - if self.moe.dp_size == 1: - return None - # For now, fp4 moe only works with the flashinfer dispatcher. - prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize( - self.moe - ) - logger.debug_once("%s", prepare_finalize.__class__.__name__) - return prepare_finalize - else: - return super().maybe_make_prepare_finalize(routing_tables) + raise ValueError( + "ModelOptNvFp4FusedMoE uses the new modular kernel " + "initialization logic. This function should not be called." + ) def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalize, layer: torch.nn.Module, ) -> mk.FusedMoEPermuteExpertsUnpermute: - assert self.moe_quant_config is not None - experts = select_nvfp4_gemm_impl( - self.moe, - self.moe_quant_config, - allow_flashinfer=self.nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS, + raise ValueError( + "ModelOptNvFp4FusedMoE uses the new modular kernel " + "initialization logic. This function should not be called." 
) - logger.debug_once("Using %s", experts.__class__.__name__) - return experts def uses_weight_scale_2_pattern(self) -> bool: """ @@ -1556,12 +1550,13 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: replace_parameter(layer, "w2_input_scale", a2_scale) self.moe_quant_config = self.get_fused_moe_quant_config(layer) - use_dp = self.moe.dp_size > 1 - if self.moe_quant_config is not None and not use_dp: + if self.moe_quant_config is not None: + assert self.experts_cls is not None self.kernel = make_nvfp4_moe_kernel( - backend=self.nvfp4_backend, - quant_config=self.moe_quant_config, + layer=layer, + moe_quant_config=self.moe_quant_config, moe_config=self.moe, + experts_cls=self.experts_cls, ) def prepare_dp_allgather_tensor( diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index 912ff5a4a12a..8517162a9512 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -11,19 +11,13 @@ from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, - FusedMoEQuantConfig, + FusedMoEParallelConfig, RoutingMethodType, ) -from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import ( - FlashInferCuteDSLExperts, -) -from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( - FlashInferExperts, -) -from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 - create_flashinfer_prepare_finalize, -) from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kNvfp4Dynamic, + kNvfp4Static, swizzle_blockscale, ) from vllm.platforms import current_platform @@ -44,9 +38,74 @@ "is_flashinfer_fp4_cutlass_moe_available", "is_flashinfer_fp4_cutedsl_moe_available", "reorder_w1w3_to_w3w1", - 
"build_flashinfer_fp4_cutlass_moe_prepare_finalize", ] +# +# Methods used by the oracle for kernel selection. +# + + +def _supports_current_device() -> bool: + """Supports only Blackwell-family GPUs.""" + p = current_platform + # Add check flashinfer trtllm is available + return p.is_cuda() and p.is_device_capability_family(100) + + +def _supports_no_act_and_mul() -> bool: + """Does not support non-gated MoE (i.e. Nemotron-Nano).""" + return False + + +def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, +) -> bool: + """Supports Nvfp4 quantization.""" + SUPPORTED_W_A = [ + (kNvfp4Static, kNvfp4Dynamic), + ] + return (weight_key, activation_key) in SUPPORTED_W_A + + +def _supports_activation(activation: str) -> bool: + """Supports silu activation only.""" + return activation in ["silu"] + + +def _supports_moe_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + """Supports EP.""" + return True + + +def is_supported_config_trtllm( + moe_config: FusedMoEConfig, + weight_key: QuantKey | None, + activation_key: QuantKey | None, + activation_format: mk.FusedMoEActivationFormat, +) -> tuple[bool, str | None]: + """ + This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config + """ + + def _make_reason(reason: str) -> str: + return f"kernel does not support {reason}" + + if not _supports_current_device(): + return False, _make_reason("current device") + elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()): + return False, _make_reason("no act_and_mul MLP layer") + elif not _supports_activation(moe_config.activation): + return False, _make_reason(f"{moe_config.activation} activation") + elif not _supports_quant_scheme(weight_key, activation_key): + return False, _make_reason("quantization scheme") + elif not _supports_moe_parallel_config(moe_config.moe_parallel_config): + return False, _make_reason("parallel config") + elif activation_format != mk.FusedMoEActivationFormat.Standard: + 
return False, _make_reason("activation format") + + return True, None + def is_flashinfer_fp4_cutlass_moe_available() -> bool: """Return `True` when FlashInfer CUTLASS NV-FP4 kernels can be used.""" @@ -85,48 +144,6 @@ def reorder_w1w3_to_w3w1( ) -def build_flashinfer_fp4_cutlass_moe_prepare_finalize( - moe: FusedMoEConfig, -) -> mk.FusedMoEPrepareAndFinalize: - """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel""" - use_dp = moe.moe_parallel_config.dp_size > 1 - enable_alltoallv = moe.moe_parallel_config.all2all_backend == "flashinfer_all2allv" - return create_flashinfer_prepare_finalize( - use_dp=use_dp, use_nvfp4=True, enable_alltoallv=enable_alltoallv - ) - - -def select_nvfp4_gemm_impl( - moe: FusedMoEConfig, - moe_quant_config: FusedMoEQuantConfig, - allow_flashinfer: bool, -) -> mk.FusedMoEPermuteExpertsUnpermute: - """Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers""" - - if allow_flashinfer: - if envs.VLLM_FLASHINFER_MOE_BACKEND == "masked_gemm": - return FlashInferCuteDSLExperts( - out_dtype=moe.in_dtype, - quant_config=moe_quant_config, - ) - elif envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput": - return FlashInferExperts( - out_dtype=moe.in_dtype, - quant_config=moe_quant_config, - ep_rank=moe.moe_parallel_config.ep_rank, - ep_size=moe.moe_parallel_config.ep_size, - tp_rank=moe.moe_parallel_config.tp_rank, - tp_size=moe.moe_parallel_config.tp_size, - use_dp=moe.moe_parallel_config.dp_size > 1, - ) - - # native cutlass experts currently don't support DP; TP case won't call this - raise ValueError( - "CutlassExpertsFp4 doesn't support DP. 
Use flashinfer CUTLASS " - "Fused MoE backend instead (set VLLM_USE_FLASHINFER_MOE_FP4=1)" - ) - - def prepare_static_weights_for_trtllm_fp4_moe( # args_dequant, # args, diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index 799854479823..19e7124684ad 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -4,19 +4,8 @@ import torch -import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import envs from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEConfig, - FusedMoEQuantConfig, -) -from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( - FlashInferExperts, -) -from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 - create_flashinfer_prepare_finalize, -) from vllm.platforms import current_platform from vllm.utils.math_utils import round_up @@ -191,45 +180,6 @@ def make_fp8_moe_alpha_scales_for_fi( return g1_alphas, g2_alphas -def build_flashinfer_fp8_cutlass_moe_prepare_finalize( - moe: FusedMoEConfig | None, use_deepseek_fp8_block_scale: bool = False -) -> mk.FusedMoEPrepareAndFinalize: - """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel""" - use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False - # Propagate block-scale flag so prepare/finalize can skip act quantization - # and inform the kernel to consume per-block weight scales. 
- return create_flashinfer_prepare_finalize( - use_dp, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale - ) - - -def select_cutlass_fp8_gemm_impl( - moe: FusedMoEConfig | None, - quant_config: FusedMoEQuantConfig, - out_dtype: torch.dtype | None = None, - use_deepseek_fp8_block_scale: bool = False, -) -> mk.FusedMoEPermuteExpertsUnpermute: - """Return a GEMM *experts* implementation for fused-MoE layers""" - - if moe is not None: - return FlashInferExperts( - out_dtype=moe.in_dtype, - quant_config=quant_config, - ep_rank=moe.moe_parallel_config.ep_rank, - ep_size=moe.moe_parallel_config.ep_size, - tp_rank=moe.moe_parallel_config.tp_rank, - tp_size=moe.moe_parallel_config.tp_size, - use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale, - ) - - assert out_dtype is not None, "If moe config is None, out_dtype must be passed" - return FlashInferExperts( - out_dtype=out_dtype, - quant_config=quant_config, - use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale, - ) - - def get_flashinfer_moe_backend() -> FlashinferMoeBackend: backend_map = { "throughput": FlashinferMoeBackend.CUTLASS, diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 08262ed1a314..15b7f3444cde 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -48,6 +48,7 @@ class GroupShape(_GroupShape): # Aliases for common quantization group shapes PER_TENSOR: ClassVar["GroupShape"] PER_TOKEN: ClassVar["GroupShape"] + PER_CHANNEL: ClassVar["GroupShape"] def is_per_tensor(self) -> bool: return self.row == -1 and self.col == -1 @@ -55,12 +56,16 @@ def is_per_tensor(self) -> bool: def is_per_token(self) -> bool: return self.row == 1 and self.col == -1 + def is_per_channel(self) -> bool: + return self.row == -1 and self.col == 1 + def is_per_group(self) -> bool: return self.row == 1 and self.col >= 1 GroupShape.PER_TENSOR 
= GroupShape(-1, -1) GroupShape.PER_TOKEN = GroupShape(1, -1) +GroupShape.PER_CHANNEL = GroupShape(-1, 1) @dataclass(frozen=True) @@ -77,16 +82,12 @@ class ScaleDesc: group_shape: GroupShape def __str__(self): - group_shape = ( - "per_tensor" - if self.group_shape == GroupShape.PER_TENSOR - else ( - "per_token" - if self.group_shape == GroupShape.PER_TOKEN - else str(self.group_shape) - ) - ) - + d = { + GroupShape.PER_TENSOR: "per_tensor", + GroupShape.PER_TOKEN: "per_token", + GroupShape.PER_CHANNEL: "per_channel", + } + group_shape = d.get(self.group_shape, str(self.group_shape)) return ( f"{fx.graph.dtype_abbrs[self.dtype]}," f"{'static' if self.static else 'dynamic'},{group_shape}" @@ -123,15 +124,28 @@ def __str__(self): kDynamicTensorScale = ScaleDesc(torch.float32, False, GroupShape.PER_TENSOR) kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, kDynamicTensorScale, symmetric=True) +kStaticChannelScale = ScaleDesc(torch.float32, True, GroupShape.PER_CHANNEL) +kFp8StaticChannelSym = QuantKey(FP8_DTYPE, kStaticChannelScale, symmetric=True) + kDynamicTokenScale = ScaleDesc(torch.float32, False, GroupShape.PER_TOKEN) kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True) -kNvfp4GroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16)) -kNvfp4Quant = QuantKey(FP4_DTYPE, scale=kNvfp4GroupScale, scale2=kStaticTensorScale) +kNvfp4DynamicGroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16)) +kNvfp4Dynamic = QuantKey( + FP4_DTYPE, scale=kNvfp4DynamicGroupScale, scale2=kStaticTensorScale +) + +kNvfp4StaticGroupScale = ScaleDesc(FP8_DTYPE, True, GroupShape(1, 16)) +kNvfp4Static = QuantKey( + FP4_DTYPE, scale=kNvfp4StaticGroupScale, scale2=kStaticTensorScale +) kDynamic128Scale = ScaleDesc(torch.float32, False, GroupShape(1, 128)) kFp8Dynamic128Sym = QuantKey(FP8_DTYPE, kDynamic128Scale, symmetric=True) +kStatic128BlockScale = ScaleDesc(torch.float32, True, GroupShape(128, 128)) +kFp8Static128BlockSym = QuantKey(FP8_DTYPE, kStatic128BlockScale, 
symmetric=True) + kDynamic64Scale = ScaleDesc(torch.float32, False, GroupShape(1, 64)) kFp8Dynamic64Sym = QuantKey(FP8_DTYPE, kDynamic64Scale, symmetric=True) diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index dde6db7c204b..7cc170638428 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -434,6 +434,7 @@ def load_moe_expert_weights( # Whether the MoE expert weights are loaded successfully. expert_param_loaded = False + loaded_weight = loaded_weight.to("cuda") # If fused is True, the loaded weight is in the layout of: # [num_experts, hidden_in, hidden_out], so we must transpose the last diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 9892c360d3d6..5ed860cc56a9 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -29,7 +29,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kFp8StaticTensorSym, - kNvfp4Quant, + kNvfp4Dynamic, ) from vllm.platforms import current_platform from vllm.platforms.interface import DeviceCapability @@ -1184,7 +1184,7 @@ def fused_output_quant_supported(self, quant_key: QuantKey): return ( self.support_trtllm_attn and self.kv_cache_dtype.startswith("fp8") - and quant_key in (kFp8StaticTensorSym, kNvfp4Quant) + and quant_key in (kFp8StaticTensorSym, kNvfp4Dynamic) ) # FlashInfer requires attention sinks to be float32