diff --git a/benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py b/benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py index d6b5820a5b41..10a3e3eab5fd 100644 --- a/benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py +++ b/benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py @@ -11,14 +11,20 @@ import torch import torch.utils.benchmark as benchmark +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe.config import ( fp8_w8a8_moe_quant_config, nvfp4_moe_quant_config, ) -from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 +from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + CutlassExpertsFp4, +) from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP, +) from vllm.scalar_type import scalar_types from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.v1.worker.workspace import init_workspace_manager @@ -188,19 +194,24 @@ def run_cutlass_moe_fp4( g1_alphas=w1_gs, g2_alphas=w2_gs, ) + + kernel = mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(defer_input_quant=True), + CutlassExpertsFp4( + out_dtype=dtype, + max_experts_per_worker=e, + quant_config=quant_config, + ), + ) + for _ in range(num_repeats): with nvtx.annotate("cutlass_moe_fp4", color="green"): - cutlass_moe_fp4( - a=a, - w1_fp4=w1_fp4, - w2_fp4=w2_fp4, + kernel( + hidden_states=a, + w1=w1_fp4, + w2=w2_fp4, topk_weights=topk_weights, topk_ids=topk_ids, - m=m, - n=n, - k=k, - e=num_experts, - quant_config=quant_config, ) def run_cutlass_from_graph( @@ -230,20 +241,24 @@ def run_cutlass_from_graph( g2_alphas=w2_gs, ) + kernel = mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(defer_input_quant=True), + CutlassExpertsFp4( + out_dtype=dtype, + max_experts_per_worker=e, + 
quant_config=quant_config, + ), + ) + with set_current_vllm_config( VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) ): - return cutlass_moe_fp4( - a=a, - w1_fp4=w1_fp4, - w2_fp4=w2_fp4, + return kernel( + hidden_states=a, + w1=w1_fp4, + w2=w2_fp4, topk_weights=topk_weights, topk_ids=topk_ids, - m=m, - n=n, - k=k, - e=num_experts, - quant_config=quant_config, ) def run_triton_from_graph( diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index d683b538c415..18216b5965af 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -86,7 +86,7 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels | triton | standard | all1 | G,A,T | silu, gelu,
swigluoai,
silu_no_mul,
gelu_no_mul | Y | Y | [`fused_experts`][vllm.model_executor.layers.fused_moe.fused_moe.fused_experts],
[`TritonExperts`][vllm.model_executor.layers.fused_moe.fused_moe.TritonExperts] | | triton (batched) | batched | all1 | G,A,T | silu, gelu | 6 | Y | [`BatchedTritonExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedTritonExperts] | | deep gemm | standard,
batched | fp8 | G(128),A,T | silu, gelu | 6 | Y | [`deep_gemm_moe_fp8`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.deep_gemm_moe_fp8],
[`DeepGemmExperts`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.DeepGemmExperts],
[`BatchedDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe.BatchedDeepGemmExperts] | -| cutlass_fp4 | standard,
batched | nvfp4 | A,T | silu | Y | Y | [`cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.cutlass_moe.cutlass_moe_fp4],
[`CutlassExpertsFp4`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp4] | +| cutlass_fp4 | standard,
batched | nvfp4 | A,T | silu | Y | Y | [`CutlassExpertsFp4`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp4] | | cutlass_fp8 | standard,
batched | fp8 | A,T | silu, gelu | Y | Y | [`CutlassExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp8],
[`CutlasBatchedExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassBatchedExpertsFp8] | | flashinfer | standard | nvfp4,
fp8 | T | 5 | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],
[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] | | gpt oss triton | standard | N/A | N/A | 5 | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],
[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] | diff --git a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py index fd7388e1cff8..873d72117de7 100644 --- a/tests/kernels/moe/test_nvfp4_moe.py +++ b/tests/kernels/moe/test_nvfp4_moe.py @@ -3,6 +3,7 @@ import pytest import torch +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from tests.kernels.moe.utils import make_test_weights from tests.kernels.quantization.nvfp4_utils import ( FLOAT4_E2M1_MAX, @@ -13,8 +14,13 @@ from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe.config import nvfp4_moe_quant_config -from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 +from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + CutlassExpertsFp4, +) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP, +) from vllm.platforms import current_platform from vllm.utils.torch_utils import set_random_seed @@ -83,17 +89,21 @@ def test_cutlass_fp4_moe_no_graph( w2_scale=w2_blockscale, ) - cutlass_output = cutlass_moe_fp4( - a=a, - w1_fp4=w1_q, - w2_fp4=w2_q, + kernel = mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(defer_input_quant=True), + CutlassExpertsFp4( + out_dtype=dtype, + max_experts_per_worker=e, + quant_config=quant_config, + ), + ) + + cutlass_output = kernel( + hidden_states=a, + w1=w1_q, + w2=w2_q, topk_weights=topk_weights, topk_ids=topk_ids, - quant_config=quant_config, - m=m, - n=n, - k=k, - e=e, ) # Reference check: diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index e63404086ed9..9f04397e91f7 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py 
@@ -72,7 +72,6 @@ def get_config() -> dict[str, Any] | None: CutlassBatchedExpertsFp8, CutlassExpertsFp8, CutlassExpertsW4A8Fp8, - cutlass_moe_fp4, cutlass_moe_w4a8_fp8, ) from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts @@ -95,7 +94,6 @@ def get_config() -> dict[str, Any] | None: "fused_experts", "get_config_file_name", "GroupedTopk", - "cutlass_moe_fp4", "cutlass_moe_w4a8_fp8", "CutlassExpertsFp8", "CutlassBatchedExpertsFp8", diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 17d5ec4bcda7..23b86fdca898 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -331,6 +331,10 @@ def use_fp8_w8a16(self) -> bool: def use_int4_w4a16(self) -> bool: return self._a1.dtype is None and self._w1.dtype == "int4" + @property + def use_nvfp4_w4a16(self) -> bool: + return self._a1.dtype is None and self._w1.dtype == "nvfp4" + @property def ocp_mx_scheme(self) -> str | None: if not hasattr(self, "_ocp_mx_scheme"): @@ -683,6 +687,25 @@ def nvfp4_moe_quant_config( ) +def nvfp4_w4a16_moe_quant_config( + g1_alphas: torch.Tensor, + g2_alphas: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, +) -> FusedMoEQuantConfig: + """ + Construct a quant config for 16-bit activations and nvfp4 weights. 
+ """ + return FusedMoEQuantConfig.make( + quant_dtype=None, + w1_scale=w1_scale, + w2_scale=w2_scale, + g1_alphas=g1_alphas, + g2_alphas=g2_alphas, + weight_dtype="nvfp4", + ) + + def int4_w4a16_moe_quant_config( w1_scale: torch.Tensor, w2_scale: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 6e397f1e76a1..32ea040c743c 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -706,68 +706,6 @@ def apply( ) -def cutlass_moe_fp4( - a: torch.Tensor, - w1_fp4: torch.Tensor, - w2_fp4: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - quant_config: FusedMoEQuantConfig, - m: int, - n: int, - k: int, - e: int, - expert_map: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, -) -> torch.Tensor: - assert expert_map is None, ( - "Expert Parallelism / expert_map " - "is currently not supported for " - "ModelOptNvFp4FusedMoE's cutlass_moe_fp4." - ) - - # TODO(bnell): this feels a bit hacky - # NVFP4 requires two levels of quantization, which involves - # computing some scaling factors dynamically. This makes it - # incompatible with the typical prepare -> MoE -> finalize - # pipeline. Move the quantization logic into the MoE body. 
- quant_config = FusedMoEQuantConfig.make( - quant_dtype=None, # skip quantization in prepare/finalize - per_act_token_quant=quant_config.per_act_token_quant, - per_out_ch_quant=quant_config.per_out_ch_quant, - block_shape=quant_config.block_shape, - g1_alphas=quant_config.g1_alphas, - g2_alphas=quant_config.g2_alphas, - a1_gscale=quant_config.a1_gscale, - a2_gscale=quant_config.a2_gscale, - w1_scale=quant_config.w1_scale, - w2_scale=quant_config.w2_scale, - ) - - fn = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), - CutlassExpertsFp4( - max_experts_per_worker=e, - out_dtype=a.dtype, - quant_config=quant_config, - use_batched_format=False, - ), - ) - - return fn( - hidden_states=a, - w1=w1_fp4, - w2=w2_fp4, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=False, - activation="silu", - global_num_experts=e, - expert_map=None, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - - # W4A8 def run_cutlass_moe_w4a8_fp8( output: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py index 6e0b57156cb3..ce93ae235f27 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py @@ -335,42 +335,3 @@ def flashinfer_cutedsl_moe_masked( alpha_dtype=get_cute_dtype(w2_alpha), ) # in logical [m, k, l] out = out.permute(2, 0, 1) - - -def flashinfer_cutedsl_moe_fp4( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - quant_config: FusedMoEQuantConfig, - inplace: bool = False, - activation: str = "silu", - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, -) -> torch.Tensor: - from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 - create_flashinfer_prepare_finalize, - ) - - 
fused_experts = mk.FusedMoEModularKernel( - create_flashinfer_prepare_finalize(use_dp=False), # could be swapped later - FlashInferCuteDSLExperts( - out_dtype=hidden_states.dtype, - quant_config=quant_config, - ), - ) - - return fused_experts( - hidden_states=hidden_states, - w1=w1, - w2=w2, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=inplace, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index 0b0efdafbd4d..dfff860750d6 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -355,21 +355,17 @@ def create_flashinfer_prepare_finalize( use_deepseek_fp8_block_scale: bool = False, ) -> FlashInferCutlassMoEPrepareAndFinalize | MoEPrepareAndFinalizeNoEP: """Factory function to create the appropriate FlashInfer implementation.""" - # TODO(rob): migrate non-DP cases to MoEPrepareAndFinalizeNoEP - # once we complete the FP8 refactor. - if use_nvfp4: - if enable_alltoallv: - return FlashInferAllToAllMoEPrepareAndFinalize(use_dp) - else: - return FlashInferAllGatherMoEPrepareAndFinalize(use_dp) - # FP8 DP path currently supported via AllGather. if use_dp: + if enable_alltoallv: + assert use_nvfp4 + return FlashInferAllToAllMoEPrepareAndFinalize(use_dp) return FlashInferAllGatherMoEPrepareAndFinalize( use_dp=True, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale, ) else: - # NOTE(rob): CUTLASS FP8 block quant executes the input - # quantzation and grouped gemm in a single kernel. 
- return MoEPrepareAndFinalizeNoEP(defer_input_quant=use_deepseek_fp8_block_scale) + # CUTLASS FP8 BLOCK and CUTLASS NVFP4 apply input quantization + # in a single call with the MoE experts kernel. + defer_input_quant = use_deepseek_fp8_block_scale or use_nvfp4 + return MoEPrepareAndFinalizeNoEP(defer_input_quant=defer_input_quant) diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index c031d9efcc0b..e82a838959de 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -540,9 +540,10 @@ def __init__( # TODO (varun) : Enable activation quantization assert ( quant_config.use_mxfp4_w4a16 + or quant_config.use_nvfp4_w4a16 or quant_config.use_int4_w4a16 or quant_config.use_fp8_w8a16 - ), "Supports only mxfp4_w4a16, int4_w4a16 or fp8_w8a16" + ), "Supports only {mxfp,nvfp,int}4_w4a16 or fp8_w8a16" self.w13_g_idx = w13_g_idx self.w2_g_idx = w2_g_idx self.w13_g_idx_sort_indices = w13_g_idx_sort_indices @@ -555,7 +556,7 @@ def quant_type_id(self) -> int: # uint4b8 will be set for int4 weight and float4_e2m1f will be used for mxfp4 if self.quant_config.use_int4_w4a16: return scalar_types.uint4b8.id - elif self.quant_config.use_mxfp4_w4a16: + elif self.quant_config.use_mxfp4_w4a16 or self.quant_config.use_nvfp4_w4a16: return scalar_types.float4_e2m1f.id elif ( self.quant_config.use_fp8_w8a16 @@ -692,6 +693,8 @@ def apply( gating_output=None, topk_weights=topk_weights, topk_ids=topk_ids, + global_scale1=self.g1_alphas, + global_scale2=self.g2_alphas, quant_type_id=self.quant_type_id, apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index fb441963a97d..3a1860b472d0 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ 
b/vllm/model_executor/layers/fused_moe/layer.py @@ -38,9 +38,6 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, ) -from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - is_flashinfer_supporting_global_sf, -) from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe from vllm.utils.math_utils import cdiv, round_up @@ -1119,14 +1116,9 @@ def weight_loader( global_expert_id = expert_id expert_id = self._map_global_expert_id_to_local_expert_id(global_expert_id) - allow_flashinfer = getattr(self.quant_method, "allow_flashinfer", False) - moe_backend = getattr(self.quant_method, "flashinfer_moe_backend", None) - use_global_sf = ( - allow_flashinfer - and is_flashinfer_supporting_global_sf(moe_backend) + getattr(self.quant_method, "use_global_sf", False) and "input_scale" in weight_name - and quant_method_name == "ModelOptNvFp4FusedMoE" ) if expert_id == -1 and not use_global_sf: diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py new file mode 100644 index 000000000000..547a2a795d19 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py @@ -0,0 +1,280 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from enum import Enum + +import torch + +import vllm.envs as envs +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, + nvfp4_moe_quant_config, + nvfp4_w4a16_moe_quant_config, +) +from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + CutlassExpertsFp4, +) +from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( + FlashInferExperts, +) +from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + MarlinExperts, +) +from 
vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP, +) +from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( + is_flashinfer_fp4_cutedsl_moe_available, + is_flashinfer_fp4_cutlass_moe_available, + prepare_nvfp4_moe_layer_for_fi_or_cutlass, +) +from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( + FlashinferMoeBackend, + get_flashinfer_moe_backend, +) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + is_fp4_marlin_supported, + prepare_nvfp4_moe_layer_for_marlin, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + cutlass_fp4_supported, +) + +logger = init_logger(__name__) + + +class NvFp4MoeBackend(Enum): + FLASHINFER_CUTLASS = "FlashInfer CUTLASS" + FLASHINFER_TRTLLM = "FlashInfer TRTLLM" + FLASHINFER_CUTEDSL = "FlashInfer CUTEDSL" + VLLM_CUTLASS = "vLLM CUTLASS" + MARLIN = "vLLM MARLIN" + + +FLASHINFER_NVFP4_MOE_BACKENDS = [ + NvFp4MoeBackend.FLASHINFER_CUTLASS, + NvFp4MoeBackend.FLASHINFER_TRTLLM, + NvFp4MoeBackend.FLASHINFER_CUTEDSL, +] + +fi_2_vllm_backend_map: dict[FlashinferMoeBackend, NvFp4MoeBackend] = { + FlashinferMoeBackend.CUTLASS: NvFp4MoeBackend.FLASHINFER_CUTLASS, + FlashinferMoeBackend.TENSORRT_LLM: NvFp4MoeBackend.FLASHINFER_TRTLLM, + FlashinferMoeBackend.CUTEDSL: NvFp4MoeBackend.FLASHINFER_CUTEDSL, +} + + +def is_global_sf_supported_for_nvfp4_backend(backend: NvFp4MoeBackend) -> bool: + # Checks whether `backend` supports quantizing with scaling factors + # of all experts in Expert Parallel Mode when all experts are not + # on the same rank. 
+ + return backend in [ + NvFp4MoeBackend.FLASHINFER_CUTLASS, + NvFp4MoeBackend.FLASHINFER_TRTLLM, + ] + + +def select_nvfp4_moe_backend() -> NvFp4MoeBackend: + def _make_log_backend(backend: NvFp4MoeBackend): + return f"Using {backend.value} backend for NvFp4 MoE" + + if cutlass_fp4_supported() and not envs.VLLM_TEST_FORCE_FP8_MARLIN: + allow_flashinfer = ( + is_flashinfer_fp4_cutlass_moe_available() + or is_flashinfer_fp4_cutedsl_moe_available() + ) + if allow_flashinfer and envs.VLLM_USE_FLASHINFER_MOE_FP4: + backend = fi_2_vllm_backend_map[get_flashinfer_moe_backend()] + else: + backend = NvFp4MoeBackend.VLLM_CUTLASS + elif is_fp4_marlin_supported(): + backend = NvFp4MoeBackend.MARLIN + else: + raise ValueError("No NvFp4 kernel backend available for NvFp4 MoE.") + + # Log warning if FI backend requested but not available. + if ( + backend not in FLASHINFER_NVFP4_MOE_BACKENDS + and envs.VLLM_USE_FLASHINFER_MOE_FP4 + ): + logger.warning_once( + "Requested FlashInfer backend for NvFp4 MoE, but it's not available. 
" + "Falling back to %s for NvFp4 MoE", + backend.value, + scope="local", + ) + else: + logger.info_once(_make_log_backend(backend), scope="local") + return backend + + +def convert_to_nvfp4_moe_kernel_format( + nvfp4_backend: NvFp4MoeBackend, + layer: torch.nn.Module, + w13: torch.Tensor, + w13_scale: torch.Tensor, + w13_scale_2: torch.Tensor, + a13_scale: torch.Tensor | None, + w2: torch.Tensor, + w2_scale: torch.Tensor, + w2_scale_2: torch.Tensor, + a2_scale: torch.Tensor | None, + is_act_and_mul: bool, +) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: + if ( + nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS + or nvfp4_backend == NvFp4MoeBackend.VLLM_CUTLASS + ): + ( + w13, + w13_scale, + w13_scale_2, + a13_scale, + w2, + w2_scale, + w2_scale_2, + a2_scale, + ) = prepare_nvfp4_moe_layer_for_fi_or_cutlass( + backend=nvfp4_backend, + layer=layer, + w13=w13, + w13_scale=w13_scale, + w13_scale_2=w13_scale_2, + a13_scale=a13_scale, + w2=w2, + w2_scale=w2_scale, + w2_scale_2=w2_scale_2, + a2_scale=a2_scale, + is_act_and_mul=is_act_and_mul, + ) + elif nvfp4_backend == NvFp4MoeBackend.MARLIN: + a13_scale = None + a2_scale = None + ( + w13, + w13_scale, + w13_scale_2, + w2, + w2_scale, + w2_scale_2, + ) = prepare_nvfp4_moe_layer_for_marlin( + layer=layer, + w13=w13, + w13_scale=w13_scale, + w13_scale_2=w13_scale_2, + w2=w2, + w2_scale=w2_scale, + w2_scale_2=w2_scale_2, + ) + else: + raise ValueError(f"Unknown NvFp4 backend for MoE: {nvfp4_backend}") + + return ( + w13, + w13_scale, + w13_scale_2, + a13_scale, + w2, + w2_scale, + w2_scale_2, + a2_scale, + ) + + +def make_nvfp4_moe_quant_config( + backend: NvFp4MoeBackend, + w13_scale: torch.Tensor, + w2_scale: torch.Tensor, + w13_scale_2: torch.Tensor, + w2_scale_2: torch.Tensor, + a13_scale: torch.Tensor, + a2_scale: torch.Tensor, +) -> FusedMoEQuantConfig | None: + UNSUPPORTED = [NvFp4MoeBackend.FLASHINFER_TRTLLM] + if backend 
in UNSUPPORTED: + return None + + elif backend == NvFp4MoeBackend.MARLIN: + return nvfp4_w4a16_moe_quant_config( + g1_alphas=w13_scale_2, + g2_alphas=w2_scale_2, + w1_scale=w13_scale, + w2_scale=w2_scale, + ) + + g1_alphas = a13_scale * w13_scale_2 + g2_alphas = a2_scale * w2_scale_2 + return nvfp4_moe_quant_config( + g1_alphas=g1_alphas, + g2_alphas=g2_alphas, + a1_gscale=(1.0 / a13_scale), + a2_gscale=(1.0 / a2_scale), + w1_scale=w13_scale, + w2_scale=w2_scale, + ) + + +def make_nvfp4_moe_kernel( + backend: NvFp4MoeBackend, + quant_config: FusedMoEQuantConfig, + moe_config: FusedMoEConfig, +) -> mk.FusedMoEModularKernel | None: + assert moe_config.dp_size == 1 + + UNSUPPORTED_BACKENDS = [ + # TRTLLM does not use the modular kernel abstraction. + NvFp4MoeBackend.FLASHINFER_TRTLLM, + # CUTEDSL is used with BATCHED (masked) format only. + # TODO: add here once we support dp/ep via the oracle. + NvFp4MoeBackend.FLASHINFER_CUTEDSL, + ] + + if backend in UNSUPPORTED_BACKENDS: + return None + + elif backend == NvFp4MoeBackend.FLASHINFER_CUTLASS: + return mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(defer_input_quant=True), + FlashInferExperts( + out_dtype=moe_config.in_dtype, + quant_config=quant_config, + ep_rank=moe_config.ep_rank, + ep_size=moe_config.ep_size, + tp_rank=moe_config.tp_rank, + tp_size=moe_config.tp_size, + use_dp=False, + use_deepseek_fp8_block_scale=False, + ), + ) + + elif backend == NvFp4MoeBackend.VLLM_CUTLASS: + return mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(defer_input_quant=True), + CutlassExpertsFp4( + out_dtype=moe_config.in_dtype, + # TODO(rob): see what impact this has on expert map? 
+ max_experts_per_worker=moe_config.num_experts, + quant_config=quant_config, + ), + ) + + elif backend == NvFp4MoeBackend.MARLIN: + return mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(), + MarlinExperts(quant_config=quant_config), + ) + + else: + raise ValueError(f"Unknown NvFp4 MoE backend: {backend}") diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index a2b3aec4457e..509de5dff9c1 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -11,7 +11,6 @@ QuantizationArgs, QuantizationStrategy, ) -from torch.nn.parameter import Parameter import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops @@ -34,12 +33,8 @@ int4_w4afp8_moe_quant_config, int8_w8a8_moe_quant_config, int8_w8a16_moe_quant_config, - nvfp4_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.cpu_fused_moe import select_experts -from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( - is_valid_flashinfer_cutlass_fused_moe, -) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( BatchedMarlinExperts, MarlinExperts, @@ -51,6 +46,15 @@ make_fp8_moe_kernel, select_fp8_moe_backend, ) +from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( + FLASHINFER_NVFP4_MOE_BACKENDS, + NvFp4MoeBackend, + convert_to_nvfp4_moe_kernel_format, + is_global_sf_supported_for_nvfp4_backend, + make_nvfp4_moe_kernel, + make_nvfp4_moe_quant_config, + select_nvfp4_moe_backend, +) from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP, @@ -58,14 +62,9 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( 
build_flashinfer_fp4_cutlass_moe_prepare_finalize, flashinfer_trtllm_fp4_moe, - prepare_static_weights_for_trtllm_fp4_moe, - reorder_w1w3_to_w3w1, + flashinfer_trtllm_fp4_routed_moe, select_nvfp4_gemm_impl, ) -from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - FlashinferMoeBackend, - get_flashinfer_moe_backend, -) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( process_fp8_input_tensor_strategy_moe, process_fp8_weight_tensor_strategy_moe, @@ -77,20 +76,15 @@ marlin_make_workspace_new, marlin_moe_permute_scales, ) -from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - prepare_moe_fp4_layer_for_marlin, -) from vllm.model_executor.layers.quantization.utils.quant_utils import ( convert_bf16_scales_to_fp8, convert_packed_uint4b8_to_signed_int4_inplace, - swizzle_blockscale, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( normalize_e4m3fn_to_e4m3fnuz, ) from vllm.model_executor.utils import replace_parameter, set_weight_attrs from vllm.platforms import CpuArchEnum, current_platform -from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -218,31 +212,19 @@ def get_moe_method( class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): def __init__(self, moe: FusedMoEConfig, layer_name: str | None = None): - from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501 - detect_nvfp4_moe_support, - ) + if not moe.is_act_and_mul: + raise ValueError( + "CompressedTensorsW4A4Nvfp4MoEMethod does not yet " + "support non gated MoE models." 
+ ) super().__init__(moe) - _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__) - self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported - self.allow_flashinfer = _nvfp4.allow_flashinfer - self.use_marlin = _nvfp4.use_marlin self.group_size = 16 - self.layer_name = layer_name - self.marlin_input_dtype = ( - get_marlin_input_dtype(layer_name) if self.use_marlin else None - ) - self.flashinfer_moe_backend = None - if self.allow_flashinfer: - self.flashinfer_moe_backend = get_flashinfer_moe_backend() - logger.info_once( - f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" - " for CompressedTensorsW4A4Nvfp4MoEMethod." - ) - elif self.use_marlin: - logger.info_once("Using Marlin for CompressedTensorsW4A4Nvfp4MoEMethod.") - else: - logger.info_once("Using Cutlass for CompressedTensorsW4A4Nvfp4MoEMethod.") + self.nvfp4_backend = select_nvfp4_moe_backend() + self.use_global_sf = is_global_sf_supported_for_nvfp4_backend( + self.nvfp4_backend + ) + self.kernel: mk.FusedMoEModularKernel | None = None def create_weights( self, @@ -355,7 +337,13 @@ def create_weights( set_weight_attrs(w2_input_scale, extra_weight_attrs) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # From packed to weight + """ + Convert NVFP4 MoE weights into kernel format and setup the kernel. + """ + # NOTE(rob): wN_weight_packed -> wN_weight is because ModularKernelMethod + # requires this naming convention. However, the name change breaks + # reloading because the state dict no longer matches disk. Once we + # remove MKM, we should revert this change to ensure compatibility. layer.w13_weight = torch.nn.Parameter( layer.w13_weight_packed.data, requires_grad=False ) @@ -366,144 +354,79 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) delattr(layer, "w2_weight_packed") - # reorder GEMM1 weights and block scales for FlashInfer CUTLASS kernel. 
- if self.allow_flashinfer: - w, s = reorder_w1w3_to_w3w1( - layer.w13_weight.data, layer.w13_weight_scale.data, dim=-2 - ) - layer.w13_weight = torch.nn.Parameter(w, requires_grad=False) - layer.w13_weight_scale = torch.nn.Parameter(s, requires_grad=False) - - if not torch.allclose( + # Use a single gscale for w13. + if self.moe.is_act_and_mul and not torch.allclose( layer.w13_weight_global_scale[:, 0], layer.w13_weight_global_scale[:, 1] ): logger.warning_once( "w1_weight_global_scale must match w3_weight_global_scale. " - "Accuracy may be affected." - ) - - # Take inverse of global scale saved to disk - layer.w13_weight_scale_2 = torch.nn.Parameter( - 1 / layer.w13_weight_global_scale[:, 0], requires_grad=False - ) - - layer.w2_weight_scale_2 = torch.nn.Parameter( - 1 / layer.w2_weight_global_scale.data, requires_grad=False - ) - - if self.use_marlin: - prepare_moe_fp4_layer_for_marlin(layer, input_dtype=self.marlin_input_dtype) - return - # w13 - if ( - self.allow_flashinfer - and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM - ): - w13_input_global_scale = ( - layer.w13_input_global_scale.min() - .to(torch.float32) - .expand(layer.num_experts) - ) - else: - w13_input_global_scale = layer.w13_input_global_scale.min(dim=1).values.to( - torch.float32 - ) - layer.g1_alphas = torch.nn.Parameter( - ((1 / w13_input_global_scale) * layer.w13_weight_scale_2), - requires_grad=False, - ) - - layer.w13_input_scale_quant = torch.nn.Parameter( - (w13_input_global_scale), requires_grad=False - ) - - # w2 - if ( - self.allow_flashinfer - and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM - ): - w2_input_global_scale = ( - layer.w2_input_global_scale.min() - .to(torch.float32) - .expand(layer.num_experts) - ) - else: - w2_input_global_scale = layer.w2_input_global_scale - - layer.g2_alphas = torch.nn.Parameter( - ((1 / w2_input_global_scale) * layer.w2_weight_scale_2).to(torch.float32), - requires_grad=False, - ) - - 
layer.w2_input_scale_quant = torch.nn.Parameter( - (w2_input_global_scale), requires_grad=False + "Accuracy may be affected.", + ) + w13_weight_global_scale = layer.w13_weight_global_scale[:, 0].contiguous() + + # Shuffle weights into the NvFp4 kernel format. + ( + w13, + w13_scale, + w13_scale_2, + a13_scale, + w2, + w2_scale, + w2_scale_2, + a2_scale, + ) = convert_to_nvfp4_moe_kernel_format( + nvfp4_backend=self.nvfp4_backend, + layer=layer, + w13=layer.w13_weight, + w13_scale=layer.w13_weight_scale, + w13_scale_2=(1.0 / w13_weight_global_scale), + a13_scale=(1.0 / layer.w13_input_global_scale), + w2=layer.w2_weight, + w2_scale=layer.w2_weight_scale, + w2_scale_2=(1.0 / layer.w2_weight_global_scale), + a2_scale=(1.0 / layer.w2_input_global_scale), + is_act_and_mul=self.moe.is_act_and_mul, ) - # TensorRT-LLM specific processing - if ( - self.allow_flashinfer - and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM - ): - # Prepare static weights for TRT-LLM kernel - # alternate: prepare_static_weight_layouts_for_trtllm_moe - ( - gemm1_weights_fp4_shuffled, - gemm1_scales_fp4_shuffled, - gemm2_weights_fp4_shuffled, - gemm2_scales_fp4_shuffled, - ) = prepare_static_weights_for_trtllm_fp4_moe( - layer.w13_weight, - layer.w2_weight, - layer.w13_weight_scale, - layer.w2_weight_scale, - layer.w2_weight.size(-2), # hidden_size - layer.w13_weight.size(-2) // 2, # intermediate_size - layer.w13_weight.size(0), # num_experts - ) - logger.debug_once("Finished shuffling weights for TRT-LLM MOE") - - layer.w13_weight = Parameter( - gemm1_weights_fp4_shuffled, requires_grad=False - ) - layer.w2_weight = Parameter(gemm2_weights_fp4_shuffled, requires_grad=False) - layer.w13_weight_scale = Parameter( - gemm1_scales_fp4_shuffled, requires_grad=False - ) - layer.w2_weight_scale = Parameter( - gemm2_scales_fp4_shuffled, requires_grad=False - ) - - # Additional parameter needed for TRT-LLM - layer.g1_scale_c = Parameter( - (layer.w2_input_scale_quant * 
layer.g1_alphas).to(torch.float32), - requires_grad=False, - ) - else: - # swizzle weight scales - layer.w13_weight_scale = torch.nn.Parameter( - swizzle_blockscale(layer.w13_weight_scale), requires_grad=False - ) + replace_parameter(layer, "w13_weight", w13) + replace_parameter(layer, "w13_weight_scale", w13_scale) + replace_parameter(layer, "w2_weight", w2) + replace_parameter(layer, "w2_weight_scale", w2_scale) + layer.w13_weight_scale_2 = w13_scale_2 + layer.w2_weight_scale_2 = w2_scale_2 + layer.w13_input_scale = a13_scale + layer.w2_input_scale = a2_scale - layer.w2_weight_scale = torch.nn.Parameter( - swizzle_blockscale(layer.w2_weight_scale), requires_grad=False + # Initialize the kernel that will be called in apply(). + self.moe_quant_config = self.get_fused_moe_quant_config(layer) + use_dp = self.moe.dp_size > 1 + if self.moe_quant_config is not None and not use_dp: + self.kernel = make_nvfp4_moe_kernel( + backend=self.nvfp4_backend, + quant_config=self.moe_quant_config, + moe_config=self.moe, ) def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPrepareAndFinalize | None: - if self.use_marlin or ( - self.allow_flashinfer - and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM - ): + UNSUPPORTED = [NvFp4MoeBackend.MARLIN, NvFp4MoeBackend.FLASHINFER_TRTLLM] + if self.nvfp4_backend in UNSUPPORTED: return None - elif not self.allow_flashinfer: + elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS: + # TP case: avoid convert to ModularKernelMethod - to be refactored. + if self.moe.dp_size == 1: + return None + # For now, fp4 moe only works with the flashinfer dispatcher. 
+ prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize( + self.moe + ) + logger.debug_once("%s", prepare_finalize.__class__.__name__) + return prepare_finalize + else: return super().maybe_make_prepare_finalize(routing_tables) - prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(self.moe) - logger.debug_once("%s", prepare_finalize.__class__.__name__) - return prepare_finalize - def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalize, @@ -514,7 +437,7 @@ def select_gemm_impl( experts = select_nvfp4_gemm_impl( self.moe, self.moe_quant_config, - allow_flashinfer=self.allow_flashinfer, + allow_flashinfer=(self.nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS), ) logger.debug_once("Using %s", experts.__class__.__name__) return experts @@ -522,19 +445,14 @@ def select_gemm_impl( def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: - if ( - self.use_marlin - or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM - ): - return None - - return nvfp4_moe_quant_config( - g1_alphas=layer.g1_alphas, - g2_alphas=layer.g2_alphas, - a1_gscale=layer.w13_input_scale_quant, - a2_gscale=layer.w2_input_scale_quant, - w1_scale=layer.w13_weight_scale, + return make_nvfp4_moe_quant_config( + backend=self.nvfp4_backend, + w13_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, + w13_scale_2=layer.w13_weight_scale_2, + w2_scale_2=layer.w2_weight_scale_2, + a13_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, ) def apply( @@ -546,14 +464,9 @@ def apply( assert layer.activation == "silu", "Only SiLU activation is supported." if ( - self.allow_flashinfer - and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM + self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM + and not layer.enable_eplb ): - if layer.enable_eplb: - raise NotImplementedError( - "EPLB not supported for `CompressedTensorsW4A4MoEMethod` yet." 
- ) - return flashinfer_trtllm_fp4_moe( layer=layer, x=x, @@ -566,79 +479,41 @@ def apply( e_score_correction_bias=layer.e_score_correction_bias, ) + # Hidden_states in select_experts is only used to extract metadata + if isinstance(x, tuple): + x_routing, _ = x + else: + x_routing = x topk_weights, topk_ids = layer.select_experts( - hidden_states=x, + hidden_states=x_routing, router_logits=router_logits, ) - if self.use_marlin: - return fused_marlin_moe( + # EPLB path + if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: + assert layer.enable_eplb + return flashinfer_trtllm_fp4_routed_moe( + layer=layer, + x=x, + topk_ids=topk_ids, + topk_weights=topk_weights, + top_k=layer.top_k, + global_num_experts=layer.global_num_experts, + ) + else: + assert self.kernel is not None + return self.kernel( x, layer.w13_weight, layer.w2_weight, - None, - None, - layer.w13_weight_scale, - layer.w2_weight_scale, - router_logits, topk_weights, topk_ids, - global_scale1=layer.w13_weight_scale_2, - global_scale2=layer.w2_weight_scale_2, - quant_type_id=scalar_types.float4_e2m1f.id, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - global_num_experts=layer.global_num_experts, - expert_map=layer.expert_map, - input_dtype=self.marlin_input_dtype, - workspace=layer.workspace, - ) - - # FlashInfer fused experts path - elif self.allow_flashinfer: - from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 - flashinfer_cutlass_moe_fp4, - ) - - assert is_valid_flashinfer_cutlass_fused_moe( - x, layer.w13_weight, layer.w2_weight - ), "Flashinfer CUTLASS Fused MoE not applicable!" 
- - assert self.moe_quant_config is not None - - return flashinfer_cutlass_moe_fp4( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - quant_config=self.moe_quant_config, - inplace=False, # TODO(shuw): fix later, now output is high prec + inplace=False, activation=layer.activation, global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, apply_router_weight_on_input=layer.apply_router_weight_on_input, ) - else: - # If no modular kernel is provided, use cutlass_moe_fp4 for TP case - # only (no EP). - from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 - - assert self.moe_quant_config is not None - return cutlass_moe_fp4( - a=x, - w1_fp4=layer.w13_weight, - w2_fp4=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - quant_config=self.moe_quant_config, - expert_map=layer.expert_map, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - # TODO(bnell): derive these from arguments - m=x.shape[0], - n=layer.w2_weight.shape[2] * 2, - k=x.shape[1], - e=layer.w13_weight.shape[0], - ).to(x.dtype) class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index bd7a90a80af1..2e4f1daf6690 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -15,9 +15,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, - nvfp4_moe_quant_config, ) -from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, @@ -30,6 +28,15 @@ make_fp8_moe_quant_config, select_fp8_moe_backend, ) +from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( + FLASHINFER_NVFP4_MOE_BACKENDS, + 
NvFp4MoeBackend, + convert_to_nvfp4_moe_kernel_format, + is_global_sf_supported_for_nvfp4_backend, + make_nvfp4_moe_kernel, + make_nvfp4_moe_quant_config, + select_nvfp4_moe_backend, +) from vllm.model_executor.layers.linear import ( LinearBase, LinearMethodBase, @@ -45,16 +52,11 @@ build_flashinfer_fp4_cutlass_moe_prepare_finalize, flashinfer_trtllm_fp4_moe, flashinfer_trtllm_fp4_routed_moe, - prepare_static_weights_for_trtllm_fp4_moe, - reorder_w1w3_to_w3w1, select_nvfp4_gemm_impl, ) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - FlashinferMoeBackend, apply_fi_trtllm_fp8_per_tensor_moe, build_flashinfer_fp8_cutlass_moe_prepare_finalize, - get_flashinfer_moe_backend, - is_flashinfer_supporting_global_sf, select_cutlass_fp8_gemm_impl, ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( @@ -69,7 +71,6 @@ apply_fp4_marlin_linear, is_fp4_marlin_supported, prepare_fp4_layer_for_marlin, - prepare_moe_fp4_layer_for_marlin, ) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, @@ -89,7 +90,6 @@ PerTensorScaleParameter, ) from vllm.model_executor.utils import replace_parameter -from vllm.scalar_type import scalar_types from vllm.utils.flashinfer import ( flashinfer_scaled_fp4_mm, has_flashinfer, @@ -1327,43 +1327,32 @@ def __init__( quant_config: ModelOptNvFp4Config, layer: FusedMoE, ) -> None: - from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( - detect_nvfp4_moe_support, # noqa: E501 - ) - super().__init__(layer.moe_config) self.quant_config = quant_config - self.layer = layer - _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__) - self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported - self.allow_flashinfer = _nvfp4.allow_flashinfer - self.use_marlin = _nvfp4.use_marlin - self.marlin_input_dtype = None - self.flashinfer_moe_backend = None - if self.allow_flashinfer: - self.flashinfer_moe_backend = get_flashinfer_moe_backend() - logger.info_once( - 
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" - " for ModelOptNvFp4FusedMoE." + self.nvfp4_backend = select_nvfp4_moe_backend() + # TODO: move this type of check into the oracle. + if ( + not self.moe.is_act_and_mul + and not self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS + ): + raise NotImplementedError( + "Non-gated activations are only supported by FlashInfer " + "CUTLASS NvFP4 MoE backend." ) - elif self.use_marlin: - logger.info_once("Using Marlin for ModelOptNvFp4FusedMoE.") - else: - logger.info_once("Using Cutlass for ModelOptNvFp4FusedMoE.") + + self.use_global_sf = is_global_sf_supported_for_nvfp4_backend( + self.nvfp4_backend + ) + self.kernel: mk.FusedMoEModularKernel | None = None def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPrepareAndFinalize | None: - if self.use_marlin or ( - self.allow_flashinfer - and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM - ): + UNSUPPORTED = [NvFp4MoeBackend.MARLIN, NvFp4MoeBackend.FLASHINFER_TRTLLM] + if self.nvfp4_backend in UNSUPPORTED: return None - elif ( - self.allow_flashinfer - and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS - ): + elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS: # TP case: avoid convert to ModularKernelMethod - to be refactored. if self.moe.dp_size == 1: return None @@ -1385,7 +1374,7 @@ def select_gemm_impl( experts = select_nvfp4_gemm_impl( self.moe, self.moe_quant_config, - allow_flashinfer=self.allow_flashinfer, + allow_flashinfer=self.nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS, ) logger.debug_once("Using %s", experts.__class__.__name__) return experts @@ -1405,11 +1394,7 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ): - if not self.quant_config.is_checkpoint_nvfp4_serialized: - raise ValueError( - "NVFP4 quantization was selected, " - " dynamic quantization is not supported." 
- ) + assert self.quant_config.is_checkpoint_nvfp4_serialized layer.num_experts = num_experts layer.params_dtype = params_dtype @@ -1498,14 +1483,12 @@ def create_weights( {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} ) - use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf( - self.flashinfer_moe_backend + global_sf_num_experts = ( + global_num_experts if self.use_global_sf else num_experts ) - global_scale_num_experts = global_num_experts if use_global_sf else num_experts - w13_input_scale = PerTensorScaleParameter( data=torch.empty( - global_scale_num_experts, + global_sf_num_experts, 2 if self.moe.is_act_and_mul else 1, dtype=torch.float32, ), @@ -1514,32 +1497,17 @@ def create_weights( layer.register_parameter("w13_input_scale", w13_input_scale) w2_input_scale = PerTensorScaleParameter( - data=torch.empty(global_scale_num_experts, dtype=torch.float32), + data=torch.empty(global_sf_num_experts, dtype=torch.float32), weight_loader=weight_loader, ) layer.register_parameter("w2_input_scale", w2_input_scale) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # GEMM 1 processing - gemm1_weight = layer.w13_weight.data - gemm1_weight_scale = layer.w13_weight_scale.data - - if ( - self.allow_flashinfer - and ( - self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS - or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM - ) - and self.moe.is_act_and_mul - ): - gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1( - gemm1_weight, gemm1_weight_scale, dim=-2 - ) - - layer.w13_weight = Parameter(gemm1_weight, requires_grad=False) - layer.w13_weight_scale = Parameter(gemm1_weight_scale, requires_grad=False) + """ + Convert NVFP4 MoE weights into kernel format and setup the kernel. + """ - # Common processing for w13_weight_scale_2 + # Use a single gscale for w13. 
if self.moe.is_act_and_mul and not torch.allclose( layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1] ): @@ -1547,136 +1515,47 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: "w1_weight_scale_2 must match w3_weight_scale_2. " "Accuracy may be affected." ) - w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0].contiguous() - layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False) - # Common processing for input scales and alphas - use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf( - self.flashinfer_moe_backend - ) - if use_global_sf: - # For backends provide by Flashinfer, the input global scales are - # shared across all experts. - w13_input_scale = ( - layer.w13_input_scale.max().to(torch.float32).expand(layer.num_experts) - ) - else: - w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32) - layer.g1_alphas = Parameter( - (w13_input_scale * w13_weight_scale_2).to(torch.float32), - requires_grad=False, - ) - - # This is for quantization, so we need to invert it. - layer.w13_input_scale_quant = Parameter( - (1 / w13_input_scale).to(torch.float32), requires_grad=False - ) - - # GEMM 2 processing - if use_global_sf: - # For backends provide by Flashinfer, the input global scales are - # shared across all experts. - w2_input_scale = ( - layer.w2_input_scale.max().to(torch.float32).expand(layer.num_experts) - ) - else: - w2_input_scale = layer.w2_input_scale - layer.g2_alphas = Parameter( - (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32), - requires_grad=False, - ) - - # This is for quantization, so we need to invert it. 
- layer.w2_input_scale_quant = Parameter( - (1 / w2_input_scale).to(torch.float32), requires_grad=False + ( + w13, + w13_scale, + w13_scale_2, + a13_scale, + w2, + w2_scale, + w2_scale_2, + a2_scale, + ) = convert_to_nvfp4_moe_kernel_format( + nvfp4_backend=self.nvfp4_backend, + layer=layer, + w13=layer.w13_weight, + w13_scale=layer.w13_weight_scale, + w13_scale_2=w13_weight_scale_2, + a13_scale=layer.w13_input_scale, + w2=layer.w2_weight, + w2_scale=layer.w2_weight_scale, + w2_scale_2=layer.w2_weight_scale_2, + a2_scale=layer.w2_input_scale, + is_act_and_mul=self.moe.is_act_and_mul, ) - # TensorRT-LLM specific processing - if ( - self.allow_flashinfer - and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM - ): - # Prepare static weights for TRT-LLM kernel - # alternate: prepare_static_weight_layouts_for_trtllm_moe - ( - gemm1_weights_fp4_shuffled, - gemm1_scales_fp4_shuffled, - gemm2_weights_fp4_shuffled, - gemm2_scales_fp4_shuffled, - ) = prepare_static_weights_for_trtllm_fp4_moe( - layer.w13_weight, - layer.w2_weight, - layer.w13_weight_scale, - layer.w2_weight_scale, - layer.w2_weight.size(-2), # hidden_size - layer.w13_weight.size(-2) // 2, # intermediate_size - layer.w13_weight.size(0), # num_experts - ) - logger.debug_once("Finished shuffling weights for TRT-LLM MOE") - - layer.w13_weight = Parameter( - gemm1_weights_fp4_shuffled, requires_grad=False - ) - layer.w2_weight = Parameter(gemm2_weights_fp4_shuffled, requires_grad=False) - layer.w13_weight_scale = Parameter( - gemm1_scales_fp4_shuffled, requires_grad=False - ) - layer.w2_weight_scale = Parameter( - gemm2_scales_fp4_shuffled, requires_grad=False - ) - - # Additional parameter needed for TRT-LLM - layer.g1_scale_c = Parameter( - (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32), - requires_grad=False, - ) - elif self.use_marlin: - # Marlin processing - prepare_moe_fp4_layer_for_marlin(layer) - del layer.g1_alphas - del layer.g2_alphas - del layer.w13_input_scale_quant 
- del layer.w2_input_scale_quant - else: - # Non-TRT-LLM processing (Cutlass or non-flashinfer) - w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale) - layer.w13_weight_scale = Parameter( - w13_blockscale_swizzled, requires_grad=False - ) - - w13_weight = layer.w13_weight - intermediate_size_pad = w13_blockscale_swizzled.size(1) - w13_weight.size(1) - if intermediate_size_pad: - # padding gated activations will require to split w1 and w3 - # and pad them individually - assert not self.moe.is_act_and_mul, ( - "The intermediate size required padding, " - "but padding is not implemented for gated activations" - ) - - layer.w13_weight = Parameter( - torch.nn.functional.pad( - w13_weight, (0, 0, 0, intermediate_size_pad) - ), - requires_grad=False, - ) - layer.w2_weight = Parameter( - torch.nn.functional.pad( - layer.w2_weight, (0, intermediate_size_pad // 2, 0, 0) - ), - requires_grad=False, - ) - layer.w2_weight_scale = Parameter( - torch.nn.functional.pad( - layer.w2_weight_scale, (0, intermediate_size_pad // 16) - ), - requires_grad=False, - ) + replace_parameter(layer, "w13_weight", w13) + replace_parameter(layer, "w13_weight_scale", w13_scale) + replace_parameter(layer, "w13_weight_scale_2", w13_scale_2) + replace_parameter(layer, "w13_input_scale", a13_scale) + replace_parameter(layer, "w2_weight", w2) + replace_parameter(layer, "w2_weight_scale", w2_scale) + replace_parameter(layer, "w2_weight_scale_2", w2_scale_2) + replace_parameter(layer, "w2_input_scale", a2_scale) - w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale) - layer.w2_weight_scale = Parameter( - w2_blockscale_swizzled, requires_grad=False + self.moe_quant_config = self.get_fused_moe_quant_config(layer) + use_dp = self.moe.dp_size > 1 + if self.moe_quant_config is not None and not use_dp: + self.kernel = make_nvfp4_moe_kernel( + backend=self.nvfp4_backend, + quant_config=self.moe_quant_config, + moe_config=self.moe, ) def prepare_dp_allgather_tensor( @@ -1688,7 
+1567,8 @@ def prepare_dp_allgather_tensor( """Optionally prepare extra tensors to carry through DP allgather/EP.""" import flashinfer - a1_gscale = layer.w13_input_scale_quant + assert self.moe_quant_config is not None + a1_gscale = self.moe_quant_config.a1_gscale hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize( hidden_states, a1_gscale, @@ -1700,19 +1580,14 @@ def prepare_dp_allgather_tensor( def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: - if ( - self.use_marlin - or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM - ): - return None - - return nvfp4_moe_quant_config( - w1_scale=layer.w13_weight_scale, + return make_nvfp4_moe_quant_config( + backend=self.nvfp4_backend, + w13_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, - g1_alphas=layer.g1_alphas, - g2_alphas=layer.g2_alphas, - a1_gscale=layer.w13_input_scale_quant, - a2_gscale=layer.w2_input_scale_quant, + w13_scale_2=layer.w13_weight_scale_2, + w2_scale_2=layer.w2_weight_scale_2, + a13_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, ) @property @@ -1725,18 +1600,8 @@ def apply( x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - if not self.moe.is_act_and_mul: - assert ( - self.allow_flashinfer - and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS - ), ( - "Non-gated activations are only supported by the" - " flashinfer CUTLASS backend for modelopt checkpoints" - ) - if ( - self.allow_flashinfer - and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM + self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM and not layer.enable_eplb ): return flashinfer_trtllm_fp4_moe( @@ -1762,10 +1627,8 @@ def apply( ) # EPLB path - if ( - self.allow_flashinfer - and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM - ): + if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: + assert layer.enable_eplb return 
flashinfer_trtllm_fp4_routed_moe( layer=layer, x=x, @@ -1774,81 +1637,20 @@ def apply( top_k=layer.top_k, global_num_experts=layer.global_num_experts, ) - - if self.use_marlin: - return fused_marlin_moe( + else: + assert self.kernel is not None + return self.kernel( x, layer.w13_weight, layer.w2_weight, - None, - None, - layer.w13_weight_scale, - layer.w2_weight_scale, - router_logits, topk_weights, topk_ids, - global_scale1=layer.w13_weight_scale_2, - global_scale2=layer.w2_weight_scale_2, - quant_type_id=scalar_types.float4_e2m1f.id, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - global_num_experts=layer.global_num_experts, - expert_map=layer.expert_map, - input_dtype=self.marlin_input_dtype, - ) - - elif self.allow_flashinfer: - assert self.flashinfer_moe_backend in ( - FlashinferMoeBackend.CUTLASS, - FlashinferMoeBackend.CUTEDSL, - ) - if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: - from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 - flashinfer_cutlass_moe_fp4, - ) - - flashinfer_fn_moe_fp4 = flashinfer_cutlass_moe_fp4 - else: - from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import ( # noqa: E501 - flashinfer_cutedsl_moe_fp4, - ) - - flashinfer_fn_moe_fp4 = flashinfer_cutedsl_moe_fp4 - - assert self.moe_quant_config is not None - return flashinfer_fn_moe_fp4( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - quant_config=self.moe_quant_config, inplace=False, activation=layer.activation, global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, apply_router_weight_on_input=layer.apply_router_weight_on_input, ) - else: - # If no modular kernel is provided, use cutlass_moe_fp4 for TP case - # only (no EP). 
- from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 - - assert self.moe_quant_config is not None - return cutlass_moe_fp4( - a=x, - w1_fp4=layer.w13_weight, - w2_fp4=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - quant_config=self.moe_quant_config, - expert_map=layer.expert_map, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - # TODO: derive from arguments - m=x.shape[0], - n=layer.w2_weight.shape[2] * 2, - k=x.shape[1], - e=layer.w13_weight.shape[0], - ) ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index 1d410316d629..eaf45ead5afd 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -2,10 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Utility helpers for NVFP4 + FlashInfer fused-MoE path""" +from typing import TYPE_CHECKING + import torch import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, @@ -20,12 +23,23 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 create_flashinfer_prepare_finalize, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + swizzle_blockscale, +) from vllm.platforms import current_platform from vllm.utils.flashinfer import ( has_flashinfer_cutedsl_grouped_gemm_nt_masked, has_flashinfer_cutlass_fused_moe, ) +if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( + NvFp4MoeBackend, + ) + +logger = init_logger(__name__) + + __all__ = [ "is_flashinfer_fp4_cutlass_moe_available", 
"is_flashinfer_fp4_cutedsl_moe_available", @@ -273,10 +287,9 @@ def flashinfer_trtllm_fp4_moe( hidden_states_fp4, hidden_states_scale_linear_fp4 = x else: # hidden_states is the already quantized - a1_gscale = layer.w13_input_scale_quant (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( x, - a1_gscale, + layer.a1_gscale, is_sf_swizzled_layout=False, ) @@ -369,10 +382,9 @@ def flashinfer_trtllm_fp4_routed_moe( hidden_states_fp4, hidden_states_scale_linear_fp4 = x else: # Quantize input to FP4 - a1_gscale = layer.w13_input_scale_quant (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( x, - a1_gscale, + layer.a1_gscale, is_sf_swizzled_layout=False, ) @@ -410,3 +422,93 @@ def flashinfer_trtllm_fp4_routed_moe( )[0] return out + + +def prepare_nvfp4_moe_layer_for_fi_or_cutlass( + backend: "NvFp4MoeBackend", + layer: torch.nn.Module, + w13: torch.Tensor, + w13_scale: torch.Tensor, + w13_scale_2: torch.Tensor, + a13_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + w2_scale_2: torch.Tensor, + a2_scale: torch.Tensor, + is_act_and_mul: bool, +) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: + # Delayed import for circular dependency avoidance. + from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( + NvFp4MoeBackend, + is_global_sf_supported_for_nvfp4_backend, + ) + + assert backend in [ + NvFp4MoeBackend.VLLM_CUTLASS, + NvFp4MoeBackend.FLASHINFER_CUTLASS, + NvFp4MoeBackend.FLASHINFER_TRTLLM, + NvFp4MoeBackend.FLASHINFER_TRTLLM, + ] + + # Reorder [w1, w3] to [w3, w1] for FI NVFP4 MoE kernels. + if is_act_and_mul and backend in [ + NvFp4MoeBackend.FLASHINFER_CUTLASS, + NvFp4MoeBackend.FLASHINFER_TRTLLM, + ]: + w13, w13_scale = reorder_w1w3_to_w3w1(w13, w13_scale) + + # For some FI kernels, the input scales are shared by all experts. 
+ if is_global_sf_supported_for_nvfp4_backend(backend): + num_experts = w13.shape[0] + a13_scale = a13_scale.max().to(torch.float32).expand(num_experts) + a2_scale = a2_scale.max().to(torch.float32).expand(num_experts) + else: + a13_scale = a13_scale.max(dim=1).values.to(torch.float32) + + # Shuffle weights and scales for FI TRTLLM NVFP4 MoE kernels. + if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: + w13, w13_scale, w2, w2_scale = prepare_static_weights_for_trtllm_fp4_moe( + w13, + w2, + w13_scale, + w2_scale, + w2.size(-2), # hidden_size + w13.size(-2) // 2, # intermediate_size + w13.size(0), # num_experts + ) + + # We do not need to make this a parameter, because + # it is not used during the weight (re)-loading process. + layer.g1_scale_c = a13_scale * w13_scale_2 / a2_scale + layer.a1_gscale = 1.0 / a13_scale + layer.g1_alphas = a13_scale * w13_scale_2 + layer.g2_alphas = a2_scale * w2_scale_2 + else: + # Swizzle the block scales for other FI NVFP4 MoE kernels. + w13_scale = swizzle_blockscale(w13_scale) + + # Apply padding if needed. 
+ pad_size = w13_scale.size(1) - w13.size(1) + if pad_size > 0: + if is_act_and_mul: + raise NotImplementedError( + "Intermediate size padding for w1 and w3, for %s " + "NvFp4 backend, but this is not currently supported", + backend.value, + ) + w13 = torch.nn.functional.pad(w13, (0, 0, 0, pad_size)) + w2 = torch.nn.functional.pad(w2, (0, pad_size // 2, 0, 0)) + w2_scale = torch.nn.functional.pad(w2_scale, (0, pad_size // 16)) + + w2_scale = swizzle_blockscale(w2_scale) + + return w13, w13_scale, w13_scale_2, a13_scale, w2, w2_scale, w2_scale_2, a2_scale diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py index 4d0a34c3be11..2ced41ef886a 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py @@ -8,6 +8,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.marlin_utils import ( USE_FP32_REDUCE_DEFAULT, + get_marlin_input_dtype, marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales, @@ -226,6 +227,106 @@ def prepare_fp4_layer_for_marlin( return +def prepare_nvfp4_moe_layer_for_marlin( + layer: torch.nn.Module, + w13: torch.Tensor, + w13_scale: torch.Tensor, + w13_scale_2: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + w2_scale_2: torch.Tensor, +) -> tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor +]: + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel. This may degrade " + "performance for compute-heavy workloads." 
+ ) + + input_dtype = get_marlin_input_dtype(prefix="") + if input_dtype is not None and input_dtype.itemsize == 1: + raise RuntimeError("NVFP4 weight + INT8/FP8 activation is not supported.") + + GROUP_SIZE = 16 + E = layer.num_experts + K = layer.hidden_size + N = layer.intermediate_size_per_partition + + device = w13.device + param_dtype = layer.params_dtype + is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1 + + # WORKSPACE + layer.workspace = marlin_make_workspace_new(device, 4) + perm = torch.empty(0, dtype=torch.int, device=device) + + # WEIGHT + # Repack weights to marlin format + def repack_weight(weight: torch.Tensor, name: str) -> torch.Tensor: + tensor_list = [] + if "w13" in name: + size_n, size_k = N * 2, K + else: + size_n, size_k = K, N + + assert weight.shape == (E, size_n, size_k // 2) + + for i in range(E): + qweight = weight[i].view(torch.int32).T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=qweight, + perm=perm, + size_k=size_k, + size_n=size_n, + num_bits=4, + is_a_8bit=is_a_8bit, + ) + tensor_list.append(marlin_qweight) + + return torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + + w13 = repack_weight(w13, "w13") + w2 = repack_weight(w2, "w2") + + # WEIGHT SCALES + # Permute scales + def premute_scales( + scales: torch.Tensor, g_scales: torch.Tensor, name: str + ) -> tuple[torch.Tensor, torch.Tensor]: + scales = scales.to(param_dtype) + g_scales = g_scales.to(param_dtype) + + tensor_list = [] + if "w13" in name: + size_n, size_k = N * 2, K + else: + size_n, size_k = K, N + + for i in range(E): + scale = scales[i].T + marlin_scales = marlin_permute_scales( + s=scale, + size_k=size_k, + size_n=size_n, + group_size=GROUP_SIZE, + is_a_8bit=is_a_8bit, + ) + marlin_scales = nvfp4_marlin_process_scales(marlin_scales) + tensor_list.append(marlin_scales) + + scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + g_scales = nvfp4_marlin_process_global_scale(g_scales) + return scales, g_scales + + 
w13_scale, w13_scale_2 = premute_scales(w13_scale, w13_scale_2, "w13") + w2_scale, w2_scale_2 = premute_scales(w2_scale, w2_scale_2, "w2") + + return w13, w13_scale, w13_scale_2, w2, w2_scale, w2_scale_2 + + def prepare_moe_fp4_layer_for_marlin( layer: torch.nn.Module, input_dtype: torch.dtype | None = None ) -> None: