Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
31cd321
redo FusedMoEQuantConfig
bnellnm Aug 8, 2025
2ac62f9
wip
bnellnm Aug 8, 2025
895ac34
wip
bnellnm Aug 8, 2025
86a59a6
wip tests
bnellnm Aug 9, 2025
320487a
fix merge
bnellnm Aug 11, 2025
a39e98a
fixes
bnellnm Aug 11, 2025
0586ffb
comment
bnellnm Aug 11, 2025
3768476
fix
bnellnm Aug 11, 2025
f27d6f8
fixes
bnellnm Aug 11, 2025
9b5242a
fix
bnellnm Aug 11, 2025
3a16300
fixes
bnellnm Aug 11, 2025
872e0eb
fix deepgemm
bnellnm Aug 12, 2025
4b6aac0
fix cutlass + deep gemm tests
bnellnm Aug 12, 2025
4191eff
fix pplx test
bnellnm Aug 12, 2025
093c2d8
test fixes
bnellnm Aug 13, 2025
b761a70
fix
bnellnm Aug 13, 2025
2bff0be
wip
bnellnm Aug 13, 2025
6818d90
wip
bnellnm Aug 14, 2025
87f9dcf
cleanups
bnellnm Aug 15, 2025
448b3c4
lint
bnellnm Aug 15, 2025
5f7e266
fix
bnellnm Aug 15, 2025
d3671ef
fix lint stuff
bnellnm Aug 16, 2025
85d9df2
fix lint
bnellnm Aug 16, 2025
66fb9bf
wip debugging nans
bnellnm Aug 16, 2025
bacf503
fixed
bnellnm Aug 16, 2025
dc11232
add back in disabled test
bnellnm Aug 16, 2025
ce42c6c
fix lint
bnellnm Aug 16, 2025
3b641fa
fix fp4 stuff
bnellnm Aug 17, 2025
67d3449
cleanups
bnellnm Aug 18, 2025
382959a
merge fixes
bnellnm Aug 18, 2025
cab3dea
cleanups
bnellnm Aug 18, 2025
9f3615c
docs
bnellnm Aug 18, 2025
78a85a3
fix typo
bnellnm Aug 19, 2025
2e388f8
move trtllm stuff to separate file
bnellnm Aug 19, 2025
9b0fa6d
fix merge
bnellnm Aug 19, 2025
cc0168e
fix merge
bnellnm Aug 20, 2025
64a779a
tweak comment
bnellnm Aug 23, 2025
7f06feb
fix quantization layer merge
bnellnm Aug 25, 2025
c578f8f
clean up moe method dispatching + add asserts
bnellnm Aug 25, 2025
c897c9a
fix lint
bnellnm Aug 26, 2025
f2ca63a
fix merge error
bnellnm Aug 27, 2025
20e81be
fix merge issue
bnellnm Aug 27, 2025
748029b
fixes
bnellnm Sep 4, 2025
9f2350d
fix merge errors + review comments
bnellnm Sep 9, 2025
5ac746c
fix lint
bnellnm Sep 9, 2025
f63da81
fix TrtLlmGenExperts
bnellnm Sep 10, 2025
7a04c4c
fix lint
bnellnm Sep 10, 2025
980f91e
more lint
bnellnm Sep 10, 2025
939a4e4
fix test
bnellnm Sep 10, 2025
eb6385b
zp -> bias
bnellnm Sep 11, 2025
e91c977
fixes
bnellnm Sep 12, 2025
0d9baa7
fixes
bnellnm Sep 12, 2025
c23c91f
fix mxfp4 lint
bnellnm Sep 12, 2025
5d5e009
fix flashinfer test
bnellnm Sep 16, 2025
6a5453c
update other uses of moe quant configs
bnellnm Sep 16, 2025
08d3a64
update other uses of moe quant configs
bnellnm Sep 16, 2025
ef1c188
resolve merge conflicts
bnellnm Sep 16, 2025
19909f6
fix buffer problems
bnellnm Sep 16, 2025
5bc186b
fix another blackwell test
bnellnm Sep 16, 2025
2082f16
rebase
bnellnm Sep 17, 2025
b19d0bc
fix merge
bnellnm Sep 17, 2025
2de3019
fix merge
bnellnm Sep 17, 2025
469aaee
fix merge
bnellnm Sep 17, 2025
8d94b93
fix more merge issues
bnellnm Sep 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 36 additions & 22 deletions benchmarks/kernels/benchmark_cutlass_fp4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@

from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config,
nvfp4_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
from vllm.scalar_type import scalar_types
Expand Down Expand Up @@ -140,17 +144,20 @@ def run_triton_moe(
a_fp8_scale: torch.Tensor,
num_repeats: int,
):
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_fp8_scale,
)

for _ in range(num_repeats):
fused_experts(
a,
w1,
w2,
topk_weights,
topk_ids,
use_fp8_w8a8=True,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_fp8_scale,
quant_config=quant_config,
)

def run_cutlass_moe_fp4(
Expand All @@ -172,25 +179,27 @@ def run_cutlass_moe_fp4(
device: torch.device,
num_repeats: int,
):
quant_config = nvfp4_moe_quant_config(
a1_gscale=a1_gs,
a2_gscale=a2_gs,
w1_scale=w1_blockscale,
w2_scale=w2_blockscale,
g1_alphas=w1_gs,
g2_alphas=w2_gs,
)
for _ in range(num_repeats):
with nvtx.annotate("cutlass_moe_fp4", color="green"):
cutlass_moe_fp4(
a=a,
a1_gscale=a1_gs,
a2_gscale=a2_gs,
w1_fp4=w1_fp4,
w1_blockscale=w1_blockscale,
w1_alphas=w1_gs,
w2_fp4=w2_fp4,
w2_blockscale=w2_blockscale,
w2_alphas=w2_gs,
topk_weights=topk_weights,
topk_ids=topk_ids,
m=m,
n=n,
k=k,
e=num_experts,
device=device,
quant_config=quant_config,
)

def run_cutlass_from_graph(
Expand All @@ -211,26 +220,29 @@ def run_cutlass_from_graph(
e: int,
device: torch.device,
):
quant_config = nvfp4_moe_quant_config(
a1_gscale=a1_gs,
a2_gscale=a2_gs,
w1_scale=w1_blockscale,
w2_scale=w2_blockscale,
g1_alphas=w1_gs,
g2_alphas=w2_gs,
)

with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
):
return cutlass_moe_fp4(
a=a,
a1_gscale=a1_gs,
w1_fp4=w1_fp4,
w1_blockscale=w1_blockscale,
w1_alphas=w1_alphas,
a2_gscale=a2_gs,
w2_fp4=w2_fp4,
w2_blockscale=w2_blockscale,
w2_alphas=w2_alphas,
topk_weights=topk_weights,
topk_ids=topk_ids,
m=m,
n=n,
k=k,
e=num_experts,
device=device,
quant_config=quant_config,
)

def run_triton_from_graph(
Expand All @@ -246,16 +258,18 @@ def run_triton_from_graph(
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
):
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_fp8_scale,
)
return fused_experts(
a,
w1,
w2,
topk_weights,
topk_ids,
use_fp8_w8a8=True,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_fp8_scale,
quant_config=quant_config,
)

def replay_graph(graph, num_repeats):
Expand Down
43 changes: 27 additions & 16 deletions benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts,
Expand Down Expand Up @@ -96,17 +97,19 @@ def run_triton_moe(
a_scale: torch.Tensor,
num_repeats: int,
):
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_scale,
)
for _ in range(num_repeats):
fused_experts(
a,
w1,
w2,
topk_weights,
topk_ids,
use_fp8_w8a8=True,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_scale,
quant_config=quant_config,
)

def run_cutlass_moe(
Expand All @@ -125,21 +128,24 @@ def run_cutlass_moe(
per_act_token: bool,
num_repeats: int,
):
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
per_act_token_quant=per_act_token,
)

for _ in range(num_repeats):
cutlass_moe_fp8(
a,
w1,
w2,
topk_weights,
topk_ids,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
per_act_token,
a1_scale=None,
quant_config=quant_config,
)

def run_cutlass_from_graph(
Expand All @@ -156,6 +162,12 @@ def run_cutlass_from_graph(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
):
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
per_act_token_quant=per_act_token,
)

with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
):
Expand All @@ -165,14 +177,11 @@ def run_cutlass_from_graph(
w2_q,
topk_weights,
topk_ids,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
per_act_token,
a1_scale=None,
quant_config=quant_config,
)

def run_triton_from_graph(
Expand All @@ -185,6 +194,11 @@ def run_triton_from_graph(
w2_scale: torch.Tensor,
a_scale: torch.Tensor,
):
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_scale,
)
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
):
Expand All @@ -194,10 +208,7 @@ def run_triton_from_graph(
w2,
topk_weights,
topk_ids,
use_fp8_w8a8=True,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_scale,
quant_config=quant_config,
)

def replay_graph(graph, num_repeats):
Expand Down
73 changes: 35 additions & 38 deletions benchmarks/kernels/benchmark_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
import torch
from ray.experimental.tqdm_ray import tqdm

from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
_get_config_dtype_str,
)
from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
Expand Down Expand Up @@ -134,43 +138,36 @@ def prepare(i: int):
def run():
from vllm.model_executor.layers.fused_moe import override_config

if use_fp8_w8a8:
quant_dtype = torch.float8_e4m3fn
elif use_int8_w8a16:
quant_dtype = torch.int8
else:
quant_dtype = None

quant_config = FusedMoEQuantConfig.make(
quant_dtype=quant_dtype,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_quant_shape,
)

with override_config(config):
if use_deep_gemm:
topk_weights, topk_ids, token_expert_indices = fused_topk(
x, input_gating, topk, False
)
return fused_experts(
x,
w1,
w2,
topk_weights,
topk_ids,
inplace=True,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_quant_shape,
allow_deep_gemm=True,
)
else:
fused_moe(
x,
w1,
w2,
input_gating,
topk,
renormalize=True,
inplace=True,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_quant_shape,
)
topk_weights, topk_ids, token_expert_indices = fused_topk(
x, input_gating, topk, renormalize=not use_deep_gemm
)
return fused_experts(
x,
w1,
w2,
topk_weights,
topk_ids,
inplace=True,
quant_config=quant_config,
allow_deep_gemm=use_deep_gemm,
)

# JIT compilation & warmup
run()
Expand Down Expand Up @@ -414,7 +411,7 @@ def benchmark(
use_deep_gemm: bool = False,
) -> tuple[dict[str, int], float]:
current_platform.seed_everything(self.seed)
dtype_str = get_config_dtype_str(
dtype_str = _get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
Expand Down Expand Up @@ -547,7 +544,7 @@ def save_configs(
block_quant_shape: list[int],
save_dir: str,
) -> None:
dtype_str = get_config_dtype_str(
dtype_str = _get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
)

Expand Down
Loading