Merged

213 commits
e5c50db
cleanup process weights after loading
Dec 22, 2025
b1dddfd
removing spurious aiter stuff
Dec 22, 2025
78e9289
removing spurious aiter stuff
Dec 22, 2025
f1ae727
good codex bot
Dec 22, 2025
1a576b8
revert spurious aiter stuff
Dec 22, 2025
70367ac
reduce LOC changes
Dec 22, 2025
8425200
reduce LOC changes
Dec 22, 2025
f414f6c
further simplification
Dec 22, 2025
b31a8cb
updated
Dec 22, 2025
f2c70e1
cleanup comment
Dec 22, 2025
bd2046b
fix custom routing function for flashinfer
Dec 22, 2025
6655381
invalid checks in FP8MoE
Dec 22, 2025
a8820d8
stashing ... mixtral via flashinfer is not working properly'
Dec 22, 2025
844093c
merged aiter
Dec 23, 2025
cc2df79
updated
Dec 23, 2025
6db374b
updated
Dec 23, 2025
c2cc8e1
stash
Dec 23, 2025
1cf3b88
fix bad merge and bad qdq for per-tensor
Dec 23, 2025
985a0ab
weight rotation should only happen for per-tensor
Dec 23, 2025
7defa14
improve error message
Dec 23, 2025
078fdc4
updated
Dec 23, 2025
b55d8e3
update marlin ordering
Dec 23, 2025
8fbae90
improve comments
Dec 23, 2025
4bbb70f
add helper functions to share between online and offline quantization
Dec 23, 2025
bd72e61
fix up condition
Dec 23, 2025
24d219b
fix up condition
Dec 23, 2025
c290ebb
fix up condition
Dec 23, 2025
550e763
update to revert cleanup
Dec 23, 2025
d1fba0e
Merge branch 'main' into clean-up-fp8-process-after-loading
robertgshaw2-redhat Dec 23, 2025
7822c4d
Merge remote-tracking branch 'origin/main' into clean-up-fp8-process-…
Dec 27, 2025
32ef76a
updated
Dec 27, 2025
a8c5927
updated with proper impots
Dec 27, 2025
80039a0
added missing file
Dec 27, 2025
d9dfa7b
abstract the kernel swapping
Dec 27, 2025
bbcd012
abstract the kernel swapping
Dec 27, 2025
5822325
start applying to ct_moe_fp8
Dec 27, 2025
3f2322a
further progress
Dec 27, 2025
a7b6550
further progress
Dec 27, 2025
1737ed6
nit logging
Dec 27, 2025
cba5504
factor out the process after loading for requantize
Dec 27, 2025
b2bc870
rename setup_kernel -> make_kernel
Dec 27, 2025
14247f0
make convert_weights_to_runtime_format a pure function
Dec 27, 2025
570c359
factor out process weights after load
Dec 27, 2025
50d0bea
make names shorter
Dec 27, 2025
1e750fc
remove marlin_input_dtype
Dec 27, 2025
d4c5c4d
force reset
Dec 27, 2025
d626763
remove unneeded change
Dec 27, 2025
f7aaa16
remove imports
Dec 27, 2025
cdcc9f5
guard against TRT-LLM
Dec 27, 2025
ad6dc86
rename get_fp8_moe_backend -> select_fp8_moe_backend
Dec 27, 2025
dab20c6
add todo
Dec 27, 2025
5128d77
initial refactoring
Dec 29, 2025
2bc0a42
apply refactor to fp8 and ct
Dec 29, 2025
8a3b303
apply refactor globally
Dec 29, 2025
47a58a2
updated fused moe quant config for weight scale name
Dec 29, 2025
4aa3a64
Merge branch 'main' into fix-up-marlin-prepare-layer
robertgshaw2-redhat Dec 29, 2025
3eaa036
update comment
Dec 29, 2025
7ce4ef8
fix marlin tensor
Dec 29, 2025
5a55df8
Merge remote-tracking branch 'origin/main' into apply-refactor-to-ct
Dec 29, 2025
9d49046
merged main
Dec 29, 2025
8369d49
in process of refactoring cutlass fp8
Dec 29, 2025
4c3b253
convert to use mk for compressed-tensors
Dec 29, 2025
34cdaaa
updated
Dec 29, 2025
9ad1a09
fix online quantization
Dec 29, 2025
dad15e2
fix online quantization
Dec 29, 2025
542d2de
reduce name lenght for less newlines
Dec 29, 2025
758fae9
update name to convert_to_fp8_moe_kernel_format
Dec 29, 2025
35a0426
w13_weight -> w13 etc to reduce line breaks
Dec 29, 2025
33898b4
try to reduce loc change
Dec 29, 2025
7e8db4f
try to reduce loc change
Dec 29, 2025
0e9629d
rename make_kernel -> make_fp8_moe_kernel
Dec 29, 2025
2a32bac
update commentary about disable expert map
Dec 29, 2025
864379e
cleanup
Dec 29, 2025
b7091e3
removing run_cutlass_moe_fp8
Dec 29, 2025
da47fe0
add back ops import
Dec 29, 2025
3658b48
remove strides construction
Dec 29, 2025
464bb50
remove run_cutlass_fp8
Dec 29, 2025
b411440
apply to test cutlass moe
Dec 29, 2025
b17cb0e
remove import
Dec 29, 2025
e045766
fixed failing test
Dec 29, 2025
1d134d5
attempt to fix cutlass moe unit test
Dec 29, 2025
8e864f0
init workspace manager
Dec 29, 2025
b3f9a10
pre-commit on marlin input dtype
Dec 29, 2025
f915373
some more tweaks
Dec 29, 2025
f30ae4b
revert change to stray file
Dec 29, 2025
5f99481
clean up select_gemm_impl
Dec 29, 2025
8bdedc2
we are now passing for fp8.py triton block!
Dec 29, 2025
face1ce
,merged main
Dec 30, 2025
9f31d65
reduce loc change
Dec 30, 2025
47ca569
update comment
Dec 30, 2025
df49c3a
update tritonordeepgemmexperts
Dec 30, 2025
4828063
updated
Dec 30, 2025
6cefc23
split into separate situation
Dec 30, 2025
4880aef
add small batch fallback for cutlass
Dec 30, 2025
b89f68f
add small batch fallback for cutlass
Dec 30, 2025
d43dbb5
fix fallback
Dec 30, 2025
19ae34f
revert changes to marlin utils file
Dec 30, 2025
3845ee8
revert changes to get_marlin_input_dtype
Dec 30, 2025
3196cb1
nits
Dec 30, 2025
461cd8c
the ordered on ABC, Class matters for some reason in python --- duck
Dec 30, 2025
16aa74e
stash
Dec 30, 2025
be0abe2
apply changes to modelopt
robertgshaw2-redhat Dec 31, 2025
4eed452
remove unneeded cruft
robertgshaw2-redhat Dec 31, 2025
450f035
cleanup initialization
robertgshaw2-redhat Dec 31, 2025
d6a1f64
initial commit
Dec 31, 2025
9a28683
move to cuda before the reshapes for r&d
Jan 1, 2026
e84eaa2
clean
Jan 1, 2026
058a998
clean
Jan 1, 2026
d53b6ff
clean
Jan 1, 2026
9a7cf4d
clean
Jan 1, 2026
1182e1d
clean
Jan 1, 2026
2126f98
clean
Jan 1, 2026
33741a8
clean
Jan 1, 2026
31c4e22
stash
Jan 1, 2026
e0129dd
working end to end
Jan 1, 2026
eb6699b
comment nits
Jan 1, 2026
5be7ab1
comment nits
Jan 1, 2026
844a65a
remove
Jan 1, 2026
24a0302
rename method
Jan 1, 2026
7edf70f
stash trtllm fix
Jan 1, 2026
9d994a6
stash changes
Jan 1, 2026
6ff4b75
Merge branch 'fix-flashinfer-experts-quant-config-hack' of https://gi…
Jan 1, 2026
c9a7e5b
updated
Jan 1, 2026
2408ad2
make trtllm work
Jan 1, 2026
96ff599
ad back import
Jan 1, 2026
f9a4724
update comments
Jan 1, 2026
e8831f9
apply changes to fp8.py
Jan 1, 2026
f8f9a33
nit
Jan 1, 2026
59f97a6
revert unneeded assets
Jan 1, 2026
113e472
rename
Jan 1, 2026
a98a380
update comment
Jan 1, 2026
df82e9c
Merge branch 'fix-flashinfer-experts-quant-config-hack' of https://gi…
Jan 1, 2026
df5035c
naming
Jan 1, 2026
3678402
add back check to prevent mixtral
Jan 1, 2026
783b64d
Merge branch 'main' into fix-flashinfer-experts-quant-config-hack
robertgshaw2-redhat Jan 1, 2026
c30d404
remove delete me
Jan 2, 2026
a285f5e
update to address pavani's feedback
Jan 2, 2026
2d96161
reduce LOC change
Jan 2, 2026
a910872
fix
Jan 2, 2026
d2decd6
reduce loc nit
Jan 2, 2026
870fc6a
clean up fi checking
Jan 2, 2026
23e79fd
updated
Jan 2, 2026
86a0e5c
fix assign float to parameter
Jan 2, 2026
7eaa18b
updated doc string
Jan 2, 2026
83a7d9b
fix up
Jan 2, 2026
7300bc5
fix up
Jan 2, 2026
344167d
fix up configs
Jan 2, 2026
b887c4f
cleanup
Jan 2, 2026
d4d4231
remove unneeded assert
Jan 2, 2026
218e697
standardize how we add the params
Jan 2, 2026
b2e3a50
updated
Jan 2, 2026
dd30416
updated
Jan 2, 2026
3d22ba3
updated
Jan 2, 2026
56edeca
a few small nits
Jan 2, 2026
e917f5d
fix tests
Jan 2, 2026
39987f6
revert the llama weight loading hack
Jan 2, 2026
140f447
stash
Jan 2, 2026
de6faa1
unstash
Jan 2, 2026
173e67d
Merge remote-tracking branch 'origin/main' into fix-flashinfer-expert…
Jan 3, 2026
fac4014
merge
robertgshaw2-redhat Jan 3, 2026
c1c1195
fix merge from amd guard
robertgshaw2-redhat Jan 3, 2026
79acaac
merge the fi branch
robertgshaw2-redhat Jan 3, 2026
caf46be
rename flashinfer trtllm funciton names
robertgshaw2-redhat Jan 3, 2026
2a5a58b
cleanup!
robertgshaw2-redhat Jan 3, 2026
4d76cb6
tests are passing
robertgshaw2-redhat Jan 3, 2026
d304a73
remove comments
robertgshaw2-redhat Jan 3, 2026
2ee1c44
updated
Jan 3, 2026
a5a1d0b
circular import
Jan 4, 2026
5a627d8
updated the details
Jan 4, 2026
2a5c255
fix typing
Jan 4, 2026
08a1979
fixed cutlass block
Jan 4, 2026
8f2341f
clean up deepgemm a bit
Jan 4, 2026
85d59c8
use proper naming for modelopt
Jan 4, 2026
1a69be2
use proper naming for modelopt
Jan 4, 2026
8c4dddf
use proper naming for modelopt
Jan 4, 2026
16721e5
merged main
Jan 6, 2026
8fec574
update log for unsupported
Jan 6, 2026
95f0b37
Merge remote-tracking branch 'origin/main' into apply-refactor-to-ct
Jan 6, 2026
609b9b9
merge main
Jan 6, 2026
3ceb254
nit
Jan 6, 2026
af2cbd3
update convert_weights_to_kernel_format
Jan 6, 2026
fb6e402
rever use for flashinfer_moe_backend
Jan 6, 2026
da6218a
stash work
Jan 6, 2026
64b7ba5
fix trtllm kernel
Jan 6, 2026
ebd76f2
fix importing issues
Jan 6, 2026
6afe4bb
fix compressed tensors issue
Jan 6, 2026
eecb7dc
fix lint
Jan 6, 2026
804c147
fix error from lack of routing
Jan 6, 2026
b6e5dc5
delayed imports
Jan 6, 2026
617c662
fix cutlass tensor
Jan 6, 2026
6b1d1ad
make marlin pass
Jan 6, 2026
2c2e274
make things easier to follow in the ci logs
Jan 6, 2026
eb83e8d
add dp/ep
Jan 6, 2026
c7424b7
Merge branch 'main' into apply-refactor-to-ct
robertgshaw2-redhat Jan 6, 2026
e41d147
nit
Jan 6, 2026
3469b8d
updated
Jan 7, 2026
af5a4f4
updated
Jan 7, 2026
34175d1
ipdate the test coverage for dp/ep
Jan 7, 2026
2687d2c
fix up again, some of the a2a backends are not working
Jan 7, 2026
1eceb09
Merge branch 'main' into apply-refactor-to-ct
robertgshaw2-redhat Jan 7, 2026
9ce786a
update oracle to not use cutlass for block quant
Jan 7, 2026
fefa376
delete llama 4 load time optim
Jan 7, 2026
5cccc6c
docs fix
Jan 7, 2026
9810137
revert .to(cuda)
Jan 7, 2026
35c3bc3
updated with bills comments
Jan 7, 2026
ce8deb7
revert change for llama experts not being loaded properly
Jan 7, 2026
08742c5
fix marlin comment
Jan 7, 2026
a369a51
Merge branch 'main' into apply-refactor-to-ct
robertgshaw2-redhat Jan 7, 2026
d4486f8
updated the access logs
Jan 7, 2026
be3dc9a
updated
Jan 7, 2026
a927bec
updated
Jan 7, 2026
192339d
delay imports
Jan 7, 2026
5dccfc6
fix missing
Jan 7, 2026
7 changes: 7 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -1422,3 +1422,10 @@ steps:
num_gpus: 2
commands:
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=evals/gsm8k/configs/moe-refactor/config-b200.txt

- label: MoE Refactor Integration Test (B200 DP - TEMPORARY) # optional
gpu: b200
optional: true
num_gpus: 2
commands:
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=evals/gsm8k/configs/moe-refactor-dp-ep/config-b200.txt
116 changes: 24 additions & 92 deletions benchmarks/kernels/benchmark_cutlass_moe_fp8.py
@@ -6,13 +6,16 @@
but use different quantization strategies and backends.
"""

import nvtx
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.worker.workspace import init_workspace_manager
@@ -59,6 +62,7 @@ def bench_run(
per_out_ch: bool,
mkn: tuple[int, int, int],
):
init_workspace_manager(torch.cuda.current_device())
(m, k, n) = mkn

dtype = torch.half
@@ -121,85 +125,6 @@
# Force per-tensor quantization for all cases
per_act_token = False

# Create stride tensors for CUTLASS
ab_strides1 = torch.full((num_experts,), k, dtype=torch.int64, device=device)
ab_strides2 = torch.full((num_experts,), n, dtype=torch.int64, device=device)
c_strides1 = torch.full((num_experts,), 2 * n, dtype=torch.int64, device=device)
c_strides2 = torch.full((num_experts,), k, dtype=torch.int64, device=device)

def run_triton_moe(
[Review comment from the author: this function was unused]

a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
a1_scale: torch.Tensor,
a2_scale: torch.Tensor,
num_repeats: int,
):
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
)

for _ in range(num_repeats):
fused_experts(
a,
w1,
w2,
topk_weights,
topk_ids,
quant_config=quant_config,
)

def run_cutlass_moe_fp8(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
a1_scale: torch.Tensor,
a2_scale: torch.Tensor,
num_repeats: int,
):
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
)

for _ in range(num_repeats):
with nvtx.annotate("cutlass_moe_fp8", color="blue"):
cutlass_moe_fp8(
a=a,
w1_q=w1,
w2_q=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
ab_strides1=ab_strides1,
ab_strides2=ab_strides2,
c_strides1=c_strides1,
c_strides2=c_strides2,
quant_config=quant_config,
activation="silu",
global_num_experts=num_experts,
)

# Pre-create quantization config to avoid creating it inside CUDA graph
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
@@ -210,23 +135,30 @@ def run_cutlass_moe_fp8(
per_out_ch_quant=per_out_ch,
)

fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp8(
out_dtype=a.dtype,
e=num_experts,
n=n,
k=k,
quant_config=quant_config,
device=w1.device,
),
)

# Create CUDA graphs for CUTLASS (match benchmark_moe.py pattern exactly)
cutlass_stream = torch.cuda.Stream()
cutlass_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
# Capture 10 invocations like benchmark_moe.py
for _ in range(10):
cutlass_moe_fp8(
a=a,
w1_q=w1_fp8q_cutlass,
w2_q=w2_fp8q_cutlass,
topk_weights=topk_weights,
topk_ids=topk_ids,
ab_strides1=ab_strides1,
ab_strides2=ab_strides2,
c_strides1=c_strides1,
c_strides2=c_strides2,
quant_config=quant_config,
fn(
a,
w1_fp8q_cutlass,
w2_fp8q_cutlass,
topk_weights,
topk_ids,
activation="silu",
global_num_experts=num_experts,
)
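The hunk above replaces direct `cutlass_moe_fp8(...)` calls with a composed `mk.FusedMoEModularKernel(MoEPrepareAndFinalizeNoEP(), CutlassExpertsFp8(...))`. A minimal sketch of that composition pattern is below; the class names mirror the vLLM ones, but every body here is an illustrative stand-in, not the real kernels:

```python
# Hypothetical sketch of the "modular kernel" composition the diff migrates
# to: a prepare/finalize stage wrapped around a swappable experts impl.

class PrepareAndFinalizeNoEP:
    """No expert parallelism: prepare/finalize are identity passes."""

    def prepare(self, hidden_states):
        return hidden_states  # real impl would quantize/scatter activations

    def finalize(self, expert_output):
        return expert_output  # real impl would gather/reduce across ranks


class ToyExperts:
    """Stand-in for an experts kernel (e.g. a CUTLASS FP8 grouped GEMM)."""

    def apply(self, hidden_states, topk_ids):
        # Pretend each "expert" just scales its tokens by (expert_id + 1).
        return [h * (e + 1) for h, e in zip(hidden_states, topk_ids)]


class ModularKernel:
    """Composes prepare/finalize with an experts impl, like FusedMoEModularKernel."""

    def __init__(self, prepare_finalize, experts):
        self.prepare_finalize = prepare_finalize
        self.experts = experts

    def __call__(self, hidden_states, topk_ids):
        x = self.prepare_finalize.prepare(hidden_states)
        y = self.experts.apply(x, topk_ids)
        return self.prepare_finalize.finalize(y)


fn = ModularKernel(PrepareAndFinalizeNoEP(), ToyExperts())
print(fn([1.0, 2.0], [0, 2]))  # → [1.0, 6.0]
```

The benchmark benefits directly: because the experts implementation is a constructor argument, the same call site (`fn(a, w1, w2, topk_weights, topk_ids, ...)`) can be captured in a CUDA graph regardless of which backend is plugged in.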
90 changes: 37 additions & 53 deletions benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
@@ -5,14 +5,18 @@
import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES_MOE

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts,
fused_topk,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.worker.workspace import init_workspace_manager

@@ -45,6 +49,7 @@ def bench_run(
per_out_ch: bool,
mkn: tuple[int, int, int],
):
init_workspace_manager(torch.cuda.current_device())
label = "Quant Matmul"

sub_label = (
@@ -82,11 +87,6 @@
a, score, topk, renormalize=False
)

ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)

def run_triton_moe(
a: torch.Tensor,
w1: torch.Tensor,
@@ -120,10 +120,6 @@ def run_cutlass_moe(
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
per_act_token: bool,
@@ -135,31 +131,29 @@
per_act_token_quant=per_act_token,
)

for _ in range(num_repeats):
cutlass_moe_fp8(
a,
w1,
w2,
topk_weights,
topk_ids,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp8(
out_dtype=a.dtype,
# NOTE(rob): w2 is shaped as [E, hidden, intermediate]
e=w2.shape[0],
n=w2.shape[2],
k=w2.shape[1],
quant_config=quant_config,
)
device=w1.device,
),
)

for _ in range(num_repeats):
fn(a, w1, w2, topk_weights, topk_ids)

def run_cutlass_from_graph(
a: torch.Tensor,
a_scale: torch.Tensor,
w1_q: torch.Tensor,
w2_q: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
):
@@ -169,21 +163,23 @@ def run_cutlass_from_graph(
per_act_token_quant=per_act_token,
)

fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp8(
out_dtype=a.dtype,
# NOTE(rob): w2 is shaped as [E, hidden, intermediate]
e=w2.shape[0],
n=w2.shape[2],
k=w2.shape[1],
quant_config=quant_config,
device=w1.device,
),
)

with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
):
return cutlass_moe_fp8(
a,
w1_q,
w2_q,
topk_weights,
topk_ids,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
quant_config=quant_config,
)
return fn(a, w1, w2, topk_weights, topk_ids)

def run_triton_from_graph(
a: torch.Tensor,
@@ -227,10 +223,6 @@ def replay_graph(graph, num_repeats):
w2_q,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
topk_weights,
topk_ids,
)
@@ -268,10 +260,6 @@ def replay_graph(graph, num_repeats):
"w1_scale": w1_scale,
"w2_scale": w2_scale,
"per_act_token": per_act_token,
"ab_strides1": ab_strides1,
"ab_strides2": ab_strides2,
"c_strides1": c_strides1,
"c_strides2": c_strides2,
# cuda graph params
"cutlass_graph": cutlass_graph,
"triton_graph": triton_graph,
@@ -330,10 +318,6 @@ def replay_graph(graph, num_repeats):
w2_q,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
topk_weights,
topk_ids,
per_act_token,
@@ -342,7 +326,7 @@ def replay_graph(graph, num_repeats):

results.append(
benchmark.Timer(
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
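Both benchmark diffs also delete the hand-built per-expert stride tensors (`ab_strides1 = k`, `ab_strides2 = n`, `c_strides1 = 2 * n`, `c_strides2 = k`); after the refactor, `CutlassExpertsFp8` can derive them from the weight shapes it is constructed with. A sketch of that derivation, assuming the layouts noted in the diff (w13 as `[E, 2n, k]`, w2 as `[E, k, n]`); `derive_strides` is a hypothetical helper, and the real code fills `torch.int64` tensors of length `E` with these values rather than returning scalars:

```python
# Hypothetical sketch: recover the per-expert grouped-GEMM leading dimensions
# from the MoE weight shapes alone, matching the values the old benchmark
# constructed by hand (ab_strides1=k, ab_strides2=n, c_strides1=2n, c_strides2=k).

def derive_strides(w13_shape, w2_shape):
    e, two_n, k = w13_shape   # first GEMM weights: [E, 2n, k]
    e2, k2, n = w2_shape      # second GEMM weights: [E, k, n]
    assert e == e2 and k == k2 and two_n == 2 * n, "inconsistent MoE shapes"
    return {
        "ab_strides1": k,      # input/weight leading dim, first grouped GEMM
        "ab_strides2": n,      # input/weight leading dim, second grouped GEMM
        "c_strides1": 2 * n,   # first GEMM output is [m, 2n] (gate + up proj)
        "c_strides2": k,       # second GEMM output is [m, k] (hidden size)
    }


print(derive_strides((8, 512, 1024), (8, 1024, 256)))
```

Centralizing this inside the experts class is what lets the call sites shrink from ten-plus positional arguments down to `fn(a, w1, w2, topk_weights, topk_ids)`.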
2 changes: 0 additions & 2 deletions benchmarks/kernels/benchmark_moe.py
@@ -48,8 +48,6 @@ def clear_triton_cache():

# Try to clear Triton's runtime cache
try:
import triton

if (
hasattr(triton, "runtime")
and hasattr(triton.runtime, "cache")
2 changes: 1 addition & 1 deletion docs/design/moe_kernel_features.md
@@ -87,7 +87,7 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels
| triton (batched) | batched | all<sup>1</sup> | G,A,T | silu, gelu | <sup>6</sup> | Y | [`BatchedTritonExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedTritonExperts] |
| deep gemm | standard,</br>batched | fp8 | G(128),A,T | silu, gelu | <sup>6</sup> | Y | [`deep_gemm_moe_fp8`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.deep_gemm_moe_fp8],</br>[`DeepGemmExperts`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.DeepGemmExperts],</br>[`BatchedDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe.BatchedDeepGemmExperts] |
| cutlass_fp4 | standard,</br>batched | nvfp4 | A,T | silu | Y | Y | [`cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.cutlass_moe.cutlass_moe_fp4],</br>[`CutlassExpertsFp4`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp4] |
| cutlass_fp8 | standard,</br>batched | fp8 | A,T | silu, gelu | Y | Y | [`cutlass_moe_fp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.cutlass_moe_fp8],</br>[`CutlassExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp8],</br>[`CutlasBatchedExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassBatchedExpertsFp8] |
| cutlass_fp8 | standard,</br>batched | fp8 | A,T | silu, gelu | Y | Y | [`CutlassExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp8],</br>[`CutlasBatchedExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassBatchedExpertsFp8] |
| flashinfer | standard | nvfp4,</br>fp8 | T | <sup>5</sup> | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],</br>[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] |
| gpt oss triton | standard | N/A | N/A | <sup>5</sup> | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],</br>[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
| marlin | standard,</br>batched | <sup>3</sup> / N/A | <sup>3</sup> / N/A | silu,</br>swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],</br>[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],</br>[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
@@ -0,0 +1,5 @@
model_name: "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
accuracy_threshold: 0.92
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel"
@@ -0,0 +1,8 @@
model_name: "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8"
accuracy_threshold: 0.88
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel --all2all-backend deepep_high_throughput"
env:
VLLM_USE_DEEP_GEMM: "1"
VLLM_USE_DEEP_GEMM_MOE: "1"
@@ -0,0 +1,9 @@
model_name: "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8"
accuracy_threshold: 0.88
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel --all2all-backend deepep_low_latency --disable-uvicorn-access-log"
env:
VLLM_USE_DEEP_GEMM: "1"
VLLM_USE_DEEP_GEMM_MOE: "1"
VLLM_USE_DEEP_GEMM_E8M0: "0"
@@ -0,0 +1,8 @@
model_name: "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8"
accuracy_threshold: 0.88
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel"
env:
VLLM_USE_DEEP_GEMM: "1"
VLLM_USE_DEEP_GEMM_MOE: "1"