Closed
118 commits
- f8851e0 stash (Jan 8, 2026)
- 085adf7 stash (Jan 8, 2026)
- a6b039d update interface (Jan 8, 2026)
- f8052ce stash (Jan 8, 2026)
- 13b619f stash (Jan 8, 2026)
- 04bb010 first correctness! (Jan 8, 2026)
- b1320de updated (Jan 8, 2026)
- 4d47206 comments (Jan 8, 2026)
- f86fad8 updated (Jan 8, 2026)
- 5601b95 Merge branch 'main' into naive-dispatch-combine (robertgshaw2-redhat, Jan 8, 2026)
- 8c1a530 updateds (Jan 8, 2026)
- 7d7d5a6 nit changes (Jan 8, 2026)
- 63357f7 support apply router weight on input (Jan 8, 2026)
- 3886cfb attempt to get everything working for llama scout modelopt flashinfer… (Jan 8, 2026)
- 2284b59 updated (Jan 8, 2026)
- e131054 apply to batched deep gemm (Jan 10, 2026)
- 77c7b05 updated (Jan 11, 2026)
- 477d699 stash (Jan 11, 2026)
- 9f2e10b remove NaiveBatchedExperts (Jan 11, 2026)
- ef5e664 stash (Jan 11, 2026)
- f6e85bc stash (Jan 11, 2026)
- 0db0b11 added back moe torch iterative (Jan 11, 2026)
- 09dc4f5 revert changes (Jan 11, 2026)
- 3311e9e re add (Jan 11, 2026)
- 755a3a2 add back iterative (Jan 11, 2026)
- a8bb9d0 add methodology to kernels (Jan 11, 2026)
- eb571a2 restructure kernel selection logic (Jan 12, 2026)
- 8f0a969 remove is_cuda (Jan 12, 2026)
- e461b6f added renamed file (Jan 12, 2026)
- 93bd28b improve quant scheme (Jan 12, 2026)
- 2b24d70 improve validation (Jan 12, 2026)
- 312b767 improve platform selection logic (Jan 12, 2026)
- 0f37e95 nit newline (Jan 12, 2026)
- c187335 nit newline (Jan 12, 2026)
- 6ad575e revert spurious LOC change (Jan 12, 2026)
- 2aabb4f revert spurious LOC change (Jan 12, 2026)
- 3c5f602 updated (Jan 12, 2026)
- 00a130a update marlin (Jan 12, 2026)
- bd38266 updated (Jan 12, 2026)
- bba97f4 updated (Jan 12, 2026)
- dc9723a support differentiating on static vs dynamic (Jan 12, 2026)
- fc9ea07 newline nit (Jan 12, 2026)
- ef8bb7e update logic for aiter foudn (Jan 12, 2026)
- dde5dd2 cleanup support logic (Jan 12, 2026)
- f8045c6 update logic for trtllm config (Jan 12, 2026)
- 40ecf90 updated (Jan 12, 2026)
- 372c697 more progress (Jan 13, 2026)
- 139941a updated (Jan 13, 2026)
- fb90fc0 updated (Jan 13, 2026)
- b3c818d updated (Jan 13, 2026)
- 5efe9ff Merge branch 'main' into naive-dispatch-combine (robertgshaw2-redhat, Jan 13, 2026)
- 234144b merged (Jan 13, 2026)
- e932743 updated (Jan 13, 2026)
- b21e8be updated (Jan 13, 2026)
- 7a269bc updated (Jan 13, 2026)
- 02b0848 updated (Jan 13, 2026)
- 85615f1 update kernel selection logic (Jan 13, 2026)
- 9a4a871 seem to have accuracy with dp/ep and tp for deepgemm (Jan 13, 2026)
- 9f829f0 updated (Jan 13, 2026)
- 03ce528 re-add (Jan 13, 2026)
- 9ce0412 added back naive batched experts (Jan 13, 2026)
- 856d9dc added back naive batched experts (Jan 13, 2026)
- 38ae5e7 added back naive batched experts (Jan 13, 2026)
- 81b5ed3 added back naive batched experts (Jan 13, 2026)
- 550b42a reduce LOC change (Jan 13, 2026)
- d85bd8b reduce LOC change (Jan 13, 2026)
- 642b85c reduce LOC change (Jan 13, 2026)
- 5027a28 reduce loc change (Jan 13, 2026)
- 6a009d6 reduce LOC change (Jan 13, 2026)
- 1cc46be fix native ag/rs (Jan 14, 2026)
- 62d9d7c oracle is now working properly (Jan 14, 2026)
- 7d7fcd9 confirmed aiter env variables work as expected (Jan 14, 2026)
- 51a93fe fix up oracle (Jan 14, 2026)
- a600601 trying to make ct work (Jan 14, 2026)
- 63ea73a Merge remote-tracking branch 'origin/main' into oracle-improvements (Jan 14, 2026)
- f83c63e made llama 4 work via compressed tensors (Jan 14, 2026)
- 9cb7771 initial attempt at modelopt (Jan 14, 2026)
- ad28a18 attempt to fix naive multicast (Jan 14, 2026)
- b507afa flashinfer appears to be working (Jan 14, 2026)
- 18f6159 Merge remote-tracking branch 'origin/main' into oracle-improvements (Jan 14, 2026)
- 50b106b stash changes for review (Jan 14, 2026)
- 75aca72 made modelopt start up (Jan 14, 2026)
- 3287512 still enable flashinfer (Jan 14, 2026)
- 84b83b6 remove comments (Jan 14, 2026)
- 3f90b88 appears that we have AG/RS working for nvfp4 (Jan 14, 2026)
- 19537ff updated (Jan 14, 2026)
- dfefddf updated (Jan 14, 2026)
- 232a5b9 updared comment (Jan 14, 2026)
- d2dd979 make marling work with current device (Jan 14, 2026)
- d35c247 added compressed tensors nvfp4 (Jan 14, 2026)
- 4c9656a added compressed tensors nvfp4 - nit (Jan 14, 2026)
- 9d5d3ee remove shared expert overlap functionality (Jan 14, 2026)
- 553efc1 make cutedsl work (Jan 14, 2026)
- 218e9bf hook up marlin experts properly (Jan 14, 2026)
- 3c97fda hook up marlin experts properly (Jan 14, 2026)
- f2714d8 make marlin work with AG/RS (Jan 14, 2026)
- 471428f convert to using quant key (Jan 14, 2026)
- e17d456 convert to using quant key (Jan 14, 2026)
- 8cacf5a reject usage for things that have not migrated over yet (Jan 14, 2026)
- 9518c97 reject usage for things that have not migrated over yet (Jan 14, 2026)
- ab090c1 reject usage for things that have not migrated over yet (Jan 14, 2026)
- 1db50dc reject usage for things that have not migrated over yet (Jan 14, 2026)
- 6090a06 differentiate static vs dynamic quantization (Jan 14, 2026)
- 6a3a75b remove newline (Jan 14, 2026)
- 89911ea fix static vs dynamic (Jan 14, 2026)
- ad8fe2e get things working again (Jan 14, 2026)
- 7bc3674 updated (Jan 15, 2026)
- f8d3af7 updatred (Jan 15, 2026)
- 10b957c attempt to get sp working (Jan 15, 2026)
- e8ab545 nit (Jan 15, 2026)
- 50a8c97 appears to be working properly (Jan 15, 2026)
- ddc2eb1 fix pre commit (Jan 15, 2026)
- 5f913ea updated (Jan 15, 2026)
- e58e783 remove flashinfer constructors (Jan 15, 2026)
- a8de4da nits (Jan 15, 2026)
- 2c9e9e6 remove do naive dispach combine comment (Jan 15, 2026)
- 9a907e0 update backend names (Jan 15, 2026)
- 0670758 update fallback experts (Jan 15, 2026)
2 changes: 1 addition & 1 deletion docs/design/moe_kernel_features.md
@@ -88,7 +88,7 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels
| deep gemm | standard,</br>batched | fp8 | G(128),A,T | silu, gelu | <sup>6</sup> | Y | [`deep_gemm_moe_fp8`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.deep_gemm_moe_fp8],</br>[`DeepGemmExperts`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.DeepGemmExperts],</br>[`BatchedDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe.BatchedDeepGemmExperts] |
| cutlass_fp4 | standard,</br>batched | nvfp4 | A,T | silu | Y | Y | [`CutlassExpertsFp4`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp4] |
| cutlass_fp8 | standard,</br>batched | fp8 | A,T | silu, gelu | Y | Y | [`CutlassExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp8],</br>[`CutlasBatchedExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassBatchedExpertsFp8] |
- | flashinfer | standard | nvfp4,</br>fp8 | T | <sup>5</sup> | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],</br>[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] |
+ | flashinfer | standard | nvfp4,</br>fp8 | T | <sup>5</sup> | N | Y | [`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] |
| gpt oss triton | standard | N/A | N/A | <sup>5</sup> | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],</br>[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
| marlin | standard,</br>batched | <sup>3</sup> / N/A | <sup>3</sup> / N/A | silu,</br>swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],</br>[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],</br>[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
| trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
6 changes: 3 additions & 3 deletions tests/compile/test_fusion_attn.py
@@ -30,7 +30,7 @@
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
-     kNvfp4Quant,
+     kNvfp4Dynamic,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.platforms import current_platform
@@ -202,7 +202,7 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
"""Test model for AttentionNvfp4QuantPattern fusion."""

-     quant_key = kNvfp4Quant
+     quant_key = kNvfp4Dynamic

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -449,7 +449,7 @@ def test_attention_quant_pattern(

# Note: for fp8, fully_replaced=False because query quant ops remain in graph.
# Only output quant ops are fused into attention.
-     test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Quant)
+     test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Dynamic)

# access the underlying `AttnFusionPass` on the `LazyInitPass`
assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
6 changes: 3 additions & 3 deletions tests/compile/test_silu_mul_quant_fusion.py
@@ -29,7 +29,7 @@
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
kFp8StaticTensorSym,
-     kNvfp4Quant,
+     kNvfp4Dynamic,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
@@ -121,11 +121,11 @@ def forward(self, x):
def ops_in_model_before(self):
return [
SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
-             QUANT_OPS[kNvfp4Quant],
+             QUANT_OPS[kNvfp4Dynamic],
]

def ops_in_model_after(self):
-         return [FUSED_OPS[kNvfp4Quant]]
+         return [FUSED_OPS[kNvfp4Dynamic]]


class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
9 changes: 6 additions & 3 deletions tests/kernels/moe/modular_kernel_tools/common.py
@@ -161,15 +161,15 @@ def is_fp8_block_quantized(self):

def is_batched_prepare_finalize(self):
info = prepare_finalize_info(self.prepare_finalize_type)
-         return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format
+         return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format()

def is_batched_fused_experts(self):
info = expert_info(self.fused_experts_type)
-         return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format
+         return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format()

def is_standard_fused_experts(self):
info = expert_info(self.fused_experts_type)
-         return mk.FusedMoEActivationFormat.Standard == info.activation_format
+         return mk.FusedMoEActivationFormat.Standard == info.activation_format()

def fe_supported_types(self):
info = expert_info(self.fused_experts_type)
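The hunks above change `info.activation_format` from attribute access to a method call. A self-contained illustration of why that distinction matters (the classes here are minimal stand-ins, not vLLM's actual `ExpertInfo`): comparing an enum member against a bound method object silently evaluates to `False`, whereas calling the method performs the intended comparison.

```python
from enum import Enum

class ActivationFormat(Enum):
    Standard = "standard"
    BatchedExperts = "batched"

class ExpertInfo:
    # In this sketch activation_format is a method, mirroring the call-site
    # change above from `info.activation_format` to `info.activation_format()`.
    def activation_format(self) -> ActivationFormat:
        return ActivationFormat.BatchedExperts

info = ExpertInfo()

# Comparing against the bound method object is always False and fails silently:
assert (ActivationFormat.BatchedExperts == info.activation_format) is False

# Calling the method performs the intended comparison:
assert ActivationFormat.BatchedExperts == info.activation_format()
```

This is exactly the kind of bug that passes the type checker only if annotations are loose, which is presumably why the call sites were updated in lockstep with the `ExpertInfo` definition.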
@@ -574,10 +574,13 @@ def next_power_of_2(x):
num_experts=config.E,
experts_per_token=config.topk,
hidden_dim=config.K,
+         intermediate_size_per_partition=config.N,
num_local_experts=config.num_local_experts,
moe_parallel_config=moe_parallel_config,
in_dtype=config.dtype,
max_num_tokens=next_power_of_2(config.M),
+         activation="silu",
+         device=vllm_config.device_config.device,
)

# make modular kernel
10 changes: 4 additions & 6 deletions tests/kernels/moe/modular_kernel_tools/mk_objects.py
@@ -237,7 +237,6 @@ def expert_info(kind) -> ExpertInfo:
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
FlashInferCutlassMoEPrepareAndFinalize,
-     create_flashinfer_prepare_finalize,
)

register_prepare_and_finalize(
@@ -389,13 +388,12 @@ def make_prepare_finalize(
quant_config: FusedMoEQuantConfig,
) -> mk.FusedMoEPrepareAndFinalize:
if backend != "naive" and backend is not None:
-         prepare_finalize = maybe_make_prepare_finalize(moe, quant_config)
+         # TODO(rob): add defer input quant.
+         prepare_finalize = maybe_make_prepare_finalize(
+             moe, quant_config, allow_new_interface=True
+         )
assert prepare_finalize is not None
return prepare_finalize
-     elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
-         return create_flashinfer_prepare_finalize(
-             use_dp=moe.moe_parallel_config.dp_size > 1
-         )
else:
return MoEPrepareAndFinalizeNoEP()

1 change: 1 addition & 0 deletions tests/kernels/moe/test_flashinfer.py
@@ -127,6 +127,7 @@ def make_moe_tensors_8bit(
ep_rank=0,
use_ep=False,
all2all_backend="naive",
+         isequence_parallel=False,
)

# flashinfer expects swapped rows for w13
44 changes: 39 additions & 5 deletions tests/kernels/moe/test_flashinfer_moe.py
@@ -12,15 +12,19 @@
from tests.kernels.utils import torch_moe
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
+ from vllm.model_executor.layers.fused_moe.config import (
+     FusedMoEConfig,
+     FusedMoEParallelConfig,
+ )
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
is_valid_flashinfer_cutlass_fused_moe,
)
- from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (
-     create_flashinfer_prepare_finalize,
- )
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
+ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+     MoEPrepareAndFinalizeNoEP,
+ )
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.torch_utils import set_random_seed
@@ -86,9 +90,39 @@ def test_flashinfer_fp4_moe_no_graph(

assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q)

+     moe_config = FusedMoEConfig(
+         num_experts=e,
+         experts_per_token=topk,
+         hidden_dim=k,
+         intermediate_size_per_partition=n,
+         num_local_experts=e,
+         activation=activation,
+         device="cuda",
+         moe_parallel_config=FusedMoEParallelConfig(
+             tp_size=1,
+             pcp_size=1,
+             dp_size=1,
+             ep_size=1,
+             tp_rank=0,
+             pcp_rank=0,
+             dp_rank=0,
+             ep_rank=0,
+             use_ep=False,
+             all2all_backend="allgather_reducescatter",
+             isequence_parallel=False,
+         ),
+         in_dtype=dtype,
+         is_act_and_mul=is_gated_act,
+     )
+
flashinfer_experts = FusedMoEModularKernel(
-         create_flashinfer_prepare_finalize(use_dp=False, use_nvfp4=True),
-         FlashInferExperts(out_dtype=dtype, quant_config=quant_config),
+         MoEPrepareAndFinalizeNoEP(
+             defer_input_quant=FlashInferExperts.should_pf_defer_input_quant(
+                 moe_config=moe_config,
+                 quant_config=quant_config,
+             )
+         ),
+         FlashInferExperts(moe_config=moe_config, quant_config=quant_config),
)

fi_activation = {"silu_and_mul": "silu", "relu2": "relu2_no_mul"}[activation]
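The new test construction wires a prepare/finalize stage and an experts implementation into one modular kernel, with the defer-input-quant decision delegated to the experts class. A stripped-down sketch of that wiring (all classes below are hypothetical stand-ins, not vLLM's actual `FusedMoEModularKernel` API):

```python
# Hypothetical stand-ins for the modular-kernel wiring shown in the diff.
class FlashInferExpertsSketch:
    def __init__(self, moe_config, quant_config):
        self.moe_config = moe_config
        self.quant_config = quant_config

    @staticmethod
    def should_pf_defer_input_quant(moe_config, quant_config) -> bool:
        # Assumed policy: defer when the experts kernel fuses its own input
        # quantization, so prepare/finalize passes activations through raw.
        return quant_config.get("fused_input_quant", False)

class PrepareAndFinalizeSketch:
    def __init__(self, defer_input_quant: bool):
        self.defer_input_quant = defer_input_quant

class ModularKernelSketch:
    def __init__(self, prepare_finalize, experts):
        self.prepare_finalize = prepare_finalize
        self.experts = experts

quant_config = {"fused_input_quant": True}
moe_config = {}
kernel = ModularKernelSketch(
    PrepareAndFinalizeSketch(
        defer_input_quant=FlashInferExpertsSketch.should_pf_defer_input_quant(
            moe_config, quant_config
        )
    ),
    FlashInferExpertsSketch(moe_config, quant_config),
)
assert kernel.prepare_finalize.defer_input_quant is True
```

The design point: the prepare/finalize stage no longer needs a backend-specific constructor (`create_flashinfer_prepare_finalize` is removed); the experts class itself advertises whether input quantization should be deferred to it.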
1 change: 0 additions & 1 deletion tests/kernels/moe/test_nvfp4_moe.py
@@ -93,7 +93,6 @@ def test_cutlass_fp4_moe_no_graph(
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
CutlassExpertsFp4(
out_dtype=dtype,
-             max_experts_per_worker=e,
quant_config=quant_config,
),
)
6 changes: 3 additions & 3 deletions vllm/compilation/activation_quant_fusion.py
@@ -18,7 +18,7 @@
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
-     kNvfp4Quant,
+     kNvfp4Dynamic,
)
from vllm.platforms import current_platform

@@ -41,7 +41,7 @@
torch.ops._C, "silu_and_mul_nvfp4_quant"
)
if silu_and_mul_nvfp4_quant_supported:
-     FUSED_OPS[kNvfp4Quant] = torch.ops._C.silu_and_mul_nvfp4_quant.default  # noqa: E501
+     FUSED_OPS[kNvfp4Dynamic] = torch.ops._C.silu_and_mul_nvfp4_quant.default  # noqa: E501


class ActivationQuantPattern(ABC):
@@ -129,7 +129,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
"""

def __init__(self) -> None:
-         super().__init__(kNvfp4Quant)
+         super().__init__(kNvfp4Dynamic)

def get_inputs(self) -> list[torch.Tensor]:
result = self.empty_quant(5, 32)
4 changes: 2 additions & 2 deletions vllm/compilation/fusion.py
@@ -20,7 +20,7 @@
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
-     kNvfp4Quant,
+     kNvfp4Dynamic,
kStaticTensorScale,
)
from vllm.platforms import current_platform
@@ -63,7 +63,7 @@ def empty_i64(*args: Any, **kwargs: Any) -> torch.Tensor:
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
}
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
-     QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default
+     QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default
if current_platform.is_cuda():
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
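The `QUANT_OPS` registrations above are guarded by `hasattr` checks so the nvfp4 op is only registered when the compiled extension actually exposes it. A minimal, self-contained sketch of that pattern (`_C` and `kNvfp4Dynamic` here are stand-ins, not vLLM's real `torch.ops._C` namespace or `QuantKey` constant):

```python
# Sketch of the hasattr-guarded op registry: register the nvfp4 quant op
# only if the (stand-in) extension module provides it.
class _C:
    @staticmethod
    def scaled_fp4_quant(x):
        return x  # placeholder for the real fused quantization kernel

kNvfp4Dynamic = "nvfp4-dynamic"  # stand-in for the QuantKey constant
QUANT_OPS = {}

if hasattr(_C, "scaled_fp4_quant"):
    QUANT_OPS[kNvfp4Dynamic] = _C.scaled_fp4_quant

# Fusion passes can then test for support before pattern matching:
assert kNvfp4Dynamic in QUANT_OPS
```

This lets the same fusion-pass code run on builds (or platforms) where the kernel is absent: the pattern simply never registers, instead of failing at import time.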
4 changes: 2 additions & 2 deletions vllm/compilation/fusion_attn.py
@@ -16,7 +16,7 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
-     kNvfp4Quant,
+     kNvfp4Dynamic,
kStaticTensorScale,
)
from vllm.platforms import current_platform
@@ -217,7 +217,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
"""

def __init__(self, layer: Attention, dtype: torch.dtype) -> None:
-         super().__init__(layer, kNvfp4Quant, dtype)
+         super().__init__(layer, kNvfp4Dynamic, dtype)

def _register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
4 changes: 2 additions & 2 deletions vllm/compilation/matcher_utils.py
@@ -21,7 +21,7 @@
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
-     kNvfp4Quant,
+     kNvfp4Dynamic,
)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform
@@ -38,7 +38,7 @@
}

if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
-     QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default  # noqa: E501
+     QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default  # noqa: E501

if current_platform.is_cuda():
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
1 change: 1 addition & 0 deletions vllm/config/parallel.py
@@ -425,6 +425,7 @@ def stateless_init_dp_group(self) -> ProcessGroup:
# to avoid the excess work.
#
# Not needed for pplx-kernels as it can handle duplicate input tokens.
+     # TODO(rob): investigate 'flashinfer_all2allv'?
@property
def use_sequence_parallel_moe(self) -> bool:
return (
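The body of `use_sequence_parallel_moe` is truncated in the diff above, but the surrounding comment explains the intent: sequence-parallel MoE is used to avoid redundant work on duplicated tokens, and is unnecessary for backends like pplx-kernels that deduplicate themselves. A hypothetical sketch of such a property (the condition below is invented for illustration and is not vLLM's actual logic):

```python
# Hypothetical ParallelConfig sketch; the real vLLM condition is truncated
# in the diff above and is not reproduced here.
class ParallelConfig:
    def __init__(self, all2all_backend: str, dp_size: int, tp_size: int):
        self.all2all_backend = all2all_backend
        self.dp_size = dp_size
        self.tp_size = tp_size

    @property
    def use_sequence_parallel_moe(self) -> bool:
        # Assumed rule: sequence parallelism pays off when tokens are
        # duplicated across TP ranks and the all2all backend cannot
        # deduplicate them; pplx-kernels handles duplicates itself.
        return (
            self.all2all_backend in ("naive", "allgather_reducescatter")
            and self.dp_size > 1
            and self.tp_size > 1
        )

assert ParallelConfig("naive", dp_size=2, tp_size=2).use_sequence_parallel_moe
```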