diff --git a/benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py b/benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py
index d6b5820a5b41..10a3e3eab5fd 100644
--- a/benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py
+++ b/benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py
@@ -11,14 +11,20 @@
import torch
import torch.utils.benchmark as benchmark
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config,
nvfp4_moe_quant_config,
)
-from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
+from vllm.model_executor.layers.fused_moe.cutlass_moe import (
+ CutlassExpertsFp4,
+)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+ MoEPrepareAndFinalizeNoEP,
+)
from vllm.scalar_type import scalar_types
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.worker.workspace import init_workspace_manager
@@ -188,19 +194,24 @@ def run_cutlass_moe_fp4(
g1_alphas=w1_gs,
g2_alphas=w2_gs,
)
+
+ kernel = mk.FusedMoEModularKernel(
+ MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
+ CutlassExpertsFp4(
+ out_dtype=dtype,
+ max_experts_per_worker=e,
+ quant_config=quant_config,
+ ),
+ )
+
for _ in range(num_repeats):
with nvtx.annotate("cutlass_moe_fp4", color="green"):
- cutlass_moe_fp4(
- a=a,
- w1_fp4=w1_fp4,
- w2_fp4=w2_fp4,
+ kernel(
+ hidden_states=a,
+ w1=w1_fp4,
+ w2=w2_fp4,
topk_weights=topk_weights,
topk_ids=topk_ids,
- m=m,
- n=n,
- k=k,
- e=num_experts,
- quant_config=quant_config,
)
def run_cutlass_from_graph(
@@ -230,20 +241,24 @@ def run_cutlass_from_graph(
g2_alphas=w2_gs,
)
+ kernel = mk.FusedMoEModularKernel(
+ MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
+ CutlassExpertsFp4(
+ out_dtype=dtype,
+ max_experts_per_worker=e,
+ quant_config=quant_config,
+ ),
+ )
+
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
):
- return cutlass_moe_fp4(
- a=a,
- w1_fp4=w1_fp4,
- w2_fp4=w2_fp4,
+ return kernel(
+ hidden_states=a,
+ w1=w1_fp4,
+ w2=w2_fp4,
topk_weights=topk_weights,
topk_ids=topk_ids,
- m=m,
- n=n,
- k=k,
- e=num_experts,
- quant_config=quant_config,
)
def run_triton_from_graph(
diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md
index d683b538c415..18216b5965af 100644
--- a/docs/design/moe_kernel_features.md
+++ b/docs/design/moe_kernel_features.md
@@ -86,7 +86,7 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels
| triton | standard | all1 | G,A,T | silu, gelu,swigluoai,silu_no_mul,gelu_no_mul | Y | Y | [`fused_experts`][vllm.model_executor.layers.fused_moe.fused_moe.fused_experts],[`TritonExperts`][vllm.model_executor.layers.fused_moe.fused_moe.TritonExperts] |
| triton (batched) | batched | all1 | G,A,T | silu, gelu | 6 | Y | [`BatchedTritonExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedTritonExperts] |
| deep gemm | standard,batched | fp8 | G(128),A,T | silu, gelu | 6 | Y | [`deep_gemm_moe_fp8`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.deep_gemm_moe_fp8],[`DeepGemmExperts`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.DeepGemmExperts],[`BatchedDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe.BatchedDeepGemmExperts] |
-| cutlass_fp4 | standard,batched | nvfp4 | A,T | silu | Y | Y | [`cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.cutlass_moe.cutlass_moe_fp4],[`CutlassExpertsFp4`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp4] |
+| cutlass_fp4 | standard,batched | nvfp4 | A,T | silu | Y | Y | [`CutlassExpertsFp4`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp4] |
| cutlass_fp8 | standard,batched | fp8 | A,T | silu, gelu | Y | Y | [`CutlassExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp8],[`CutlasBatchedExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassBatchedExpertsFp8] |
| flashinfer | standard | nvfp4,fp8 | T | 5 | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] |
| gpt oss triton | standard | N/A | N/A | 5 | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
diff --git a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py
index fd7388e1cff8..873d72117de7 100644
--- a/tests/kernels/moe/test_nvfp4_moe.py
+++ b/tests/kernels/moe/test_nvfp4_moe.py
@@ -3,6 +3,7 @@
import pytest
import torch
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.moe.utils import make_test_weights
from tests.kernels.quantization.nvfp4_utils import (
FLOAT4_E2M1_MAX,
@@ -13,8 +14,13 @@
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import nvfp4_moe_quant_config
-from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
+from vllm.model_executor.layers.fused_moe.cutlass_moe import (
+ CutlassExpertsFp4,
+)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+ MoEPrepareAndFinalizeNoEP,
+)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
@@ -83,17 +89,21 @@ def test_cutlass_fp4_moe_no_graph(
w2_scale=w2_blockscale,
)
- cutlass_output = cutlass_moe_fp4(
- a=a,
- w1_fp4=w1_q,
- w2_fp4=w2_q,
+ kernel = mk.FusedMoEModularKernel(
+ MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
+ CutlassExpertsFp4(
+ out_dtype=dtype,
+ max_experts_per_worker=e,
+ quant_config=quant_config,
+ ),
+ )
+
+ cutlass_output = kernel(
+ hidden_states=a,
+ w1=w1_q,
+ w2=w2_q,
topk_weights=topk_weights,
topk_ids=topk_ids,
- quant_config=quant_config,
- m=m,
- n=n,
- k=k,
- e=e,
)
# Reference check:
diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py
index e63404086ed9..9f04397e91f7 100644
--- a/vllm/model_executor/layers/fused_moe/__init__.py
+++ b/vllm/model_executor/layers/fused_moe/__init__.py
@@ -72,7 +72,6 @@ def get_config() -> dict[str, Any] | None:
CutlassBatchedExpertsFp8,
CutlassExpertsFp8,
CutlassExpertsW4A8Fp8,
- cutlass_moe_fp4,
cutlass_moe_w4a8_fp8,
)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
@@ -95,7 +94,6 @@ def get_config() -> dict[str, Any] | None:
"fused_experts",
"get_config_file_name",
"GroupedTopk",
- "cutlass_moe_fp4",
"cutlass_moe_w4a8_fp8",
"CutlassExpertsFp8",
"CutlassBatchedExpertsFp8",
diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py
index 17d5ec4bcda7..23b86fdca898 100644
--- a/vllm/model_executor/layers/fused_moe/config.py
+++ b/vllm/model_executor/layers/fused_moe/config.py
@@ -331,6 +331,10 @@ def use_fp8_w8a16(self) -> bool:
def use_int4_w4a16(self) -> bool:
return self._a1.dtype is None and self._w1.dtype == "int4"
+ @property
+ def use_nvfp4_w4a16(self) -> bool:
+ return self._a1.dtype is None and self._w1.dtype == "nvfp4"
+
@property
def ocp_mx_scheme(self) -> str | None:
if not hasattr(self, "_ocp_mx_scheme"):
@@ -683,6 +687,25 @@ def nvfp4_moe_quant_config(
)
+def nvfp4_w4a16_moe_quant_config(
+ g1_alphas: torch.Tensor,
+ g2_alphas: torch.Tensor,
+ w1_scale: torch.Tensor,
+ w2_scale: torch.Tensor,
+) -> FusedMoEQuantConfig:
+ """
+    Construct a quant config for 16-bit activations and nvfp4 weights.
+ """
+ return FusedMoEQuantConfig.make(
+ quant_dtype=None,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ g1_alphas=g1_alphas,
+ g2_alphas=g2_alphas,
+ weight_dtype="nvfp4",
+ )
+
+
def int4_w4a16_moe_quant_config(
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
index 6e397f1e76a1..32ea040c743c 100644
--- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -706,68 +706,6 @@ def apply(
)
-def cutlass_moe_fp4(
- a: torch.Tensor,
- w1_fp4: torch.Tensor,
- w2_fp4: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- quant_config: FusedMoEQuantConfig,
- m: int,
- n: int,
- k: int,
- e: int,
- expert_map: torch.Tensor | None = None,
- apply_router_weight_on_input: bool = False,
-) -> torch.Tensor:
- assert expert_map is None, (
- "Expert Parallelism / expert_map "
- "is currently not supported for "
- "ModelOptNvFp4FusedMoE's cutlass_moe_fp4."
- )
-
- # TODO(bnell): this feels a bit hacky
- # NVFP4 requires two levels of quantization, which involves
- # computing some scaling factors dynamically. This makes it
- # incompatible with the typical prepare -> MoE -> finalize
- # pipeline. Move the quantization logic into the MoE body.
- quant_config = FusedMoEQuantConfig.make(
- quant_dtype=None, # skip quantization in prepare/finalize
- per_act_token_quant=quant_config.per_act_token_quant,
- per_out_ch_quant=quant_config.per_out_ch_quant,
- block_shape=quant_config.block_shape,
- g1_alphas=quant_config.g1_alphas,
- g2_alphas=quant_config.g2_alphas,
- a1_gscale=quant_config.a1_gscale,
- a2_gscale=quant_config.a2_gscale,
- w1_scale=quant_config.w1_scale,
- w2_scale=quant_config.w2_scale,
- )
-
- fn = mk.FusedMoEModularKernel(
- MoEPrepareAndFinalizeNoEP(),
- CutlassExpertsFp4(
- max_experts_per_worker=e,
- out_dtype=a.dtype,
- quant_config=quant_config,
- use_batched_format=False,
- ),
- )
-
- return fn(
- hidden_states=a,
- w1=w1_fp4,
- w2=w2_fp4,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- inplace=False,
- activation="silu",
- global_num_experts=e,
- expert_map=None,
- apply_router_weight_on_input=apply_router_weight_on_input,
- )
-
-
# W4A8
def run_cutlass_moe_w4a8_fp8(
output: torch.Tensor,
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py
index 6e0b57156cb3..ce93ae235f27 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py
@@ -335,42 +335,3 @@ def flashinfer_cutedsl_moe_masked(
alpha_dtype=get_cute_dtype(w2_alpha),
) # in logical [m, k, l]
out = out.permute(2, 0, 1)
-
-
-def flashinfer_cutedsl_moe_fp4(
- hidden_states: torch.Tensor,
- w1: torch.Tensor,
- w2: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- quant_config: FusedMoEQuantConfig,
- inplace: bool = False,
- activation: str = "silu",
- global_num_experts: int = -1,
- expert_map: torch.Tensor | None = None,
- apply_router_weight_on_input: bool = False,
-) -> torch.Tensor:
- from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
- create_flashinfer_prepare_finalize,
- )
-
- fused_experts = mk.FusedMoEModularKernel(
- create_flashinfer_prepare_finalize(use_dp=False), # could be swapped later
- FlashInferCuteDSLExperts(
- out_dtype=hidden_states.dtype,
- quant_config=quant_config,
- ),
- )
-
- return fused_experts(
- hidden_states=hidden_states,
- w1=w1,
- w2=w2,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- inplace=inplace,
- activation=activation,
- global_num_experts=global_num_experts,
- expert_map=expert_map,
- apply_router_weight_on_input=apply_router_weight_on_input,
- )
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
index 0b0efdafbd4d..dfff860750d6 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
@@ -355,21 +355,17 @@ def create_flashinfer_prepare_finalize(
use_deepseek_fp8_block_scale: bool = False,
) -> FlashInferCutlassMoEPrepareAndFinalize | MoEPrepareAndFinalizeNoEP:
"""Factory function to create the appropriate FlashInfer implementation."""
- # TODO(rob): migrate non-DP cases to MoEPrepareAndFinalizeNoEP
- # once we complete the FP8 refactor.
- if use_nvfp4:
- if enable_alltoallv:
- return FlashInferAllToAllMoEPrepareAndFinalize(use_dp)
- else:
- return FlashInferAllGatherMoEPrepareAndFinalize(use_dp)
- # FP8 DP path currently supported via AllGather.
if use_dp:
+ if enable_alltoallv:
+ assert use_nvfp4
+ return FlashInferAllToAllMoEPrepareAndFinalize(use_dp)
return FlashInferAllGatherMoEPrepareAndFinalize(
use_dp=True,
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
)
else:
- # NOTE(rob): CUTLASS FP8 block quant executes the input
- # quantzation and grouped gemm in a single kernel.
- return MoEPrepareAndFinalizeNoEP(defer_input_quant=use_deepseek_fp8_block_scale)
+ # CUTLASS FP8 BLOCK and CUTLASS NVFP4 apply input quantization
+ # in a single call with the MoE experts kernel.
+ defer_input_quant = use_deepseek_fp8_block_scale or use_nvfp4
+ return MoEPrepareAndFinalizeNoEP(defer_input_quant=defer_input_quant)
diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
index c031d9efcc0b..e82a838959de 100644
--- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -540,9 +540,10 @@ def __init__(
# TODO (varun) : Enable activation quantization
assert (
quant_config.use_mxfp4_w4a16
+ or quant_config.use_nvfp4_w4a16
or quant_config.use_int4_w4a16
or quant_config.use_fp8_w8a16
- ), "Supports only mxfp4_w4a16, int4_w4a16 or fp8_w8a16"
+ ), "Supports only {mxfp,nvfp,int}4_w4a16 or fp8_w8a16"
self.w13_g_idx = w13_g_idx
self.w2_g_idx = w2_g_idx
self.w13_g_idx_sort_indices = w13_g_idx_sort_indices
@@ -555,7 +556,7 @@ def quant_type_id(self) -> int:
# uint4b8 will be set for int4 weight and float4_e2m1f will be used for mxfp4
if self.quant_config.use_int4_w4a16:
return scalar_types.uint4b8.id
- elif self.quant_config.use_mxfp4_w4a16:
+ elif self.quant_config.use_mxfp4_w4a16 or self.quant_config.use_nvfp4_w4a16:
return scalar_types.float4_e2m1f.id
elif (
self.quant_config.use_fp8_w8a16
@@ -692,6 +693,8 @@ def apply(
gating_output=None,
topk_weights=topk_weights,
topk_ids=topk_ids,
+ global_scale1=self.g1_alphas,
+ global_scale2=self.g2_alphas,
quant_type_id=self.quant_type_id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index fb441963a97d..3a1860b472d0 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -38,9 +38,6 @@
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
)
-from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
- is_flashinfer_supporting_global_sf,
-)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe
from vllm.utils.math_utils import cdiv, round_up
@@ -1119,14 +1116,9 @@ def weight_loader(
global_expert_id = expert_id
expert_id = self._map_global_expert_id_to_local_expert_id(global_expert_id)
- allow_flashinfer = getattr(self.quant_method, "allow_flashinfer", False)
- moe_backend = getattr(self.quant_method, "flashinfer_moe_backend", None)
-
use_global_sf = (
- allow_flashinfer
- and is_flashinfer_supporting_global_sf(moe_backend)
+ getattr(self.quant_method, "use_global_sf", False)
and "input_scale" in weight_name
- and quant_method_name == "ModelOptNvFp4FusedMoE"
)
if expert_id == -1 and not use_global_sf:
diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
new file mode 100644
index 000000000000..547a2a795d19
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
@@ -0,0 +1,280 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from enum import Enum
+
+import torch
+
+import vllm.envs as envs
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEConfig,
+ FusedMoEQuantConfig,
+ nvfp4_moe_quant_config,
+ nvfp4_w4a16_moe_quant_config,
+)
+from vllm.model_executor.layers.fused_moe.cutlass_moe import (
+ CutlassExpertsFp4,
+)
+from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
+ FlashInferExperts,
+)
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+ MarlinExperts,
+)
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+ MoEPrepareAndFinalizeNoEP,
+)
+from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
+ is_flashinfer_fp4_cutedsl_moe_available,
+ is_flashinfer_fp4_cutlass_moe_available,
+ prepare_nvfp4_moe_layer_for_fi_or_cutlass,
+)
+from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
+ FlashinferMoeBackend,
+ get_flashinfer_moe_backend,
+)
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
+ is_fp4_marlin_supported,
+ prepare_nvfp4_moe_layer_for_marlin,
+)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ cutlass_fp4_supported,
+)
+
+logger = init_logger(__name__)
+
+
+class NvFp4MoeBackend(Enum):
+ FLASHINFER_CUTLASS = "FlashInfer CUTLASS"
+ FLASHINFER_TRTLLM = "FlashInfer TRTLLM"
+ FLASHINFER_CUTEDSL = "FlashInfer CUTEDSL"
+    VLLM_CUTLASS = "vLLM CUTLASS"
+ MARLIN = "vLLM MARLIN"
+
+
+FLASHINFER_NVFP4_MOE_BACKENDS = [
+ NvFp4MoeBackend.FLASHINFER_CUTLASS,
+ NvFp4MoeBackend.FLASHINFER_TRTLLM,
+ NvFp4MoeBackend.FLASHINFER_CUTEDSL,
+]
+
+fi_2_vllm_backend_map: dict[FlashinferMoeBackend, NvFp4MoeBackend] = {
+ FlashinferMoeBackend.CUTLASS: NvFp4MoeBackend.FLASHINFER_CUTLASS,
+ FlashinferMoeBackend.TENSORRT_LLM: NvFp4MoeBackend.FLASHINFER_TRTLLM,
+ FlashinferMoeBackend.CUTEDSL: NvFp4MoeBackend.FLASHINFER_CUTEDSL,
+}
+
+
+def is_global_sf_supported_for_nvfp4_backend(backend: NvFp4MoeBackend) -> bool:
+ # Checks whether `backend` supports quantizing with scaling factors
+ # of all experts in Expert Parallel Mode when all experts are not
+ # on the same rank.
+
+ return backend in [
+ NvFp4MoeBackend.FLASHINFER_CUTLASS,
+ NvFp4MoeBackend.FLASHINFER_TRTLLM,
+ ]
+
+
+def select_nvfp4_moe_backend() -> NvFp4MoeBackend:
+ def _make_log_backend(backend: NvFp4MoeBackend):
+ return f"Using {backend.value} backend for NvFp4 MoE"
+
+ if cutlass_fp4_supported() and not envs.VLLM_TEST_FORCE_FP8_MARLIN:
+ allow_flashinfer = (
+ is_flashinfer_fp4_cutlass_moe_available()
+ or is_flashinfer_fp4_cutedsl_moe_available()
+ )
+ if allow_flashinfer and envs.VLLM_USE_FLASHINFER_MOE_FP4:
+ backend = fi_2_vllm_backend_map[get_flashinfer_moe_backend()]
+ else:
+ backend = NvFp4MoeBackend.VLLM_CUTLASS
+ elif is_fp4_marlin_supported():
+ backend = NvFp4MoeBackend.MARLIN
+ else:
+ raise ValueError("No NvFp4 kernel backend available for NvFp4 MoE.")
+
+ # Log warning if FI backend requested but not available.
+ if (
+ backend not in FLASHINFER_NVFP4_MOE_BACKENDS
+ and envs.VLLM_USE_FLASHINFER_MOE_FP4
+ ):
+ logger.warning_once(
+ "Requested FlashInfer backend for NvFp4 MoE, but it's not available. "
+ "Falling back to %s for NvFp4 MoE",
+ backend.value,
+ scope="local",
+ )
+ else:
+ logger.info_once(_make_log_backend(backend), scope="local")
+ return backend
+
+
+def convert_to_nvfp4_moe_kernel_format(
+ nvfp4_backend: NvFp4MoeBackend,
+ layer: torch.nn.Module,
+ w13: torch.Tensor,
+ w13_scale: torch.Tensor,
+ w13_scale_2: torch.Tensor,
+ a13_scale: torch.Tensor | None,
+ w2: torch.Tensor,
+ w2_scale: torch.Tensor,
+ w2_scale_2: torch.Tensor,
+ a2_scale: torch.Tensor | None,
+ is_act_and_mul: bool,
+) -> tuple[
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+]:
+ if (
+ nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS
+ or nvfp4_backend == NvFp4MoeBackend.VLLM_CUTLASS
+ ):
+ (
+ w13,
+ w13_scale,
+ w13_scale_2,
+ a13_scale,
+ w2,
+ w2_scale,
+ w2_scale_2,
+ a2_scale,
+ ) = prepare_nvfp4_moe_layer_for_fi_or_cutlass(
+ backend=nvfp4_backend,
+ layer=layer,
+ w13=w13,
+ w13_scale=w13_scale,
+ w13_scale_2=w13_scale_2,
+ a13_scale=a13_scale,
+ w2=w2,
+ w2_scale=w2_scale,
+ w2_scale_2=w2_scale_2,
+ a2_scale=a2_scale,
+ is_act_and_mul=is_act_and_mul,
+ )
+ elif nvfp4_backend == NvFp4MoeBackend.MARLIN:
+ a13_scale = None
+ a2_scale = None
+ (
+ w13,
+ w13_scale,
+ w13_scale_2,
+ w2,
+ w2_scale,
+ w2_scale_2,
+ ) = prepare_nvfp4_moe_layer_for_marlin(
+ layer=layer,
+ w13=w13,
+ w13_scale=w13_scale,
+ w13_scale_2=w13_scale_2,
+ w2=w2,
+ w2_scale=w2_scale,
+ w2_scale_2=w2_scale_2,
+ )
+ else:
+ raise ValueError(f"Unknown NvFp4 backend for MoE: {nvfp4_backend}")
+
+ return (
+ w13,
+ w13_scale,
+ w13_scale_2,
+ a13_scale,
+ w2,
+ w2_scale,
+ w2_scale_2,
+ a2_scale,
+ )
+
+
+def make_nvfp4_moe_quant_config(
+ backend: NvFp4MoeBackend,
+ w13_scale: torch.Tensor,
+ w2_scale: torch.Tensor,
+ w13_scale_2: torch.Tensor,
+ w2_scale_2: torch.Tensor,
+ a13_scale: torch.Tensor,
+ a2_scale: torch.Tensor,
+) -> FusedMoEQuantConfig | None:
+ UNSUPPORTED = [NvFp4MoeBackend.FLASHINFER_TRTLLM]
+ if backend in UNSUPPORTED:
+ return None
+
+ elif backend == NvFp4MoeBackend.MARLIN:
+ return nvfp4_w4a16_moe_quant_config(
+ g1_alphas=w13_scale_2,
+ g2_alphas=w2_scale_2,
+ w1_scale=w13_scale,
+ w2_scale=w2_scale,
+ )
+
+ g1_alphas = a13_scale * w13_scale_2
+ g2_alphas = a2_scale * w2_scale_2
+ return nvfp4_moe_quant_config(
+ g1_alphas=g1_alphas,
+ g2_alphas=g2_alphas,
+ a1_gscale=(1.0 / a13_scale),
+ a2_gscale=(1.0 / a2_scale),
+ w1_scale=w13_scale,
+ w2_scale=w2_scale,
+ )
+
+
+def make_nvfp4_moe_kernel(
+ backend: NvFp4MoeBackend,
+ quant_config: FusedMoEQuantConfig,
+ moe_config: FusedMoEConfig,
+) -> mk.FusedMoEModularKernel | None:
+ assert moe_config.dp_size == 1
+
+ UNSUPPORTED_BACKENDS = [
+        # TRTLLM does not use the modular kernel abstraction.
+ NvFp4MoeBackend.FLASHINFER_TRTLLM,
+ # CUTEDSL is used with BATCHED (masked) format only.
+ # TODO: add here once we support dp/ep via the oracle.
+ NvFp4MoeBackend.FLASHINFER_CUTEDSL,
+ ]
+
+ if backend in UNSUPPORTED_BACKENDS:
+ return None
+
+ elif backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
+ return mk.FusedMoEModularKernel(
+ MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
+ FlashInferExperts(
+ out_dtype=moe_config.in_dtype,
+ quant_config=quant_config,
+ ep_rank=moe_config.ep_rank,
+ ep_size=moe_config.ep_size,
+ tp_rank=moe_config.tp_rank,
+ tp_size=moe_config.tp_size,
+ use_dp=False,
+ use_deepseek_fp8_block_scale=False,
+ ),
+ )
+
+ elif backend == NvFp4MoeBackend.VLLM_CUTLASS:
+ return mk.FusedMoEModularKernel(
+ MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
+ CutlassExpertsFp4(
+ out_dtype=moe_config.in_dtype,
+ # TODO(rob): see what impact this has on expert map?
+ max_experts_per_worker=moe_config.num_experts,
+ quant_config=quant_config,
+ ),
+ )
+
+ elif backend == NvFp4MoeBackend.MARLIN:
+ return mk.FusedMoEModularKernel(
+ MoEPrepareAndFinalizeNoEP(),
+ MarlinExperts(quant_config=quant_config),
+ )
+
+ else:
+ raise ValueError(f"Unknown NvFp4 MoE backend: {backend}")
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index a2b3aec4457e..509de5dff9c1 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -11,7 +11,6 @@
QuantizationArgs,
QuantizationStrategy,
)
-from torch.nn.parameter import Parameter
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
@@ -34,12 +33,8 @@
int4_w4afp8_moe_quant_config,
int8_w8a8_moe_quant_config,
int8_w8a16_moe_quant_config,
- nvfp4_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.cpu_fused_moe import select_experts
-from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
- is_valid_flashinfer_cutlass_fused_moe,
-)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
BatchedMarlinExperts,
MarlinExperts,
@@ -51,6 +46,15 @@
make_fp8_moe_kernel,
select_fp8_moe_backend,
)
+from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
+ FLASHINFER_NVFP4_MOE_BACKENDS,
+ NvFp4MoeBackend,
+ convert_to_nvfp4_moe_kernel_format,
+ is_global_sf_supported_for_nvfp4_backend,
+ make_nvfp4_moe_kernel,
+ make_nvfp4_moe_quant_config,
+ select_nvfp4_moe_backend,
+)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
WNA16_SUPPORTED_BITS,
WNA16_SUPPORTED_TYPES_MAP,
@@ -58,14 +62,9 @@
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize,
flashinfer_trtllm_fp4_moe,
- prepare_static_weights_for_trtllm_fp4_moe,
- reorder_w1w3_to_w3w1,
+ flashinfer_trtllm_fp4_routed_moe,
select_nvfp4_gemm_impl,
)
-from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
- FlashinferMoeBackend,
- get_flashinfer_moe_backend,
-)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_input_tensor_strategy_moe,
process_fp8_weight_tensor_strategy_moe,
@@ -77,20 +76,15 @@
marlin_make_workspace_new,
marlin_moe_permute_scales,
)
-from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
- prepare_moe_fp4_layer_for_marlin,
-)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
convert_bf16_scales_to_fp8,
convert_packed_uint4b8_to_signed_int4_inplace,
- swizzle_blockscale,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
normalize_e4m3fn_to_e4m3fnuz,
)
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import CpuArchEnum, current_platform
-from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
@@ -218,31 +212,19 @@ def get_moe_method(
class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
def __init__(self, moe: FusedMoEConfig, layer_name: str | None = None):
- from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
- detect_nvfp4_moe_support,
- )
+ if not moe.is_act_and_mul:
+ raise ValueError(
+ "CompressedTensorsW4A4Nvfp4MoEMethod does not yet "
+ "support non gated MoE models."
+ )
super().__init__(moe)
- _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
- self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
- self.allow_flashinfer = _nvfp4.allow_flashinfer
- self.use_marlin = _nvfp4.use_marlin
self.group_size = 16
- self.layer_name = layer_name
- self.marlin_input_dtype = (
- get_marlin_input_dtype(layer_name) if self.use_marlin else None
- )
- self.flashinfer_moe_backend = None
- if self.allow_flashinfer:
- self.flashinfer_moe_backend = get_flashinfer_moe_backend()
- logger.info_once(
- f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
- " for CompressedTensorsW4A4Nvfp4MoEMethod."
- )
- elif self.use_marlin:
- logger.info_once("Using Marlin for CompressedTensorsW4A4Nvfp4MoEMethod.")
- else:
- logger.info_once("Using Cutlass for CompressedTensorsW4A4Nvfp4MoEMethod.")
+ self.nvfp4_backend = select_nvfp4_moe_backend()
+ self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
+ self.nvfp4_backend
+ )
+ self.kernel: mk.FusedMoEModularKernel | None = None
def create_weights(
self,
@@ -355,7 +337,13 @@ def create_weights(
set_weight_attrs(w2_input_scale, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
- # From packed to weight
+ """
+ Convert NVFP4 MoE weights into kernel format and setup the kernel.
+ """
+ # NOTE(rob): wN_weight_packed -> wN_weight is because ModularKernelMethod
+ # requires this naming convention. However, the name change breaks
+ # reloading because the state dict no longer matches disk. Once we
+ # remove MKM, we should revert this change to ensure compatibility.
layer.w13_weight = torch.nn.Parameter(
layer.w13_weight_packed.data, requires_grad=False
)
@@ -366,144 +354,79 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
)
delattr(layer, "w2_weight_packed")
- # reorder GEMM1 weights and block scales for FlashInfer CUTLASS kernel.
- if self.allow_flashinfer:
- w, s = reorder_w1w3_to_w3w1(
- layer.w13_weight.data, layer.w13_weight_scale.data, dim=-2
- )
- layer.w13_weight = torch.nn.Parameter(w, requires_grad=False)
- layer.w13_weight_scale = torch.nn.Parameter(s, requires_grad=False)
-
- if not torch.allclose(
+ # Use a single gscale for w13.
+ if self.moe.is_act_and_mul and not torch.allclose(
layer.w13_weight_global_scale[:, 0], layer.w13_weight_global_scale[:, 1]
):
logger.warning_once(
"w1_weight_global_scale must match w3_weight_global_scale. "
- "Accuracy may be affected."
- )
-
- # Take inverse of global scale saved to disk
- layer.w13_weight_scale_2 = torch.nn.Parameter(
- 1 / layer.w13_weight_global_scale[:, 0], requires_grad=False
- )
-
- layer.w2_weight_scale_2 = torch.nn.Parameter(
- 1 / layer.w2_weight_global_scale.data, requires_grad=False
- )
-
- if self.use_marlin:
- prepare_moe_fp4_layer_for_marlin(layer, input_dtype=self.marlin_input_dtype)
- return
- # w13
- if (
- self.allow_flashinfer
- and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
- ):
- w13_input_global_scale = (
- layer.w13_input_global_scale.min()
- .to(torch.float32)
- .expand(layer.num_experts)
- )
- else:
- w13_input_global_scale = layer.w13_input_global_scale.min(dim=1).values.to(
- torch.float32
- )
- layer.g1_alphas = torch.nn.Parameter(
- ((1 / w13_input_global_scale) * layer.w13_weight_scale_2),
- requires_grad=False,
- )
-
- layer.w13_input_scale_quant = torch.nn.Parameter(
- (w13_input_global_scale), requires_grad=False
- )
-
- # w2
- if (
- self.allow_flashinfer
- and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
- ):
- w2_input_global_scale = (
- layer.w2_input_global_scale.min()
- .to(torch.float32)
- .expand(layer.num_experts)
- )
- else:
- w2_input_global_scale = layer.w2_input_global_scale
-
- layer.g2_alphas = torch.nn.Parameter(
- ((1 / w2_input_global_scale) * layer.w2_weight_scale_2).to(torch.float32),
- requires_grad=False,
- )
-
- layer.w2_input_scale_quant = torch.nn.Parameter(
- (w2_input_global_scale), requires_grad=False
+ "Accuracy may be affected.",
+ )
+ w13_weight_global_scale = layer.w13_weight_global_scale[:, 0].contiguous()
+
+ # Shuffle weights into the NvFp4 kernel format.
+ (
+ w13,
+ w13_scale,
+ w13_scale_2,
+ a13_scale,
+ w2,
+ w2_scale,
+ w2_scale_2,
+ a2_scale,
+ ) = convert_to_nvfp4_moe_kernel_format(
+ nvfp4_backend=self.nvfp4_backend,
+ layer=layer,
+ w13=layer.w13_weight,
+ w13_scale=layer.w13_weight_scale,
+ w13_scale_2=(1.0 / w13_weight_global_scale),
+ a13_scale=(1.0 / layer.w13_input_global_scale),
+ w2=layer.w2_weight,
+ w2_scale=layer.w2_weight_scale,
+ w2_scale_2=(1.0 / layer.w2_weight_global_scale),
+ a2_scale=(1.0 / layer.w2_input_global_scale),
+ is_act_and_mul=self.moe.is_act_and_mul,
)
- # TensorRT-LLM specific processing
- if (
- self.allow_flashinfer
- and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
- ):
- # Prepare static weights for TRT-LLM kernel
- # alternate: prepare_static_weight_layouts_for_trtllm_moe
- (
- gemm1_weights_fp4_shuffled,
- gemm1_scales_fp4_shuffled,
- gemm2_weights_fp4_shuffled,
- gemm2_scales_fp4_shuffled,
- ) = prepare_static_weights_for_trtllm_fp4_moe(
- layer.w13_weight,
- layer.w2_weight,
- layer.w13_weight_scale,
- layer.w2_weight_scale,
- layer.w2_weight.size(-2), # hidden_size
- layer.w13_weight.size(-2) // 2, # intermediate_size
- layer.w13_weight.size(0), # num_experts
- )
- logger.debug_once("Finished shuffling weights for TRT-LLM MOE")
-
- layer.w13_weight = Parameter(
- gemm1_weights_fp4_shuffled, requires_grad=False
- )
- layer.w2_weight = Parameter(gemm2_weights_fp4_shuffled, requires_grad=False)
- layer.w13_weight_scale = Parameter(
- gemm1_scales_fp4_shuffled, requires_grad=False
- )
- layer.w2_weight_scale = Parameter(
- gemm2_scales_fp4_shuffled, requires_grad=False
- )
-
- # Additional parameter needed for TRT-LLM
- layer.g1_scale_c = Parameter(
- (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
- requires_grad=False,
- )
- else:
- # swizzle weight scales
- layer.w13_weight_scale = torch.nn.Parameter(
- swizzle_blockscale(layer.w13_weight_scale), requires_grad=False
- )
+ replace_parameter(layer, "w13_weight", w13)
+ replace_parameter(layer, "w13_weight_scale", w13_scale)
+ replace_parameter(layer, "w2_weight", w2)
+ replace_parameter(layer, "w2_weight_scale", w2_scale)
+ layer.w13_weight_scale_2 = w13_scale_2
+ layer.w2_weight_scale_2 = w2_scale_2
+ layer.w13_input_scale = a13_scale
+ layer.w2_input_scale = a2_scale
- layer.w2_weight_scale = torch.nn.Parameter(
- swizzle_blockscale(layer.w2_weight_scale), requires_grad=False
+ # Initialize the kernel that will be called in apply().
+ self.moe_quant_config = self.get_fused_moe_quant_config(layer)
+ use_dp = self.moe.dp_size > 1
+ if self.moe_quant_config is not None and not use_dp:
+ self.kernel = make_nvfp4_moe_kernel(
+ backend=self.nvfp4_backend,
+ quant_config=self.moe_quant_config,
+ moe_config=self.moe,
)
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
- if self.use_marlin or (
- self.allow_flashinfer
- and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
- ):
+ UNSUPPORTED = [NvFp4MoeBackend.MARLIN, NvFp4MoeBackend.FLASHINFER_TRTLLM]
+ if self.nvfp4_backend in UNSUPPORTED:
return None
- elif not self.allow_flashinfer:
+ elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
+ # TP case: avoid convert to ModularKernelMethod - to be refactored.
+ if self.moe.dp_size == 1:
+ return None
+ # For now, fp4 moe only works with the flashinfer dispatcher.
+ prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
+ self.moe
+ )
+ logger.debug_once("%s", prepare_finalize.__class__.__name__)
+ return prepare_finalize
+ else:
return super().maybe_make_prepare_finalize(routing_tables)
- prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(self.moe)
- logger.debug_once("%s", prepare_finalize.__class__.__name__)
- return prepare_finalize
-
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
@@ -514,7 +437,7 @@ def select_gemm_impl(
experts = select_nvfp4_gemm_impl(
self.moe,
self.moe_quant_config,
- allow_flashinfer=self.allow_flashinfer,
+ allow_flashinfer=(self.nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS),
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
@@ -522,19 +445,14 @@ def select_gemm_impl(
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
- if (
- self.use_marlin
- or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
- ):
- return None
-
- return nvfp4_moe_quant_config(
- g1_alphas=layer.g1_alphas,
- g2_alphas=layer.g2_alphas,
- a1_gscale=layer.w13_input_scale_quant,
- a2_gscale=layer.w2_input_scale_quant,
- w1_scale=layer.w13_weight_scale,
+ return make_nvfp4_moe_quant_config(
+ backend=self.nvfp4_backend,
+ w13_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
+ w13_scale_2=layer.w13_weight_scale_2,
+ w2_scale_2=layer.w2_weight_scale_2,
+ a13_scale=layer.w13_input_scale,
+ a2_scale=layer.w2_input_scale,
)
def apply(
@@ -546,14 +464,9 @@ def apply(
assert layer.activation == "silu", "Only SiLU activation is supported."
if (
- self.allow_flashinfer
- and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+ self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
+ and not layer.enable_eplb
):
- if layer.enable_eplb:
- raise NotImplementedError(
- "EPLB not supported for `CompressedTensorsW4A4MoEMethod` yet."
- )
-
return flashinfer_trtllm_fp4_moe(
layer=layer,
x=x,
@@ -566,79 +479,41 @@ def apply(
e_score_correction_bias=layer.e_score_correction_bias,
)
+ # Hidden_states in select_experts is only used to extract metadata
+ if isinstance(x, tuple):
+ x_routing, _ = x
+ else:
+ x_routing = x
topk_weights, topk_ids = layer.select_experts(
- hidden_states=x,
+ hidden_states=x_routing,
router_logits=router_logits,
)
- if self.use_marlin:
- return fused_marlin_moe(
+ # EPLB path
+ if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
+ assert layer.enable_eplb
+ return flashinfer_trtllm_fp4_routed_moe(
+ layer=layer,
+ x=x,
+ topk_ids=topk_ids,
+ topk_weights=topk_weights,
+ top_k=layer.top_k,
+ global_num_experts=layer.global_num_experts,
+ )
+ else:
+ assert self.kernel is not None
+ return self.kernel(
x,
layer.w13_weight,
layer.w2_weight,
- None,
- None,
- layer.w13_weight_scale,
- layer.w2_weight_scale,
- router_logits,
topk_weights,
topk_ids,
- global_scale1=layer.w13_weight_scale_2,
- global_scale2=layer.w2_weight_scale_2,
- quant_type_id=scalar_types.float4_e2m1f.id,
- apply_router_weight_on_input=layer.apply_router_weight_on_input,
- global_num_experts=layer.global_num_experts,
- expert_map=layer.expert_map,
- input_dtype=self.marlin_input_dtype,
- workspace=layer.workspace,
- )
-
- # FlashInfer fused experts path
- elif self.allow_flashinfer:
- from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
- flashinfer_cutlass_moe_fp4,
- )
-
- assert is_valid_flashinfer_cutlass_fused_moe(
- x, layer.w13_weight, layer.w2_weight
- ), "Flashinfer CUTLASS Fused MoE not applicable!"
-
- assert self.moe_quant_config is not None
-
- return flashinfer_cutlass_moe_fp4(
- hidden_states=x,
- w1=layer.w13_weight,
- w2=layer.w2_weight,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- quant_config=self.moe_quant_config,
- inplace=False, # TODO(shuw): fix later, now output is high prec
+ inplace=False,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
- else:
- # If no modular kernel is provided, use cutlass_moe_fp4 for TP case
- # only (no EP).
- from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
-
- assert self.moe_quant_config is not None
- return cutlass_moe_fp4(
- a=x,
- w1_fp4=layer.w13_weight,
- w2_fp4=layer.w2_weight,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- quant_config=self.moe_quant_config,
- expert_map=layer.expert_map,
- apply_router_weight_on_input=layer.apply_router_weight_on_input,
- # TODO(bnell): derive these from arguments
- m=x.shape[0],
- n=layer.w2_weight.shape[2] * 2,
- k=x.shape[1],
- e=layer.w13_weight.shape[0],
- ).to(x.dtype)
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index bd7a90a80af1..2e4f1daf6690 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -15,9 +15,7 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
- nvfp4_moe_quant_config,
)
-from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEMethodBase,
@@ -30,6 +28,15 @@
make_fp8_moe_quant_config,
select_fp8_moe_backend,
)
+from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
+ FLASHINFER_NVFP4_MOE_BACKENDS,
+ NvFp4MoeBackend,
+ convert_to_nvfp4_moe_kernel_format,
+ is_global_sf_supported_for_nvfp4_backend,
+ make_nvfp4_moe_kernel,
+ make_nvfp4_moe_quant_config,
+ select_nvfp4_moe_backend,
+)
from vllm.model_executor.layers.linear import (
LinearBase,
LinearMethodBase,
@@ -45,16 +52,11 @@
build_flashinfer_fp4_cutlass_moe_prepare_finalize,
flashinfer_trtllm_fp4_moe,
flashinfer_trtllm_fp4_routed_moe,
- prepare_static_weights_for_trtllm_fp4_moe,
- reorder_w1w3_to_w3w1,
select_nvfp4_gemm_impl,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
- FlashinferMoeBackend,
apply_fi_trtllm_fp8_per_tensor_moe,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
- get_flashinfer_moe_backend,
- is_flashinfer_supporting_global_sf,
select_cutlass_fp8_gemm_impl,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
@@ -69,7 +71,6 @@
apply_fp4_marlin_linear,
is_fp4_marlin_supported,
prepare_fp4_layer_for_marlin,
- prepare_moe_fp4_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
@@ -89,7 +90,6 @@
PerTensorScaleParameter,
)
from vllm.model_executor.utils import replace_parameter
-from vllm.scalar_type import scalar_types
from vllm.utils.flashinfer import (
flashinfer_scaled_fp4_mm,
has_flashinfer,
@@ -1327,43 +1327,32 @@ def __init__(
quant_config: ModelOptNvFp4Config,
layer: FusedMoE,
) -> None:
- from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (
- detect_nvfp4_moe_support, # noqa: E501
- )
-
super().__init__(layer.moe_config)
self.quant_config = quant_config
- self.layer = layer
- _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
- self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
- self.allow_flashinfer = _nvfp4.allow_flashinfer
- self.use_marlin = _nvfp4.use_marlin
- self.marlin_input_dtype = None
- self.flashinfer_moe_backend = None
- if self.allow_flashinfer:
- self.flashinfer_moe_backend = get_flashinfer_moe_backend()
- logger.info_once(
- f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
- " for ModelOptNvFp4FusedMoE."
+ self.nvfp4_backend = select_nvfp4_moe_backend()
+ # TODO: move this type of check into the oracle.
+ if (
+ not self.moe.is_act_and_mul
+ and not self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS
+ ):
+ raise NotImplementedError(
+ "Non-gated activations are only supported by FlashInfer "
+ "CUTLASS NvFP4 MoE backend."
)
- elif self.use_marlin:
- logger.info_once("Using Marlin for ModelOptNvFp4FusedMoE.")
- else:
- logger.info_once("Using Cutlass for ModelOptNvFp4FusedMoE.")
+
+ self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
+ self.nvfp4_backend
+ )
+ self.kernel: mk.FusedMoEModularKernel | None = None
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
- if self.use_marlin or (
- self.allow_flashinfer
- and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
- ):
+ UNSUPPORTED = [NvFp4MoeBackend.MARLIN, NvFp4MoeBackend.FLASHINFER_TRTLLM]
+ if self.nvfp4_backend in UNSUPPORTED:
return None
- elif (
- self.allow_flashinfer
- and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
- ):
+ elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
# TP case: avoid convert to ModularKernelMethod - to be refactored.
if self.moe.dp_size == 1:
return None
@@ -1385,7 +1374,7 @@ def select_gemm_impl(
experts = select_nvfp4_gemm_impl(
self.moe,
self.moe_quant_config,
- allow_flashinfer=self.allow_flashinfer,
+ allow_flashinfer=self.nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS,
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
@@ -1405,11 +1394,7 @@ def create_weights(
params_dtype: torch.dtype,
**extra_weight_attrs,
):
- if not self.quant_config.is_checkpoint_nvfp4_serialized:
- raise ValueError(
- "NVFP4 quantization was selected, "
- " dynamic quantization is not supported."
- )
+ assert self.quant_config.is_checkpoint_nvfp4_serialized
layer.num_experts = num_experts
layer.params_dtype = params_dtype
@@ -1498,14 +1483,12 @@ def create_weights(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
)
- use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf(
- self.flashinfer_moe_backend
+ global_sf_num_experts = (
+ global_num_experts if self.use_global_sf else num_experts
)
- global_scale_num_experts = global_num_experts if use_global_sf else num_experts
-
w13_input_scale = PerTensorScaleParameter(
data=torch.empty(
- global_scale_num_experts,
+ global_sf_num_experts,
2 if self.moe.is_act_and_mul else 1,
dtype=torch.float32,
),
@@ -1514,32 +1497,17 @@ def create_weights(
layer.register_parameter("w13_input_scale", w13_input_scale)
w2_input_scale = PerTensorScaleParameter(
- data=torch.empty(global_scale_num_experts, dtype=torch.float32),
+ data=torch.empty(global_sf_num_experts, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("w2_input_scale", w2_input_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
- # GEMM 1 processing
- gemm1_weight = layer.w13_weight.data
- gemm1_weight_scale = layer.w13_weight_scale.data
-
- if (
- self.allow_flashinfer
- and (
- self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
- or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
- )
- and self.moe.is_act_and_mul
- ):
- gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
- gemm1_weight, gemm1_weight_scale, dim=-2
- )
-
- layer.w13_weight = Parameter(gemm1_weight, requires_grad=False)
- layer.w13_weight_scale = Parameter(gemm1_weight_scale, requires_grad=False)
+ """
+ Convert NVFP4 MoE weights into kernel format and setup the kernel.
+ """
- # Common processing for w13_weight_scale_2
+ # Use a single gscale for w13.
if self.moe.is_act_and_mul and not torch.allclose(
layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
):
@@ -1547,136 +1515,47 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"w1_weight_scale_2 must match w3_weight_scale_2. "
"Accuracy may be affected."
)
-
w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0].contiguous()
- layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)
- # Common processing for input scales and alphas
- use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf(
- self.flashinfer_moe_backend
- )
- if use_global_sf:
- # For backends provide by Flashinfer, the input global scales are
- # shared across all experts.
- w13_input_scale = (
- layer.w13_input_scale.max().to(torch.float32).expand(layer.num_experts)
- )
- else:
- w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
- layer.g1_alphas = Parameter(
- (w13_input_scale * w13_weight_scale_2).to(torch.float32),
- requires_grad=False,
- )
-
- # This is for quantization, so we need to invert it.
- layer.w13_input_scale_quant = Parameter(
- (1 / w13_input_scale).to(torch.float32), requires_grad=False
- )
-
- # GEMM 2 processing
- if use_global_sf:
- # For backends provide by Flashinfer, the input global scales are
- # shared across all experts.
- w2_input_scale = (
- layer.w2_input_scale.max().to(torch.float32).expand(layer.num_experts)
- )
- else:
- w2_input_scale = layer.w2_input_scale
- layer.g2_alphas = Parameter(
- (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
- requires_grad=False,
- )
-
- # This is for quantization, so we need to invert it.
- layer.w2_input_scale_quant = Parameter(
- (1 / w2_input_scale).to(torch.float32), requires_grad=False
+ (
+ w13,
+ w13_scale,
+ w13_scale_2,
+ a13_scale,
+ w2,
+ w2_scale,
+ w2_scale_2,
+ a2_scale,
+ ) = convert_to_nvfp4_moe_kernel_format(
+ nvfp4_backend=self.nvfp4_backend,
+ layer=layer,
+ w13=layer.w13_weight,
+ w13_scale=layer.w13_weight_scale,
+ w13_scale_2=w13_weight_scale_2,
+ a13_scale=layer.w13_input_scale,
+ w2=layer.w2_weight,
+ w2_scale=layer.w2_weight_scale,
+ w2_scale_2=layer.w2_weight_scale_2,
+ a2_scale=layer.w2_input_scale,
+ is_act_and_mul=self.moe.is_act_and_mul,
)
- # TensorRT-LLM specific processing
- if (
- self.allow_flashinfer
- and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
- ):
- # Prepare static weights for TRT-LLM kernel
- # alternate: prepare_static_weight_layouts_for_trtllm_moe
- (
- gemm1_weights_fp4_shuffled,
- gemm1_scales_fp4_shuffled,
- gemm2_weights_fp4_shuffled,
- gemm2_scales_fp4_shuffled,
- ) = prepare_static_weights_for_trtllm_fp4_moe(
- layer.w13_weight,
- layer.w2_weight,
- layer.w13_weight_scale,
- layer.w2_weight_scale,
- layer.w2_weight.size(-2), # hidden_size
- layer.w13_weight.size(-2) // 2, # intermediate_size
- layer.w13_weight.size(0), # num_experts
- )
- logger.debug_once("Finished shuffling weights for TRT-LLM MOE")
-
- layer.w13_weight = Parameter(
- gemm1_weights_fp4_shuffled, requires_grad=False
- )
- layer.w2_weight = Parameter(gemm2_weights_fp4_shuffled, requires_grad=False)
- layer.w13_weight_scale = Parameter(
- gemm1_scales_fp4_shuffled, requires_grad=False
- )
- layer.w2_weight_scale = Parameter(
- gemm2_scales_fp4_shuffled, requires_grad=False
- )
-
- # Additional parameter needed for TRT-LLM
- layer.g1_scale_c = Parameter(
- (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
- requires_grad=False,
- )
- elif self.use_marlin:
- # Marlin processing
- prepare_moe_fp4_layer_for_marlin(layer)
- del layer.g1_alphas
- del layer.g2_alphas
- del layer.w13_input_scale_quant
- del layer.w2_input_scale_quant
- else:
- # Non-TRT-LLM processing (Cutlass or non-flashinfer)
- w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale)
- layer.w13_weight_scale = Parameter(
- w13_blockscale_swizzled, requires_grad=False
- )
-
- w13_weight = layer.w13_weight
- intermediate_size_pad = w13_blockscale_swizzled.size(1) - w13_weight.size(1)
- if intermediate_size_pad:
- # padding gated activations will require to split w1 and w3
- # and pad them individually
- assert not self.moe.is_act_and_mul, (
- "The intermediate size required padding, "
- "but padding is not implemented for gated activations"
- )
-
- layer.w13_weight = Parameter(
- torch.nn.functional.pad(
- w13_weight, (0, 0, 0, intermediate_size_pad)
- ),
- requires_grad=False,
- )
- layer.w2_weight = Parameter(
- torch.nn.functional.pad(
- layer.w2_weight, (0, intermediate_size_pad // 2, 0, 0)
- ),
- requires_grad=False,
- )
- layer.w2_weight_scale = Parameter(
- torch.nn.functional.pad(
- layer.w2_weight_scale, (0, intermediate_size_pad // 16)
- ),
- requires_grad=False,
- )
+ replace_parameter(layer, "w13_weight", w13)
+ replace_parameter(layer, "w13_weight_scale", w13_scale)
+ replace_parameter(layer, "w13_weight_scale_2", w13_scale_2)
+ replace_parameter(layer, "w13_input_scale", a13_scale)
+ replace_parameter(layer, "w2_weight", w2)
+ replace_parameter(layer, "w2_weight_scale", w2_scale)
+ replace_parameter(layer, "w2_weight_scale_2", w2_scale_2)
+ replace_parameter(layer, "w2_input_scale", a2_scale)
- w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
- layer.w2_weight_scale = Parameter(
- w2_blockscale_swizzled, requires_grad=False
+ self.moe_quant_config = self.get_fused_moe_quant_config(layer)
+ use_dp = self.moe.dp_size > 1
+ if self.moe_quant_config is not None and not use_dp:
+ self.kernel = make_nvfp4_moe_kernel(
+ backend=self.nvfp4_backend,
+ quant_config=self.moe_quant_config,
+ moe_config=self.moe,
)
def prepare_dp_allgather_tensor(
@@ -1688,7 +1567,8 @@ def prepare_dp_allgather_tensor(
"""Optionally prepare extra tensors to carry through DP allgather/EP."""
import flashinfer
- a1_gscale = layer.w13_input_scale_quant
+ assert self.moe_quant_config is not None
+ a1_gscale = self.moe_quant_config.a1_gscale
hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize(
hidden_states,
a1_gscale,
@@ -1700,19 +1580,14 @@ def prepare_dp_allgather_tensor(
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
- if (
- self.use_marlin
- or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
- ):
- return None
-
- return nvfp4_moe_quant_config(
- w1_scale=layer.w13_weight_scale,
+ return make_nvfp4_moe_quant_config(
+ backend=self.nvfp4_backend,
+ w13_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
- g1_alphas=layer.g1_alphas,
- g2_alphas=layer.g2_alphas,
- a1_gscale=layer.w13_input_scale_quant,
- a2_gscale=layer.w2_input_scale_quant,
+ w13_scale_2=layer.w13_weight_scale_2,
+ w2_scale_2=layer.w2_weight_scale_2,
+ a13_scale=layer.w13_input_scale,
+ a2_scale=layer.w2_input_scale,
)
@property
@@ -1725,18 +1600,8 @@ def apply(
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if not self.moe.is_act_and_mul:
- assert (
- self.allow_flashinfer
- and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
- ), (
- "Non-gated activations are only supported by the"
- " flashinfer CUTLASS backend for modelopt checkpoints"
- )
-
if (
- self.allow_flashinfer
- and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+ self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
and not layer.enable_eplb
):
return flashinfer_trtllm_fp4_moe(
@@ -1762,10 +1627,8 @@ def apply(
)
# EPLB path
- if (
- self.allow_flashinfer
- and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
- ):
+ if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
+ assert layer.enable_eplb
return flashinfer_trtllm_fp4_routed_moe(
layer=layer,
x=x,
@@ -1774,81 +1637,20 @@ def apply(
top_k=layer.top_k,
global_num_experts=layer.global_num_experts,
)
-
- if self.use_marlin:
- return fused_marlin_moe(
+ else:
+ assert self.kernel is not None
+ return self.kernel(
x,
layer.w13_weight,
layer.w2_weight,
- None,
- None,
- layer.w13_weight_scale,
- layer.w2_weight_scale,
- router_logits,
topk_weights,
topk_ids,
- global_scale1=layer.w13_weight_scale_2,
- global_scale2=layer.w2_weight_scale_2,
- quant_type_id=scalar_types.float4_e2m1f.id,
- apply_router_weight_on_input=layer.apply_router_weight_on_input,
- global_num_experts=layer.global_num_experts,
- expert_map=layer.expert_map,
- input_dtype=self.marlin_input_dtype,
- )
-
- elif self.allow_flashinfer:
- assert self.flashinfer_moe_backend in (
- FlashinferMoeBackend.CUTLASS,
- FlashinferMoeBackend.CUTEDSL,
- )
- if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
- from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
- flashinfer_cutlass_moe_fp4,
- )
-
- flashinfer_fn_moe_fp4 = flashinfer_cutlass_moe_fp4
- else:
- from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import ( # noqa: E501
- flashinfer_cutedsl_moe_fp4,
- )
-
- flashinfer_fn_moe_fp4 = flashinfer_cutedsl_moe_fp4
-
- assert self.moe_quant_config is not None
- return flashinfer_fn_moe_fp4(
- hidden_states=x,
- w1=layer.w13_weight,
- w2=layer.w2_weight,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- quant_config=self.moe_quant_config,
inplace=False,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
- else:
- # If no modular kernel is provided, use cutlass_moe_fp4 for TP case
- # only (no EP).
- from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
-
- assert self.moe_quant_config is not None
- return cutlass_moe_fp4(
- a=x,
- w1_fp4=layer.w13_weight,
- w2_fp4=layer.w2_weight,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- quant_config=self.moe_quant_config,
- expert_map=layer.expert_map,
- apply_router_weight_on_input=layer.apply_router_weight_on_input,
- # TODO: derive from arguments
- m=x.shape[0],
- n=layer.w2_weight.shape[2] * 2,
- k=x.shape[1],
- e=layer.w13_weight.shape[0],
- )
ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
index 1d410316d629..eaf45ead5afd 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
@@ -2,10 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utility helpers for NVFP4 + FlashInfer fused-MoE path"""
+from typing import TYPE_CHECKING
+
import torch
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
@@ -20,12 +23,23 @@
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
create_flashinfer_prepare_finalize,
)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ swizzle_blockscale,
+)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import (
has_flashinfer_cutedsl_grouped_gemm_nt_masked,
has_flashinfer_cutlass_fused_moe,
)
+if TYPE_CHECKING:
+ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
+ NvFp4MoeBackend,
+ )
+
+logger = init_logger(__name__)
+
+
__all__ = [
"is_flashinfer_fp4_cutlass_moe_available",
"is_flashinfer_fp4_cutedsl_moe_available",
@@ -273,10 +287,9 @@ def flashinfer_trtllm_fp4_moe(
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
else:
# hidden_states is the already quantized
- a1_gscale = layer.w13_input_scale_quant
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
x,
- a1_gscale,
+ layer.a1_gscale,
is_sf_swizzled_layout=False,
)
@@ -369,10 +382,9 @@ def flashinfer_trtllm_fp4_routed_moe(
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
else:
# Quantize input to FP4
- a1_gscale = layer.w13_input_scale_quant
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
x,
- a1_gscale,
+ layer.a1_gscale,
is_sf_swizzled_layout=False,
)
@@ -410,3 +422,93 @@ def flashinfer_trtllm_fp4_routed_moe(
)[0]
return out
+
+
+def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
+ backend: "NvFp4MoeBackend",
+ layer: torch.nn.Module,
+ w13: torch.Tensor,
+ w13_scale: torch.Tensor,
+ w13_scale_2: torch.Tensor,
+ a13_scale: torch.Tensor,
+ w2: torch.Tensor,
+ w2_scale: torch.Tensor,
+ w2_scale_2: torch.Tensor,
+ a2_scale: torch.Tensor,
+ is_act_and_mul: bool,
+) -> tuple[
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+]:
+ # Delayed import for circular dependency avoidance.
+ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
+ NvFp4MoeBackend,
+ is_global_sf_supported_for_nvfp4_backend,
+ )
+
+    # NOTE(review): Marlin takes the prepare_nvfp4_moe_layer_for_marlin path.
+    assert backend in [
+        NvFp4MoeBackend.VLLM_CUTLASS,
+        NvFp4MoeBackend.FLASHINFER_CUTLASS,
+        NvFp4MoeBackend.FLASHINFER_TRTLLM,
+    ]
+
+ # Reorder [w1, w3] to [w3, w1] for FI NVFP4 MoE kernels.
+ if is_act_and_mul and backend in [
+ NvFp4MoeBackend.FLASHINFER_CUTLASS,
+ NvFp4MoeBackend.FLASHINFER_TRTLLM,
+ ]:
+        w13, w13_scale = reorder_w1w3_to_w3w1(w13, w13_scale, dim=-2)
+
+ # For some FI kernels, the input scales are shared by all experts.
+ if is_global_sf_supported_for_nvfp4_backend(backend):
+ num_experts = w13.shape[0]
+ a13_scale = a13_scale.max().to(torch.float32).expand(num_experts)
+ a2_scale = a2_scale.max().to(torch.float32).expand(num_experts)
+ else:
+ a13_scale = a13_scale.max(dim=1).values.to(torch.float32)
+
+ # Shuffle weights and scales for FI TRTLLM NVFP4 MoE kernels.
+ if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
+ w13, w13_scale, w2, w2_scale = prepare_static_weights_for_trtllm_fp4_moe(
+ w13,
+ w2,
+ w13_scale,
+ w2_scale,
+ w2.size(-2), # hidden_size
+ w13.size(-2) // 2, # intermediate_size
+ w13.size(0), # num_experts
+ )
+
+ # We do not need to make this a parameter, because
+ # it is not used during the weight (re)-loading process.
+ layer.g1_scale_c = a13_scale * w13_scale_2 / a2_scale
+ layer.a1_gscale = 1.0 / a13_scale
+ layer.g1_alphas = a13_scale * w13_scale_2
+ layer.g2_alphas = a2_scale * w2_scale_2
+ else:
+ # Swizzle the block scales for other FI NVFP4 MoE kernels.
+ w13_scale = swizzle_blockscale(w13_scale)
+
+ # Apply padding if needed.
+ pad_size = w13_scale.size(1) - w13.size(1)
+ if pad_size > 0:
+ if is_act_and_mul:
+                raise NotImplementedError(
+                    "Intermediate size padding for w1 and w3, for "
+                    f"{backend.value} NvFp4 backend, but this is not "
+                    "currently supported"
+                )
+ w13 = torch.nn.functional.pad(w13, (0, 0, 0, pad_size))
+ w2 = torch.nn.functional.pad(w2, (0, pad_size // 2, 0, 0))
+ w2_scale = torch.nn.functional.pad(w2_scale, (0, pad_size // 16))
+
+ w2_scale = swizzle_blockscale(w2_scale)
+
+ return w13, w13_scale, w13_scale_2, a13_scale, w2, w2_scale, w2_scale_2, a2_scale
diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py
index 4d0a34c3be11..2ced41ef886a 100644
--- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py
+++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py
@@ -8,6 +8,7 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
USE_FP32_REDUCE_DEFAULT,
+ get_marlin_input_dtype,
marlin_make_workspace_new,
marlin_permute_bias,
marlin_permute_scales,
@@ -226,6 +227,106 @@ def prepare_fp4_layer_for_marlin(
return
+def prepare_nvfp4_moe_layer_for_marlin(
+ layer: torch.nn.Module,
+ w13: torch.Tensor,
+ w13_scale: torch.Tensor,
+ w13_scale_2: torch.Tensor,
+ w2: torch.Tensor,
+ w2_scale: torch.Tensor,
+ w2_scale_2: torch.Tensor,
+) -> tuple[
+ torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
+]:
+ logger.warning_once(
+ "Your GPU does not have native support for FP4 computation but "
+ "FP4 quantization is being used. Weight-only FP4 compression will "
+ "be used leveraging the Marlin kernel. This may degrade "
+ "performance for compute-heavy workloads."
+ )
+
+ input_dtype = get_marlin_input_dtype(prefix="")
+ if input_dtype is not None and input_dtype.itemsize == 1:
+ raise RuntimeError("NVFP4 weight + INT8/FP8 activation is not supported.")
+
+ GROUP_SIZE = 16
+ E = layer.num_experts
+ K = layer.hidden_size
+ N = layer.intermediate_size_per_partition
+
+ device = w13.device
+ param_dtype = layer.params_dtype
+ is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
+
+ # WORKSPACE
+ layer.workspace = marlin_make_workspace_new(device, 4)
+ perm = torch.empty(0, dtype=torch.int, device=device)
+
+ # WEIGHT
+ # Repack weights to marlin format
+ def repack_weight(weight: torch.Tensor, name: str) -> torch.Tensor:
+ tensor_list = []
+ if "w13" in name:
+ size_n, size_k = N * 2, K
+ else:
+ size_n, size_k = K, N
+
+ assert weight.shape == (E, size_n, size_k // 2)
+
+ for i in range(E):
+ qweight = weight[i].view(torch.int32).T.contiguous()
+
+ marlin_qweight = ops.gptq_marlin_repack(
+ b_q_weight=qweight,
+ perm=perm,
+ size_k=size_k,
+ size_n=size_n,
+ num_bits=4,
+ is_a_8bit=is_a_8bit,
+ )
+ tensor_list.append(marlin_qweight)
+
+ return torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
+
+ w13 = repack_weight(w13, "w13")
+ w2 = repack_weight(w2, "w2")
+
+ # WEIGHT SCALES
+ # Permute scales
+ def premute_scales(
+ scales: torch.Tensor, g_scales: torch.Tensor, name: str
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ scales = scales.to(param_dtype)
+ g_scales = g_scales.to(param_dtype)
+
+ tensor_list = []
+ if "w13" in name:
+ size_n, size_k = N * 2, K
+ else:
+ size_n, size_k = K, N
+
+ for i in range(E):
+ scale = scales[i].T
+ marlin_scales = marlin_permute_scales(
+ s=scale,
+ size_k=size_k,
+ size_n=size_n,
+ group_size=GROUP_SIZE,
+ is_a_8bit=is_a_8bit,
+ )
+ marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
+ tensor_list.append(marlin_scales)
+
+ scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
+ g_scales = nvfp4_marlin_process_global_scale(g_scales)
+ return scales, g_scales
+
+ w13_scale, w13_scale_2 = premute_scales(w13_scale, w13_scale_2, "w13")
+ w2_scale, w2_scale_2 = premute_scales(w2_scale, w2_scale_2, "w2")
+
+ return w13, w13_scale, w13_scale_2, w2, w2_scale, w2_scale_2
+
+
def prepare_moe_fp4_layer_for_marlin(
layer: torch.nn.Module, input_dtype: torch.dtype | None = None
) -> None: