Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
76 commits
Select commit Hold shift + click to select a range
60a830a
initial commit
Jan 4, 2026
d77fc66
refactoring to use the oracle
Jan 5, 2026
6ec8d1b
nit
Jan 5, 2026
c0bf3ce
raise error
Jan 5, 2026
c0414f8
move not implemented error to __init__ from apply
Jan 5, 2026
3f85920
minor rename
Jan 5, 2026
e0ade2a
first working via flashinfer cutlass
Jan 5, 2026
3b59a99
update
Jan 5, 2026
ed04dc1
update
Jan 5, 2026
4fdca7e
actually use apply
Jan 5, 2026
b16c0cd
we have correctness for flashinfer cutlass via the .apply() pathway
Jan 5, 2026
fd54809
convert to using the PFNoEP
Jan 5, 2026
8aad006
create the kernel in the oracle
Jan 5, 2026
751286a
forgot to make the kernel :)
Jan 5, 2026
966b123
last nit?
Jan 5, 2026
d98adfd
minor comment change
Jan 5, 2026
e0ba913
minor tweak to oracle
Jan 5, 2026
15876cc
remove flashinfer cutedsl from apply since it does not work via apply …
Jan 5, 2026
7a191ad
flashinfer trtllm working properly again
Jan 5, 2026
04e8a13
we are able to run with vllm cutlass, but have 0% accuracy score
Jan 5, 2026
f943472
got accuracy with vllm cutlass
Jan 5, 2026
715e428
convert to int comparison
Jan 5, 2026
ae14963
update comment
Jan 5, 2026
c222df0
update to decouple use_global_sf from ModelOpt
Jan 5, 2026
3fb8d01
initial attempt to apply this to compressed-tensors
Jan 5, 2026
c13549c
flashinfer cutlass working end-to-end with compressed-tensors
Jan 5, 2026
5d7422b
updated
Jan 5, 2026
f1ce2ee
fixed nvfp4 trtllm
Jan 5, 2026
49ccce2
stash marlin fixes
Jan 5, 2026
5eaf83d
stash
Jan 5, 2026
cc97856
revert changes to ct
Jan 5, 2026
f90c4e8
first evidence of new marlin structure working
Jan 5, 2026
4d2eb0c
marlin now working again for modelopt
Jan 5, 2026
189f7ba
stash
Jan 5, 2026
66ac3a4
revert change for failing test
Jan 5, 2026
7b36d99
test is now passing
Jan 5, 2026
74eba5a
updated the --enforce-eager
Jan 5, 2026
a9d807e
updated
Jan 5, 2026
25db395
revert debug issues
Jan 5, 2026
6b33d39
fixed up block_quant
Jan 5, 2026
adb0aac
update comment
Jan 5, 2026
6b8a47d
stash changes
Jan 5, 2026
e6a41fd
update
Jan 5, 2026
d47717a
updated comments
Jan 5, 2026
858989c
updated to remove debug nits
Jan 5, 2026
d2c112e
updated
Jan 5, 2026
60233a3
Merge branch 'nvfp3-refactor' of https://github.com/vllm-project/vllm…
Jan 5, 2026
33a91d3
marlin nvfp4 working via modular kernels for modelopt
Jan 5, 2026
dab8425
clean up
Jan 5, 2026
f0b80da
replaced warning
Jan 5, 2026
6cd74c1
stash changes
Jan 5, 2026
f0a9a27
merge the bugfix
Jan 5, 2026
1fa99f7
compressed tensors nvfp4
Jan 5, 2026
17ed636
update comment
Jan 5, 2026
a8a6017
fix the linter
Jan 5, 2026
10c6ffc
updated compressed-tensors
Jan 5, 2026
aaee9aa
fix dp/ep regression
Jan 5, 2026
d79719f
get past the linter
Jan 5, 2026
f187bf1
Merge branch 'main' into nvfp4-refactor
robertgshaw2-redhat Jan 6, 2026
7f04a73
Merge branch 'nvfp3-refactor' of https://github.com/vllm-project/vllm…
Jan 6, 2026
947e41b
Merge remote-tracking branch 'origin/main' into nvfp3-refactor
Jan 6, 2026
7775b22
updated
Jan 7, 2026
b92ead4
update uvicorn access
Jan 7, 2026
05bd281
update invocation of convert to kernel logic
Jan 7, 2026
4367556
running
Jan 7, 2026
359d467
remove cutlass_moe_fp4
Jan 7, 2026
e528323
nit loc changes
Jan 7, 2026
9db0bbf
reduce loc changes
Jan 7, 2026
157fbb9
fix missing import
Jan 7, 2026
e8119e7
updated with minor nits to reduce LOC or eliminate debugs
Jan 7, 2026
3e4ad7a
do the merge
Jan 8, 2026
5977f00
Update vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
robertgshaw2-redhat Jan 8, 2026
c6d7cc8
updated
Jan 8, 2026
a74743a
updated
Jan 8, 2026
071f0fb
clean up --enforce-eager
Jan 8, 2026
6e769c4
Update vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
robertgshaw2-redhat Jan 8, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 34 additions & 19 deletions benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,20 @@
import torch
import torch.utils.benchmark as benchmark

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config,
nvfp4_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp4,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.scalar_type import scalar_types
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.worker.workspace import init_workspace_manager
Expand Down Expand Up @@ -188,19 +194,24 @@ def run_cutlass_moe_fp4(
g1_alphas=w1_gs,
g2_alphas=w2_gs,
)

kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
CutlassExpertsFp4(
out_dtype=dtype,
max_experts_per_worker=e,
quant_config=quant_config,
),
)

for _ in range(num_repeats):
with nvtx.annotate("cutlass_moe_fp4", color="green"):
cutlass_moe_fp4(
a=a,
w1_fp4=w1_fp4,
w2_fp4=w2_fp4,
kernel(
hidden_states=a,
w1=w1_fp4,
w2=w2_fp4,
topk_weights=topk_weights,
topk_ids=topk_ids,
m=m,
n=n,
k=k,
e=num_experts,
quant_config=quant_config,
)

def run_cutlass_from_graph(
Expand Down Expand Up @@ -230,20 +241,24 @@ def run_cutlass_from_graph(
g2_alphas=w2_gs,
)

kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
CutlassExpertsFp4(
out_dtype=dtype,
max_experts_per_worker=e,
quant_config=quant_config,
),
)

with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
):
return cutlass_moe_fp4(
a=a,
w1_fp4=w1_fp4,
w2_fp4=w2_fp4,
return kernel(
hidden_states=a,
w1=w1_fp4,
w2=w2_fp4,
topk_weights=topk_weights,
topk_ids=topk_ids,
m=m,
n=n,
k=k,
e=num_experts,
quant_config=quant_config,
)

def run_triton_from_graph(
Expand Down
2 changes: 1 addition & 1 deletion docs/design/moe_kernel_features.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels
| triton | standard | all<sup>1</sup> | G,A,T | silu, gelu,</br>swigluoai,</br>silu_no_mul,</br>gelu_no_mul | Y | Y | [`fused_experts`][vllm.model_executor.layers.fused_moe.fused_moe.fused_experts],</br>[`TritonExperts`][vllm.model_executor.layers.fused_moe.fused_moe.TritonExperts] |
| triton (batched) | batched | all<sup>1</sup> | G,A,T | silu, gelu | <sup>6</sup> | Y | [`BatchedTritonExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedTritonExperts] |
| deep gemm | standard,</br>batched | fp8 | G(128),A,T | silu, gelu | <sup>6</sup> | Y | [`deep_gemm_moe_fp8`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.deep_gemm_moe_fp8],</br>[`DeepGemmExperts`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.DeepGemmExperts],</br>[`BatchedDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe.BatchedDeepGemmExperts] |
| cutlass_fp4 | standard,</br>batched | nvfp4 | A,T | silu | Y | Y | [`cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.cutlass_moe.cutlass_moe_fp4],</br>[`CutlassExpertsFp4`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp4] |
| cutlass_fp4 | standard,</br>batched | nvfp4 | A,T | silu | Y | Y | [`CutlassExpertsFp4`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp4] |
| cutlass_fp8 | standard,</br>batched | fp8 | A,T | silu, gelu | Y | Y | [`CutlassExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp8],</br>[`CutlasBatchedExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassBatchedExpertsFp8] |
| flashinfer | standard | nvfp4,</br>fp8 | T | <sup>5</sup> | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],</br>[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] |
| gpt oss triton | standard | N/A | N/A | <sup>5</sup> | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],</br>[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
Expand Down
30 changes: 20 additions & 10 deletions tests/kernels/moe/test_nvfp4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.moe.utils import make_test_weights
from tests.kernels.quantization.nvfp4_utils import (
FLOAT4_E2M1_MAX,
Expand All @@ -13,8 +14,13 @@
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import nvfp4_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp4,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed

Expand Down Expand Up @@ -83,17 +89,21 @@ def test_cutlass_fp4_moe_no_graph(
w2_scale=w2_blockscale,
)

cutlass_output = cutlass_moe_fp4(
a=a,
w1_fp4=w1_q,
w2_fp4=w2_q,
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
CutlassExpertsFp4(
out_dtype=dtype,
max_experts_per_worker=e,
quant_config=quant_config,
),
)

cutlass_output = kernel(
hidden_states=a,
w1=w1_q,
w2=w2_q,
topk_weights=topk_weights,
topk_ids=topk_ids,
quant_config=quant_config,
m=m,
n=n,
k=k,
e=e,
)

# Reference check:
Expand Down
2 changes: 0 additions & 2 deletions vllm/model_executor/layers/fused_moe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def get_config() -> dict[str, Any] | None:
CutlassBatchedExpertsFp8,
CutlassExpertsFp8,
CutlassExpertsW4A8Fp8,
cutlass_moe_fp4,
cutlass_moe_w4a8_fp8,
)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
Expand All @@ -95,7 +94,6 @@ def get_config() -> dict[str, Any] | None:
"fused_experts",
"get_config_file_name",
"GroupedTopk",
"cutlass_moe_fp4",
"cutlass_moe_w4a8_fp8",
"CutlassExpertsFp8",
"CutlassBatchedExpertsFp8",
Expand Down
23 changes: 23 additions & 0 deletions vllm/model_executor/layers/fused_moe/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,10 @@ def use_fp8_w8a16(self) -> bool:
def use_int4_w4a16(self) -> bool:
return self._a1.dtype is None and self._w1.dtype == "int4"

@property
def use_nvfp4_w4a16(self) -> bool:
    """True when activations are left unquantized and weights are nvfp4."""
    acts_unquantized = self._a1.dtype is None
    weights_nvfp4 = self._w1.dtype == "nvfp4"
    return acts_unquantized and weights_nvfp4

@property
def ocp_mx_scheme(self) -> str | None:
if not hasattr(self, "_ocp_mx_scheme"):
Expand Down Expand Up @@ -683,6 +687,25 @@ def nvfp4_moe_quant_config(
)


def nvfp4_w4a16_moe_quant_config(
    g1_alphas: torch.Tensor,
    g2_alphas: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
) -> FusedMoEQuantConfig:
    """
    Construct a quant config for 16-bit activations and nvfp4 weights.

    Activation quantization is disabled (quant_dtype=None); only the weight
    side carries nvfp4 metadata, so `use_nvfp4_w4a16` will be true on the
    resulting config.

    Args:
        g1_alphas: Alpha scales for the first grouped gemm, forwarded to
            FusedMoEQuantConfig.make.
        g2_alphas: Alpha scales for the second grouped gemm, forwarded to
            FusedMoEQuantConfig.make.
        w1_scale: Scale tensor for the nvfp4 w1 weights.
        w2_scale: Scale tensor for the nvfp4 w2 weights.

    Returns:
        A FusedMoEQuantConfig with weight_dtype="nvfp4" and no activation
        quantization.
    """
    return FusedMoEQuantConfig.make(
        quant_dtype=None,  # activations stay 16-bit: no input quantization
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        g1_alphas=g1_alphas,
        g2_alphas=g2_alphas,
        weight_dtype="nvfp4",
    )


def int4_w4a16_moe_quant_config(
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
Expand Down
62 changes: 0 additions & 62 deletions vllm/model_executor/layers/fused_moe/cutlass_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,68 +706,6 @@ def apply(
)


def cutlass_moe_fp4(
    a: torch.Tensor,
    w1_fp4: torch.Tensor,
    w2_fp4: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    quant_config: FusedMoEQuantConfig,
    m: int,
    n: int,
    k: int,
    e: int,
    expert_map: torch.Tensor | None = None,
    apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
    """Run a CUTLASS nvfp4 fused-MoE forward pass over pre-quantized weights.

    Builds a one-shot FusedMoEModularKernel (no expert parallelism) around
    CutlassExpertsFp4 and invokes it on the given activations.

    Args:
        a: Input activations; the kernel's output dtype follows ``a.dtype``.
        w1_fp4: nvfp4-quantized expert weights for the first grouped gemm.
        w2_fp4: nvfp4-quantized expert weights for the second grouped gemm.
        topk_weights: Per-token router weights for the selected experts.
        topk_ids: Per-token expert ids selected by the router.
        quant_config: Source of scale/alpha tensors; re-made below with
            ``quant_dtype=None`` so prepare/finalize skips input quantization.
        m, n, k: Problem dimensions. NOTE(review): not referenced in this
            body — apparently kept for caller compatibility.
        e: Number of experts; used as both ``max_experts_per_worker`` and
            ``global_num_experts``.
        expert_map: Must be ``None`` — expert parallelism is unsupported here.
        apply_router_weight_on_input: Forwarded to the modular kernel.

    Returns:
        The MoE output tensor (a new tensor; ``inplace=False``).
    """
    assert expert_map is None, (
        "Expert Parallelism / expert_map "
        "is currently not supported for "
        "ModelOptNvFp4FusedMoE's cutlass_moe_fp4."
    )

    # TODO(bnell): this feels a bit hacky
    # NVFP4 requires two levels of quantization, which involves
    # computing some scaling factors dynamically. This makes it
    # incompatible with the typical prepare -> MoE -> finalize
    # pipeline. Move the quantization logic into the MoE body.
    quant_config = FusedMoEQuantConfig.make(
        quant_dtype=None,  # skip quantization in prepare/finalize
        per_act_token_quant=quant_config.per_act_token_quant,
        per_out_ch_quant=quant_config.per_out_ch_quant,
        block_shape=quant_config.block_shape,
        g1_alphas=quant_config.g1_alphas,
        g2_alphas=quant_config.g2_alphas,
        a1_gscale=quant_config.a1_gscale,
        a2_gscale=quant_config.a2_gscale,
        w1_scale=quant_config.w1_scale,
        w2_scale=quant_config.w2_scale,
    )

    # Assemble the single-worker (no-EP) modular kernel around the
    # CUTLASS fp4 experts implementation.
    fn = mk.FusedMoEModularKernel(
        MoEPrepareAndFinalizeNoEP(),
        CutlassExpertsFp4(
            max_experts_per_worker=e,
            out_dtype=a.dtype,
            quant_config=quant_config,
            use_batched_format=False,
        ),
    )

    return fn(
        hidden_states=a,
        w1=w1_fp4,
        w2=w2_fp4,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        activation="silu",
        global_num_experts=e,
        expert_map=None,
        apply_router_weight_on_input=apply_router_weight_on_input,
    )


# W4A8
def run_cutlass_moe_w4a8_fp8(
output: torch.Tensor,
Expand Down
39 changes: 0 additions & 39 deletions vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,42 +335,3 @@ def flashinfer_cutedsl_moe_masked(
alpha_dtype=get_cute_dtype(w2_alpha),
) # in logical [m, k, l]
out = out.permute(2, 0, 1)


def flashinfer_cutedsl_moe_fp4(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    quant_config: FusedMoEQuantConfig,
    inplace: bool = False,
    activation: str = "silu",
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
    apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
    """Run an nvfp4 fused-MoE forward pass via FlashInfer CuteDSL experts.

    Builds a FusedMoEModularKernel pairing the FlashInfer prepare/finalize
    (created with ``use_dp=False``, i.e. no data parallelism) with
    FlashInferCuteDSLExperts, then invokes it.

    Args:
        hidden_states: Input activations; output dtype follows
            ``hidden_states.dtype``.
        w1: Expert weights for the first grouped gemm.
        w2: Expert weights for the second grouped gemm.
        topk_weights: Per-token router weights for the selected experts.
        topk_ids: Per-token expert ids selected by the router.
        quant_config: Quantization config forwarded to the experts kernel.
        inplace: Forwarded to the modular kernel.
        activation: Activation name forwarded to the modular kernel.
        global_num_experts: Total expert count; -1 presumably means
            "derive from inputs" — TODO confirm against FusedMoEModularKernel.
        expert_map: Optional expert remapping, forwarded as-is.
        apply_router_weight_on_input: Forwarded to the modular kernel.

    Returns:
        The MoE output tensor.
    """
    # Deferred import to avoid a circular dependency at module load time —
    # NOTE(review): presumably; verify against the module import graph.
    from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
        create_flashinfer_prepare_finalize,
    )

    fused_experts = mk.FusedMoEModularKernel(
        create_flashinfer_prepare_finalize(use_dp=False),  # could be swapped later
        FlashInferCuteDSLExperts(
            out_dtype=hidden_states.dtype,
            quant_config=quant_config,
        ),
    )

    return fused_experts(
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=inplace,
        activation=activation,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        apply_router_weight_on_input=apply_router_weight_on_input,
    )
Original file line number Diff line number Diff line change
Expand Up @@ -355,21 +355,17 @@ def create_flashinfer_prepare_finalize(
use_deepseek_fp8_block_scale: bool = False,
) -> FlashInferCutlassMoEPrepareAndFinalize | MoEPrepareAndFinalizeNoEP:
"""Factory function to create the appropriate FlashInfer implementation."""
# TODO(rob): migrate non-DP cases to MoEPrepareAndFinalizeNoEP
# once we complete the FP8 refactor.
if use_nvfp4:
if enable_alltoallv:
return FlashInferAllToAllMoEPrepareAndFinalize(use_dp)
else:
return FlashInferAllGatherMoEPrepareAndFinalize(use_dp)

# FP8 DP path currently supported via AllGather.
if use_dp:
if enable_alltoallv:
assert use_nvfp4
return FlashInferAllToAllMoEPrepareAndFinalize(use_dp)
return FlashInferAllGatherMoEPrepareAndFinalize(
use_dp=True,
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
)
else:
# NOTE(rob): CUTLASS FP8 block quant executes the input
# quantzation and grouped gemm in a single kernel.
return MoEPrepareAndFinalizeNoEP(defer_input_quant=use_deepseek_fp8_block_scale)
# CUTLASS FP8 BLOCK and CUTLASS NVFP4 apply input quantization
# in a single call with the MoE experts kernel.
defer_input_quant = use_deepseek_fp8_block_scale or use_nvfp4
return MoEPrepareAndFinalizeNoEP(defer_input_quant=defer_input_quant)
7 changes: 5 additions & 2 deletions vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,9 +540,10 @@ def __init__(
# TODO (varun) : Enable activation quantization
assert (
quant_config.use_mxfp4_w4a16
or quant_config.use_nvfp4_w4a16
or quant_config.use_int4_w4a16
or quant_config.use_fp8_w8a16
), "Supports only mxfp4_w4a16, int4_w4a16 or fp8_w8a16"
), "Supports only {mxfp,nvfp,int}4_w4a16 or fp8_w8a16"
self.w13_g_idx = w13_g_idx
self.w2_g_idx = w2_g_idx
self.w13_g_idx_sort_indices = w13_g_idx_sort_indices
Expand All @@ -555,7 +556,7 @@ def quant_type_id(self) -> int:
# uint4b8 will be set for int4 weight and float4_e2m1f will be used for mxfp4
if self.quant_config.use_int4_w4a16:
return scalar_types.uint4b8.id
elif self.quant_config.use_mxfp4_w4a16:
elif self.quant_config.use_mxfp4_w4a16 or self.quant_config.use_nvfp4_w4a16:
return scalar_types.float4_e2m1f.id
elif (
self.quant_config.use_fp8_w8a16
Expand Down Expand Up @@ -692,6 +693,8 @@ def apply(
gating_output=None,
topk_weights=topk_weights,
topk_ids=topk_ids,
global_scale1=self.g1_alphas,
global_scale2=self.g2_alphas,
quant_type_id=self.quant_type_id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
Expand Down
Loading