From f8851e0a53477021590fb69666e1606fc4fd7c09 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 7 Jan 2026 19:01:47 -0500 Subject: [PATCH 001/113] stash Signed-off-by: Robert Shaw --- .../layers/fused_moe/prepare_finalize.py | 86 ++++++++++++++++++- 1 file changed, 85 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index 5d806fa843a3..5b9a7d65c366 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -11,12 +11,96 @@ ) from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input +class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize): + def __init__(self) -> None: + super().__init__() + self.dummy_tensor = torch.empty(1, device='cuda') + + @property + def activation_format(self) -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + def max_num_tokens_per_rank(self) -> int | None: + return None + + def topk_indices_dtype(self) -> torch.dtype | None: + return None + + def num_dispatchers(self) -> int: + return 1 + + def output_is_reduced(self) -> bool: + return False + + def prepare( + self, + a1: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: torch.Tensor | None, + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> mk.PrepareResultType: + from vllm.distributed import get_ep_group + + if not hasattr(self, "dummy_tensor"): + self.dummy_tensor = torch.zeros(1, device='cuda') + + extra_tensors = [topk_weights,topk_ids] + a1, _, extra_tensors = get_ep_group().dispatch( + a1, + self.dummy_tensor, + is_sequence_parallel=False, # TODO? 
+ extra_tensors=extra_tensors, + ) + topk_weights, topk_ids = extra_tensors + + a1q, a1q_scale = moe_kernel_quantize_input( + a1, + quant_config.a1_scale, + quant_config.quant_dtype, + quant_config.per_act_token_quant, + quant_config.block_shape, + ) + expert_tokens_meta = None # TODO? + + return a1q, a1q_scale, None, topk_weights, topk_ids + + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: + if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): + weight_and_reduce_impl = TopKWeightAndReduceContiguous() + weight_and_reduce_impl.apply( + output=output, + fused_expert_output=fused_expert_output, + topk_weights=topk_weights, + topk_ids=topk_ids, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + + from vllm.distributed import get_ep_group + combined_output = get_ep_group().combine( + output, + is_sequence_parallel=False + ) + + combined_output.copy_(output) + + class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): def __init__(self, defer_input_quant: bool = False) -> None: super().__init__() self.defer_input_quant = defer_input_quant - + @property def activation_format(self) -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.Standard From 085adf7377692bfe287dd58c8e5d98b2dad32e18 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 7 Jan 2026 19:02:06 -0500 Subject: [PATCH 002/113] stash Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/prepare_finalize.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index 5b9a7d65c366..06bcb4f68b22 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ 
-50,7 +50,7 @@ def prepare( extra_tensors = [topk_weights,topk_ids] a1, _, extra_tensors = get_ep_group().dispatch( a1, - self.dummy_tensor, + self.dummy_tensor, # router logits is_sequence_parallel=False, # TODO? extra_tensors=extra_tensors, ) @@ -91,7 +91,6 @@ def finalize( output, is_sequence_parallel=False ) - combined_output.copy_(output) From a6b039dd691bfd74fb017a807cb0e39bf0270a7f Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 7 Jan 2026 19:14:29 -0500 Subject: [PATCH 003/113] update interface Signed-off-by: Robert Shaw --- .../device_communicators/all2all.py | 44 ++++++++++++------- .../base_device_communicator.py | 10 +++-- .../device_communicators/cuda_communicator.py | 10 +++-- .../device_communicators/xpu_communicator.py | 8 ++-- .../model_executor/layers/quantization/fp8.py | 4 ++ 5 files changed, 49 insertions(+), 27 deletions(-) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 7a4e81cf967d..891c9439da25 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -62,10 +62,11 @@ def naive_multicast( def dispatch( self, hidden_states: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, is_sequence_parallel: bool = False, extra_tensors: list[torch.Tensor] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if extra_tensors is not None: raise NotImplementedError( "extra_tensors is not supported for NaiveAll2AllManager" @@ -78,11 +79,14 @@ def dispatch( hidden_states = self.naive_multicast( hidden_states, cu_tokens_across_sp_cpu, is_sequence_parallel ) - router_logits = self.naive_multicast( - router_logits, cu_tokens_across_sp_cpu, is_sequence_parallel + topk_weights = self.naive_multicast( + topk_weights, cu_tokens_across_sp_cpu, is_sequence_parallel + ) + topk_ids = self.naive_multicast( + topk_ids, 
cu_tokens_across_sp_cpu, is_sequence_parallel ) - return hidden_states, router_logits + return hidden_states, topk_weights, topk_ids def combine( self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False @@ -117,12 +121,13 @@ def __init__(self, cpu_group): def dispatch( self, hidden_states: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, is_sequence_parallel: bool = False, extra_tensors: list[torch.Tensor] | None = None, ) -> ( - tuple[torch.Tensor, torch.Tensor] - | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] + tuple[torch.Tensor, torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]] ): """ Gather hidden_states and router_logits from all dp ranks. @@ -134,7 +139,7 @@ def dispatch( dist_group = get_ep_group() if is_sequence_parallel else get_dp_group() assert sizes[dist_group.rank_in_group] == hidden_states.shape[0] - tensors_to_gather = [hidden_states, router_logits] + tensors_to_gather = [hidden_states, topk_weights, topk_ids] if extra_tensors is not None: tensors_to_gather.extend(extra_tensors) @@ -144,9 +149,14 @@ def dispatch( sizes=sizes, ) - if extra_tensors is not None: - return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:]) - return gathered_tensors[0], gathered_tensors[1] + hidden_states = gathered_tensors[0] + topk_weights = gathered_tensors[1] + topk_ids = gathered_tensors[2] + + if extra_tensors is None: + return hidden_states, topk_weights, topk_ids + + return hidden_states, topk_weights, topk_ids, gathered_tensors[3:] def combine( self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False @@ -219,10 +229,11 @@ def get_handle(self, kwargs): def dispatch( self, hidden_states: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, is_sequence_parallel: bool = False, extra_tensors: list[torch.Tensor] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) 
-> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: raise NotImplementedError def combine( @@ -267,10 +278,11 @@ def get_handle(self, kwargs): def dispatch( self, hidden_states: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, is_sequence_parallel: bool = False, extra_tensors: list[torch.Tensor] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: raise NotImplementedError def combine( diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 4a2a7ec5b728..3a951e4be23b 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -67,7 +67,8 @@ def get_handle(self, kwargs): def dispatch( self, hidden_states: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, is_sequence_parallel: bool = False, extra_tensors: list[torch.Tensor] | None = None, ) -> Any: @@ -283,15 +284,16 @@ def prepare_communication_buffer_for_model(self, model: torch.nn.Module) -> None def dispatch( self, hidden_states: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, is_sequence_parallel: bool = False, extra_tensors: list[torch.Tensor] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Dispatch the hidden states and router logits to the appropriate device. This is a no-op in the base class. 
""" - return hidden_states, router_logits + return hidden_states, topk_weights, topk_ids def combine( self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 9542498c453e..6d725be241e1 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -321,17 +321,19 @@ def _all_gather_single(input_: torch.Tensor, sizes: list[int] | None = None): def dispatch( # type: ignore[override] self, hidden_states: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, is_sequence_parallel: bool = False, extra_tensors: list[torch.Tensor] | None = None, ) -> ( - tuple[torch.Tensor, torch.Tensor] - | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] + tuple[torch.Tensor, torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]] ): assert self.all2all_manager is not None return self.all2all_manager.dispatch( hidden_states, - router_logits, + topk_weights, + topk_ids, is_sequence_parallel, extra_tensors, # type: ignore[call-arg] ) diff --git a/vllm/distributed/device_communicators/xpu_communicator.py b/vllm/distributed/device_communicators/xpu_communicator.py index f3d9262d20cf..d7a9d832370f 100644 --- a/vllm/distributed/device_communicators/xpu_communicator.py +++ b/vllm/distributed/device_communicators/xpu_communicator.py @@ -76,14 +76,16 @@ def broadcast(self, input_: torch.Tensor, src: int = 0) -> None: def dispatch( self, hidden_states: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, is_sequence_parallel: bool = False, extra_tensors: list[torch.Tensor] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: assert self.all2all_manager is not None return 
self.all2all_manager.dispatch( hidden_states, - router_logits, + topk_weights, + topk_ids, is_sequence_parallel, extra_tensors, # type: ignore[call-arg] ) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 1223c6902e5f..003a683e26aa 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1323,10 +1323,14 @@ def apply( apply_router_weight_on_input=layer.apply_router_weight_on_input, ) + logger.info_once(f"{router_logits.shape=}") topk_weights, topk_ids = layer.select_experts( hidden_states=x, router_logits=router_logits, ) + logger.info_once(f"{router_logits.shape=}") + logger.info_once(f"{topk_weights.shape=}") + result = self.kernel( x, layer.w13_weight, From f8052ce5b8e9d7072776b78397542b4c2c286d1f Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 7 Jan 2026 19:37:47 -0500 Subject: [PATCH 004/113] stash Signed-off-by: Robert Shaw --- vllm/distributed/parallel_state.py | 12 +++--- .../layers/fused_moe/prepare_finalize.py | 39 +++++++++++-------- .../model_executor/layers/quantization/fp8.py | 5 +-- 3 files changed, 31 insertions(+), 25 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 4611d42a5874..f5a9aea4a285 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1005,22 +1005,24 @@ def prepare_communication_buffer_for_model(self, model: torch.nn.Module): def dispatch( self, hidden_states: torch.Tensor, - router_logits: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, is_sequence_parallel: bool = False, extra_tensors: list[torch.Tensor] | None = None, ) -> ( - tuple[torch.Tensor, torch.Tensor] - | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] + tuple[torch.Tensor, torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]] ): if self.device_communicator is not None: return 
self.device_communicator.dispatch( # type: ignore[call-arg] hidden_states, - router_logits, + topk_weights, + topk_ids, is_sequence_parallel, extra_tensors, ) else: - return hidden_states, router_logits + return hidden_states, topk_weights, topk_ids def combine( self, hidden_states, is_sequence_parallel: bool = False diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index 06bcb4f68b22..1f036b5e227c 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -11,10 +11,11 @@ ) from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input + class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize): def __init__(self) -> None: super().__init__() - self.dummy_tensor = torch.empty(1, device='cuda') + self.dummy_tensor = torch.empty(1, device="cuda") @property def activation_format(self) -> mk.FusedMoEActivationFormat: @@ -31,7 +32,7 @@ def num_dispatchers(self) -> int: def output_is_reduced(self) -> bool: return False - + def prepare( self, a1: torch.Tensor, @@ -44,17 +45,20 @@ def prepare( ) -> mk.PrepareResultType: from vllm.distributed import get_ep_group - if not hasattr(self, "dummy_tensor"): - self.dummy_tensor = torch.zeros(1, device='cuda') + print(f"before: {a1.shape=}") + print(f"before: {topk_weights.shape=}") + print(f"before: {topk_ids.shape=}") - extra_tensors = [topk_weights,topk_ids] - a1, _, extra_tensors = get_ep_group().dispatch( + a1, topk_weights, topk_ids = get_ep_group().dispatch( a1, - self.dummy_tensor, # router logits - is_sequence_parallel=False, # TODO? - extra_tensors=extra_tensors, + topk_weights, + topk_ids, + # TODO? 
+ is_sequence_parallel=False, ) - topk_weights, topk_ids = extra_tensors + print(f"after: {a1.shape=}") + print(f"after: {topk_weights.shape=}") + print(f"after: {topk_ids.shape=}") a1q, a1q_scale = moe_kernel_quantize_input( a1, @@ -63,7 +67,11 @@ def prepare( quant_config.per_act_token_quant, quant_config.block_shape, ) - expert_tokens_meta = None # TODO? + # TODO? + expert_tokens_meta = None + print( + f"{a1q.dtype=}, {a1q_scale.dtype=} {topk_weights.dtype=}, {topk_ids.dtype=}" + ) return a1q, a1q_scale, None, topk_weights, topk_ids @@ -87,19 +95,16 @@ def finalize( ) from vllm.distributed import get_ep_group - combined_output = get_ep_group().combine( - output, - is_sequence_parallel=False - ) - combined_output.copy_(output) + combined_output = get_ep_group().combine(output, is_sequence_parallel=False) + combined_output.copy_(output) class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): def __init__(self, defer_input_quant: bool = False) -> None: super().__init__() self.defer_input_quant = defer_input_quant - + @property def activation_format(self) -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.Standard diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 003a683e26aa..610ebe3b3399 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1123,6 +1123,7 @@ def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPrepareAndFinalize | None: + # return MoEPrepareAndFinalizeNaiveEP() if ( self.fp8_backend == Fp8MoeBackend.AITER or self.fp8_backend == Fp8MoeBackend.MARLIN @@ -1323,13 +1324,11 @@ def apply( apply_router_weight_on_input=layer.apply_router_weight_on_input, ) - logger.info_once(f"{router_logits.shape=}") topk_weights, topk_ids = layer.select_experts( hidden_states=x, router_logits=router_logits, ) - 
logger.info_once(f"{router_logits.shape=}") - logger.info_once(f"{topk_weights.shape=}") + print(f"{x.dtype=}, {topk_weights.dtype=}, {topk_ids.dtype=}") result = self.kernel( x, From 13b619ff7d04d6e579c39bfbcf10cd61153f2d1d Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 7 Jan 2026 20:20:23 -0500 Subject: [PATCH 005/113] stash Signed-off-by: Robert Shaw --- .../layers/fused_moe/fused_moe_modular_method.py | 1 + vllm/model_executor/layers/quantization/fp8.py | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py index 6abefde0763e..53a132351374 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py @@ -95,6 +95,7 @@ def apply( hidden_states=x, router_logits=router_logits, ) + logger.info(f"after se: {x.dtype=}, {topk_weights.dtype=}, {topk_ids.dtype=}") result = self.fused_experts( hidden_states=x, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 610ebe3b3399..d612ada0449d 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1123,7 +1123,11 @@ def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPrepareAndFinalize | None: - # return MoEPrepareAndFinalizeNaiveEP() + from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNaiveEP, + ) + + return MoEPrepareAndFinalizeNaiveEP() if ( self.fp8_backend == Fp8MoeBackend.AITER or self.fp8_backend == Fp8MoeBackend.MARLIN From 04bb01013b7b19d6b98a2f608e36cceaaf2afce2 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 7 Jan 2026 20:33:47 -0500 Subject: [PATCH 006/113] first correctness! 
Signed-off-by: Robert Shaw --- .../layers/fused_moe/fused_moe.py | 3 ++ .../layers/fused_moe/prepare_finalize.py | 28 ++++++------------- 2 files changed, 11 insertions(+), 20 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index c4047401c0e7..7ca65bc85628 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -2315,6 +2315,9 @@ def apply( expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, ): + logger.info( + f"in apply: {hidden_states.dtype=}, {topk_weights.dtype=}, {topk_ids.dtype=}, {a1q_scale.dtype=}" + ) # Check constraints. if self.quant_config.use_int4_w4a16: assert hidden_states.size(-1) // 2 == w1.size(2), "Hidden size mismatch" diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index 1f036b5e227c..de86469d491a 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -4,6 +4,7 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.distributed import get_ep_group from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceContiguous, @@ -43,12 +44,6 @@ def prepare( apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, ) -> mk.PrepareResultType: - from vllm.distributed import get_ep_group - - print(f"before: {a1.shape=}") - print(f"before: {topk_weights.shape=}") - print(f"before: {topk_ids.shape=}") - a1, topk_weights, topk_ids = get_ep_group().dispatch( a1, topk_weights, @@ -56,9 +51,6 @@ def prepare( # TODO? 
is_sequence_parallel=False, ) - print(f"after: {a1.shape=}") - print(f"after: {topk_weights.shape=}") - print(f"after: {topk_ids.shape=}") a1q, a1q_scale = moe_kernel_quantize_input( a1, @@ -67,13 +59,10 @@ def prepare( quant_config.per_act_token_quant, quant_config.block_shape, ) - # TODO? + # TODO - this is just for deepgemm expert_tokens_meta = None - print( - f"{a1q.dtype=}, {a1q_scale.dtype=} {topk_weights.dtype=}, {topk_ids.dtype=}" - ) - return a1q, a1q_scale, None, topk_weights, topk_ids + return a1q, a1q_scale, None, topk_ids, topk_weights def finalize( self, @@ -86,18 +75,17 @@ def finalize( ) -> None: if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): weight_and_reduce_impl = TopKWeightAndReduceContiguous() - weight_and_reduce_impl.apply( - output=output, + + out = weight_and_reduce_impl.apply( + output=None, fused_expert_output=fused_expert_output, topk_weights=topk_weights, topk_ids=topk_ids, apply_router_weight_on_input=apply_router_weight_on_input, ) - from vllm.distributed import get_ep_group - - combined_output = get_ep_group().combine(output, is_sequence_parallel=False) - combined_output.copy_(output) + # ... 
copy the + output.copy_(get_ep_group().combine(out, is_sequence_parallel=False)) class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): From b1320de56e5d06dfe32148bd86bceb847e634f5d Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 7 Jan 2026 20:40:55 -0500 Subject: [PATCH 007/113] updated Signed-off-by: Robert Shaw --- vllm/model_executor/layers/quantization/fp8.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index d612ada0449d..26c371a8a56f 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1332,7 +1332,6 @@ def apply( hidden_states=x, router_logits=router_logits, ) - print(f"{x.dtype=}, {topk_weights.dtype=}, {topk_ids.dtype=}") result = self.kernel( x, From 4d472065615a080e72b53a1265309e1b531c3cef Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 7 Jan 2026 20:41:45 -0500 Subject: [PATCH 008/113] comments Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/fused_moe.py | 3 --- .../layers/fused_moe/fused_moe_modular_method.py | 1 - 2 files changed, 4 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 7ca65bc85628..c4047401c0e7 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -2315,9 +2315,6 @@ def apply( expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, ): - logger.info( - f"in apply: {hidden_states.dtype=}, {topk_weights.dtype=}, {topk_ids.dtype=}, {a1q_scale.dtype=}" - ) # Check constraints. 
if self.quant_config.use_int4_w4a16: assert hidden_states.size(-1) // 2 == w1.size(2), "Hidden size mismatch" diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py index 53a132351374..6abefde0763e 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py @@ -95,7 +95,6 @@ def apply( hidden_states=x, router_logits=router_logits, ) - logger.info(f"after se: {x.dtype=}, {topk_weights.dtype=}, {topk_ids.dtype=}") result = self.fused_experts( hidden_states=x, From f86fad8c14f1458d08da3890c2e292c7ef6923ad Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 7 Jan 2026 20:42:47 -0500 Subject: [PATCH 009/113] updated Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/prepare_finalize.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index de86469d491a..f406df814eff 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -14,10 +14,6 @@ class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize): - def __init__(self) -> None: - super().__init__() - self.dummy_tensor = torch.empty(1, device="cuda") - @property def activation_format(self) -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.Standard @@ -59,10 +55,10 @@ def prepare( quant_config.per_act_token_quant, quant_config.block_shape, ) - # TODO - this is just for deepgemm + # TODO - this is just for deepgemm? 
expert_tokens_meta = None - return a1q, a1q_scale, None, topk_ids, topk_weights + return a1q, a1q_scale, expert_tokens_meta, topk_ids, topk_weights def finalize( self, From 8c1a530d976776aadb69a314c32fbd6f6c8d1f2b Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 7 Jan 2026 21:17:26 -0500 Subject: [PATCH 010/113] updateds Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/all2all_utils.py | 7 +++++++ vllm/model_executor/layers/fused_moe/config.py | 10 ++++++++++ vllm/model_executor/layers/quantization/fp8.py | 6 +----- 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py index 036b3cac4cb3..d76c629b72ba 100644 --- a/vllm/model_executor/layers/fused_moe/all2all_utils.py +++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py @@ -15,6 +15,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEPrepareAndFinalize, ) +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNaiveEP, +) from vllm.platforms import current_platform from vllm.utils.import_utils import has_deep_ep, has_pplx @@ -170,4 +173,8 @@ def maybe_make_prepare_finalize( local_expert_global_ids=local_expert_global_ids, ) + elif moe.use_naive_kernels: + prepare_finalize = MoEPrepareAndFinalizeNaiveEP() + + print(f"{prepare_finalize=}") return prepare_finalize diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 17d5ec4bcda7..162fe36ba044 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -855,6 +855,12 @@ def use_deepep_ht_kernels(self): def use_deepep_ll_kernels(self): return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency" + @property + def use_naive_kernels(self): + return self.use_all2all_kernels and ( + self.all2all_backend in 
["allgather_reducescatter", "naive"] + ) + @staticmethod def flatten_tp_across_dp_and_pcp( tp_size: int, dp_size: int, dp_rank: int, pcp_size: int, pcp_rank: int @@ -1076,6 +1082,10 @@ def use_deepep_ht_kernels(self): def use_deepep_ll_kernels(self): return self.moe_parallel_config.use_deepep_ll_kernels + @property + def use_naive_kernels(self): + return self.moe_parallel_config.use_naive_kernels + @property def use_flashinfer_cutlass_kernels(self): """ diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 80db65c2f774..df494c05ec2d 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -877,11 +877,6 @@ def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPrepareAndFinalize | None: - from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNaiveEP, - ) - - return MoEPrepareAndFinalizeNaiveEP() if self.fp8_backend in [ Fp8MoeBackend.AITER, Fp8MoeBackend.MARLIN, @@ -889,6 +884,7 @@ def maybe_make_prepare_finalize( ]: return None elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: + # TODO(rob): we can remove this. 
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( self.moe, use_deepseek_fp8_block_scale=self.block_quant, From 7d7d5a62256de6fe20e146ffd1187014246322cf Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 7 Jan 2026 21:33:17 -0500 Subject: [PATCH 011/113] nit changes Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/all2all_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py index d76c629b72ba..d6740d8a3ff5 100644 --- a/vllm/model_executor/layers/fused_moe/all2all_utils.py +++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py @@ -176,5 +176,4 @@ def maybe_make_prepare_finalize( elif moe.use_naive_kernels: prepare_finalize = MoEPrepareAndFinalizeNaiveEP() - print(f"{prepare_finalize=}") return prepare_finalize From 63357f7d881959a3b97c1dde7043b9d41695f7a4 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 7 Jan 2026 21:55:16 -0500 Subject: [PATCH 012/113] support apply router weight on input Signed-off-by: Robert Shaw --- .../layers/fused_moe/prepare_finalize.py | 35 +++++++++++++++---- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index f406df814eff..c8ab382ab158 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -40,13 +40,13 @@ def prepare( apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, ) -> mk.PrepareResultType: - a1, topk_weights, topk_ids = get_ep_group().dispatch( - a1, - topk_weights, - topk_ids, - # TODO? 
- is_sequence_parallel=False, - ) + if apply_router_weight_on_input: + topk = topk_ids.size(1) + assert topk == 1, ( + "apply_router_weight_on_input is only implemented for topk=1" + ) + # Note: do not use inplace for shared experts overlap + a1 = a1 * topk_weights.to(a1.dtype) a1q, a1q_scale = moe_kernel_quantize_input( a1, @@ -58,6 +58,27 @@ def prepare( # TODO - this is just for deepgemm? expert_tokens_meta = None + from vllm.platforms import current_platform + + # The torch ops do not support fp8, so use an int8 view. + # Since this does not do a reduce, it is safe to do. + use_int8_view = a1q.dtype == current_platform.fp8_dtype() + if use_int8_view: + a1q = a1q.view(torch.int8) + + extra_tensors = None if a1q_scale is None else [a1q_scale] + a1q, topk_weights, topk_ids, et = get_ep_group().dispatch( + a1q, + topk_weights, + topk_ids, + is_sequence_parallel=False, # TODO: support SP + extra_tensors=extra_tensors, + ) + a1q_scale = et[0] if extra_tensors is not None else None + + if use_int8_view: + a1q = a1q.view(current_platform.fp8_dtype()) + return a1q, a1q_scale, expert_tokens_meta, topk_ids, topk_weights def finalize( From 3886cfbb4414aefe70dede928467eabed934b688 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 7 Jan 2026 22:09:01 -0500 Subject: [PATCH 013/113] attempt to get everything working for llama scout modelopt flashinfer cutlass Signed-off-by: Robert Shaw --- .../flashinfer_cutlass_prepare_finalize.py | 11 +++--- .../layers/fused_moe/prepare_finalize.py | 34 ++++++++++++------- vllm/model_executor/models/llama4.py | 1 + 3 files changed, 27 insertions(+), 19 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index 0b0efdafbd4d..5f48e83ee5a3 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -11,6 
+11,7 @@ from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNaiveEP, MoEPrepareAndFinalizeNoEP, ) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( @@ -363,13 +364,11 @@ def create_flashinfer_prepare_finalize( else: return FlashInferAllGatherMoEPrepareAndFinalize(use_dp) - # FP8 DP path currently supported via AllGather. + # NOTE(rob): CUTLASS FP8 block quant executes the input + # quantzation and grouped gemm in a single kernel. if use_dp: - return FlashInferAllGatherMoEPrepareAndFinalize( - use_dp=True, - use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale, + return MoEPrepareAndFinalizeNaiveEP( + defer_input_quant=use_deepseek_fp8_block_scale ) else: - # NOTE(rob): CUTLASS FP8 block quant executes the input - # quantzation and grouped gemm in a single kernel. return MoEPrepareAndFinalizeNoEP(defer_input_quant=use_deepseek_fp8_block_scale) diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index c8ab382ab158..a6d81ea16500 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -14,6 +14,10 @@ class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize): + def __init__(self, defer_input_quant: bool = False) -> None: + super().__init__() + self.defer_input_quant = defer_input_quant + @property def activation_format(self) -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.Standard @@ -48,33 +52,38 @@ def prepare( # Note: do not use inplace for shared experts overlap a1 = a1 * topk_weights.to(a1.dtype) - a1q, a1q_scale = moe_kernel_quantize_input( - a1, - quant_config.a1_scale, - quant_config.quant_dtype, - quant_config.per_act_token_quant, - quant_config.block_shape, - ) + # Defer input 
quantization to the MoE kernel. + if self.defer_input_quant: + a1q = a1 + a1q_scale = None + else: + a1q, a1q_scale = moe_kernel_quantize_input( + a1, + quant_config.a1_scale, + quant_config.quant_dtype, + quant_config.per_act_token_quant, + quant_config.block_shape, + ) # TODO - this is just for deepgemm? expert_tokens_meta = None from vllm.platforms import current_platform # The torch ops do not support fp8, so use an int8 view. - # Since this does not do a reduce, it is safe to do. + # Since dispatch does not do a reduce, this is safe to do. use_int8_view = a1q.dtype == current_platform.fp8_dtype() if use_int8_view: a1q = a1q.view(torch.int8) - extra_tensors = None if a1q_scale is None else [a1q_scale] - a1q, topk_weights, topk_ids, et = get_ep_group().dispatch( + scales = None if a1q_scale is None else [a1q_scale] + a1q, topk_weights, topk_ids, scales = get_ep_group().dispatch( a1q, topk_weights, topk_ids, is_sequence_parallel=False, # TODO: support SP - extra_tensors=extra_tensors, + extra_tensors=scales, ) - a1q_scale = et[0] if extra_tensors is not None else None + a1q_scale = scales[0] if scales is not None else None if use_int8_view: a1q = a1q.view(current_platform.fp8_dtype()) @@ -101,7 +110,6 @@ def finalize( apply_router_weight_on_input=apply_router_weight_on_input, ) - # ... copy the output.copy_(get_ep_group().combine(out, is_sequence_parallel=False)) diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index 9ed0741acba1..3698d867c5a5 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -432,6 +432,7 @@ def load_moe_expert_weights( # Whether the MoE expert weights are loaded successfully. 
expert_param_loaded = False + loaded_weight = loaded_weight.to("cuda") # If fused is True, the loaded weight is in the layout of: # [num_experts, hidden_in, hidden_out], so we must transpose the last From 2284b59ee76960d9c2d5e6b66a97bb4ffa06fcdf Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 7 Jan 2026 22:36:24 -0500 Subject: [PATCH 014/113] updated Signed-off-by: Robert Shaw --- .../layers/fused_moe/prepare_finalize.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index a6d81ea16500..b66e2a1dd73f 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -75,15 +75,24 @@ def prepare( if use_int8_view: a1q = a1q.view(torch.int8) - scales = None if a1q_scale is None else [a1q_scale] - a1q, topk_weights, topk_ids, scales = get_ep_group().dispatch( + # Skip gathering scales if we have static quantization + # (the scale is a scalar, replicated on all ranks) or + # if quantization is deferred. 
+ skip_gather_scales = a1q_scale is None or a1q_scale.ndim == 0 + scales = None if skip_gather_scales else [a1q_scale] + + res = get_ep_group().dispatch( a1q, topk_weights, topk_ids, is_sequence_parallel=False, # TODO: support SP extra_tensors=scales, ) - a1q_scale = scales[0] if scales is not None else None + if skip_gather_scales: + a1q, topk_weights, topk_ids = res + else: + a1q, topk_weights, topk_ids, scales = res + a1q_scale = res[0] if use_int8_view: a1q = a1q.view(current_platform.fp8_dtype()) From e131054dad714905a84372eaf5f911e35b0db3ef Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sat, 10 Jan 2026 10:51:36 -0500 Subject: [PATCH 015/113] apply to batched deep gemm Signed-off-by: Robert Shaw --- .../layers/fused_moe/batched_deep_gemm_moe.py | 24 ++++++++++++ .../layers/fused_moe/modular_kernel.py | 38 +++++++++++++++++++ 2 files changed, 62 insertions(+) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 15f6e3a18ed6..fa2724f2ff59 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -4,6 +4,7 @@ import torch +from vllm.platforms import current_platform import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.forward_context import get_forward_context, is_forward_context_available from vllm.logger import init_logger @@ -276,6 +277,29 @@ def activation_formats( mk.FusedMoEActivationFormat.BatchedExperts, mk.FusedMoEActivationFormat.BatchedExperts, ) + + def supports_current_device(self) -> bool: + return ( + current_platform.is_cuda() and + current_platform.has_device_capability(9,0) + ) + + def supports_no_act_and_mul(self) -> bool: + return False + + def supports_quant_config(self, quant_config: FusedMoEQuantConfig) -> bool: + return ( + quant_config.use_fp8_w8a8 and + quant_config.is_block_quantized() and + quant_config.block_shape[0] == 128 and + 
quant_config.block_shape[1] == 128 + ) + + def supports_act_fn(self, activation: str) -> bool: + return activation in ["silu"] + + def supports_ep(self) -> bool: + return True def supports_chunking(self) -> bool: return False diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 79168948f04a..a7eea0e0ace7 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -433,6 +433,44 @@ def moe_problem_size( topk = topk_ids.size(1) return E, M, N, K, topk + + # + # Various helpers for registering support for various features. + # + + @abstractmethod + def supports_current_device(self) -> bool: + """ + Whether the kernel supports the current device type + (compute cability and current platform). + """ + raise NotImplementedError + + @abstractmethod + def supports_no_act_and_mul(self) -> bool: + """ + Whether the kernel supports act_and_mul=False, i.e. + non-gated MoE models like Nemotron-Nano. + """ + raise NotImplementedError + + @abstractmethod + def supports_quant_config(self, quant_config: FusedMoEQuantConfig) -> bool: + raise NotImplementedError + + @abstractmethod + def supports_act_fn(self, activation: str) -> bool: + """ + Whether the kernel supports a particular act function. + """ + raise NotImplementedError + + @abstractmethod + def supports_ep(self) -> bool: + """ + Whether the kernel supports deployment in expert parallel. 
+ """ + raise NotImplementedError # # Various helpers for accessing quantization parameters from the From 77c7b056f37448910b8569b82e20e786b2f1438d Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 11 Jan 2026 09:29:20 -0500 Subject: [PATCH 016/113] updated Signed-off-by: Robert Shaw --- docs/design/moe_kernel_features.md | 2 +- .../layers/fused_moe/deep_gemm_moe.py | 24 ++++ .../fused_moe/flashinfer_cutedsl_moe.py | 19 +++ .../fused_moe/flashinfer_cutlass_moe.py | 120 ++++++------------ 4 files changed, 82 insertions(+), 83 deletions(-) diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index 18216b5965af..5b35b152ff38 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -88,7 +88,7 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels | deep gemm | standard,
batched | fp8 | G(128),A,T | silu, gelu | 6 | Y | [`deep_gemm_moe_fp8`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.deep_gemm_moe_fp8],
[`DeepGemmExperts`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.DeepGemmExperts],
[`BatchedDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe.BatchedDeepGemmExperts] | | cutlass_fp4 | standard,
batched | nvfp4 | A,T | silu | Y | Y | [`CutlassExpertsFp4`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp4] | | cutlass_fp8 | standard,
batched | fp8 | A,T | silu, gelu | Y | Y | [`CutlassExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp8],
[`CutlasBatchedExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassBatchedExpertsFp8] | -| flashinfer | standard | nvfp4,
fp8 | T | 5 | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],
[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] | +| flashinfer | standard | nvfp4,
fp8 | T | 5 | N | Y | [`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] | | gpt oss triton | standard | N/A | N/A | 5 | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],
[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] | | marlin | standard,
batched | 3 / N/A | 3 / N/A | silu,
swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],
[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],
[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] | | trtllm | standard | mxfp4,
nvfp4 | G(16),G(32) | 5 | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] | diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 5ca91768c976..044617005e4c 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -3,6 +3,7 @@ import torch +from vllm.platforms import current_platform import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( @@ -125,6 +126,29 @@ def activation_formats( mk.FusedMoEActivationFormat.Standard, ) + def supports_current_device(self) -> bool: + return ( + current_platform.is_cuda() and + current_platform.has_device_capability(9,0) + ) + + def supports_no_act_and_mul(self) -> bool: + return False + + def supports_quant_config(self, quant_config: FusedMoEQuantConfig) -> bool: + return ( + quant_config.use_fp8_w8a8 and + quant_config.is_block_quantized() and + quant_config.block_shape[0] == 128 and + quant_config.block_shape[1] == 128 + ) + + def supports_act_fn(self, activation: str) -> bool: + return activation in ["silu"] + + def supports_ep(self) -> bool: + return True + def supports_chunking(self) -> bool: return True diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py index ce93ae235f27..7b08da7fa6f6 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py @@ -3,6 +3,7 @@ import torch +from vllm.platforms import current_platform import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import envs from vllm.logger import init_logger @@ -69,6 +70,24 @@ def activation_formats( mk.FusedMoEActivationFormat.BatchedExperts, ) + def supports_current_device(self) 
-> bool: + return ( + current_platform.is_cuda() and + current_platform.has_device_capability(10,0) + ) + + def supports_no_act_and_mul(self) -> bool: + return False + + def supports_quant_config(self, quant_config: FusedMoEQuantConfig) -> bool: + return quant_config.use_nvfp4_w4a4 + + def supports_act_fn(self, activation: str) -> bool: + return activation in ["silu"] + + def supports_ep(self) -> bool: + return True + def supports_expert_map(self) -> bool: return False diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 09c3d9b2190f..231552f78994 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -3,6 +3,7 @@ import torch +from vllm.platforms import current_platform import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig @@ -74,6 +75,43 @@ def __init__( # - pass per-block weight scales to the kernel # - skip input activation quantization (kernel applies scaling) self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale + + def supports_current_device(self) -> bool: + return ( + current_platform.is_cuda() and + current_platform.has_device_capability(9,0) + ) + + def supports_no_act_and_mul(self) -> bool: + return False + + def supports_quant_config(self, quant_config: FusedMoEQuantConfig) -> bool: + # Supports unquantized, fp8, and nvfp4. + if not ( + quant_config.use_nvfp4_w4a4 or + quant_config.use_fp8_w8a8 or + quant_config.quant_dtype == None # TODO: how to express unquantized? + ): + return False + + # For FP8, only support static per tensor or DeepGEMM SwapAB for hopper. 
+ if quant_config.use_fp8_w8a8: + if quant_config.is_per_tensor: + return True + elif quant_config.is_per_act_token: + return False + elif quant_config.is_block_quantized: + if (current_platform.is_cuda and current_platform.is_device_capability(9,0) and quant_config.block_shape[0] == 128 and quant_config.block_shape[1] == 128): + return True + return False + + return True + + def supports_act_fn(self, activation: str) -> bool: + return activation in ["silu", "relu2_no_mul"] + + def supports_ep(self) -> bool: + return True @property def activation_formats( @@ -225,85 +263,3 @@ def apply( # Informs FlashInfer to use the block-scale decoding path when True use_deepseek_fp8_block_scale=self.use_deepseek_fp8_block_scale, ) - - -def flashinfer_cutlass_moe_fp4( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - quant_config: FusedMoEQuantConfig, - inplace: bool = False, - activation: str = "silu", - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, -) -> torch.Tensor: - fused_experts = mk.FusedMoEModularKernel( - create_flashinfer_prepare_finalize( - use_dp=False, use_nvfp4=True, enable_alltoallv=False - ), - FlashInferExperts( - out_dtype=hidden_states.dtype, - quant_config=quant_config, - use_dp=False, - ), - ) - - return fused_experts( - hidden_states=hidden_states, - w1=w1, - w2=w2, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=inplace, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - - -def flashinfer_cutlass_moe( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - quant_config: FusedMoEQuantConfig, - inplace: bool = False, - activation: str = "silu", - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - 
apply_router_weight_on_input: bool = False, - tp_rank: int = 0, - tp_size: int = 1, - ep_rank: int = 0, - ep_size: int = 1, - use_dp: bool = False, -) -> torch.Tensor: - fused_experts = mk.FusedMoEModularKernel( - create_flashinfer_prepare_finalize(use_dp=use_dp), - FlashInferExperts( - out_dtype=hidden_states.dtype, - quant_config=quant_config, - tp_rank=tp_rank, - tp_size=tp_size, - ep_rank=ep_rank, - ep_size=ep_size, - ), - ) - - return fused_experts( - hidden_states=hidden_states, - w1=w1, - w2=w2, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=inplace, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) From 477d6992935a71f472951ccd102c3c109edc711c Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 11 Jan 2026 09:36:11 -0500 Subject: [PATCH 017/113] stash Signed-off-by: Robert Shaw --- .../layers/fused_moe/fused_batched_moe.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 7fd8511e297d..40fd6ae3a817 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -4,6 +4,7 @@ import torch +from vllm.platforms import current_platform import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_moe import try_get_optimal_moe_config @@ -653,6 +654,26 @@ def activation_formats( mk.FusedMoEActivationFormat.BatchedExperts, mk.FusedMoEActivationFormat.BatchedExperts, ) + + def supports_current_device(self) -> bool: + return current_platform.is_cuda_alike() + + def supports_no_act_and_mul(self) -> bool: + return False + + def supports_quant_config(self, quant_config: FusedMoEQuantConfig) -> bool: + return ( 
+ quant_config.use_fp8_w8a8 and + quant_config.is_block_quantized() and + quant_config.block_shape[0] == 128 and + quant_config.block_shape[1] == 128 + ) + + def supports_act_fn(self, activation: str) -> bool: + return activation in ["silu", "gelu", "swigluoai"] + + def supports_ep(self) -> bool: + return True def supports_chunking(self) -> bool: return False From 9f2e10b7e31cf1619b29c848484a51981d8895aa Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 11 Jan 2026 09:57:24 -0500 Subject: [PATCH 018/113] remove NaiveBatchedExperts Signed-off-by: Robert Shaw --- .../moe/modular_kernel_tools/mk_objects.py | 14 -- tests/kernels/moe/test_batched_moe.py | 20 +- tests/kernels/moe/test_pplx_moe.py | 58 +----- tests/kernels/moe/utils.py | 41 ----- .../layers/fused_moe/fused_batched_moe.py | 171 +++--------------- 5 files changed, 29 insertions(+), 275 deletions(-) diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index 99b168dc7554..4400d2cf53d5 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -20,7 +20,6 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedTritonExperts, - NaiveBatchedExperts, ) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP, @@ -184,15 +183,6 @@ def expert_info(kind) -> ExpertInfo: needs_matching_quant=True, ) -register_experts( - NaiveBatchedExperts, - batched_format, - common_float_and_int_types, - blocked_quantization_support=True, - supports_chunking=False, - supports_expert_map=True, -) - # Disable on blackwell for now if has_deep_ep() and not current_platform.has_device_capability(100): from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( @@ -455,10 +445,6 @@ def make_fused_experts( kwargs = quant_kwargs | deepgemm_kwargs 
print(f"Making TritonOrDeepGemmExperts {kwargs} ...") experts = TritonOrDeepGemmExperts(**kwargs) - elif fused_experts_type == NaiveBatchedExperts: - kwargs = batch_kwargs | quant_kwargs - print(f"Making NaiveBatchedExperts {kwargs} ...") - experts = NaiveBatchedExperts(**kwargs) elif fused_experts_type == CutlassExpertsFp8: strides = make_cutlass_strides(moe.num_experts, N, moe.hidden_dim) kwargs = { diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index c9d425b5b990..08e080f64d2b 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -10,7 +10,6 @@ batched_moe, make_quantized_test_activations, make_test_weights, - naive_batched_moe, ) from tests.kernels.quant_utils import native_batched_masked_quant_matmul from tests.kernels.utils import torch_experts @@ -318,21 +317,6 @@ def test_fused_moe_batched_experts( block_shape=block_shape, ) - batched_output = naive_batched_moe( - a, - w1, - w2, - topk_weight, - topk_ids, - w1_scale=w1_s, - w2_scale=w2_s, - a1_scale=a1_scale, - a2_scale=a2_scale, - quant_dtype=quant_dtype, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, - ) - triton_output = batched_moe( a, w1, @@ -348,6 +332,4 @@ def test_fused_moe_batched_experts( block_shape=block_shape, ) - torch.testing.assert_close(batched_output, baseline_output, atol=3e-2, rtol=2e-2) - - torch.testing.assert_close(triton_output, batched_output, atol=2e-2, rtol=2e-2) + torch.testing.assert_close(triton_output, baseline_output, atol=2e-2, rtol=2e-2) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index c08a54f0e9f6..e0a1f49aced8 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -31,7 +31,6 @@ from tests.kernels.moe.utils import ( make_shared_experts, make_test_weights, - naive_batched_moe, ) from tests.kernels.quant_utils import dequant from tests.kernels.utils import torch_experts @@ -171,40 +170,6 
@@ def torch_batched_moe( return torch_finalize(out, topk_weight, topk_ids) -@pytest.mark.parametrize("m,n,k", BATCHED_MOE_MNK_FACTORS) -@pytest.mark.parametrize("e", NUM_EXPERTS) -@pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) -def test_fused_moe_batched_experts( - m: int, - n: int, - k: int, - e: int, - topk: int, - dtype: torch.dtype, - workspace_init, -): - set_random_seed(7) - - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - score = torch.randn((m, e), device="cuda", dtype=dtype) - - with set_current_vllm_config(vllm_config): - topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) - baseline_output = torch_experts( - a, w1, w2, topk_weight, topk_ids - ) # only for baseline - torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids) - batched_output = naive_batched_moe( - a, w1, w2, topk_weight, topk_ids - ) # pick torch_experts or this - - torch.testing.assert_close(baseline_output, torch_output, atol=2e-2, rtol=0) - torch.testing.assert_close(baseline_output, batched_output, atol=2e-2, rtol=0) - - def create_pplx_prepare_finalize( num_tokens: int, hidden_dim: int, @@ -716,21 +681,6 @@ def _pplx_moe( block_shape=block_shape, ) - batched_output = naive_batched_moe( - a, - w1, - w2, - topk_weight, - topk_ids, - w1_scale=w1_s, - w2_scale=w2_s, - a1_scale=a1_scale, - a2_scale=a2_scale, - quant_dtype=quant_dtype, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, - ) - pplx_outputs = pplx_moe( group_name, rank, @@ -766,14 +716,12 @@ def _pplx_moe( else: chunked_shared_output = None - chunked_batch_output = chunk_by_rank( - batched_output, pgi.rank, pgi.world_size + chunked_torch_output = chunk_by_rank( + torch_output, pgi.rank, pgi.world_size ).to(pplx_output.device) - torch.testing.assert_close(batched_output, torch_output, atol=3e-2, 
rtol=3e-2) - torch.testing.assert_close( - pplx_output, chunked_batch_output, atol=3e-2, rtol=3e-2 + pplx_output, chunked_torch_output, atol=3e-2, rtol=3e-2 ) if shared_experts is not None: diff --git a/tests/kernels/moe/utils.py b/tests/kernels/moe/utils.py index f0c8c8033b8e..8d66fdecee37 100644 --- a/tests/kernels/moe/utils.py +++ b/tests/kernels/moe/utils.py @@ -12,7 +12,6 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedPrepareAndFinalize, BatchedTritonExperts, - NaiveBatchedExperts, ) from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input @@ -87,46 +86,6 @@ def batched_moe( return fused_experts(a, w1, w2, topk_weight, topk_ids) -def naive_batched_moe( - a: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weight: torch.Tensor, - topk_ids: torch.Tensor, - w1_scale: torch.Tensor | None = None, - w2_scale: torch.Tensor | None = None, - a1_scale: torch.Tensor | None = None, - a2_scale: torch.Tensor | None = None, - quant_dtype: torch.dtype | None = None, - per_act_token_quant: bool = False, - block_shape: list[int] | None = None, -) -> torch.Tensor: - max_num_tokens = round_up(a.shape[0], 64) - - quant_config = FusedMoEQuantConfig.make( - quant_dtype, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, - ) - - fused_experts = FusedMoEModularKernel( - BatchedPrepareAndFinalize( - max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0 - ), - NaiveBatchedExperts( - max_num_tokens=max_num_tokens, - num_dispatchers=1, - quant_config=quant_config, - ), - ) - - return fused_experts(a, w1, w2, topk_weight, topk_ids) - - def chunk_scales( scales: torch.Tensor | None, start: int, end: int ) -> torch.Tensor | None: diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py 
b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 40fd6ae3a817..5e7551797ff0 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -625,152 +625,6 @@ def finalize( ) -class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): - """ - A reference MoE expert class that operates on expert batched format, - i.e. E x max_num_tokens x K. This is the format that the pplx - dispatch/combine kernels use. - """ - - def __init__( - self, - max_num_tokens: int, - num_dispatchers: int, - quant_config: FusedMoEQuantConfig, - ): - super().__init__(quant_config) - assert not self.quant_config.use_int8_w8a8, "NYI" - assert not self.quant_config.use_int8_w8a16, "NYI" - assert not self.quant_config.use_int4_w4a16, "NYI" - assert self.quant_config.ocp_mx_scheme is None, "NYI" - self.max_num_tokens = max_num_tokens - self.num_dispatchers = num_dispatchers - - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts, - ) - - def supports_current_device(self) -> bool: - return current_platform.is_cuda_alike() - - def supports_no_act_and_mul(self) -> bool: - return False - - def supports_quant_config(self, quant_config: FusedMoEQuantConfig) -> bool: - return ( - quant_config.use_fp8_w8a8 and - quant_config.is_block_quantized() and - quant_config.block_shape[0] == 128 and - quant_config.block_shape[1] == 128 - ) - - def supports_act_fn(self, activation: str) -> bool: - return activation in ["silu", "gelu", "swigluoai"] - - def supports_ep(self) -> bool: - return True - - def supports_chunking(self) -> bool: - return False - - def supports_expert_map(self) -> bool: - return False - - def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: - # Let PrepareAndFinalize::finalize() decide the impl. 
- return TopKWeightAndReduceDelegate() - - def workspace_shapes( - self, - M: int, - N: int, - K: int, - topk: int, - global_num_experts: int, - local_num_experts: int, - expert_tokens_meta: mk.ExpertTokensMetadata | None, - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: - num_dp = self.num_dispatchers - num_experts = local_num_experts - workspace13 = (num_experts, self.max_num_tokens * num_dp, K) - workspace2 = (self.max_num_tokens * num_dp, N) - output = workspace13 - return (workspace13, workspace2, output) - - def dequant(self, t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: - assert self.quant_config.is_quantized - f32 = torch.float32 - if self.quant_config.is_per_act_token or self.quant_config.is_per_tensor: - return t.to(f32) * scale - else: - return t.to(f32) * group_broadcast(scale, t.shape) - - def apply( - self, - output: torch.Tensor, - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str, - global_num_experts: int, - expert_map: torch.Tensor | None, - a1q_scale: torch.Tensor | None, - a2_scale: torch.Tensor | None, - workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_tokens_meta: mk.ExpertTokensMetadata | None, - apply_router_weight_on_input: bool, - ): - assert hidden_states.dim() == 3 - assert expert_tokens_meta is not None - expert_num_tokens = expert_tokens_meta.expert_num_tokens - - num_local_experts = w1.size(0) - assert num_local_experts == w1.size(0), f"{num_local_experts} == {w1.size(0)}" - - N = w1.size(1) // 2 - - for expert in range(num_local_experts): - # Indexing expert_num_tokens doesn't work w/cudagraphs or inductor - if ( - torch.compiler.is_compiling() - or torch.cuda.is_current_stream_capturing() - ): - num = hidden_states.shape[1] - else: - num = int(expert_num_tokens[expert].item()) - - if num == 0: - continue - - tmp = _resize_cache(workspace2, (num, N)) - - if self.quant_config.is_quantized: - assert 
a1q_scale is not None and self.w1_scale is not None - input = self.dequant(hidden_states[expert, :, :], a1q_scale[expert]) - w1_dq = self.dequant(w1[expert], self.w1_scale[expert]) - input = input[:num] @ w1_dq.transpose(0, 1) - else: - input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1) - - self.activation(activation, tmp, input.to(tmp.dtype)) - - if self.quant_config.is_quantized: - assert self.w2_scale is not None - w2_dq = self.dequant(w2[expert], self.w2_scale[expert]) - else: - w2_dq = w2[expert] - - output[expert, :num, :] = tmp @ w2_dq.transpose(0, 1).to(tmp.dtype) - - def batched_moe_kernel_quantize_input( A: torch.Tensor, A_scale: torch.Tensor | None, @@ -868,6 +722,31 @@ def activation_formats( mk.FusedMoEActivationFormat.BatchedExperts, mk.FusedMoEActivationFormat.BatchedExperts, ) + + def supports_current_device(self) -> bool: + return current_platform.is_cuda_alike() + + def supports_no_act_and_mul(self) -> bool: + return False + + def supports_quant_config(self, quant_config: FusedMoEQuantConfig) -> bool: + # Supports unquantized and fp8. + # TODO(rob): allow int4 (for kimi?) + if not ( + quant_config.use_fp8_w8a8 or + quant_config.quant_dtype == None # TODO: how to express unquantized? 
+ ): + return False + + if quant_config.use_fp8_w8a8: + return (current_platform.is_rocm or + current_platform.has_device_capability(9,0)) + + def supports_act_fn(self, activation: str) -> bool: + return activation in ["silu", "gelu", "swigluoai"] + + def supports_ep(self) -> bool: + return True def supports_chunking(self) -> bool: return False From ef5e664853f16e6dc3ae3164cfd923a133a1fd96 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 11 Jan 2026 10:08:22 -0500 Subject: [PATCH 019/113] stash Signed-off-by: Robert Shaw --- docs/design/moe_kernel_features.md | 1 - .../layers/fused_moe/fused_batched_moe.py | 2 +- .../layers/fused_moe/fused_marlin_moe.py | 22 +++++++++++++++++++ 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index 5b35b152ff38..0c48aaa03dbd 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -95,7 +95,6 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels | iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] | | rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_experts] | | cpu_fused_moe | standard | N/A | N/A | silu | N | N | [`CPUFusedMOE`][vllm.model_executor.layers.fused_moe.cpu_fused_moe.CPUFusedMOE] | -| naive batched4 | batched | int8,
fp8 | G,A,T | silu, gelu | 6 | Y | [`NaiveBatchedExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.NaiveBatchedExperts] | !!! info "Table key" 1. All types: mxfp4, nvfp4, int4, int8, fp8 diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 5e7551797ff0..3fa50d91254b 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -731,7 +731,7 @@ def supports_no_act_and_mul(self) -> bool: def supports_quant_config(self, quant_config: FusedMoEQuantConfig) -> bool: # Supports unquantized and fp8. - # TODO(rob): allow int4 (for kimi?) + # TODO(rob): allow int4 (for kimi --- no, we have marlinexperts for this. if not ( quant_config.use_fp8_w8a8 or quant_config.quant_dtype == None # TODO: how to express unquantized? diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index e82a838959de..7739e8f102f3 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -551,6 +551,28 @@ def __init__( self.is_k_full = is_k_full super().__init__(quant_config) + def supports_current_device(self) -> bool: + return current_platform.is_cuda() and current_platform.has_device_capability(8,0) + + def supports_no_act_and_mul(self) -> bool: + return False + + def supports_quant_config(self, quant_config: FusedMoEQuantConfig) -> bool: + # TODO(rob): check if we support the Fp8 activation + return ( + quant_config.use_fp8_w8a16 or + quant_config.use_int8_w8a16 or + quant_config.use_int4_w4a16 or + quant_config.use_nvfp4_w4a16 or + quant_config.use_mxfp4_w4a16 + ) + + def supports_act_fn(self, activation: str) -> bool: + return activation in ["silu", "swigluoai"] + + def supports_ep(self) -> bool: + return True + @property def quant_type_id(self) -> int: # uint4b8 will be set 
for int4 weight and float4_e2m1f will be used for mxfp4 From f6e85bc07682f9589ed9a2d9064174893f07ca52 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 11 Jan 2026 16:41:52 -0500 Subject: [PATCH 020/113] stash Signed-off-by: Robert Shaw --- docs/design/moe_kernel_features.md | 1 - tests/kernels/moe/test_moe.py | 58 +++++++++++++++++- .../layers/fused_moe/fused_moe.py | 26 ++++++++ .../fused_moe/gpt_oss_triton_kernels_moe.py | 20 ++++++- .../layers/fused_moe/moe_torch_iterative.py | 60 ------------------- 5 files changed, 99 insertions(+), 66 deletions(-) delete mode 100644 vllm/model_executor/layers/fused_moe/moe_torch_iterative.py diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index 0c48aaa03dbd..4070705d8c7a 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -92,7 +92,6 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels | gpt oss triton | standard | N/A | N/A | 5 | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],
[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] | | marlin | standard,
batched | 3 / N/A | 3 / N/A | silu,
swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],
[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],
[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] | | trtllm | standard | mxfp4,
nvfp4 | G(16),G(32) | 5 | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] | -| iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] | | rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_experts] | | cpu_fused_moe | standard | N/A | N/A | silu | N | N | [`CPUFusedMOE`][vllm.model_executor.layers.fused_moe.cpu_fused_moe.CPUFusedMOE] | diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index b58c42b7d3f2..e62eef688074 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -39,9 +39,6 @@ fused_topk, modular_triton_fused_moe, ) -from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( - fused_moe as iterative_moe, -) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( marlin_permute_bias, ) @@ -152,6 +149,61 @@ vllm_config = VllmConfig() +def iterative_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + global_num_experts: int, + expert_map: torch.Tensor = None, + renormalize: bool = False, +) -> torch.Tensor: + """ + Args: + hidden_states: [*, hidden_size] + w1: [num_experts, intermediate_size * 2, hidden_size] + w2: [num_experts, hidden_size, intermediate_size] + gating_output: [*, num_experts] + expert_map: [num_experts] + """ + orig_shape = hidden_states.shape + hidden_size = hidden_states.shape[-1] + num_tokens = hidden_states.shape[:-1].numel() + num_experts = w1.shape[0] + intermediate_size = w2.shape[-1] + dtype = hidden_states.dtype + + hidden_states = hidden_states.view(num_tokens, hidden_size) + gating_output = gating_output.view(num_tokens, global_num_experts) + topk_weights = gating_output.softmax(dim=-1, dtype=torch.float) + topk_weights, selected_experts = 
topk_weights.topk(topk, dim=-1) + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + topk_weights = topk_weights.to(dtype) + + if expert_map is not None: + selected_experts = expert_map[selected_experts] + + final_hidden_states = None + for expert_idx in range(num_experts): + expert_w1 = w1[expert_idx] + expert_w2 = w2[expert_idx] + expert_mask = selected_experts == expert_idx + expert_weights = (topk_weights * expert_mask).sum(dim=-1, keepdim=True) + x = F.linear(hidden_states, expert_w1) + gate = F.silu(x[:, :intermediate_size]) + x = x[:, intermediate_size:] * gate + x = F.linear(x, expert_w2) + current_hidden_states = x * expert_weights + if final_hidden_states is None: + final_hidden_states = current_hidden_states + else: + final_hidden_states = final_hidden_states + current_hidden_states + + return final_hidden_states.view(orig_shape) # type: ignore + + def run_moe_test( baseline: Callable | torch.Tensor, moe_fn: Callable, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 735c4aa8fcd1..76dcf75025a0 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -2321,6 +2321,32 @@ def activation_formats( mk.FusedMoEActivationFormat.Standard, ) + def supports_current_device(self) -> bool: + return current_platform.is_cuda_alike() + + def supports_no_act_and_mul(self) -> bool: + return False + + def supports_quant_config(self, quant_config: FusedMoEQuantConfig) -> bool: + # Supports unquantized and fp8. + # TODO(rob): allow int4 (for kimi --- no, we have marlinexperts for this. + if not ( + quant_config.use_fp8_w8a8 + or quant_config.quant_dtype == None # TODO: how to express unquantized? 
+ ): + return False + + if quant_config.use_fp8_w8a8: + return current_platform.is_rocm or current_platform.has_device_capability( + 9, 0 + ) + + def supports_act_fn(self, activation: str) -> bool: + return activation in ["silu", "gelu", "swigluoai"] + + def supports_ep(self) -> bool: + return True + def supports_chunking(self) -> bool: return True diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index dff8a9f3a8f0..21836547d88f 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -15,6 +15,7 @@ TopKWeightAndReduceNoOP, ) from vllm.model_executor.layers.fused_moe.utils import _resize_cache +from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils.import_utils import has_triton_kernels @@ -241,8 +242,23 @@ def make_routing_data( class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__(self, quant_config: FusedMoEQuantConfig): - super().__init__(quant_config) + def supports_current_device(self) -> bool: + if current_platform.is_cuda(): + return current_platform.has_device_capability(9, 0) + else: + return current_platform.is_cuda_alike() + + def supports_no_act_and_mul(self) -> bool: + return False + + def supports_quant_config(self, quant_config: FusedMoEQuantConfig) -> bool: + return quant_config.use_mxfp4_w4a16 + + def supports_act_fn(self, activation: str) -> bool: + return activation in ["swigluoai"] + + def supports_ep(self) -> bool: + return True def supports_expert_map(self) -> bool: return True diff --git a/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py b/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py deleted file mode 100644 index f721d00d75ea..000000000000 --- a/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py +++ /dev/null @@ -1,60 +0,0 @@ -# 
SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch -import torch.nn.functional as F - - -def fused_moe( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - global_num_experts: int, - expert_map: torch.Tensor = None, - renormalize: bool = False, -) -> torch.Tensor: - """ - Args: - hidden_states: [*, hidden_size] - w1: [num_experts, intermediate_size * 2, hidden_size] - w2: [num_experts, hidden_size, intermediate_size] - gating_output: [*, num_experts] - expert_map: [num_experts] - """ - orig_shape = hidden_states.shape - hidden_size = hidden_states.shape[-1] - num_tokens = hidden_states.shape[:-1].numel() - num_experts = w1.shape[0] - intermediate_size = w2.shape[-1] - dtype = hidden_states.dtype - - hidden_states = hidden_states.view(num_tokens, hidden_size) - gating_output = gating_output.view(num_tokens, global_num_experts) - topk_weights = gating_output.softmax(dim=-1, dtype=torch.float) - topk_weights, selected_experts = topk_weights.topk(topk, dim=-1) - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - topk_weights = topk_weights.to(dtype) - - if expert_map is not None: - selected_experts = expert_map[selected_experts] - - final_hidden_states = None - for expert_idx in range(num_experts): - expert_w1 = w1[expert_idx] - expert_w2 = w2[expert_idx] - expert_mask = selected_experts == expert_idx - expert_weights = (topk_weights * expert_mask).sum(dim=-1, keepdim=True) - x = F.linear(hidden_states, expert_w1) - gate = F.silu(x[:, :intermediate_size]) - x = x[:, intermediate_size:] * gate - x = F.linear(x, expert_w2) - current_hidden_states = x * expert_weights - if final_hidden_states is None: - final_hidden_states = current_hidden_states - else: - final_hidden_states = final_hidden_states + current_hidden_states - - return final_hidden_states.view(orig_shape) # type: ignore From 
0db0b1182ce855c4c42d122d5b516ced8464f76a Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 11 Jan 2026 16:49:02 -0500 Subject: [PATCH 021/113] added back moe torch iterative Signed-off-by: Robert Shaw --- tests/kernels/moe/test_moe.py | 58 ++--------------------------------- 1 file changed, 3 insertions(+), 55 deletions(-) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index e62eef688074..d3b422e035a7 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -18,6 +18,9 @@ from torch.nn import functional as F from transformers import MixtralConfig from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock +from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( + fused_moe as iterative_moe, +) import vllm.model_executor.layers.fused_moe # noqa from tests.kernels.moe.utils import fused_moe @@ -149,61 +152,6 @@ vllm_config = VllmConfig() -def iterative_moe( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - global_num_experts: int, - expert_map: torch.Tensor = None, - renormalize: bool = False, -) -> torch.Tensor: - """ - Args: - hidden_states: [*, hidden_size] - w1: [num_experts, intermediate_size * 2, hidden_size] - w2: [num_experts, hidden_size, intermediate_size] - gating_output: [*, num_experts] - expert_map: [num_experts] - """ - orig_shape = hidden_states.shape - hidden_size = hidden_states.shape[-1] - num_tokens = hidden_states.shape[:-1].numel() - num_experts = w1.shape[0] - intermediate_size = w2.shape[-1] - dtype = hidden_states.dtype - - hidden_states = hidden_states.view(num_tokens, hidden_size) - gating_output = gating_output.view(num_tokens, global_num_experts) - topk_weights = gating_output.softmax(dim=-1, dtype=torch.float) - topk_weights, selected_experts = topk_weights.topk(topk, dim=-1) - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - topk_weights = 
topk_weights.to(dtype) - - if expert_map is not None: - selected_experts = expert_map[selected_experts] - - final_hidden_states = None - for expert_idx in range(num_experts): - expert_w1 = w1[expert_idx] - expert_w2 = w2[expert_idx] - expert_mask = selected_experts == expert_idx - expert_weights = (topk_weights * expert_mask).sum(dim=-1, keepdim=True) - x = F.linear(hidden_states, expert_w1) - gate = F.silu(x[:, :intermediate_size]) - x = x[:, intermediate_size:] * gate - x = F.linear(x, expert_w2) - current_hidden_states = x * expert_weights - if final_hidden_states is None: - final_hidden_states = current_hidden_states - else: - final_hidden_states = final_hidden_states + current_hidden_states - - return final_hidden_states.view(orig_shape) # type: ignore - - def run_moe_test( baseline: Callable | torch.Tensor, moe_fn: Callable, From 09dc4f51df8c3f182cf5417ec392895a6c23f2d6 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 11 Jan 2026 16:52:21 -0500 Subject: [PATCH 022/113] revert changes Signed-off-by: Robert Shaw --- tests/kernels/moe/test_moe.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index d3b422e035a7..b58c42b7d3f2 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -18,9 +18,6 @@ from torch.nn import functional as F from transformers import MixtralConfig from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock -from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( - fused_moe as iterative_moe, -) import vllm.model_executor.layers.fused_moe # noqa from tests.kernels.moe.utils import fused_moe @@ -42,6 +39,9 @@ fused_topk, modular_triton_fused_moe, ) +from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( + fused_moe as iterative_moe, +) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( marlin_permute_bias, ) From 3311e9ef099bdef1e2e9b463cdbf1226bfa3fd15 
Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 11 Jan 2026 16:52:36 -0500 Subject: [PATCH 023/113] re add Signed-off-by: Robert Shaw --- .../layers/fused_moe/moe_torch_iterative.py | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 vllm/model_executor/layers/fused_moe/moe_torch_iterative.py diff --git a/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py b/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py new file mode 100644 index 000000000000..f721d00d75ea --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +import torch.nn.functional as F + + +def fused_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + global_num_experts: int, + expert_map: torch.Tensor = None, + renormalize: bool = False, +) -> torch.Tensor: + """ + Args: + hidden_states: [*, hidden_size] + w1: [num_experts, intermediate_size * 2, hidden_size] + w2: [num_experts, hidden_size, intermediate_size] + gating_output: [*, num_experts] + expert_map: [num_experts] + """ + orig_shape = hidden_states.shape + hidden_size = hidden_states.shape[-1] + num_tokens = hidden_states.shape[:-1].numel() + num_experts = w1.shape[0] + intermediate_size = w2.shape[-1] + dtype = hidden_states.dtype + + hidden_states = hidden_states.view(num_tokens, hidden_size) + gating_output = gating_output.view(num_tokens, global_num_experts) + topk_weights = gating_output.softmax(dim=-1, dtype=torch.float) + topk_weights, selected_experts = topk_weights.topk(topk, dim=-1) + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + topk_weights = topk_weights.to(dtype) + + if expert_map is not None: + selected_experts = expert_map[selected_experts] + + final_hidden_states = None + for expert_idx in 
range(num_experts): + expert_w1 = w1[expert_idx] + expert_w2 = w2[expert_idx] + expert_mask = selected_experts == expert_idx + expert_weights = (topk_weights * expert_mask).sum(dim=-1, keepdim=True) + x = F.linear(hidden_states, expert_w1) + gate = F.silu(x[:, :intermediate_size]) + x = x[:, intermediate_size:] * gate + x = F.linear(x, expert_w2) + current_hidden_states = x * expert_weights + if final_hidden_states is None: + final_hidden_states = current_hidden_states + else: + final_hidden_states = final_hidden_states + current_hidden_states + + return final_hidden_states.view(orig_shape) # type: ignore From 755a3a2d846da4691697837216149a8fad289149 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 11 Jan 2026 16:53:13 -0500 Subject: [PATCH 024/113] add back iterative Signed-off-by: Robert Shaw --- docs/design/moe_kernel_features.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index 4070705d8c7a..0c48aaa03dbd 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -92,6 +92,7 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels | gpt oss triton | standard | N/A | N/A | 5 | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],
[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] | | marlin | standard,
batched | 3 / N/A | 3 / N/A | silu,
swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],
[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],
[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] | | trtllm | standard | mxfp4,
nvfp4 | G(16),G(32) | 5 | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] | +| iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] | | rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_experts] | | cpu_fused_moe | standard | N/A | N/A | silu | N | N | [`CPUFusedMOE`][vllm.model_executor.layers.fused_moe.cpu_fused_moe.CPUFusedMOE] | From a8bb9d0498fef3a7abd0dabded45ff17e4799057 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 11 Jan 2026 17:06:05 -0500 Subject: [PATCH 025/113] add methodology to kernels Signed-off-by: Robert Shaw --- docs/design/moe_kernel_features.md | 2 +- .../model_executor/layers/fused_moe/config.py | 4 + .../layers/fused_moe/rocm_aiter_fused_moe.py | 20 +++ .../layers/fused_moe/trtllm_moe.py | 143 ------------------ .../layers/quantization/mxfp4.py | 4 +- 5 files changed, 27 insertions(+), 146 deletions(-) delete mode 100644 vllm/model_executor/layers/fused_moe/trtllm_moe.py diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index 0c48aaa03dbd..abd30041e44a 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -91,7 +91,7 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels | flashinfer | standard | nvfp4,
fp8 | T | 5 | N | Y | [`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] | | gpt oss triton | standard | N/A | N/A | 5 | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],
[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] | | marlin | standard,
batched | 3 / N/A | 3 / N/A | silu,
swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],
[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],
[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] | -| trtllm | standard | mxfp4,
nvfp4 | G(16),G(32) | 5 | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] | +| trtllm | standard | mxfp4,
nvfp4 | G(16),G(32) | 5 | N | Y | [`TrtllmMxFp4Experts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtllmMxFp4Experts] | | iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] | | rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_experts] | | cpu_fused_moe | standard | N/A | N/A | silu | N | N | [`CPUFusedMOE`][vllm.model_executor.layers.fused_moe.cpu_fused_moe.CPUFusedMOE] | diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 23b86fdca898..c7c52e66e573 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -358,6 +358,10 @@ def ocp_mx_scheme(self) -> str | None: def use_mxfp4_w4a16(self) -> bool: return self._a1.dtype is None and self._w1.dtype == "mxfp4" + @property + def use_mxfp4_w4a8(self) -> bool: + return self._a1.dtype == "mxfp8" and self._w1.dtype == "mxfp4" + @property def use_mxfp4_w4a4(self) -> bool: return self._a1.dtype == "mxfp4" and self._w1.dtype == "mxfp4" diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 06707e5e4892..85b14507816d 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -14,6 +14,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP, ) +from vllm.platforms import current_platform class QuantMethod(IntEnum): @@ -281,6 +282,25 @@ def activation_formats( mk.FusedMoEActivationFormat.Standard, ) + def supports_current_device(self) -> bool: + return current_platform.is_rocm() + + def supports_no_act_and_mul(self) -> bool: + return False + + def supports_quant_config(self, quant_config: 
FusedMoEQuantConfig) -> bool: + return ( + quant_config.use_fp8_w8a8 + or quant_config.use_mxfp4_w4a4 + or False # TODO: how to represent unquantizes? + ) + + def supports_act_fn(self, activation: str) -> bool: + return activation in ["silu", "gelu"] + + def supports_ep(self) -> bool: + return True + def supports_expert_map(self): return True diff --git a/vllm/model_executor/layers/fused_moe/trtllm_moe.py b/vllm/model_executor/layers/fused_moe/trtllm_moe.py deleted file mode 100644 index 132d35e65aba..000000000000 --- a/vllm/model_executor/layers/fused_moe/trtllm_moe.py +++ /dev/null @@ -1,143 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch - -import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEConfig, - FusedMoEQuantConfig, -) -from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceNoOP, -) - - -class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__( - self, - moe: FusedMoEConfig, - quant_config: FusedMoEQuantConfig, - gemm1_alpha, - gemm1_beta, - gemm1_clamp_limit, - max_capture_size, - ): - super().__init__(quant_config) - self.moe = moe - self.gemm1_alpha = gemm1_alpha - self.gemm1_beta = gemm1_beta - self.gemm1_clamp_limit = gemm1_clamp_limit - self.max_capture_size = max_capture_size - - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard, - ) - - def supports_chunking(self) -> bool: - return True - - def supports_expert_map(self) -> bool: - return True - - def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: - return TopKWeightAndReduceNoOP() - - def workspace_shapes( - self, - M: int, - N: int, - K: int, - topk: int, - global_num_experts: int, - local_num_experts: 
int, - expert_tokens_meta: mk.ExpertTokensMetadata | None, - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: - # The workspaces for this implementation are managed by flashinfer. - workspace1 = (0,) - workspace2 = (0,) - output = (M, K) - return (workspace1, workspace2, output) - - def apply( - self, - output: torch.Tensor, - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str, - global_num_experts: int, - expert_map: torch.Tensor | None, - a1q_scale: torch.Tensor | None, - a2_scale: torch.Tensor | None, - workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_tokens_meta: mk.ExpertTokensMetadata | None, - apply_router_weight_on_input: bool, - ): - topk = topk_ids.size(-1) - local_num_experts = w1.size(0) - intermediate_size = w2.size(1) - local_expert_offset = self.moe.ep_rank * local_num_experts - - x_quant = hidden_states - x_scale = a1q_scale - if x_scale is not None: - x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x_quant.shape[:-1], -1) - - packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to( - torch.bfloat16 - ).view(torch.int16) - - assert self.w1_scale is not None - assert self.w2_scale is not None - kwargs = { - "topk_ids": packed_tensor, - "routing_bias": None, - "hidden_states": x_quant, - "hidden_states_scale": x_scale, - "gemm1_weights": w1, - "gemm1_weights_scale": self.w1_scale, - "gemm1_bias": self.w1_bias, - "gemm1_alpha": self.gemm1_alpha, - "gemm1_beta": self.gemm1_beta, - "gemm1_clamp_limit": self.gemm1_clamp_limit, - "gemm2_weights": w2, - "gemm2_weights_scale": self.w2_scale, - "gemm2_bias": self.w2_bias, - "output1_scale_scalar": None, - "output1_scale_gate_scalar": None, - "output2_scale_scalar": None, - "num_experts": global_num_experts, - "top_k": topk, - "n_group": None, - "topk_group": None, - "intermediate_size": intermediate_size, - "local_expert_offset": local_expert_offset, - "local_num_experts": 
local_num_experts, - "routed_scaling_factor": None, - "tile_tokens_dim": None, - "routing_method_type": 1, - "do_finalize": True, - "output": output, - "tune_max_num_tokens": max(self.max_capture_size, 1), - } - - from flashinfer import trtllm_fp4_block_scale_routed_moe - - from vllm.utils.flashinfer import autotune - - with autotune(False): - # Enable autotune when, - # https://github.com/flashinfer-ai/flashinfer/issues/2023 is - # resolved. - trtllm_fp4_block_scale_routed_moe(**kwargs) - - return output diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 8e050b795f94..3fcc61878b3c 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -32,7 +32,7 @@ OAITritonExperts, UnfusedOAITritonExperts, ) -from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts +from vllm.model_executor.layers.fused_moe.trtllm_mxfp4_moe import TrtllmMxFp4Experts from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( @@ -873,7 +873,7 @@ def select_gemm_impl( # TODO(bnell): part of quant_config "max_capture_size": self.max_capture_size, } - return TrtLlmGenExperts(self.moe, self.moe_quant_config, **kwargs) + return TrtllmMxFp4Experts(self.moe, self.moe_quant_config, **kwargs) elif self.mxfp4_backend == Mxfp4Backend.MARLIN: return MarlinExperts(self.moe_quant_config) elif self.mxfp4_backend == Mxfp4Backend.TRITON: From eb571a2c418b0f63e1a3152c86e309226863a61b Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Mon, 12 Jan 2026 11:05:49 -0500 Subject: [PATCH 026/113] restructure kernel selection logic Signed-off-by: Robert Shaw --- docs/design/fused_moe_modular_kernel.md | 2 +- .../moe/modular_kernel_tools/common.py | 7 +- tests/kernels/moe/test_flashinfer_moe.py | 3 - 
.../layers/fused_moe/batched_deep_gemm_moe.py | 49 ++- .../model_executor/layers/fused_moe/config.py | 48 +++ .../layers/fused_moe/cutlass_moe.py | 127 ++++---- .../layers/fused_moe/deep_gemm_moe.py | 44 +-- .../layers/fused_moe/fallback.py | 9 - .../fused_moe/flashinfer_cutedsl_moe.py | 71 ++--- .../fused_moe/flashinfer_cutlass_moe.py | 109 +++---- .../layers/fused_moe/flashinfer_trtllm_moe.py | 59 +++- .../layers/fused_moe/fused_batched_moe.py | 57 ++-- .../layers/fused_moe/fused_marlin_moe.py | 64 ++-- .../layers/fused_moe/fused_moe.py | 40 +-- .../fused_moe/gpt_oss_triton_kernels_moe.py | 42 ++- vllm/model_executor/layers/fused_moe/layer.py | 1 + .../layers/fused_moe/modular_kernel.py | 48 ++- .../layers/fused_moe/oracle/fp8.py | 283 ++++++++++++------ .../layers/fused_moe/rocm_aiter_fused_moe.py | 34 ++- .../layers/fused_moe/triton_cutlass_moe.py | 30 +- .../layers/fused_moe/triton_deep_gemm_moe.py | 31 +- .../compressed_tensors_moe.py | 16 +- .../model_executor/layers/quantization/fp8.py | 42 ++- .../layers/quantization/modelopt.py | 12 +- 24 files changed, 727 insertions(+), 501 deletions(-) diff --git a/docs/design/fused_moe_modular_kernel.md b/docs/design/fused_moe_modular_kernel.md index e1a96be6c344..39898189c440 100644 --- a/docs/design/fused_moe_modular_kernel.md +++ b/docs/design/fused_moe_modular_kernel.md @@ -166,7 +166,7 @@ We suggest picking an already existing `FusedMoEPrepareAndFinalize` implementati FusedMoEPermuteExpertsUnpermute performs the core of the FusedMoE operations. The various functions exposed by the abstract class and their significance is as follows, -`FusedMoEPermuteExpertsUnpermute::activation_formats()`: Return the supported Input and Output activation formats. i.e. Contiguous / Batched format. +`FusedMoEPermuteExpertsUnpermute::activation_format()`: Return the supported activation formats. i.e. Standard / Batched (MaskedGEMM) format. 
`FusedMoEPermuteExpertsUnpermute::supports_chunking()`: Return True if the implementation supports chunking. Typically implementations that input `FusedMoEActivationFormat.Standard` support chunking and `FusedMoEActivationFormat.BatchedExperts` do not. diff --git a/tests/kernels/moe/modular_kernel_tools/common.py b/tests/kernels/moe/modular_kernel_tools/common.py index 537dcae4e74b..52906c043df0 100644 --- a/tests/kernels/moe/modular_kernel_tools/common.py +++ b/tests/kernels/moe/modular_kernel_tools/common.py @@ -161,15 +161,15 @@ def is_fp8_block_quantized(self): def is_batched_prepare_finalize(self): info = prepare_finalize_info(self.prepare_finalize_type) - return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format + return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format() def is_batched_fused_experts(self): info = expert_info(self.fused_experts_type) - return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format + return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format() def is_standard_fused_experts(self): info = expert_info(self.fused_experts_type) - return mk.FusedMoEActivationFormat.Standard == info.activation_format + return mk.FusedMoEActivationFormat.Standard == info.activation_format() def fe_supported_types(self): info = expert_info(self.fused_experts_type) @@ -578,6 +578,7 @@ def next_power_of_2(x): moe_parallel_config=moe_parallel_config, in_dtype=config.dtype, max_num_tokens=next_power_of_2(config.M), + activation="silu", ) # make modular kernel diff --git a/tests/kernels/moe/test_flashinfer_moe.py b/tests/kernels/moe/test_flashinfer_moe.py index 1262eea70bab..7739d0040e07 100644 --- a/tests/kernels/moe/test_flashinfer_moe.py +++ b/tests/kernels/moe/test_flashinfer_moe.py @@ -14,7 +14,6 @@ from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( FlashInferExperts, - 
is_valid_flashinfer_cutlass_fused_moe, ) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( create_flashinfer_prepare_finalize, @@ -84,8 +83,6 @@ def test_flashinfer_fp4_moe_no_graph( score = torch.randn((m, e), device="cuda", dtype=dtype) topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False) - assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q) - flashinfer_experts = FusedMoEModularKernel( create_flashinfer_prepare_finalize(use_dp=False, use_nvfp4=True), FlashInferExperts(out_dtype=dtype, quant_config=quant_config), diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index fa2724f2ff59..ca1317c55c25 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -4,11 +4,14 @@ import torch -from vllm.platforms import current_platform import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.forward_context import get_forward_context, is_forward_context_available from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEParallelConfig, + FusedMoEQuantConfig, + FusedMoEQuantScheme, +) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate, ) @@ -269,36 +272,30 @@ def __init__( self.max_num_tokens = max_num_tokens self.num_dispatchers = num_dispatchers - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts, - ) - - def supports_current_device(self) -> bool: - return ( - current_platform.is_cuda() and - current_platform.has_device_capability(9,0) + @staticmethod + def activation_format() -> 
mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.BatchedExperts + + @staticmethod + def _supports_current_device() -> bool: + return current_platform.is_cuda() and current_platform.has_device_capability( + 9, 0 ) - - def supports_no_act_and_mul(self) -> bool: + + @staticmethod + def _supports_no_act_and_mul() -> bool: return False - def supports_quant_config(self, quant_config: FusedMoEQuantConfig) -> bool: - return ( - quant_config.use_fp8_w8a8 and - quant_config.is_block_quantized() and - quant_config.block_shape[0] == 128 and - quant_config.block_shape[1] == 128 - ) + @staticmethod + def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: + return quant_scheme.is_fp8_w8a8 and quant_scheme.block_size == [128, 128] - def supports_act_fn(self, activation: str) -> bool: + @staticmethod + def _supports_activation(activation: str) -> bool: return activation in ["silu"] - def supports_ep(self) -> bool: + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: return True def supports_chunking(self) -> bool: diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index c7c52e66e573..b7a6ede50157 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -121,6 +121,52 @@ class RoutingMethodType(IntEnum): Unspecified = 6.0 +UNQUANTIZED_DTYPES = [torch.bfloat16, torch.float16, torch.float32] + + +@dataclass +class FusedMoEQuantScheme: + weight_dtype: torch.dtype | str | None + act_dtype: torch.dtype | str | None + per_token_quant: bool + per_tensor_quant: bool + block_size: tuple[int, int] | None + + def __post_init__(self): + if self.per_tensor_quant: + assert not self.per_token_quant + assert not self.per_block_quant + elif self.per_token_quant: + assert not self.per_tensor_quant + assert not self.per_block_quant + elif self.per_block_quant: + assert not self.per_tensor_quant + assert 
not self.per_token_quant + assert self.block_size is not None + assert self.per_block_quant or self.per_token_quant or self.per_tensor_quant + if self.is_unquantized: + assert self.act_dtype in UNQUANTIZED_DTYPES + + @property + def per_block_quant(self) -> bool: + return self.block_size is not None + + @property + def is_unquantized(self) -> bool: + return self.weight_dtype in UNQUANTIZED_DTYPES + + @property + def is_fp8_w8a8(self) -> bool: + return ( + self.weight_dtype == current_platform.fp8_dtype() + and self.act_dtype == current_platform.fp8_dtype() + ) + + @property + def is_nvfp4_w4a4(self) -> bool: + return self.weight_dtype == "nvfp4" and self.act_dtype == "nvfp4" + + @dataclass class FusedMoEQuantDesc: """ @@ -1044,6 +1090,8 @@ class FusedMoEConfig: is_lora_enabled: bool = False + activation: str = "silu" + def __post_init__(self): if self.dp_size > 1: logger.debug_once( diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index fdac768da8f9..b3a484f9f3a3 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -9,7 +9,11 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEParallelConfig, + FusedMoEQuantConfig, + FusedMoEQuantScheme, +) from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( moe_permute, moe_unpermute, @@ -22,6 +26,7 @@ TopKWeightAndReduceNoOP, ) from vllm.model_executor.layers.fused_moe.utils import _resize_cache +from vllm.platforms import current_platform from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -259,6 +264,33 @@ def __init__( self.c_strides1 = c_strides1 self.c_strides2 = ab_strides1_c_strides2 + @staticmethod + def 
_supports_current_device() -> bool: + return current_platform.is_cuda() and current_platform.has_device_capability( + 9, 0 + ) + + @staticmethod + def _supports_no_act_and_mul() -> bool: + return False + + @staticmethod + def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: + if not current_platform.has_device_capability((9, 0)): + return False + + return quant_scheme.is_fp8_w8a8 and ( + quant_scheme.per_tensor_quant or quant_scheme.per_token_quant + ) + + @staticmethod + def _supports_activation(activation: str) -> bool: + return activation in ["silu"] + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + return True + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: # Let PrepareAndFinalize::finalize() decide the impl. return TopKWeightAndReduceDelegate() @@ -291,7 +323,7 @@ def apply( activation_callable = lambda o, i: self.activation(activation, o, i) use_batched_format = ( - self.activation_formats[0] == mk.FusedMoEActivationFormat.BatchedExperts + self.activation_format() == mk.FusedMoEActivationFormat.BatchedExperts ) in_dtype = hidden_states.dtype @@ -324,14 +356,9 @@ def apply( class CutlassExpertsFp8(CutlassExpertsFp8Base): - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard, - ) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard def supports_chunking(self) -> bool: return True @@ -375,14 +402,9 @@ def __init__( self.max_experts_per_worker = max_experts_per_worker self.num_dispatchers = num_dispatchers - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts, - ) + @staticmethod + def 
activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.BatchedExperts def supports_chunking(self) -> bool: return False @@ -592,27 +614,14 @@ def __init__( max_experts_per_worker: int, out_dtype: torch.dtype, quant_config: FusedMoEQuantConfig, - use_batched_format: bool = False, ): super().__init__(quant_config) self.max_experts_per_worker = max_experts_per_worker self.out_dtype = out_dtype - self.use_batched_format = use_batched_format - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - if self.use_batched_format: - return ( - mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts, - ) - else: - return ( - mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard, - ) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard def supports_expert_map(self) -> bool: return False @@ -636,17 +645,9 @@ def workspace_shapes( local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: - workspace1: tuple[int, ...] = () - workspace2: tuple[int, ...] = () - output: tuple[int, ...] 
= () - if self.use_batched_format: - workspace1 = (self.max_experts_per_worker, M, max(N, K)) - workspace2 = (self.max_experts_per_worker, M, (N // 2)) - output = (self.max_experts_per_worker, M, K) - else: - workspace1 = (M * topk, max(2 * N, K)) - workspace2 = (M * topk, N) - output = (M, K) + workspace1 = (M * topk, max(2 * N, K)) + workspace2 = (M * topk, N) + output = (M, K) return (workspace1, workspace2, output) def apply( @@ -865,15 +866,35 @@ def __init__( self.s_strides2 = s_strides2 self.group_size = group_size - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard, + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + @staticmethod + def _supports_current_device() -> bool: + return current_platform.is_cuda() and current_platform.has_device_capability( + (9, 0) ) + @staticmethod + def _supports_no_act_and_mul() -> bool: + return False + + @staticmethod + def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: + return True + # quant_scheme.is_fp8_w8a8 + # # or quant_scheme.is_mxfp4_w4a4 + # or quant_scheme.is_unquantized + + @staticmethod + def _supports_activation(activation: str) -> bool: + return activation in ["silu", "gelu"] + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + return True + def supports_chunking(self) -> bool: return True @@ -927,7 +948,7 @@ def apply( activation_callable = lambda o, i: self.activation(activation, o, i) use_batched_format = ( - self.activation_formats[0] == mk.FusedMoEActivationFormat.BatchedExperts + self.activation_format() == mk.FusedMoEActivationFormat.BatchedExperts ) assert not use_batched_format, "batched format not supported" diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py 
b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 044617005e4c..dbbeecdca02c 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -3,11 +3,12 @@ import torch -from vllm.platforms import current_platform import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEParallelConfig, FusedMoEQuantConfig, + FusedMoEQuantScheme, fp8_w8a8_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.deep_gemm_utils import ( @@ -27,6 +28,7 @@ per_token_group_quant_fp8_packed_for_deepgemm, silu_mul_per_token_group_quant_fp8_colmajor, ) +from vllm.platforms import current_platform from vllm.utils.deep_gemm import ( DeepGemmQuantScaleFMT, get_mk_alignment_for_contiguous_layout, @@ -117,36 +119,34 @@ def __init__(self, quant_config: FusedMoEQuantConfig): assert not quant_config.per_act_token_quant assert not quant_config.per_out_ch_quant - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard, - ) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard - def supports_current_device(self) -> bool: - return ( - current_platform.is_cuda() and - current_platform.has_device_capability(9,0) + @staticmethod + def _supports_current_device() -> bool: + return current_platform.is_cuda() and current_platform.has_device_capability( + 9, 0 ) - - def supports_no_act_and_mul(self) -> bool: + + @staticmethod + def _supports_no_act_and_mul() -> bool: return False - def supports_quant_config(self, quant_config: FusedMoEQuantConfig) -> bool: + @staticmethod + def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: return ( - quant_config.use_fp8_w8a8 and - 
quant_config.is_block_quantized() and - quant_config.block_shape[0] == 128 and - quant_config.block_shape[1] == 128 + quant_scheme.is_fp8_w8a8 + and quant_scheme.per_block_quant + and quant_scheme.block_size == [128, 128] ) - def supports_act_fn(self, activation: str) -> bool: + @staticmethod + def _supports_activation(activation: str) -> bool: return activation in ["silu"] - def supports_ep(self) -> bool: + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: return True def supports_chunking(self) -> bool: diff --git a/vllm/model_executor/layers/fused_moe/fallback.py b/vllm/model_executor/layers/fused_moe/fallback.py index 14ef6b9aaa5e..e4c5b149deee 100644 --- a/vllm/model_executor/layers/fused_moe/fallback.py +++ b/vllm/model_executor/layers/fused_moe/fallback.py @@ -20,15 +20,6 @@ def __init__( self.fallback_experts = fallback_experts self.experts = experts - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - assert ( - self.fallback_experts.activation_formats == self.experts.activation_formats - ) - return self.fallback_experts.activation_formats - def supports_chunking(self) -> bool: assert ( self.experts.supports_chunking() diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py index 7b08da7fa6f6..475f849720f8 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py @@ -3,14 +3,18 @@ import torch -from vllm.platforms import current_platform import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import envs from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEParallelConfig, + FusedMoEQuantConfig, + FusedMoEQuantScheme, +) from 
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate, ) +from vllm.platforms import current_platform from vllm.utils.flashinfer import ( flashinfer_cutedsl_grouped_gemm_nt_masked, has_flashinfer_cutedsl_grouped_gemm_nt_masked, @@ -21,34 +25,6 @@ logger = init_logger(__name__) -def is_valid_flashinfer_cutedsl_fused_moe( - hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor -) -> bool: - """ - Check if the given problem size is supported by the FlashInfer CuteDSL MoE - kernel. - """ - if not has_flashinfer_cutedsl_grouped_gemm_nt_masked(): - logger.debug_once( - "FlashInferCuteDSLExperts disabled: " - "flashinfer_cutedsl_fused_moe not available." - ) - return False - # Data type checks - if ( - w1.dtype != torch.uint8 - or w2.dtype != torch.uint8 - or hidden_states.dtype not in [torch.float32, torch.float16, torch.bfloat16] - ): - logger.debug_once( - "FlashInferCuteDSLExperts disabled: w1/w2 must be torch.uint8 " - f"(got w1={w1.dtype}, w2={w2.dtype}), hidden_states must be " - f"float32, float16, or bfloat16 (got {hidden_states.dtype})." 
- ) - return False - return True - - class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, @@ -61,31 +37,32 @@ def __init__( ) self.out_dtype = out_dtype - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts, - ) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.BatchedExperts - def supports_current_device(self) -> bool: + @staticmethod + def _supports_current_device() -> bool: return ( - current_platform.is_cuda() and - current_platform.has_device_capability(10,0) + current_platform.is_cuda() + and current_platform.has_device_capability(10, 0) + and has_flashinfer_cutedsl_grouped_gemm_nt_masked() ) - - def supports_no_act_and_mul(self) -> bool: + + @staticmethod + def _supports_no_act_and_mul() -> bool: return False - def supports_quant_config(self, quant_config: FusedMoEQuantConfig) -> bool: - return quant_config.use_nvfp4_w4a4 + @staticmethod + def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: + return quant_scheme.is_nvfp4_w4a4 - def supports_act_fn(self, activation: str) -> bool: + @staticmethod + def _supports_activation(activation: str) -> bool: return activation in ["silu"] - def supports_ep(self) -> bool: + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: return True def supports_expert_map(self) -> bool: diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 231552f78994..a1a0c2f1c503 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -3,16 +3,17 @@ import torch -from vllm.platforms import current_platform import 
vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 - create_flashinfer_prepare_finalize, +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEParallelConfig, + FusedMoEQuantConfig, + FusedMoEQuantScheme, ) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP, ) +from vllm.platforms import current_platform from vllm.utils.flashinfer import ( flashinfer_cutlass_fused_moe, has_flashinfer_cutlass_fused_moe, @@ -21,33 +22,6 @@ logger = init_logger(__name__) -def is_valid_flashinfer_cutlass_fused_moe( - hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor -) -> bool: - """ - Check if the given problem size is supported by the FlashInfer CUTLASS MoE - kernel. - """ - if not has_flashinfer_cutlass_fused_moe(): - logger.debug_once( - "FlashInferExperts disabled: flashinfer_cutlass_fused_moe not available." - ) - return False - # Data type checks - if ( - w1.dtype != torch.uint8 - or w2.dtype != torch.uint8 - or hidden_states.dtype not in [torch.float32, torch.float16, torch.bfloat16] - ): - logger.debug_once( - "FlashInferExperts disabled: w1/w2 must be torch.uint8 " - f"(got w1={w1.dtype}, w2={w2.dtype}), hidden_states must be " - f"float32, float16, or bfloat16 (got {hidden_states.dtype})." 
- ) - return False - return True - - class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, @@ -75,52 +49,51 @@ def __init__( # - pass per-block weight scales to the kernel # - skip input activation quantization (kernel applies scaling) self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale - - def supports_current_device(self) -> bool: + + @staticmethod + def _supports_current_device() -> bool: return ( - current_platform.is_cuda() and - current_platform.has_device_capability(9,0) + current_platform.is_cuda() + and current_platform.has_device_capability((9, 0)) + and has_flashinfer_cutlass_fused_moe() ) - - def supports_no_act_and_mul(self) -> bool: + + @staticmethod + def _supports_no_act_and_mul() -> bool: return False - def supports_quant_config(self, quant_config: FusedMoEQuantConfig) -> bool: - # Supports unquantized, fp8, and nvfp4. - if not ( - quant_config.use_nvfp4_w4a4 or - quant_config.use_fp8_w8a8 or - quant_config.quant_dtype == None # TODO: how to express unquantized? - ): - return False - - # For FP8, only support static per tensor or DeepGEMM SwapAB for hopper. 
- if quant_config.use_fp8_w8a8: - if quant_config.is_per_tensor: - return True - elif quant_config.is_per_act_token: - return False - elif quant_config.is_block_quantized: - if (current_platform.is_cuda and current_platform.is_device_capability(9,0) and quant_config.block_shape[0] == 128 and quant_config.block_shape[1] == 128): - return True - return False - - return True + @staticmethod + def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: + # Supports: + # * unquantized + # * fp8 per-tensor on 9.0+ + # * fp8 block on 9.0 + # * nvfp4 on 10.0+ + return ( + (quant_scheme.is_unquantized) + or (quant_scheme.is_fp8_w8a8 and quant_scheme.per_tensor_quant) + or ( + quant_scheme.is_fp8_w8a8 + and quant_scheme.block_size == [128, 128] + and current_platform.is_device_capability((9, 0)) + ) + or ( + quant_scheme.is_nvfp4_w4a4 + and current_platform.has_device_capability((10, 0)) + ) + ) - def supports_act_fn(self, activation: str) -> bool: + @staticmethod + def _supports_activation(activation: str) -> bool: return activation in ["silu", "relu2_no_mul"] - def supports_ep(self) -> bool: + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: return True - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard, - ) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard def supports_expert_map(self) -> bool: return False diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index 3bb5a23abb7b..e575f4fdb49d 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -3,7 +3,12 @@ import torch -from 
vllm.model_executor.layers.fused_moe.config import RoutingMethodType +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, + FusedMoEQuantScheme, + RoutingMethodType, +) from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( calculate_tile_tokens_dim, @@ -11,8 +16,60 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, ) +from vllm.platforms import current_platform from vllm.utils.torch_utils import direct_register_custom_op +# +# Methods used by the oracle for kernel selection. +# + + +def _supports_current_device() -> bool: + """Supports only Blackwell-family GPUs.""" + return current_platform.is_cuda() and current_platform.is_device_capability_family( + 10 + ) + + +def _supports_no_act_and_mul() -> bool: + """Does not support non-gated MoE (i.e. Nanotron-Mini).""" + return False + + +def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: + """Supports Fp8 per-tensor, Fp8 block, and Nvfp4 quantization.""" + return ( + (quant_scheme.is_fp8_w8a8 and quant_scheme.per_tensor_quant) + or (quant_scheme.is_fp8_w8a8 and quant_scheme.block_size == [128, 128]) + or (quant_scheme.is_nvfp4_w4a4) + ) + + +def _supports_activation(activation: str) -> bool: + """Supports silu activation only.""" + return activation in ["silu"] + + +def _supports_moe_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + """Supports EP.""" + return True + + +def is_supported_config_trtllm( + moe_config: FusedMoEConfig, + moe_quant_scheme: FusedMoEQuantScheme, +) -> bool: + """ + This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config + """ + return ( + _supports_current_device() + and (not moe_config.is_act_and_mul or _supports_no_act_and_mul()) + and _supports_activation(moe_config.activation) + and _supports_quant_scheme(moe_quant_scheme) + and 
_supports_moe_parallel_config(moe_config.moe_parallel_config) + ) + def flashinfer_fused_moe_blockscale_fp8( routing_logits: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 3fa50d91254b..078847218de2 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -4,9 +4,12 @@ import torch -from vllm.platforms import current_platform import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEParallelConfig, + FusedMoEQuantConfig, + FusedMoEQuantScheme, +) from vllm.model_executor.layers.fused_moe.fused_moe import try_get_optimal_moe_config from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate, @@ -18,7 +21,7 @@ normalize_batched_scales_shape, normalize_scales_shape, ) -from vllm.model_executor.layers.quantization.utils.quant_utils import group_broadcast +from vllm.platforms import current_platform from vllm.triton_utils import tl, triton @@ -714,38 +717,38 @@ def __init__( self.max_num_tokens = max_num_tokens self.num_dispatchers = num_dispatchers - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts, - ) - - def supports_current_device(self) -> bool: + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.BatchedExperts + + @staticmethod + def _supports_current_device() -> bool: return current_platform.is_cuda_alike() - def supports_no_act_and_mul(self) -> bool: + @staticmethod + def _supports_no_act_and_mul() -> bool: return False - def supports_quant_config(self, 
quant_config: FusedMoEQuantConfig) -> bool: + @staticmethod + def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: # Supports unquantized and fp8. # TODO(rob): allow int4 (for kimi --- no, we have marlinexperts for this. - if not ( - quant_config.use_fp8_w8a8 or - quant_config.quant_dtype == None # TODO: how to express unquantized? - ): - return False - - if quant_config.use_fp8_w8a8: - return (current_platform.is_rocm or - current_platform.has_device_capability(9,0)) - - def supports_act_fn(self, activation: str) -> bool: + if quant_scheme.is_unquantized: + return True + + if quant_scheme.is_fp8_w8a8: + return current_platform.is_rocm() or current_platform.has_device_capability( + (9, 0) + ) + + return False + + @staticmethod + def _supports_activation(activation: str) -> bool: return activation in ["silu", "gelu", "swigluoai"] - def supports_ep(self) -> bool: + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: return True def supports_chunking(self) -> bool: diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 7739e8f102f3..4cb03eb18951 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -8,7 +8,11 @@ import vllm._custom_ops as ops import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEParallelConfig, + FusedMoEQuantConfig, + FusedMoEQuantScheme, +) from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( batched_moe_align_block_size, moe_align_block_size, @@ -551,26 +555,34 @@ def __init__( self.is_k_full = is_k_full super().__init__(quant_config) - def supports_current_device(self) -> bool: - return current_platform.is_cuda() and 
current_platform.has_device_capability(8,0) + @staticmethod + def _supports_current_device() -> bool: + return current_platform.is_cuda() and current_platform.has_device_capability( + 8, 0 + ) - def supports_no_act_and_mul(self) -> bool: + @staticmethod + def _supports_no_act_and_mul() -> bool: return False - def supports_quant_config(self, quant_config: FusedMoEQuantConfig) -> bool: + @staticmethod + def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: # TODO(rob): check if we support the Fp8 activation - return ( - quant_config.use_fp8_w8a16 or - quant_config.use_int8_w8a16 or - quant_config.use_int4_w4a16 or - quant_config.use_nvfp4_w4a16 or - quant_config.use_mxfp4_w4a16 - ) - - def supports_act_fn(self, activation: str) -> bool: + return True + # return ( + # quant_scheme.use_fp8_w8a16 + # or quant_config.use_int8_w8a16 + # or quant_config.use_int4_w4a16 + # or quant_config.use_nvfp4_w4a16 + # or quant_config.use_mxfp4_w4a16 + # ) + + @staticmethod + def _supports_activation(activation: str) -> bool: return activation in ["silu", "swigluoai"] - def supports_ep(self) -> bool: + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: return True @property @@ -641,14 +653,9 @@ def supports_expert_map(self) -> bool: def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: return TopKWeightAndReduceNoOP() - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard, - ) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard def supports_chunking(self) -> bool: return True @@ -769,14 +776,9 @@ def supports_expert_map(self) -> bool: def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: return TopKWeightAndReduceDelegate() - @property - def activation_formats( - self, - ) 
-> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts, - ) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.BatchedExperts def supports_chunking(self) -> bool: return False diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 76dcf75025a0..f1a245a920a6 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -22,7 +22,9 @@ ) from vllm.model_executor.layers.fused_moe.config import ( FUSED_MOE_UNQUANTIZED_CONFIG, + FusedMoEParallelConfig, FusedMoEQuantConfig, + FusedMoEQuantScheme, _get_config_dtype_str, ) from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( @@ -2312,39 +2314,37 @@ def __init__( ): super().__init__(quant_config) - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard, - ) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard - def supports_current_device(self) -> bool: + @staticmethod + def _supports_current_device() -> bool: return current_platform.is_cuda_alike() - def supports_no_act_and_mul(self) -> bool: + @staticmethod + def _supports_no_act_and_mul() -> bool: return False - def supports_quant_config(self, quant_config: FusedMoEQuantConfig) -> bool: + @staticmethod + def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: # Supports unquantized and fp8. # TODO(rob): allow int4 (for kimi --- no, we have marlinexperts for this. - if not ( - quant_config.use_fp8_w8a8 - or quant_config.quant_dtype == None # TODO: how to express unquantized? 
- ): + if not (quant_scheme.is_fp8_w8a8 or quant_scheme.is_unquantized): return False - if quant_config.use_fp8_w8a8: - return current_platform.is_rocm or current_platform.has_device_capability( - 9, 0 + if quant_scheme.is_fp8_w8a8: + return current_platform.is_rocm() or current_platform.has_device_capability( + (9, 0) ) + return False - def supports_act_fn(self, activation: str) -> bool: + @staticmethod + def _supports_activation(activation: str) -> bool: return activation in ["silu", "gelu", "swigluoai"] - def supports_ep(self) -> bool: + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: return True def supports_chunking(self) -> bool: diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index 21836547d88f..3816399f4322 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -9,7 +9,9 @@ from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( FUSED_MOE_UNQUANTIZED_CONFIG, + FusedMoEParallelConfig, FusedMoEQuantConfig, + FusedMoEQuantScheme, ) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP, @@ -242,22 +244,28 @@ def make_routing_data( class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): - def supports_current_device(self) -> bool: + @staticmethod + def _supports_current_device() -> bool: if current_platform.is_cuda(): return current_platform.has_device_capability(9, 0) else: return current_platform.is_cuda_alike() - def supports_no_act_and_mul(self) -> bool: + @staticmethod + def _supports_no_act_and_mul() -> bool: return False - def supports_quant_config(self, quant_config: FusedMoEQuantConfig) -> bool: - return quant_config.use_mxfp4_w4a16 + @staticmethod + def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> 
bool: + return True + # return quant_scheme.is_mxfp4_w4a16 - def supports_act_fn(self, activation: str) -> bool: + @staticmethod + def _supports_activation(activation: str) -> bool: return activation in ["swigluoai"] - def supports_ep(self) -> bool: + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: return True def supports_expert_map(self) -> bool: @@ -318,14 +326,9 @@ def __init__(self, quant_config: FusedMoEQuantConfig): assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16" super().__init__(quant_config) - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard, - ) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard def supports_chunking(self) -> bool: return True @@ -410,14 +413,9 @@ def __init__(self, quant_config: FusedMoEQuantConfig): assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16" super().__init__(quant_config) - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard, - ) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard def supports_chunking(self) -> bool: return True diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 60e8ef9f77fd..7d628515a11e 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -592,6 +592,7 @@ def __init__( has_bias=has_bias, is_act_and_mul=is_act_and_mul, is_lora_enabled=vllm_config.lora_config is not None, + activation=activation, ) self.moe_config_use_flashinfer_cutlass_kernels = ( 
self.moe_config.use_flashinfer_cutlass_kernels diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index a7eea0e0ace7..ac64b9ebf92c 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -13,8 +13,10 @@ from vllm.forward_context import get_forward_context, is_forward_context_available from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig, + FusedMoEQuantScheme, ) from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, @@ -380,11 +382,9 @@ def __init__( """ self.quant_config = quant_config - @property + @staticmethod @abstractmethod - def activation_formats( - self, - ) -> tuple[FusedMoEActivationFormat, FusedMoEActivationFormat]: + def activation_format() -> FusedMoEActivationFormat: """ A property which is a tuple of the input and output activation formats for the 'apply' method. @@ -433,21 +433,40 @@ def moe_problem_size( topk = topk_ids.size(1) return E, M, N, K, topk - + # # Various helpers for registering support for various features. + # Used by the oracle to select a particular kernel for a deployment. 
# + @staticmethod + def is_supported_config( + cls: type["FusedMoEPermuteExpertsUnpermute"], + moe_config: FusedMoEConfig, + moe_quant_scheme: FusedMoEQuantScheme, + activation_format: FusedMoEActivationFormat, + ) -> bool: + return ( + (cls._supports_current_device()) + and (not moe_config.is_act_and_mul or cls._supports_no_act_and_mul()) + and cls._supports_activation(moe_config.activation) + and cls._supports_quant_scheme(moe_quant_scheme) + and cls._supports_parallel_config(moe_config.moe_parallel_config) + and (activation_format == cls.activation_format()) + ) + @abstractmethod - def supports_current_device(self) -> bool: + @staticmethod + def _supports_current_device() -> bool: """ Whether the kernel supports the current device type (compute cability and current platform). """ raise NotImplementedError - + @abstractmethod - def supports_no_act_and_mul(self) -> bool: + @staticmethod + def _supports_no_act_and_mul() -> bool: """ Whether the kernel supports act_and_mul=False, i.e. non-gated MoE models like Nemotron-Nano. @@ -455,18 +474,21 @@ def supports_no_act_and_mul(self) -> bool: raise NotImplementedError @abstractmethod - def supports_quant_config(self, quant_config: FusedMoEQuantConfig) -> bool: + @staticmethod + def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: raise NotImplementedError @abstractmethod - def supports_act_fn(self, activation: str) -> bool: + @staticmethod + def _supports_activation(activation: str) -> bool: """ Whether the kernel supports a particular act function. """ raise NotImplementedError @abstractmethod - def supports_ep(self) -> bool: + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: """ Whether the kernel supports deployment in expert parallel. 
""" @@ -739,12 +761,12 @@ def __init__( self._post_init_setup() assert ( - prepare_finalize.activation_format == fused_experts.activation_formats[0] + prepare_finalize.activation_format == fused_experts.activation_format() ), ( f"{prepare_finalize.__class__.__name__}." f"{prepare_finalize.activation_format} == " f"{fused_experts.__class__.__name__}." - f"{fused_experts.activation_formats[0]}" + f"{fused_experts.activation_format()}" ) def _post_init_setup(self): diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py index f5c3b9af611f..54ed6e1539e3 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py @@ -11,9 +11,13 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, + FusedMoEQuantScheme, fp8_w8a8_moe_quant_config, fp8_w8a16_moe_quant_config, ) +from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import ( + is_supported_config_trtllm, +) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( FlashinferMoeBackend, get_flashinfer_moe_backend, @@ -26,128 +30,211 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( prepare_fp8_moe_layer_for_marlin, ) -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - cutlass_group_gemm_supported, -) -from vllm.platforms import current_platform -from vllm.utils.deep_gemm import is_deep_gemm_supported -from vllm.utils.flashinfer import has_flashinfer_moe -from vllm.utils.import_utils import has_deep_gemm logger = init_logger(__name__) class Fp8MoeBackend(Enum): NONE = 0 - FLASHINFER_TRTLLM = 1 - FLASHINFER_CUTLASS = 2 - DEEPGEMM = 3 - MARLIN = 4 - TRITON = 5 - AITER = 6 - VLLM_CUTLASS = 7 + FLASHINFER_TRTLLM = "FlashInfer TRTLLM" + FLASHINFER_CUTLASS = "FlashInfer CUTLASS" + DEEPGEMM = "DeepGEMM" + BATCHED_DEEPGEMM = "Batched DeepGEMM" + MARLIN = "Marlin" + TRITON = "Triton" + 
BATCHED_TRITON = "Triton" + AITER = "AITER" + VLLM_CUTLASS = "vLLM CUTLASS" + + +def backend_2_kernel_cls( + backend: Fp8MoeBackend, +) -> type[mk.FusedMoEPermuteExpertsUnpermute]: + if backend == Fp8MoeBackend.NONE or backend == Fp8MoeBackend.FLASHINFER_TRTLLM: + raise NotImplementedError + elif backend == Fp8MoeBackend.FLASHINFER_CUTLASS: + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( + FlashInferExperts, + ) + + return FlashInferExperts + + elif backend == Fp8MoeBackend.DEEPGEMM: + from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( + TritonOrDeepGemmExperts, + ) + + return TritonOrDeepGemmExperts + + elif backend == Fp8MoeBackend.BATCHED_DEEPGEMM: + from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( + BatchedDeepGemmExperts, + ) + + return BatchedDeepGemmExperts + + elif backend == Fp8MoeBackend.MARLIN: + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + MarlinExperts, + ) + + return MarlinExperts + + elif backend == Fp8MoeBackend.TRITON: + from vllm.model_executor.layers.fused_moe.fused_moe import ( + TritonExperts, + ) + + return TritonExperts + + elif backend == Fp8MoeBackend.BATCHED_TRITON: + from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + BatchedTritonExperts, + ) + + return BatchedTritonExperts + + elif backend == Fp8MoeBackend.AITER: + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + AiterExperts, + ) + + return AiterExperts + + elif backend == Fp8MoeBackend.VLLM_CUTLASS: + from vllm.model_executor.layers.fused_moe.triton_cutlass_moe import ( + TritonOrCutlassExperts, + ) + + return TritonOrCutlassExperts + + else: + raise ValueError(f"Unknown FP8 MoE backend: {backend.value}") def select_fp8_moe_backend( - block_quant: bool, - tp_size: int, - with_lora_support: bool, - is_act_and_mul: bool = True, + config: FusedMoEConfig, + quant_scheme: FusedMoEQuantScheme, + activation_format: mk.FusedMoEActivationFormat, 
allow_vllm_cutlass: bool = False, -) -> Fp8MoeBackend: +) -> tuple[Fp8MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]: """ Select the primary FP8 MoE backend Note: Shape-specific fallbacks may still occur at runtime. """ - # TODO(rob): in a future PR, we will query each mk for - # supported features and return the mk directly, just like - # we do for the Attention Backend. - - if with_lora_support: - return Fp8MoeBackend.TRITON - - def _make_log_backend(backend_name: str): - return f"Using {backend_name} backend for FP8 MoE" - - # Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100. - if ( - current_platform.is_cuda() - and ( - current_platform.is_device_capability_family(100) - or current_platform.is_device_capability(90) + + if config.is_lora_enabled: + return Fp8MoeBackend.TRITON, backend_2_kernel_cls(Fp8MoeBackend.TRITON) + + def _make_log_backend(backend: Fp8MoeBackend): + return f"Using {backend.value} backend for FP8 MoE" + + def _return_or_raise( + backend: Fp8MoeBackend, + config: FusedMoEConfig, + scheme: FusedMoEQuantScheme, + activation_format: mk.FusedMoEActivationFormat, + ) -> tuple[Fp8MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]: + k_cls = backend_2_kernel_cls(backend) + if k_cls.is_supported_config(k_cls, config, scheme, activation_format): + logger.info_once(_make_log_backend(backend)) + return backend, k_cls + + raise ValueError( + f"Requested FP8 MoE backend `{backend.value}` " + "does not support the deployment configuration." ) - and envs.VLLM_USE_FLASHINFER_MOE_FP8 - and has_flashinfer_moe() - ): - backend = get_flashinfer_moe_backend() - if backend == FlashinferMoeBackend.TENSORRT_LLM: - logger.info_once(_make_log_backend("FlashInfer TRTLLM")) - if not is_act_and_mul: + + # NOTE(rob): the kernels are selected in the following order. 
+ AVAILABLE_BACKENDS = [ + Fp8MoeBackend.AITER, + Fp8MoeBackend.FLASHINFER_TRTLLM, + Fp8MoeBackend.FLASHINFER_CUTLASS, + Fp8MoeBackend.DEEPGEMM, + Fp8MoeBackend.BATCHED_DEEPGEMM, + Fp8MoeBackend.VLLM_CUTLASS, + Fp8MoeBackend.TRITON, + Fp8MoeBackend.BATCHED_TRITON, + Fp8MoeBackend.MARLIN, + ] + + if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP8"): + if not envs.VLLM_USE_FLASHINFER_MOE_FP8: + # If the user rejects FlashInfer remove those backends. + AVAILABLE_BACKENDS.remove(Fp8MoeBackend.FLASHINFER_TRTLLM) + AVAILABLE_BACKENDS.remove(Fp8MoeBackend.FLASHINFER_CUTLASS) + + elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"): + # If user is explicit about backend, validate it. + fi_backend = get_flashinfer_moe_backend() + + if fi_backend == FlashinferMoeBackend.TENSORRT_LLM: + backend = Fp8MoeBackend.FLASHINFER_TRTLLM + # TODO: validate activation format + if is_supported_config_trtllm(config, quant_scheme): + logger.info_once(_make_log_backend(backend)) + return backend, None # ? raise ValueError( - "FlashInfer TRTLLM FP8 MoE backend only supports " - "act_and_mul gate_up_project fusion. Please set " - "VLLM_USE_FLASHINFER_MOE_FP8=throughput to use the " - "FlashInfer CUTLASS backend instead." + f"Requested FP8 MoE backend `{backend.value}` " + "does not support the deployment configuration." ) - return Fp8MoeBackend.FLASHINFER_TRTLLM - else: - if block_quant and current_platform.is_device_capability_family(100): - raise ValueError( - "FlashInfer FP8 MoE throughput backend does not " - "support block quantization on SM100. Please use " - "VLLM_FLASHINFER_MOE_BACKEND=latency to use the " - "FlashInfer TRTLLM backend instead." 
+ + elif fi_backend == FlashinferMoeBackend.CUTLASS: + backend = Fp8MoeBackend.FLASHINFER_CUTLASS + return _return_or_raise( + backend, config, quant_scheme, activation_format ) - logger.info_once(_make_log_backend("FlashInfer CUTLASS")) - return Fp8MoeBackend.FLASHINFER_CUTLASS - - # weight-only path for older GPUs without native FP8 - if ( - current_platform.is_cuda() and not current_platform.has_device_capability(89) - ) or envs.VLLM_TEST_FORCE_FP8_MARLIN: - logger.info_once(_make_log_backend("Marlin"), scope="local") - return Fp8MoeBackend.MARLIN - - # Determine if we should use DeepGEMM with block-quantized weights: - # - If explicitly set by user, respect their choice - # - If not explicitly set (default), disable when TP size is >= 8 - moe_use_deep_gemm = envs.VLLM_MOE_USE_DEEP_GEMM - if not envs.is_set("VLLM_MOE_USE_DEEP_GEMM") and tp_size >= 8: - moe_use_deep_gemm = False - logger.info_once( - "DeepGEMM MoE is disabled by default when TP size is >= 8. " - "Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it.", - scope="local", - ) - use_deep_gemm = envs.VLLM_USE_DEEP_GEMM - if not is_deep_gemm_supported(): - use_deep_gemm = False - logger.info_once( - "DeepGEMM is disabled because the platform does not support it.", - scope="local", - ) + else: + assert fi_backend == FlashinferMoeBackend.CUTEDSL + raise ValueError("FlashInfer MaskedGEMM not supported for FP8") - if use_deep_gemm and moe_use_deep_gemm and block_quant: - if not has_deep_gemm(): - logger.warning_once( - "DeepGEMM backend requested but not available.", scope="local" + else: + # If the user is not explicit about the backend, try both. 
+ for backend in [ + Fp8MoeBackend.FLASHINFER_TRTLLM, + Fp8MoeBackend.FLASHINFER_CUTLASS, + ]: + k_cls = backend_2_kernel_cls(backend) + if k_cls.is_supported_config( + k_cls, config, quant_scheme, activation_format + ): + logger.info_once(_make_log_backend(backend)) + return backend, k_cls + + raise NotImplementedError( + "Found VLLM_USE_FLASHINFER_MOE_FP8=1, but no " + "FlashInfer FP8 MoE backend supports the configuration." ) - elif is_deep_gemm_supported(): - logger.info_once(_make_log_backend("DeepGEMM"), scope="local") - return Fp8MoeBackend.DEEPGEMM - if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MOE: - logger.info_once(_make_log_backend("ROCm AITER"), scope="local") - return Fp8MoeBackend.AITER + elif envs.is_set("VLLM_USE_DEEP_GEMM") or envs.is_set("VLLM_MOE_USE_DEEP_GEMM"): + if not envs.VLLM_USE_DEEP_GEMM or not envs.VLLM_MOE_USE_DEEP_GEMM: + AVAILABLE_BACKENDS.remove(Fp8MoeBackend.DEEPGEMM) + else: + backend = Fp8MoeBackend.DEEPGEMM + return _return_or_raise(backend, config, quant_scheme, activation_format) + + elif envs.VLLM_TEST_FORCE_FP8_MARLIN: + backend = Fp8MoeBackend.MARLIN + return _return_or_raise(backend, config, quant_scheme, activation_format) - if allow_vllm_cutlass and not block_quant and cutlass_group_gemm_supported(): - logger.info_once(_make_log_backend("vLLM CUTLASS"), scope="local") - return Fp8MoeBackend.VLLM_CUTLASS + elif envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MOE: + backend = Fp8MoeBackend.AITER + return _return_or_raise(backend, config, quant_scheme, activation_format) - # default to Triton - logger.info_once(_make_log_backend("Triton"), scope="local") - return Fp8MoeBackend.TRITON + if not allow_vllm_cutlass: + AVAILABLE_BACKENDS.remove(Fp8MoeBackend.VLLM_CUTLASS) + + # Select kernels in order of backend. 
+ for backend in AVAILABLE_BACKENDS: + k_cls = backend_2_kernel_cls(backend) + if k_cls.is_supported_config(k_cls, config, quant_scheme, activation_format): + logger.info_once(_make_log_backend(backend)) + return backend, k_cls + + raise NotImplementedError( + "No FP8 MoE backend supports the deployment configuration." + ) def convert_to_fp8_moe_kernel_format( diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 85b14507816d..5aa30f836db1 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -9,7 +9,9 @@ from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.layers.fused_moe.config import ( FUSED_MOE_UNQUANTIZED_CONFIG, + FusedMoEParallelConfig, FusedMoEQuantConfig, + FusedMoEQuantScheme, ) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP, @@ -273,32 +275,32 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__(self, quant_config): super().__init__(quant_config) - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard, - ) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard - def supports_current_device(self) -> bool: + @staticmethod + def _supports_current_device() -> bool: return current_platform.is_rocm() - def supports_no_act_and_mul(self) -> bool: + @staticmethod + def _supports_no_act_and_mul() -> bool: return False - def supports_quant_config(self, quant_config: FusedMoEQuantConfig) -> bool: + @staticmethod + def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: return ( - quant_config.use_fp8_w8a8 - or quant_config.use_mxfp4_w4a4 - or False # TODO: how to represent 
unquantizes? + quant_scheme.is_fp8_w8a8 + # or quant_scheme.is_mxfp4_w4a4 + or quant_scheme.is_unquantized ) - def supports_act_fn(self, activation: str) -> bool: + @staticmethod + def _supports_activation(activation: str) -> bool: return activation in ["silu", "gelu"] - def supports_ep(self) -> bool: + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: return True def supports_expert_map(self): diff --git a/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py index e874ba609be0..231be195d040 100644 --- a/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py @@ -5,7 +5,11 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEParallelConfig, + FusedMoEQuantConfig, + FusedMoEQuantScheme, +) from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 from vllm.model_executor.layers.fused_moe.fallback import FallbackExperts from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts @@ -30,6 +34,30 @@ def __init__( fallback_experts=TritonExperts(quant_config), ) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return CutlassExpertsFp8.activation_format() + + @staticmethod + def _supports_current_device() -> bool: + return CutlassExpertsFp8._supports_current_device() + + @staticmethod + def _supports_no_act_and_mul() -> bool: + return CutlassExpertsFp8._supports_no_act_and_mul() + + @staticmethod + def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: + return CutlassExpertsFp8._supports_quant_scheme(quant_scheme) + + @staticmethod + def _supports_activation(activation: str) -> bool: + return CutlassExpertsFp8._supports_activation(activation) 
+ + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + return CutlassExpertsFp8._supports_parallel_config(moe_parallel_config) + def workspace_shapes( self, M: int, diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 4fcc1a7c1fc0..d181d9ef0ee9 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -4,7 +4,11 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEParallelConfig, + FusedMoEQuantConfig, + FusedMoEQuantScheme, +) from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( DeepGemmExperts, _valid_deep_gemm, @@ -26,6 +30,31 @@ def __init__(self, quant_config: FusedMoEQuantConfig): fallback_experts=TritonExperts(quant_config), ) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + assert DeepGemmExperts.activation_format() == TritonExperts.activation_format() + return DeepGemmExperts.activation_format() + + @staticmethod + def _supports_current_device() -> bool: + return DeepGemmExperts._supports_current_device() + + @staticmethod + def _supports_no_act_and_mul() -> bool: + return DeepGemmExperts._supports_no_act_and_mul() + + @staticmethod + def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: + return DeepGemmExperts._supports_quant_scheme(quant_scheme) + + @staticmethod + def _supports_activation(activation: str) -> bool: + return DeepGemmExperts._supports_activation(activation) + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + return DeepGemmExperts._supports_parallel_config(moe_parallel_config) + def workspace_shapes( self, M: int, diff --git 
a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index b71921c5c3d7..86e07c35ed0f 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -45,7 +45,6 @@ Fp8MoeBackend, convert_to_fp8_moe_kernel_format, make_fp8_moe_kernel, - select_fp8_moe_backend, ) from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( FLASHINFER_NVFP4_MOE_BACKENDS, @@ -579,13 +578,14 @@ def __init__( "For FP8 Fused MoE layer, we require either per tensor or " "channelwise, dynamic per token quantization." ) - self.fp8_backend = select_fp8_moe_backend( - block_quant=self.block_quant, - tp_size=moe.tp_size, - with_lora_support=moe.is_lora_enabled, - # TODO(rob): enable selecting this externally. - allow_vllm_cutlass=True, - ) + # self.fp8_backend = select_fp8_moe_backend( + # block_quant=self.block_quant, + # tp_size=moe.tp_size, + # with_lora_support=moe.is_lora_enabled, + # # TODO(rob): enable selecting this externally. 
+ # allow_vllm_cutlass=True, + # ) + self.fp8_backend = Fp8MoeBackend.FLASHINFER_CUTLASS if self.fp8_backend != Fp8MoeBackend.MARLIN: per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN per_channel_quant = ( diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 1c0c35bf6f41..d1685e1bf48c 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -27,6 +27,7 @@ ) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, + FusedMoEQuantScheme, RoutingMethodType, ) from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter @@ -633,34 +634,25 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): self.weight_scale_name = ( "weight_scale_inv" if self.block_quant else "weight_scale" ) - self.fp8_backend = select_fp8_moe_backend( - block_quant=self.block_quant, - tp_size=layer.moe_parallel_config.tp_size, - with_lora_support=self.moe.is_lora_enabled, + + quant_scheme = FusedMoEQuantScheme( + weight_dtype=current_platform.fp8_dtype(), + act_dtype=current_platform.fp8_dtype(), + per_tensor_quant=not self.block_quant, + per_token_quant=False, + block_size=( + (self.weight_block_size[0], self.weight_block_size[1]) + if self.weight_block_size is not None + else None + ), ) - if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: - if self.block_quant and self.weight_block_size != [128, 128]: - raise NotImplementedError( - "FlashInfer CUTLASS FP8 MoE backend only supports block " - "size [128, 128]." - ) - if layer.activation != "silu": - raise NotImplementedError( - "FlashInfer CUTLASS FP8 MoE backend only supports SiLU " - "activation function, but got {layer.activation}." 
- ) - dynamic_per_token = ( - not self.block_quant and self.quant_config.activation_scheme != "static" + self.fp8_backend, self.experts_cls = select_fp8_moe_backend( + config=layer.moe_config, + quant_scheme=quant_scheme, + # TODO(rob): select prepare_finalize here. + activation_format=mk.FusedMoEActivationFormat.Standard, ) - if dynamic_per_token and self.fp8_backend in [ - Fp8MoeBackend.FLASHINFER_TRTLLM, - Fp8MoeBackend.FLASHINFER_CUTLASS, - ]: - raise NotImplementedError( - "FlashInfer FP8 MoE backend does not support dynamic per token " - "activation quantization." - ) self.kernel: mk.FusedMoEModularKernel | None = None diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index a646012ddd3a..986d0a906e05 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -28,7 +28,6 @@ convert_to_fp8_moe_kernel_format, make_fp8_moe_kernel, make_fp8_moe_quant_config, - select_fp8_moe_backend, ) from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( FLASHINFER_NVFP4_MOE_BACKENDS, @@ -729,11 +728,12 @@ def __init__( super().__init__(moe_config) self.quant_config = quant_config assert self.quant_config.is_checkpoint_fp8_serialized - self.fp8_backend = select_fp8_moe_backend( - block_quant=False, - tp_size=moe_config.moe_parallel_config.tp_size, - with_lora_support=self.moe.is_lora_enabled, - ) + # self.fp8_backend = select_fp8_moe_backend( + # block_quant=False, + # tp_size=moe_config.moe_parallel_config.tp_size, + # with_lora_support=self.moe.is_lora_enabled, + # ) + self.fp8_backend = Fp8MoeBackend.FLASHINFER_CUTLASS self.kernel: mk.FusedMoEModularKernel | None = None def maybe_make_prepare_finalize( From 8f0a969bd09ec3521362ee789e65abaa82b6685e Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Mon, 12 Jan 2026 11:22:32 -0500 Subject: [PATCH 027/113] remove is_cuda Signed-off-by: Robert Shaw --- 
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py | 4 +--- vllm/model_executor/layers/fused_moe/cutlass_moe.py | 4 +--- vllm/model_executor/layers/fused_moe/deep_gemm_moe.py | 4 +--- vllm/model_executor/layers/fused_moe/modular_kernel.py | 2 +- 4 files changed, 4 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index ca1317c55c25..1928a7ba7544 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -278,9 +278,7 @@ def activation_format() -> mk.FusedMoEActivationFormat: @staticmethod def _supports_current_device() -> bool: - return current_platform.is_cuda() and current_platform.has_device_capability( - 9, 0 - ) + return current_platform.has_device_capability((9, 0)) @staticmethod def _supports_no_act_and_mul() -> bool: diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index b3a484f9f3a3..18699b1150f2 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -266,9 +266,7 @@ def __init__( @staticmethod def _supports_current_device() -> bool: - return current_platform.is_cuda() and current_platform.has_device_capability( - 9, 0 - ) + return current_platform.has_device_capability((9, 0)) @staticmethod def _supports_no_act_and_mul() -> bool: diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index dbbeecdca02c..f0c7b3884593 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -125,9 +125,7 @@ def activation_format() -> mk.FusedMoEActivationFormat: @staticmethod def _supports_current_device() -> bool: - return current_platform.is_cuda() and 
current_platform.has_device_capability( - 9, 0 - ) + return current_platform.has_device_capability((9, 0)) @staticmethod def _supports_no_act_and_mul() -> bool: diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index ac64b9ebf92c..b854e30b0fc5 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -448,7 +448,7 @@ def is_supported_config( ) -> bool: return ( (cls._supports_current_device()) - and (not moe_config.is_act_and_mul or cls._supports_no_act_and_mul()) + and (moe_config.is_act_and_mul or cls._supports_no_act_and_mul()) and cls._supports_activation(moe_config.activation) and cls._supports_quant_scheme(moe_quant_scheme) and cls._supports_parallel_config(moe_config.moe_parallel_config) From e461b6f01da64606b98bd19fcac342bfc07b7108 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Mon, 12 Jan 2026 11:25:46 -0500 Subject: [PATCH 028/113] added renamed file Signed-off-by: Robert Shaw --- .../model_executor/layers/fused_moe/config.py | 4 + .../layers/fused_moe/trtllm_mxfp4_moe.py | 167 ++++++++++++++++++ 2 files changed, 171 insertions(+) create mode 100644 vllm/model_executor/layers/fused_moe/trtllm_mxfp4_moe.py diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index b7a6ede50157..b0311b1f58b4 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -166,6 +166,10 @@ def is_fp8_w8a8(self) -> bool: def is_nvfp4_w4a4(self) -> bool: return self.weight_dtype == "nvfp4" and self.act_dtype == "nvfp4" + @property + def is_mxfp4_w4a4(self) -> bool: + return self.weight_dtype == "mxfp4" and self.act_dtype == "mxfp4" + @dataclass class FusedMoEQuantDesc: diff --git a/vllm/model_executor/layers/fused_moe/trtllm_mxfp4_moe.py b/vllm/model_executor/layers/fused_moe/trtllm_mxfp4_moe.py new file mode 100644 index 
000000000000..bb9ef2af182f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/trtllm_mxfp4_moe.py @@ -0,0 +1,167 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, + FusedMoEQuantConfig, + FusedMoEQuantScheme, +) +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceNoOP, +) +from vllm.platforms import current_platform + + +class TrtllmMxFp4Experts(mk.FusedMoEPermuteExpertsUnpermute): + def __init__( + self, + moe: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, + gemm1_alpha, + gemm1_beta, + gemm1_clamp_limit, + max_capture_size, + ): + super().__init__(quant_config) + self.moe = moe + self.gemm1_alpha = gemm1_alpha + self.gemm1_beta = gemm1_beta + self.gemm1_clamp_limit = gemm1_clamp_limit + self.max_capture_size = max_capture_size + + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + @staticmethod + def _supports_current_device() -> bool: + return current_platform.is_cuda() and current_platform.has_device_capability( + 10, 0 + ) + + @staticmethod + def _supports_no_act_and_mul() -> bool: + return False + + @staticmethod + def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: + return ( + True + # quant_scheme.is_mxfp4_w4a16 or + # quant_scheme.is_mxfp4_w4a8 + ) + + @staticmethod + def _supports_activation(activation: str) -> bool: + return activation in ["swigluoai"] + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + return True + + def supports_chunking(self) -> bool: + return True + + def supports_expert_map(self) -> bool: + return True + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + return 
TopKWeightAndReduceNoOP() + + def workspace_shapes( + self, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + # The workspaces for this implementation are managed by flashinfer. + workspace1 = (0,) + workspace2 = (0,) + output = (M, K) + return (workspace1, workspace2, output) + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + apply_router_weight_on_input: bool, + ): + topk = topk_ids.size(-1) + local_num_experts = w1.size(0) + intermediate_size = w2.size(1) + local_expert_offset = self.moe.ep_rank * local_num_experts + + x_quant = hidden_states + x_scale = a1q_scale + if x_scale is not None: + x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x_quant.shape[:-1], -1) + + packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to( + torch.bfloat16 + ).view(torch.int16) + + assert self.w1_scale is not None + assert self.w2_scale is not None + kwargs = { + "topk_ids": packed_tensor, + "routing_bias": None, + "hidden_states": x_quant, + "hidden_states_scale": x_scale, + "gemm1_weights": w1, + "gemm1_weights_scale": self.w1_scale, + "gemm1_bias": self.w1_bias, + "gemm1_alpha": self.gemm1_alpha, + "gemm1_beta": self.gemm1_beta, + "gemm1_clamp_limit": self.gemm1_clamp_limit, + "gemm2_weights": w2, + "gemm2_weights_scale": self.w2_scale, + "gemm2_bias": self.w2_bias, + "output1_scale_scalar": None, + "output1_scale_gate_scalar": None, + "output2_scale_scalar": None, + "num_experts": global_num_experts, + "top_k": topk, + 
"n_group": None, + "topk_group": None, + "intermediate_size": intermediate_size, + "local_expert_offset": local_expert_offset, + "local_num_experts": local_num_experts, + "routed_scaling_factor": None, + "tile_tokens_dim": None, + "routing_method_type": 1, + "do_finalize": True, + "output": output, + "tune_max_num_tokens": max(self.max_capture_size, 1), + } + + from flashinfer import trtllm_fp4_block_scale_routed_moe + + from vllm.utils.flashinfer import autotune + + with autotune(False): + # Enable autotune when, + # https://github.com/flashinfer-ai/flashinfer/issues/2023 is + # resolved. + trtllm_fp4_block_scale_routed_moe(**kwargs) + + return output From 93bd28bf42bdddd201760d33a7450fa5f01a1257 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Mon, 12 Jan 2026 11:29:10 -0500 Subject: [PATCH 029/113] improve quant scheme Signed-off-by: Robert Shaw --- .../model_executor/layers/fused_moe/config.py | 27 +++++++++++++------ .../layers/fused_moe/trtllm_mxfp4_moe.py | 10 ++----- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index b0311b1f58b4..765e3d16f09a 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -132,6 +132,14 @@ class FusedMoEQuantScheme: per_tensor_quant: bool block_size: tuple[int, int] | None + @property + def per_block_quant(self) -> bool: + return self.block_size is not None + + @property + def is_unquantized(self) -> bool: + return self.weight_dtype in UNQUANTIZED_DTYPES + def __post_init__(self): if self.per_tensor_quant: assert not self.per_token_quant @@ -147,14 +155,6 @@ def __post_init__(self): if self.is_unquantized: assert self.act_dtype in UNQUANTIZED_DTYPES - @property - def per_block_quant(self) -> bool: - return self.block_size is not None - - @property - def is_unquantized(self) -> bool: - return self.weight_dtype in UNQUANTIZED_DTYPES - @property def 
is_fp8_w8a8(self) -> bool: return ( @@ -170,6 +170,17 @@ def is_nvfp4_w4a4(self) -> bool: def is_mxfp4_w4a4(self) -> bool: return self.weight_dtype == "mxfp4" and self.act_dtype == "mxfp4" + @property + def is_mxfp4_w4a8(self) -> bool: + return ( + self.weight_dtype == "mxfp4" + and self.act_dtype == current_platform.fp8_dtype() + ) + + @property + def is_mxfp4_w4a16(self) -> bool: + return self.weight_dtype == "mxfp4" and self.act_dtype in UNQUANTIZED_DTYPES + @dataclass class FusedMoEQuantDesc: diff --git a/vllm/model_executor/layers/fused_moe/trtllm_mxfp4_moe.py b/vllm/model_executor/layers/fused_moe/trtllm_mxfp4_moe.py index bb9ef2af182f..e2fe96fa9de3 100644 --- a/vllm/model_executor/layers/fused_moe/trtllm_mxfp4_moe.py +++ b/vllm/model_executor/layers/fused_moe/trtllm_mxfp4_moe.py @@ -39,9 +39,7 @@ def activation_format() -> mk.FusedMoEActivationFormat: @staticmethod def _supports_current_device() -> bool: - return current_platform.is_cuda() and current_platform.has_device_capability( - 10, 0 - ) + return current_platform.has_device_capability((10, 0)) @staticmethod def _supports_no_act_and_mul() -> bool: @@ -49,11 +47,7 @@ def _supports_no_act_and_mul() -> bool: @staticmethod def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: - return ( - True - # quant_scheme.is_mxfp4_w4a16 or - # quant_scheme.is_mxfp4_w4a8 - ) + return quant_scheme.is_mxfp4_w4a16 or quant_scheme.is_mxfp4_w4a8 @staticmethod def _supports_activation(activation: str) -> bool: From 2b24d70356ff9c192461b59dee3099305bf43d82 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Mon, 12 Jan 2026 11:32:07 -0500 Subject: [PATCH 030/113] improve validation Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py index 475f849720f8..05f965435617 100644 --- 
a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py @@ -44,8 +44,7 @@ def activation_format() -> mk.FusedMoEActivationFormat: @staticmethod def _supports_current_device() -> bool: return ( - current_platform.is_cuda() - and current_platform.has_device_capability(10, 0) + current_platform.has_device_capability((10, 0)) and has_flashinfer_cutedsl_grouped_gemm_nt_masked() ) From 312b7676d8679a3a802bf288447ed584bd440a9f Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Mon, 12 Jan 2026 11:43:26 -0500 Subject: [PATCH 031/113] improve platform selection logic Signed-off-by: Robert Shaw --- .../layers/fused_moe/batched_deep_gemm_moe.py | 3 ++- vllm/model_executor/layers/fused_moe/cutlass_moe.py | 12 +++++------- .../model_executor/layers/fused_moe/deep_gemm_moe.py | 4 ++-- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 1928a7ba7544..568917b3353d 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -23,6 +23,7 @@ fp8_m_grouped_gemm_nt_masked, get_mk_alignment_for_contiguous_layout, is_deep_gemm_e8m0_used, + is_deep_gemm_supported, ) from vllm.utils.math_utils import cdiv, round_up @@ -278,7 +279,7 @@ def activation_format() -> mk.FusedMoEActivationFormat: @staticmethod def _supports_current_device() -> bool: - return current_platform.has_device_capability((9, 0)) + return is_deep_gemm_supported() @staticmethod def _supports_no_act_and_mul() -> bool: diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 18699b1150f2..fccaf81734d8 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -26,6 +26,9 @@ 
TopKWeightAndReduceNoOP, ) from vllm.model_executor.layers.fused_moe.utils import _resize_cache +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + cutlass_group_gemm_supported, +) from vllm.platforms import current_platform from vllm.scalar_type import scalar_types @@ -266,7 +269,7 @@ def __init__( @staticmethod def _supports_current_device() -> bool: - return current_platform.has_device_capability((9, 0)) + return cutlass_group_gemm_supported() @staticmethod def _supports_no_act_and_mul() -> bool: @@ -274,12 +277,7 @@ def _supports_no_act_and_mul() -> bool: @staticmethod def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: - if not current_platform.has_device_capability((9, 0)): - return False - - return quant_scheme.is_fp8_w8a8 and ( - quant_scheme.per_tensor_quant or quant_scheme.per_token_quant - ) + return quant_scheme.is_fp8_w8a8 and not quant_scheme.per_block_quant @staticmethod def _supports_activation(activation: str) -> bool: diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index f0c7b3884593..4f5171154d64 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -28,10 +28,10 @@ per_token_group_quant_fp8_packed_for_deepgemm, silu_mul_per_token_group_quant_fp8_colmajor, ) -from vllm.platforms import current_platform from vllm.utils.deep_gemm import ( DeepGemmQuantScaleFMT, get_mk_alignment_for_contiguous_layout, + is_deep_gemm_supported, m_grouped_fp8_gemm_nt_contiguous, ) from vllm.utils.import_utils import has_deep_gemm @@ -125,7 +125,7 @@ def activation_format() -> mk.FusedMoEActivationFormat: @staticmethod def _supports_current_device() -> bool: - return current_platform.has_device_capability((9, 0)) + return is_deep_gemm_supported() @staticmethod def _supports_no_act_and_mul() -> bool: From 0f37e9547ed1060fd0e397c1c5ccf7ca3e5fa7b3 Mon Sep 17 00:00:00 2001 From: 
Robert Shaw Date: Mon, 12 Jan 2026 11:45:33 -0500 Subject: [PATCH 032/113] nit newline Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/fused_marlin_moe.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 4cb03eb18951..bf26948efbcf 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -557,9 +557,8 @@ def __init__( @staticmethod def _supports_current_device() -> bool: - return current_platform.is_cuda() and current_platform.has_device_capability( - 8, 0 - ) + p = current_platform + return p.is_cuda() and p.has_device_capability((8, 0)) @staticmethod def _supports_no_act_and_mul() -> bool: From c187335c01a64d120661cc9cc8330b055ed41b2c Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Mon, 12 Jan 2026 11:48:53 -0500 Subject: [PATCH 033/113] nit newline Signed-off-by: Robert Shaw --- .../fused_moe/gpt_oss_triton_kernels_moe.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index 3816399f4322..5acde4f8904e 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -244,12 +244,15 @@ def make_routing_data( class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): + def __init__(self, quant_config: FusedMoEQuantConfig): + # TODO (varun) : Enable activation quantization + assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16" + super().__init__(quant_config) + @staticmethod def _supports_current_device() -> bool: - if current_platform.is_cuda(): - return current_platform.has_device_capability(9, 0) - else: - return current_platform.is_cuda_alike() + p 
= current_platform + return (p.is_cuda() and p.has_device_capability((9, 0))) or p.is_rocm() @staticmethod def _supports_no_act_and_mul() -> bool: @@ -321,11 +324,6 @@ def _make_routing_data( class OAITritonExperts(BaseOAITritonExperts): - def __init__(self, quant_config: FusedMoEQuantConfig): - # TODO (varun) : Enable activation quantization - assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16" - super().__init__(quant_config) - @staticmethod def activation_format() -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.Standard @@ -412,6 +410,7 @@ def __init__(self, quant_config: FusedMoEQuantConfig): # TODO (varun) : Enable activation quantization assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16" super().__init__(quant_config) + self.quant_config = quant_config @staticmethod def activation_format() -> mk.FusedMoEActivationFormat: From 6ad575e532e8f727733128777c81e72573c0e656 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Mon, 12 Jan 2026 11:50:18 -0500 Subject: [PATCH 034/113] revert spurious LOC change Signed-off-by: Robert Shaw --- .../layers/fused_moe/gpt_oss_triton_kernels_moe.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index 5acde4f8904e..e4cf3e7808b9 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -245,8 +245,6 @@ def make_routing_data( class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__(self, quant_config: FusedMoEQuantConfig): - # TODO (varun) : Enable activation quantization - assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16" super().__init__(quant_config) @staticmethod @@ -324,6 +322,11 @@ def _make_routing_data( class OAITritonExperts(BaseOAITritonExperts): + def __init__(self, quant_config: 
FusedMoEQuantConfig): + # TODO (varun) : Enable activation quantization + assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16" + super().__init__(quant_config) + @staticmethod def activation_format() -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.Standard From 2aabb4f26a7a064c86b23832bf03a8989a412fed Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Mon, 12 Jan 2026 11:53:09 -0500 Subject: [PATCH 035/113] revert spurious LOC change Signed-off-by: Robert Shaw --- .../layers/fused_moe/fused_moe.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index f1a245a920a6..1492a08336e3 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -2328,16 +2328,16 @@ def _supports_no_act_and_mul() -> bool: @staticmethod def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: - # Supports unquantized and fp8. - # TODO(rob): allow int4 (for kimi --- no, we have marlinexperts for this. 
- if not (quant_scheme.is_fp8_w8a8 or quant_scheme.is_unquantized): - return False - - if quant_scheme.is_fp8_w8a8: - return current_platform.is_rocm() or current_platform.has_device_capability( - (9, 0) - ) - return False + p = current_platform + device_supports_fp8 = p.is_rocm() or ( + p.is_cuda() and p.has_device_capability((9, 0)) + ) + + return ( + quant_scheme.is_unquantized() + or quant_scheme.is_fp8_w8a8 + and device_supports_fp8 + ) @staticmethod def _supports_activation(activation: str) -> bool: From 3c5f602d3c45f81cd149d4a9a754c372c4a92083 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Mon, 12 Jan 2026 11:59:19 -0500 Subject: [PATCH 036/113] updated Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/config.py | 7 +++++++ vllm/model_executor/layers/fused_moe/cutlass_moe.py | 10 +++------- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 765e3d16f09a..4996b4bee87b 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -181,6 +181,13 @@ def is_mxfp4_w4a8(self) -> bool: def is_mxfp4_w4a16(self) -> bool: return self.weight_dtype == "mxfp4" and self.act_dtype in UNQUANTIZED_DTYPES + @property + def is_int4_w4a8(self) -> bool: + return ( + self.weight_dtype == "int4" + and self.act_dtype == current_platform.fp8_dtype() + ) + @dataclass class FusedMoEQuantDesc: diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index fccaf81734d8..5725c98a42ea 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -868,9 +868,8 @@ def activation_format() -> mk.FusedMoEActivationFormat: @staticmethod def _supports_current_device() -> bool: - return current_platform.is_cuda() and current_platform.has_device_capability( - (9, 0) - ) + p = 
current_platform + return p.is_cuda() and p.has_device_capability((9, 0)) @staticmethod def _supports_no_act_and_mul() -> bool: @@ -878,10 +877,7 @@ def _supports_no_act_and_mul() -> bool: @staticmethod def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: - return True - # quant_scheme.is_fp8_w8a8 - # # or quant_scheme.is_mxfp4_w4a4 - # or quant_scheme.is_unquantized + return quant_scheme.is_int4_w4a8 and quant_scheme.per_token_quant @staticmethod def _supports_activation(activation: str) -> bool: From 00a130a03146d74300d01c0bcf806f323f887fbb Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Mon, 12 Jan 2026 12:02:13 -0500 Subject: [PATCH 037/113] update marlin Signed-off-by: Robert Shaw --- .../layers/fused_moe/fused_marlin_moe.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index bf26948efbcf..b98c4943393b 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -566,15 +566,13 @@ def _supports_no_act_and_mul() -> bool: @staticmethod def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: - # TODO(rob): check if we support the Fp8 activation - return True - # return ( - # quant_scheme.use_fp8_w8a16 - # or quant_config.use_int8_w8a16 - # or quant_config.use_int4_w4a16 - # or quant_config.use_nvfp4_w4a16 - # or quant_config.use_mxfp4_w4a16 - # ) + return quant_scheme.weight_dtype in [ + current_platform.fp8_dtype(), + torch.int8, + "int4", + "nvfp4", + "mxfp4", + ] @staticmethod def _supports_activation(activation: str) -> bool: From bd38266bc4e110527130eb87490994834f1b73a8 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Mon, 12 Jan 2026 12:02:49 -0500 Subject: [PATCH 038/113] updated Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/fused_marlin_moe.py | 1 + 1 file changed, 1 
insertion(+) diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index b98c4943393b..5bf5060c0892 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -558,6 +558,7 @@ def __init__( @staticmethod def _supports_current_device() -> bool: p = current_platform + # Is this right? Can we do < Ampere? return p.is_cuda() and p.has_device_capability((8, 0)) @staticmethod From bba97f457a673df31976c4029b7add4c048f87e2 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Mon, 12 Jan 2026 12:03:30 -0500 Subject: [PATCH 039/113] updated Signed-off-by: Robert Shaw --- .../layers/fused_moe/gpt_oss_triton_kernels_moe.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index e4cf3e7808b9..5e16e1e81273 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -258,8 +258,7 @@ def _supports_no_act_and_mul() -> bool: @staticmethod def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: - return True - # return quant_scheme.is_mxfp4_w4a16 + return quant_scheme.is_mxfp4_w4a16 @staticmethod def _supports_activation(activation: str) -> bool: From dc9723acd2be79e79d0f8e33f88c99ccf2fe3b4f Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Mon, 12 Jan 2026 12:07:31 -0500 Subject: [PATCH 040/113] support differentiating on static vs dynamic Signed-off-by: Robert Shaw --- .../model_executor/layers/fused_moe/config.py | 3 +++ .../fused_moe/flashinfer_cutlass_moe.py | 19 +++++++++---------- .../layers/fused_moe/flashinfer_trtllm_moe.py | 7 ++++--- .../model_executor/layers/quantization/fp8.py | 1 + 4 files changed, 17 insertions(+), 13 deletions(-) diff --git 
a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 4996b4bee87b..33241640e489 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -130,6 +130,7 @@ class FusedMoEQuantScheme: act_dtype: torch.dtype | str | None per_token_quant: bool per_tensor_quant: bool + static_input_quant: bool block_size: tuple[int, int] | None @property @@ -147,9 +148,11 @@ def __post_init__(self): elif self.per_token_quant: assert not self.per_tensor_quant assert not self.per_block_quant + assert not self.static_input_quant elif self.per_block_quant: assert not self.per_tensor_quant assert not self.per_token_quant + assert not self.static_input_quant assert self.block_size is not None assert self.per_block_quant or self.per_token_quant or self.per_tensor_quant if self.is_unquantized: diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index a1a0c2f1c503..cbde0cb50eae 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -66,21 +66,20 @@ def _supports_no_act_and_mul() -> bool: def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: # Supports: # * unquantized - # * fp8 per-tensor on 9.0+ + # * fp8 static per-tensor on 9.0+ # * fp8 block on 9.0 # * nvfp4 on 10.0+ + s = quant_scheme + p = current_platform return ( - (quant_scheme.is_unquantized) - or (quant_scheme.is_fp8_w8a8 and quant_scheme.per_tensor_quant) + (s.is_unquantized) + or (s.is_fp8_w8a8 and s.per_tensor_quant and s.static_input_quant) or ( - quant_scheme.is_fp8_w8a8 - and quant_scheme.block_size == [128, 128] - and current_platform.is_device_capability((9, 0)) - ) - or ( - quant_scheme.is_nvfp4_w4a4 - and current_platform.has_device_capability((10, 0)) + s.is_fp8_w8a8 + and s.block_size == [128, 128] + and 
p.is_device_capability((9, 0)) ) + or (s.is_nvfp4_w4a4 and p.has_device_capability((10, 0))) ) @staticmethod diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index e575f4fdb49d..d3189db7bee7 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -38,10 +38,11 @@ def _supports_no_act_and_mul() -> bool: def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: """Supports Fp8 per-tensor, Fp8 block, and Nvfp4 quantization.""" + s = quant_scheme return ( - (quant_scheme.is_fp8_w8a8 and quant_scheme.per_tensor_quant) - or (quant_scheme.is_fp8_w8a8 and quant_scheme.block_size == [128, 128]) - or (quant_scheme.is_nvfp4_w4a4) + (s.is_fp8_w8a8 and s.per_tensor_quant and s.static_input_quant) + or (s.is_fp8_w8a8 and s.block_size == [128, 128]) + or (s.is_nvfp4_w4a4) ) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index d1685e1bf48c..6a73082ec631 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -645,6 +645,7 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): if self.weight_block_size is not None else None ), + static_input_quant=(self.quant_config.activation_scheme == "static"), ) self.fp8_backend, self.experts_cls = select_fp8_moe_backend( From fc9ea076e84cafc01643ca936fb2a9bf06c2179f Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Mon, 12 Jan 2026 12:08:39 -0500 Subject: [PATCH 041/113] newline nit Signed-off-by: Robert Shaw --- .../model_executor/layers/fused_moe/flashinfer_trtllm_moe.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index d3189db7bee7..34fe14fd2d91 100644 --- 
a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -26,9 +26,8 @@ def _supports_current_device() -> bool: """Supports only Blackwell-family GPUs.""" - return current_platform.is_cuda() and current_platform.is_device_capability_family( - 10 - ) + p = current_platform + return p.is_cuda() and p.is_device_capability_family(10) def _supports_no_act_and_mul() -> bool: From ef8bb7e4688a8fae76305436e1c718eafb1c246d Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Mon, 12 Jan 2026 12:11:25 -0500 Subject: [PATCH 042/113] update logic for aiter foudn Signed-off-by: Robert Shaw --- .../layers/fused_moe/flashinfer_trtllm_moe.py | 1 + .../layers/fused_moe/rocm_aiter_fused_moe.py | 9 ++++----- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index 34fe14fd2d91..b23c648e9def 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -27,6 +27,7 @@ def _supports_current_device() -> bool: """Supports only Blackwell-family GPUs.""" p = current_platform + # Add check flashinfer trtllm is available return p.is_cuda() and p.is_device_capability_family(10) diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 5aa30f836db1..6696e3d0126d 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -16,7 +16,6 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP, ) -from vllm.platforms import current_platform class QuantMethod(IntEnum): @@ -281,7 +280,7 @@ def activation_format() -> mk.FusedMoEActivationFormat: @staticmethod def _supports_current_device() -> bool: - 
return current_platform.is_rocm() + return rocm_aiter_ops.IS_AITER_FOUND @staticmethod def _supports_no_act_and_mul() -> bool: @@ -290,9 +289,9 @@ def _supports_no_act_and_mul() -> bool: @staticmethod def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: return ( - quant_scheme.is_fp8_w8a8 - # or quant_scheme.is_mxfp4_w4a4 - or quant_scheme.is_unquantized + quant_scheme.is_unquantized + or quant_scheme.is_fp8_w8a8 + or quant_scheme.is_mxfp4_w4a4 ) @staticmethod From dde5dd23b7820e42e98d59805ce8b94958f11473 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Mon, 12 Jan 2026 12:16:32 -0500 Subject: [PATCH 043/113] cleanup support logic Signed-off-by: Robert Shaw --- .../layers/fused_moe/fused_batched_moe.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 078847218de2..c4d9a959de7d 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -731,17 +731,13 @@ def _supports_no_act_and_mul() -> bool: @staticmethod def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: - # Supports unquantized and fp8. - # TODO(rob): allow int4 (for kimi --- no, we have marlinexperts for this. 
- if quant_scheme.is_unquantized: - return True - - if quant_scheme.is_fp8_w8a8: - return current_platform.is_rocm() or current_platform.has_device_capability( - (9, 0) - ) - - return False + p = current_platform + device_supports_fp8 = p.is_rocm() or ( + p.is_cuda() and p.has_device_capability((9, 0)) + ) + return quant_scheme.is_unquantized or ( + quant_scheme.is_fp8_w8a8 and device_supports_fp8 + ) @staticmethod def _supports_activation(activation: str) -> bool: From f8045c680e2d27f595cefce73e9e8b7c575cbb58 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Mon, 12 Jan 2026 12:17:28 -0500 Subject: [PATCH 044/113] update logic for trtllm config Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index b23c648e9def..8cc78e476e3e 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -65,7 +65,7 @@ def is_supported_config_trtllm( """ return ( _supports_current_device() - and (not moe_config.is_act_and_mul or _supports_no_act_and_mul()) + and (moe_config.is_act_and_mul or _supports_no_act_and_mul()) and _supports_activation(moe_config.activation) and _supports_quant_scheme(moe_quant_scheme) and _supports_moe_parallel_config(moe_config.moe_parallel_config) From 40ecf90bca6b2b93bfa924eddc0f5dbe247c451e Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Mon, 12 Jan 2026 12:20:22 -0500 Subject: [PATCH 045/113] updated Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/fused_moe.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1492a08336e3..daeea1a61e14 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ 
b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -2524,6 +2524,11 @@ def __init__( ): super().__init__(quant_config) + @staticmethod + def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: + # TODO + return True + def apply( self, output: torch.Tensor, From 372c697a65e9a32533f7144f1583ac2d970309f2 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Mon, 12 Jan 2026 19:59:46 -0500 Subject: [PATCH 046/113] more progress Signed-off-by: Robert Shaw --- .../moe/modular_kernel_tools/common.py | 2 + .../model_executor/layers/fused_moe/config.py | 8 +- .../layers/fused_moe/cutlass_moe.py | 16 ++- .../fused_moe/flashinfer_cutlass_moe.py | 33 +++--- vllm/model_executor/layers/fused_moe/layer.py | 2 + .../layers/fused_moe/modular_kernel.py | 67 +++++++++++ .../layers/fused_moe/oracle/fp8.py | 108 +++--------------- .../layers/fused_moe/rocm_aiter_fused_moe.py | 7 ++ .../model_executor/layers/quantization/fp8.py | 80 +++++++------ 9 files changed, 168 insertions(+), 155 deletions(-) diff --git a/tests/kernels/moe/modular_kernel_tools/common.py b/tests/kernels/moe/modular_kernel_tools/common.py index 52906c043df0..dd0519c1953a 100644 --- a/tests/kernels/moe/modular_kernel_tools/common.py +++ b/tests/kernels/moe/modular_kernel_tools/common.py @@ -574,11 +574,13 @@ def next_power_of_2(x): num_experts=config.E, experts_per_token=config.topk, hidden_dim=config.K, + intermediate_size_per_partition=config.N, num_local_experts=config.num_local_experts, moe_parallel_config=moe_parallel_config, in_dtype=config.dtype, max_num_tokens=next_power_of_2(config.M), activation="silu", + device=vllm_config.device_config.device, ) # make modular kernel diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 33241640e489..cbcfcb6f487b 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -952,6 +952,10 @@ def use_deepep_ht_kernels(self): @property def 
use_deepep_ll_kernels(self): return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency" + + @property + def use_fi_all2all_kernels(self): + return self.use_all2all_kernels and self.all2all_backend == "flashinfer_all2allv" @staticmethod def flatten_tp_across_dp_and_pcp( @@ -1097,6 +1101,7 @@ class FusedMoEConfig: num_experts: int experts_per_token: int hidden_dim: int + intermediate_size_per_partition: int num_local_experts: int moe_parallel_config: FusedMoEParallelConfig @@ -1115,7 +1120,8 @@ class FusedMoEConfig: is_lora_enabled: bool = False - activation: str = "silu" + activation: str + device: torch.dtype def __post_init__(self): if self.dp_size > 1: diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 5725c98a42ea..86479e405ecc 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -13,6 +13,7 @@ FusedMoEParallelConfig, FusedMoEQuantConfig, FusedMoEQuantScheme, + FusedMoEConfig, ) from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( moe_permute, @@ -244,24 +245,21 @@ def run_cutlass_moe_fp8( class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - e: int, - n: int, - k: int, - out_dtype: torch.dtype | None, + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, device: torch.dtype, ): assert quant_config.use_fp8_w8a8 super().__init__(quant_config) - - # E: num_experts - # N: intermediate size per partition - # K: hidden dim + + e = moe_config.num_local_experts + n = moe_config.intermediate_size_per_partition + k = moe_config.hidden_dim ab_strides1_c_strides2 = torch.full((e,), k, device=device, dtype=torch.int64) ab_strides2 = torch.full((e,), n, device=device, dtype=torch.int64) c_strides1 = torch.full((e,), 2 * n, device=device, dtype=torch.int64) - self.out_dtype = out_dtype + self.out_dtype = moe_config.in_dtype self.ab_strides1 = 
ab_strides1_c_strides2 self.ab_strides2 = ab_strides2 self.c_strides1 = c_strides1 diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index cbde0cb50eae..2111a8ed0c52 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -25,30 +25,33 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - out_dtype: torch.dtype, + moe_config: mk.FusedMoEConfig, quant_config: FusedMoEQuantConfig, - ep_rank: int = 0, - ep_size: int = 1, - tp_rank: int = 0, - tp_size: int = 1, - use_dp: bool = False, - use_deepseek_fp8_block_scale: bool = False, ): - super().__init__(quant_config) + super().__init__(moe_config, quant_config) assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn, None), ( "Only nvfp4, fp8, bfloat16 and" " float16 quantization are currently supported." ) - self.ep_rank = ep_rank - self.ep_size = ep_size - self.tp_rank = tp_rank - self.tp_size = tp_size - self.out_dtype = out_dtype - self.use_dp = use_dp + self.ep_rank = moe_config.moe_parallel_config.ep_rank + self.ep_size = moe_config.moe_parallel_config.ep_size + self.tp_rank = moe_config.moe_parallel_config.tp_rank + self.tp_size = moe_config.moe_parallel_config.tp_size + self.out_dtype = moe_config.in_dtype + self.use_dp = moe_config.moe_parallel_config.dp_size > 1 # Enables DeepSeek-style FP8 block-scale path: # - pass per-block weight scales to the kernel # - skip input activation quantization (kernel applies scaling) - self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale + self.use_deepseek_fp8_block_scale = quant_config.is_block_quantized + + @staticmethod + def should_pf_defer_input_quant(quant_config): + """ + FlashInfer CUTLASS Block FP8 path handles input quantization. 
+ """ + if quant_config.is_block_quantized: + return True + return False @staticmethod def _supports_current_device() -> bool: diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 7d628515a11e..d2e3b28206e3 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -584,6 +584,7 @@ def __init__( num_experts=self.global_num_experts, experts_per_token=top_k, hidden_dim=hidden_size, + intermediate_size_per_partition=self.intermediate_size_per_partition, num_local_experts=self.local_num_experts, moe_parallel_config=self.moe_parallel_config, in_dtype=moe_in_dtype, @@ -593,6 +594,7 @@ def __init__( is_act_and_mul=is_act_and_mul, is_lora_enabled=vllm_config.lora_config is not None, activation=activation, + device=vllm_config.device_config.device, ) self.moe_config_use_flashinfer_cutlass_kernels = ( self.moe_config.use_flashinfer_cutlass_kernels diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index b854e30b0fc5..64a36f7713c9 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -375,12 +375,79 @@ class FusedMoEPermuteExpertsUnpermute(ABC): def __init__( self, + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, ): """ + moe_config: MoE layer configuration. quant_config: Quantization parameters for this experts instance. """ + self.moe_config = moe_config self.quant_config = quant_config + self._max_num_tokens: int | None = None + self._max_dispatchers: int | None = None + + @staticmethod + def should_pf_defer_input_quant(quant_config: FusedMoEQuantConfig) -> bool: + """ + Whether or not the PrepareFinalize should defer input quantization + in the prepare step. If True, then the Experts kernel will + execute the input quantization itself. 
+ + Sample subclasses that override are AITER and FlashInfer CUTLASS. + """ + return False + + def _init_batched_experts_addl_params( + self, + max_num_tokens: int, + max_dispatchers: int, + ): + """ + Initialize any additional parameters needed for batched experts. + """ + self._max_num_tokens = max_num_tokens + self._max_dispatchers = max_dispatchers + + @property + def max_num_tokens(self) -> int: + if self._max_num_tokens is None: + raise AttributeError("max_num_tokens only valid for BatchedExperts") + return self._max_num_tokens + + @property + def max_dispatchers(self) -> int: + if self._max_dispatchers is None: + raise AttributeError("max_dispatchers only valid for BatchedExperts") + return self._max_dispatchers + + @classmethod + def make_standard_experts( + cls, + moe_config: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, + ) -> "FusedMoEPermuteExpertsUnpermute": + """ + Factory method to create an instance of this class. + """ + assert cls.activation_format() == FusedMoEActivationFormat.Standard + return cls(moe_config, quant_config) + + @classmethod + def make_batched_experts( + cls, + moe_config: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, + max_num_tokens: int, + max_dispatchers: int, + ) -> "FusedMoEPermuteExpertsUnpermute": + """ + Factory method to create an instance of this class. 
+ """ + assert cls.activation_format() == FusedMoEActivationFormat.BatchedExperts + instance = cls(moe_config, quant_config) + instance._init_batched_experts_addl_params(max_num_tokens, max_dispatchers) + return instance @staticmethod @abstractmethod diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py index 54ed6e1539e3..a8a9432ecea6 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py @@ -8,6 +8,9 @@ from vllm import envs from vllm._aiter_ops import rocm_aiter_ops from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP, +) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, @@ -145,7 +148,7 @@ def _return_or_raise( "does not support the deployment configuration." ) - # NOTE(rob): the kernels are selected in the following order. + # NOTE: the kernels are selected in the following order. AVAILABLE_BACKENDS = [ Fp8MoeBackend.AITER, Fp8MoeBackend.FLASHINFER_TRTLLM, @@ -348,98 +351,25 @@ def make_fp8_moe_quant_config( def make_fp8_moe_kernel( - layer: torch.nn.Module, moe_quant_config: FusedMoEQuantConfig, moe_config: FusedMoEConfig, fp8_backend: Fp8MoeBackend, + experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute], ) -> tuple[mk.FusedMoEModularKernel, bool]: - # Delayed import is required since the oracle is imported - # by CPU backends which cannot import all of these experts. - # TODO: update the experts to make this not happen. - from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP, + assert (experts_cls.activation_format() == + mk.FusedMoEActivationFormat.Standard) + defer_input_quant = experts_cls.should_pf_defer_input_quant( + moe_quant_config, ) - # NOTE(rob): this is a WIP refactor. We are first migrating - # all of the kernels in the TP case to use mk. 
Once this is - # done, then we will initialzie the TP case and DP/EP case - # via the same code path (i.e. via maybe_init_modular_kernel). - # NOTE(rob): in progress migrating all into this format. - use_inplace = True - if fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: - from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( - FlashInferExperts, - ) - - kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP( - defer_input_quant=moe_quant_config.is_block_quantized - ), - FlashInferExperts( - out_dtype=layer.orig_dtype, - quant_config=moe_quant_config, - ep_rank=moe_config.ep_rank, - ep_size=moe_config.ep_size, - tp_rank=moe_config.tp_rank, - tp_size=moe_config.tp_size, - use_dp=(moe_config.dp_size > 1), - use_deepseek_fp8_block_scale=moe_quant_config.is_block_quantized, - ), - ) - use_inplace = False - - elif fp8_backend == Fp8MoeBackend.AITER: - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - AiterExperts, - ) - - kernel = mk.FusedMoEModularKernel( - # TODO: make defer_input_quant an attr of the AiterExperts - MoEPrepareAndFinalizeNoEP(defer_input_quant=True), - AiterExperts(quant_config=moe_quant_config), - ) - elif fp8_backend == Fp8MoeBackend.MARLIN: - from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( - MarlinExperts, - ) - - kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), - MarlinExperts(quant_config=moe_quant_config), - ) - elif fp8_backend == Fp8MoeBackend.VLLM_CUTLASS: - from vllm.model_executor.layers.fused_moe.triton_cutlass_moe import ( - TritonOrCutlassExperts, - ) - - kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), - TritonOrCutlassExperts( - out_dtype=moe_config.in_dtype, - e=layer.local_num_experts, - n=layer.intermediate_size_per_partition, - k=layer.hidden_size, - device=layer.w13_weight.device, - quant_config=moe_quant_config, - ), - ) - elif fp8_backend == Fp8MoeBackend.DEEPGEMM: - from vllm.model_executor.layers.fused_moe import ( - 
TritonOrDeepGemmExperts, - ) - - kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), - TritonOrDeepGemmExperts(quant_config=moe_quant_config), - ) - else: - from vllm.model_executor.layers.fused_moe.fused_moe import ( - TritonExperts, - ) + kernel = mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(defer_input_quant=defer_input_quant), + experts_cls.make_standard_experts( + moe_config=moe_config, + quant_config=moe_quant_config, + ), + ) - assert fp8_backend == Fp8MoeBackend.TRITON - kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), - TritonExperts(quant_config=moe_quant_config), - ) - return kernel, use_inplace + # TODO(rob): update inplace logic + inplace = fp8_backend != Fp8MoeBackend.FLASHINFER_CUTLASS + return kernel, inplace diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 6696e3d0126d..294f91ed1393 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -277,6 +277,13 @@ def __init__(self, quant_config): @staticmethod def activation_format() -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.Standard + + @staticmethod + def should_pf_defer_input_quant(quant_config): + """ + AITER Fused MoE kernels handle input quantization. + """ + return True @staticmethod def _supports_current_device() -> bool: diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 6a73082ec631..c054ad4063dd 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -635,6 +635,9 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): "weight_scale_inv" if self.block_quant else "weight_scale" ) + # Create quant scheme, will be used later to select the quant scales. 
+ # NOTE(rob): we should update QuantConfig to just be the think ts + # holds the scales. Should change the name. quant_scheme = FusedMoEQuantScheme( weight_dtype=current_platform.fp8_dtype(), act_dtype=current_platform.fp8_dtype(), @@ -649,10 +652,16 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): ) self.fp8_backend, self.experts_cls = select_fp8_moe_backend( - config=layer.moe_config, + config=self.moe, quant_scheme=quant_scheme, - # TODO(rob): select prepare_finalize here. - activation_format=mk.FusedMoEActivationFormat.Standard, + # NOTE(rob): this is a hack until we unify the DP/EP + # and TP/TEP cases. + activation_format=( + mk.FusedMoEActivationFormat.BatchedExperts if ( + self.moe.use_deepep_ll_kernels or + self.moe.use_pplx_kernels + ) else mk.FusedMoEActivationFormat.Standard + ), ) self.kernel: mk.FusedMoEModularKernel | None = None @@ -817,6 +826,7 @@ def _setup_kernel( moe_quant_config=self.moe_quant_config, moe_config=self.moe, fp8_backend=self.fp8_backend, + experts_cls=self.experts_cls, ) def process_weights_after_loading(self, layer: Module) -> None: @@ -891,19 +901,27 @@ def select_gemm_impl( prepare_finalize: FusedMoEPrepareAndFinalize, layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: - from vllm.model_executor.layers.fused_moe import ( - BatchedDeepGemmExperts, - BatchedTritonExperts, - TritonExperts, - TritonOrDeepGemmExperts, - ) + # TODO(rob): this is probably fine? if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]: raise NotImplementedError( "Marlin and ROCm AITER are not supported with all2all yet." ) + # TODO(rob): is it really possible to get here? 
+ if self.moe.is_lora_enabled: + from vllm.model_executor.layers.fused_moe import ( + TritonExperts, + ) + return TritonExperts(quant_config=self.moe_quant_config) + assert self.moe_quant_config is not None + if self.experts_cls.activation_format() != prepare_finalize.activation_format: + raise ValueError( + f"FoundMismatch between prepare/finalize activation format " + f"{prepare_finalize.activation_format} and experts " + f"activation format {self.experts_cls.activation_format()}." + ) if ( prepare_finalize.activation_format @@ -911,53 +929,33 @@ def select_gemm_impl( ): max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() assert max_num_tokens_per_rank is not None - - experts_impl = ( - BatchedDeepGemmExperts - if self.fp8_backend == Fp8MoeBackend.DEEPGEMM - else BatchedTritonExperts - ) + logger.debug( "%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s", - experts_impl.__name__, + self.experts_cls.__name__, self.__class__.__name__, max_num_tokens_per_rank, self.weight_block_size, False, ) - return experts_impl( - max_num_tokens=max_num_tokens_per_rank, - num_dispatchers=prepare_finalize.num_dispatchers(), + return self.experts_cls.make_batched_experts( + moe_config=self.moe, quant_config=self.moe_quant_config, + max_num_tokens=max_num_tokens_per_rank, + num_dispatchers=prepare_finalize.num_dispatchers(), ) - elif self.moe.is_lora_enabled: - return TritonExperts(quant_config=self.moe_quant_config) - elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: - # Select GEMM experts with block-scale when weights are block-quantized - experts = select_cutlass_fp8_gemm_impl( - self.moe, - self.moe_quant_config, - use_deepseek_fp8_block_scale=self.block_quant, - ) - logger.debug_once("Using %s", experts.__class__.__name__) - return experts - elif self.fp8_backend == Fp8MoeBackend.DEEPGEMM: - logger.debug( - "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s", - self.__class__.__name__, - self.weight_block_size, - False, - ) - 
return TritonOrDeepGemmExperts(self.moe_quant_config) else: - assert self.fp8_backend == Fp8MoeBackend.TRITON logger.debug( - "TritonExperts(%s): block_size=%s, per_act_token=%s", + "%s(%s): block_size=%s, per_act_token=%s", + self.experts_cls.__name__, self.__class__.__name__, self.weight_block_size, False, ) - return TritonExperts(self.moe_quant_config) + return self.experts_cls.make_standard_experts( + moe_config=self.moe, + quant_config=self.moe_quant_config, + ) def get_fused_moe_quant_config( self, layer: torch.nn.Module From 139941aae20c20bbb958cb1b12132604cf78cec4 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Mon, 12 Jan 2026 20:00:33 -0500 Subject: [PATCH 047/113] updated Signed-off-by: Robert Shaw --- vllm/model_executor/layers/quantization/fp8.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index c054ad4063dd..1b2a89bf92ec 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -53,7 +53,6 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( apply_fi_trtllm_fp8_per_tensor_moe, build_flashinfer_fp8_cutlass_moe_prepare_finalize, - select_cutlass_fp8_gemm_impl, ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, @@ -657,10 +656,9 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): # NOTE(rob): this is a hack until we unify the DP/EP # and TP/TEP cases. 
activation_format=( - mk.FusedMoEActivationFormat.BatchedExperts if ( - self.moe.use_deepep_ll_kernels or - self.moe.use_pplx_kernels - ) else mk.FusedMoEActivationFormat.Standard + mk.FusedMoEActivationFormat.BatchedExperts + if (self.moe.use_deepep_ll_kernels or self.moe.use_pplx_kernels) + else mk.FusedMoEActivationFormat.Standard ), ) @@ -901,18 +899,18 @@ def select_gemm_impl( prepare_finalize: FusedMoEPrepareAndFinalize, layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: - # TODO(rob): this is probably fine? if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]: raise NotImplementedError( "Marlin and ROCm AITER are not supported with all2all yet." ) - # TODO(rob): is it really possible to get here? + # TODO(rob): look into whether we can just not do this for LoRA. if self.moe.is_lora_enabled: from vllm.model_executor.layers.fused_moe import ( TritonExperts, ) + return TritonExperts(quant_config=self.moe_quant_config) assert self.moe_quant_config is not None @@ -929,7 +927,7 @@ def select_gemm_impl( ): max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() assert max_num_tokens_per_rank is not None - + logger.debug( "%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s", self.experts_cls.__name__, @@ -942,7 +940,7 @@ def select_gemm_impl( moe_config=self.moe, quant_config=self.moe_quant_config, max_num_tokens=max_num_tokens_per_rank, - num_dispatchers=prepare_finalize.num_dispatchers(), + num_dispatchers=prepare_finalize.num_dispatchers(), ) else: logger.debug( From fb90fc06b63a503246b9fd95a2709bc86b107676 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Mon, 12 Jan 2026 20:24:18 -0500 Subject: [PATCH 048/113] updated Signed-off-by: Robert Shaw --- tests/kernels/moe/test_nvfp4_moe.py | 1 - tests/kernels/moe/test_pplx_moe.py | 6 +-- .../model_executor/layers/fused_moe/config.py | 12 ++--- .../layers/fused_moe/cutlass_moe.py | 44 +++++++------------ .../fused_moe/flashinfer_cutlass_moe.py | 6 +-- 
.../layers/fused_moe/modular_kernel.py | 10 ++--- .../layers/fused_moe/oracle/fp8.py | 9 ++-- .../layers/fused_moe/oracle/nvfp4.py | 12 +---- .../layers/fused_moe/rocm_aiter_fused_moe.py | 5 +-- .../compressed_tensors_moe.py | 13 ++---- .../model_executor/layers/quantization/fp8.py | 6 ++- .../layers/quantization/modelopt.py | 1 - 12 files changed, 49 insertions(+), 76 deletions(-) diff --git a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py index 873d72117de7..97db526c2ee4 100644 --- a/tests/kernels/moe/test_nvfp4_moe.py +++ b/tests/kernels/moe/test_nvfp4_moe.py @@ -93,7 +93,6 @@ def test_cutlass_fp4_moe_no_graph( MoEPrepareAndFinalizeNoEP(defer_input_quant=True), CutlassExpertsFp4( out_dtype=dtype, - max_experts_per_worker=e, quant_config=quant_config, ), ) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index e0a1f49aced8..aa94f6f7a566 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -716,9 +716,9 @@ def _pplx_moe( else: chunked_shared_output = None - chunked_torch_output = chunk_by_rank( - torch_output, pgi.rank, pgi.world_size - ).to(pplx_output.device) + chunked_torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to( + pplx_output.device + ) torch.testing.assert_close( pplx_output, chunked_torch_output, atol=3e-2, rtol=3e-2 diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index cbcfcb6f487b..9b44442d93f5 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -952,10 +952,12 @@ def use_deepep_ht_kernels(self): @property def use_deepep_ll_kernels(self): return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency" - + @property def use_fi_all2all_kernels(self): - return self.use_all2all_kernels and self.all2all_backend == "flashinfer_all2allv" + return ( + self.use_all2all_kernels and 
self.all2all_backend == "flashinfer_all2allv" + ) @staticmethod def flatten_tp_across_dp_and_pcp( @@ -1102,8 +1104,9 @@ class FusedMoEConfig: experts_per_token: int hidden_dim: int intermediate_size_per_partition: int - num_local_experts: int + activation: str + device: torch.dtype moe_parallel_config: FusedMoEParallelConfig # The activation type. @@ -1120,9 +1123,6 @@ class FusedMoEConfig: is_lora_enabled: bool = False - activation: str - device: torch.dtype - def __post_init__(self): if self.dp_size > 1: logger.debug_once( diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 86479e405ecc..c283e9c98e55 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -10,10 +10,10 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig, FusedMoEQuantScheme, - FusedMoEConfig, ) from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( moe_permute, @@ -247,14 +247,14 @@ def __init__( self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, - device: torch.dtype, ): assert quant_config.use_fp8_w8a8 - super().__init__(quant_config) - + super().__init__(moe_config=moe_config, quant_config=quant_config) + e = moe_config.num_local_experts n = moe_config.intermediate_size_per_partition k = moe_config.hidden_dim + device = moe_config.device ab_strides1_c_strides2 = torch.full((e,), k, device=device, dtype=torch.int64) ab_strides2 = torch.full((e,), n, device=device, dtype=torch.int64) c_strides1 = torch.full((e,), 2 * n, device=device, dtype=torch.int64) @@ -384,18 +384,6 @@ def workspace_shapes( class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): - def __init__( - self, - max_experts_per_worker: int, - num_dispatchers: int, - *args, - **kwargs, - ): - super().__init__(*args, 
**kwargs) - assert max_experts_per_worker > 0 - self.max_experts_per_worker = max_experts_per_worker - self.num_dispatchers = num_dispatchers - @staticmethod def activation_format() -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.BatchedExperts @@ -419,11 +407,12 @@ def workspace_shapes( local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: - num_dp = self.num_dispatchers + num_dp = self.max_dispatchers assert num_dp is not None - workspace1 = (self.max_experts_per_worker, M * num_dp, max(N, K)) - workspace2 = (self.max_experts_per_worker, M * num_dp, max(N // 2, K)) - output = (self.max_experts_per_worker, M, K) + experts_per_worker = self.moe_config.num_local_experts + workspace1 = (experts_per_worker, M * num_dp, max(N, K)) + workspace2 = (experts_per_worker, M * num_dp, max(N // 2, K)) + output = (experts_per_worker, M, K) return (workspace1, workspace2, output) @@ -605,13 +594,12 @@ def run_cutlass_moe_fp4( class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - max_experts_per_worker: int, - out_dtype: torch.dtype, + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, ): - super().__init__(quant_config) - self.max_experts_per_worker = max_experts_per_worker - self.out_dtype = out_dtype + super().__init__(moe_config=moe_config, quant_config=quant_config) + self.max_experts_per_worker = moe_config.num_local_experts + self.out_dtype = moe_config.in_dtype @staticmethod def activation_format() -> mk.FusedMoEActivationFormat: @@ -845,11 +833,12 @@ def __init__( c_strides2: torch.Tensor, s_strides1: torch.Tensor, s_strides2: torch.Tensor, + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, group_size: int, ): - super().__init__(quant_config) - self.out_dtype = out_dtype + super().__init__(moe_config=moe_config, quant_config=quant_config) + self.out_dtype = moe_config.in_dtype self.a_strides1 = a_strides1 
self.a_strides2 = a_strides2 self.b_strides1 = b_strides1 @@ -1067,6 +1056,7 @@ def cutlass_moe_w4a8_fp8( c_strides2=c_strides2, s_strides1=s_strides1, s_strides2=s_strides2, + # TODO: quant_config=quant_config, group_size=group_size, ), diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 2111a8ed0c52..507a6ca73589 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -43,15 +43,13 @@ def __init__( # - pass per-block weight scales to the kernel # - skip input activation quantization (kernel applies scaling) self.use_deepseek_fp8_block_scale = quant_config.is_block_quantized - + @staticmethod def should_pf_defer_input_quant(quant_config): """ FlashInfer CUTLASS Block FP8 path handles input quantization. """ - if quant_config.is_block_quantized: - return True - return False + return quant_config.is_block_quantized @staticmethod def _supports_current_device() -> bool: diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 64a36f7713c9..c83b18b87240 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -391,13 +391,13 @@ def __init__( def should_pf_defer_input_quant(quant_config: FusedMoEQuantConfig) -> bool: """ Whether or not the PrepareFinalize should defer input quantization - in the prepare step. If True, then the Experts kernel will + in the prepare step. If True, then the Experts kernel will execute the input quantization itself. Sample subclasses that override are AITER and FlashInfer CUTLASS. 
""" return False - + def _init_batched_experts_addl_params( self, max_num_tokens: int, @@ -408,7 +408,7 @@ def _init_batched_experts_addl_params( """ self._max_num_tokens = max_num_tokens self._max_dispatchers = max_dispatchers - + @property def max_num_tokens(self) -> int: if self._max_num_tokens is None: @@ -420,7 +420,7 @@ def max_dispatchers(self) -> int: if self._max_dispatchers is None: raise AttributeError("max_dispatchers only valid for BatchedExperts") return self._max_dispatchers - + @classmethod def make_standard_experts( cls, @@ -432,7 +432,7 @@ def make_standard_experts( """ assert cls.activation_format() == FusedMoEActivationFormat.Standard return cls(moe_config, quant_config) - + @classmethod def make_batched_experts( cls, diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py index a8a9432ecea6..2d29231d0484 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py @@ -8,9 +8,6 @@ from vllm import envs from vllm._aiter_ops import rocm_aiter_ops from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP, -) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, @@ -21,6 +18,9 @@ from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import ( is_supported_config_trtllm, ) +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP, +) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( FlashinferMoeBackend, get_flashinfer_moe_backend, @@ -356,8 +356,7 @@ def make_fp8_moe_kernel( fp8_backend: Fp8MoeBackend, experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute], ) -> tuple[mk.FusedMoEModularKernel, bool]: - assert (experts_cls.activation_format() == - mk.FusedMoEActivationFormat.Standard) + assert experts_cls.activation_format() == 
mk.FusedMoEActivationFormat.Standard defer_input_quant = experts_cls.should_pf_defer_input_quant( moe_quant_config, ) diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py index 547a2a795d19..6198ff415ab7 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py @@ -248,14 +248,8 @@ def make_nvfp4_moe_kernel( return mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(defer_input_quant=True), FlashInferExperts( - out_dtype=moe_config.in_dtype, + moe_config=moe_config, quant_config=quant_config, - ep_rank=moe_config.ep_rank, - ep_size=moe_config.ep_size, - tp_rank=moe_config.tp_rank, - tp_size=moe_config.tp_size, - use_dp=False, - use_deepseek_fp8_block_scale=False, ), ) @@ -263,9 +257,7 @@ def make_nvfp4_moe_kernel( return mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(defer_input_quant=True), CutlassExpertsFp4( - out_dtype=moe_config.in_dtype, - # TODO(rob): see what impact this has on expert map? 
- max_experts_per_worker=moe_config.num_experts, + moe_config=moe_config, quant_config=quant_config, ), ) diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 294f91ed1393..1bbeba024208 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -271,13 +271,10 @@ def rocm_aiter_fused_experts( class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__(self, quant_config): - super().__init__(quant_config) - @staticmethod def activation_format() -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.Standard - + @staticmethod def should_pf_defer_input_quant(quant_config): """ diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 86e07c35ed0f..ac3531012621 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -819,7 +819,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: self.moe_quant_config = self.get_fused_moe_quant_config(layer) if self.moe_quant_config: self.kernel, self.use_inplace = make_fp8_moe_kernel( - layer=layer, moe_quant_config=self.moe_quant_config, moe_config=self.moe, fp8_backend=self.fp8_backend, @@ -856,15 +855,11 @@ def select_gemm_impl( == FusedMoEActivationFormat.BatchedExperts ): logger.debug("CutlassBatchedExpertsFp8(%s)", self.__class__.__name__) - experts = CutlassBatchedExpertsFp8( - max_experts_per_worker=self.moe.num_local_experts, - num_dispatchers=num_dispatchers, - out_dtype=self.moe.in_dtype, - e=layer.local_num_experts, - n=layer.intermediate_size_per_partition, - k=layer.hidden_size, - device=layer.w13_weight.device, + experts = 
CutlassBatchedExpertsFp8.make_batched_experts( + moe_config=self.moe, quant_config=self.moe_quant_config, + max_num_tokens=prepare_finalize.max_num_tokens_per_rank(), + num_dispatchers=num_dispatchers, ) else: logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 1b2a89bf92ec..39b66004ce40 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -819,8 +819,8 @@ def _setup_kernel( # Setup modular kernel for TP case. self.moe_quant_config = self.get_fused_moe_quant_config(layer) if self.moe_quant_config: + assert self.experts_cls is not None self.kernel, self.use_inplace = make_fp8_moe_kernel( - layer=layer, moe_quant_config=self.moe_quant_config, moe_config=self.moe, fp8_backend=self.fp8_backend, @@ -905,6 +905,10 @@ def select_gemm_impl( "Marlin and ROCm AITER are not supported with all2all yet." ) + if self.experts_cls is None: + raise NotImplementedError + assert self.experts_cls is not None + # TODO(rob): look into whether we can just not do this for LoRA. 
if self.moe.is_lora_enabled: from vllm.model_executor.layers.fused_moe import ( diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 986d0a906e05..6816964698b4 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -883,7 +883,6 @@ def _setup_kernel( self.moe_quant_config = self.get_fused_moe_quant_config(layer) if self.moe_quant_config: self.kernel, self.use_inplace = make_fp8_moe_kernel( - layer=layer, moe_quant_config=self.moe_quant_config, moe_config=self.moe, fp8_backend=self.fp8_backend, From b3c818d522b2c4807fb9efc638813cbef323d8c0 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Mon, 12 Jan 2026 20:27:47 -0500 Subject: [PATCH 049/113] updated Signed-off-by: Robert Shaw --- vllm/model_executor/layers/quantization/modelopt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 6816964698b4..a2cd8758f1f8 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -886,7 +886,7 @@ def _setup_kernel( moe_quant_config=self.moe_quant_config, moe_config=self.moe, fp8_backend=self.fp8_backend, - ) + ) # noqa def process_weights_after_loading(self, layer: torch.nn.Module) -> None: w13 = layer.w13_weight From e9327433872c3534cccbaae0a271fb3b4e5446ee Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 13 Jan 2026 10:55:45 -0500 Subject: [PATCH 050/113] updated Signed-off-by: Robert Shaw --- .../layers/fused_moe/all2all_utils.py | 25 +++++- .../model_executor/layers/fused_moe/config.py | 6 +- .../layers/fused_moe/fallback.py | 4 +- vllm/model_executor/layers/fused_moe/layer.py | 10 ++- .../layers/fused_moe/oracle/fp8.py | 45 +++++++--- .../model_executor/layers/quantization/fp8.py | 90 ++++--------------- 6 files changed, 90 insertions(+), 
90 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py index d6740d8a3ff5..7913c88f6d94 100644 --- a/vllm/model_executor/layers/fused_moe/all2all_utils.py +++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py @@ -12,11 +12,15 @@ FusedMoEParallelConfig, FusedMoEQuantConfig, ) +from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( + FlashInferAllToAllMoEPrepareAndFinalize, +) from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEPrepareAndFinalize, ) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNaiveEP, + MoEPrepareAndFinalizeNoEP, ) from vllm.platforms import current_platform from vllm.utils.import_utils import has_deep_ep, has_pplx @@ -71,9 +75,15 @@ def maybe_make_prepare_finalize( moe: FusedMoEConfig, quant_config: FusedMoEQuantConfig | None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, + defer_input_quant: bool = False, + allow_new_interface: bool = False, ) -> FusedMoEPrepareAndFinalize | None: if not moe.moe_parallel_config.use_all2all_kernels: - return None + # What happens if DP>1 and EP=1? + if allow_new_interface: + return MoEPrepareAndFinalizeNoEP(defer_input_quant) + else: + return None all2all_manager = get_ep_group().device_communicator.all2all_manager assert all2all_manager is not None @@ -173,7 +183,16 @@ def maybe_make_prepare_finalize( local_expert_global_ids=local_expert_global_ids, ) - elif moe.use_naive_kernels: - prepare_finalize = MoEPrepareAndFinalizeNaiveEP() + elif moe.use_fi_all2allv_kernels: + assert quant_config is not None + # TODO: audit if this supports all cases. 
+ prepare_finalize = FlashInferAllToAllMoEPrepareAndFinalize( + use_dp=True, + num_dispatchers=all2all_manager.world_size, + use_deepseek_fp8_block_scale=quant_config.is_block_quantized, + ) + + elif moe.use_naive_kernels and allow_new_interface: + prepare_finalize = MoEPrepareAndFinalizeNaiveEP(defer_input_quant) return prepare_finalize diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index ff5b629dc06a..be558451e295 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -954,7 +954,7 @@ def use_deepep_ll_kernels(self): return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency" @property - def use_fi_all2all_kernels(self): + def use_fi_all2allv_kernels(self): return ( self.use_all2all_kernels and self.all2all_backend == "flashinfer_all2allv" ) @@ -1188,6 +1188,10 @@ def use_deepep_ht_kernels(self): def use_deepep_ll_kernels(self): return self.moe_parallel_config.use_deepep_ll_kernels + @property + def use_fi_all2allv_kernels(self): + return self.moe_parallel_config.use_fi_all2allv_kernels + @property def use_naive_kernels(self): return self.moe_parallel_config.use_naive_kernels diff --git a/vllm/model_executor/layers/fused_moe/fallback.py b/vllm/model_executor/layers/fused_moe/fallback.py index ded3b592b124..ec59c04f8b89 100644 --- a/vllm/model_executor/layers/fused_moe/fallback.py +++ b/vllm/model_executor/layers/fused_moe/fallback.py @@ -16,7 +16,9 @@ def __init__( experts: mk.FusedMoEPermuteExpertsUnpermute, fallback_experts: mk.FusedMoEPermuteExpertsUnpermute, ): - super().__init__(experts.quant_config) + super().__init__( + moe_config=experts.moe_config, quant_config=experts.quant_config + ) self.fallback_experts = fallback_experts self.experts = experts diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 11022c7ec1c9..613b026c5943 100644 --- 
a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -688,6 +688,10 @@ def _get_quant_method() -> FusedMoEMethodBase: # This is called after all weight loading and post-processing, so it # should be safe to swap out the quant_method. def maybe_init_modular_kernel(self) -> None: + if getattr(self.quant_config, "_SUPPORTS_MK_INTERALLY", False): + logger.info("SKIPPING MK INIT --> done by quant integration internally.") + return + self.ensure_moe_quant_config_init() # routing_tables only needed for round-robin expert placement with # DeepEP all2all backend. @@ -1922,8 +1926,10 @@ def forward_impl( hidden_states, router_logits, has_separate_shared_experts ) - do_naive_dispatch_combine: bool = self.dp_size > 1 and not isinstance( - self.quant_method, FusedMoEModularMethod + # TODO(rob): remove this once we migrate to internal use of MK. + do_naive_dispatch_combine: bool = self.dp_size > 1 and not ( + isinstance(self.quant_method, FusedMoEModularMethod) + or getattr(self.quant_method, "_SUPPORTS_MK_INTERALLY", False) ) ctx = get_forward_context() diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py index 2d29231d0484..0e4bf1379ece 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from enum import Enum +from typing import TYPE_CHECKING import torch @@ -18,9 +19,6 @@ from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import ( is_supported_config_trtllm, ) -from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP, -) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( FlashinferMoeBackend, get_flashinfer_moe_backend, @@ -34,6 +32,9 @@ prepare_fp8_moe_layer_for_marlin, ) +if TYPE_CHECKING: + from 
vllm.model_executor.layers.fused_moe.fused_moe import FusedMoE + logger = init_logger(__name__) @@ -351,24 +352,44 @@ def make_fp8_moe_quant_config( def make_fp8_moe_kernel( + layer: "FusedMoE", moe_quant_config: FusedMoEQuantConfig, moe_config: FusedMoEConfig, fp8_backend: Fp8MoeBackend, + prepare_finalize: mk.FusedMoEPrepareAndFinalize, experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute], ) -> tuple[mk.FusedMoEModularKernel, bool]: - assert experts_cls.activation_format() == mk.FusedMoEActivationFormat.Standard - defer_input_quant = experts_cls.should_pf_defer_input_quant( - moe_quant_config, - ) - - kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(defer_input_quant=defer_input_quant), - experts_cls.make_standard_experts( + # Create Experts. + if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.Standard: + experts = experts_cls.make_standard_experts( + moe_config=moe_config, + quant_config=moe_quant_config, + ) + else: + max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() + assert max_num_tokens_per_rank is not None + experts = experts_cls.make_batched_experts( moe_config=moe_config, quant_config=moe_quant_config, + max_num_tokens=max_num_tokens_per_rank, + num_dispatchers=prepare_finalize.num_dispatchers(), + ) + + # NOTE(rob): we only want the ModularKernel to control the SharedExpert + # if we are using all2all (for SBO). Need to make a change somewhere + # else to prevent double running the Shared Expert. + # This needs to be refactored. + kernel = mk.FusedMoEModularKernel( + prepare_finalize, + experts, + shared_experts=( + getattr(layer, "shared_expert", None) + if moe_config.moe_parallel_config.use_all2all_kernels + else None ), + moe_parallel_config=moe_config.moe_parallel_config, ) - # TODO(rob): update inplace logic + # TODO(rob): update inplace logic to be part of the kernel. 
inplace = fp8_backend != Fp8MoeBackend.FLASHINFER_CUTLASS return kernel, inplace diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 937bc37a5e82..4f25399af00b 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -19,7 +19,6 @@ ) from vllm.model_executor.layers.fused_moe import ( FusedMoE, - FusedMoEActivationFormat, FusedMoEMethodBase, FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize, @@ -52,7 +51,6 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( apply_fi_trtllm_fp8_per_tensor_moe, - build_flashinfer_fp8_cutlass_moe_prepare_finalize, ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, @@ -613,6 +611,8 @@ def apply( class Fp8MoEMethod(FusedMoEMethodBase): + _SUPPORTS_MK_INTERALLY = True + """MoE method for FP8. Supports loading FP8 checkpoints with static weight scale and dynamic/static activation scale. @@ -650,18 +650,22 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): static_input_quant=(self.quant_config.activation_scheme == "static"), ) + # Select Fp8 MoE backend. + use_batched = ( + self.moe.moe_parallel_config.use_deepep_ll_kernels + or self.moe.moe_parallel_config.use_pplx_kernels + ) self.fp8_backend, self.experts_cls = select_fp8_moe_backend( config=self.moe, quant_scheme=quant_scheme, - # NOTE(rob): this is a hack until we unify the DP/EP - # and TP/TEP cases. activation_format=( mk.FusedMoEActivationFormat.BatchedExperts - if (self.moe.use_deepep_ll_kernels or self.moe.use_pplx_kernels) + if use_batched else mk.FusedMoEActivationFormat.Standard ), ) + # Delay creation of the kernel until after process-weights. 
self.kernel: mk.FusedMoEModularKernel | None = None def create_weights( @@ -821,9 +825,11 @@ def _setup_kernel( if self.moe_quant_config: assert self.experts_cls is not None self.kernel, self.use_inplace = make_fp8_moe_kernel( + layer=layer, moe_quant_config=self.moe_quant_config, moe_config=self.moe, fp8_backend=self.fp8_backend, + prepare_finalize=self.prepare_finalize, experts_cls=self.experts_cls, ) @@ -879,37 +885,16 @@ def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPrepareAndFinalize | None: - if self.fp8_backend in [ - Fp8MoeBackend.AITER, - Fp8MoeBackend.MARLIN, - Fp8MoeBackend.FLASHINFER_TRTLLM, - ]: - return None - elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: - # TODO(rob): we can remove this. - prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( - self.moe, - use_deepseek_fp8_block_scale=self.block_quant, - ) - logger.debug_once("%s", prepare_finalize.__class__.__name__) - return prepare_finalize - return super().maybe_make_prepare_finalize(routing_tables) + raise ValueError( + "Fp8FusedMoE uses the new modular kernel initialization logic. " + "This function should not be called." + ) def select_gemm_impl( self, prepare_finalize: FusedMoEPrepareAndFinalize, layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: - # TODO(rob): this is probably fine? - if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]: - raise NotImplementedError( - "Marlin and ROCm AITER are not supported with all2all yet." - ) - - if self.experts_cls is None: - raise NotImplementedError - assert self.experts_cls is not None - # TODO(rob): look into whether we can just not do this for LoRA. 
if self.moe.is_lora_enabled: from vllm.model_executor.layers.fused_moe import ( @@ -918,47 +903,10 @@ def select_gemm_impl( return TritonExperts(quant_config=self.moe_quant_config) - assert self.moe_quant_config is not None - if self.experts_cls.activation_format() != prepare_finalize.activation_format: - raise ValueError( - f"FoundMismatch between prepare/finalize activation format " - f"{prepare_finalize.activation_format} and experts " - f"activation format {self.experts_cls.activation_format()}." - ) - - if ( - prepare_finalize.activation_format - == FusedMoEActivationFormat.BatchedExperts - ): - max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() - assert max_num_tokens_per_rank is not None - - logger.debug( - "%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s", - self.experts_cls.__name__, - self.__class__.__name__, - max_num_tokens_per_rank, - self.weight_block_size, - False, - ) - return self.experts_cls.make_batched_experts( - moe_config=self.moe, - quant_config=self.moe_quant_config, - max_num_tokens=max_num_tokens_per_rank, - num_dispatchers=prepare_finalize.num_dispatchers(), - ) - else: - logger.debug( - "%s(%s): block_size=%s, per_act_token=%s", - self.experts_cls.__name__, - self.__class__.__name__, - self.weight_block_size, - False, - ) - return self.experts_cls.make_standard_experts( - moe_config=self.moe, - quant_config=self.moe_quant_config, - ) + raise ValueError( + "Fp8FusedMoE uses the new modular kernel initialization logic. " + "This function should not be called." 
+ ) def get_fused_moe_quant_config( self, layer: torch.nn.Module From b21e8beb2512bff34f56b8a879a3a49002128845 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 13 Jan 2026 11:03:35 -0500 Subject: [PATCH 051/113] updated Signed-off-by: Robert Shaw --- vllm/model_executor/layers/quantization/fp8.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 4f25399af00b..29650d8e32fe 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -820,16 +820,28 @@ def _setup_kernel( replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale) replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale) + from vllm.model_executor.layers.fused_moe.all2all_utils import ( + maybe_make_prepare_finalize, + ) + # Setup modular kernel for TP case. self.moe_quant_config = self.get_fused_moe_quant_config(layer) if self.moe_quant_config: + prepare_finalize = maybe_make_prepare_finalize( + moe=self.moe, + quant_config=self.moe_quant_config, + routing_tables=None, # TODO: init routing tables here? 
+ defer_input_quant=self.experts_cls.should_pf_defer_input_quant(), + allow_new_interface=True, + ) + assert self.experts_cls is not None self.kernel, self.use_inplace = make_fp8_moe_kernel( layer=layer, moe_quant_config=self.moe_quant_config, moe_config=self.moe, fp8_backend=self.fp8_backend, - prepare_finalize=self.prepare_finalize, + prepare_finalize=prepare_finalize, experts_cls=self.experts_cls, ) From 7a269bce601644a3e3755e97e7aa8318fe93b73f Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 13 Jan 2026 12:23:16 -0500 Subject: [PATCH 052/113] updated Signed-off-by: Robert Shaw --- .../layers/fused_moe/fused_moe_method_base.py | 11 +++++++++++ vllm/model_executor/layers/fused_moe/layer.py | 4 ++-- .../layers/fused_moe/oracle/fp8.py | 13 ++++++++++++- .../model_executor/layers/quantization/fp8.py | 19 ++++--------------- 4 files changed, 29 insertions(+), 18 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py index 389ccf358c56..b3085d43c0b2 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py @@ -30,6 +30,17 @@ def __init__(self, moe: FusedMoEConfig): self.moe: FusedMoEConfig = moe self.moe_quant_config: FusedMoEQuantConfig | None = None + @property + def supports_mk_interally(self) -> bool: + """ + Returns True if this method supports using modular kernels (MK) + internally for MoE operations, False otherwise. + + This method should be overridden by subclasses that support + modular kernels internally. 
+ """ + return False + @abstractmethod def create_weights( self, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 613b026c5943..18d292dcb99e 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -688,7 +688,7 @@ def _get_quant_method() -> FusedMoEMethodBase: # This is called after all weight loading and post-processing, so it # should be safe to swap out the quant_method. def maybe_init_modular_kernel(self) -> None: - if getattr(self.quant_config, "_SUPPORTS_MK_INTERALLY", False): + if self.quant_method.supports_mk_interally: logger.info("SKIPPING MK INIT --> done by quant integration internally.") return @@ -1929,7 +1929,7 @@ def forward_impl( # TODO(rob): remove this once we migrate to internal use of MK. do_naive_dispatch_combine: bool = self.dp_size > 1 and not ( isinstance(self.quant_method, FusedMoEModularMethod) - or getattr(self.quant_method, "_SUPPORTS_MK_INTERALLY", False) + or self.quant_method.supports_mk_interally ) ctx = get_forward_context() diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py index 0e4bf1379ece..e7f1a7899763 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py @@ -9,6 +9,9 @@ from vllm import envs from vllm._aiter_ops import rocm_aiter_ops from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.all2all_utils import ( + maybe_make_prepare_finalize, +) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, @@ -356,9 +359,17 @@ def make_fp8_moe_kernel( moe_quant_config: FusedMoEQuantConfig, moe_config: FusedMoEConfig, fp8_backend: Fp8MoeBackend, - prepare_finalize: mk.FusedMoEPrepareAndFinalize, experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute], ) -> tuple[mk.FusedMoEModularKernel, bool]: + # Create Prepare/Finalize. 
+ prepare_finalize = maybe_make_prepare_finalize( + moe=moe_config, + quant_config=moe_quant_config, + routing_tables=None, # TODO: init routing tables here? + defer_input_quant=experts_cls.should_pf_defer_input_quant(), + allow_new_interface=True, + ) + # Create Experts. if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.Standard: experts = experts_cls.make_standard_experts( diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 29650d8e32fe..f0ad852e6fae 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -611,8 +611,6 @@ def apply( class Fp8MoEMethod(FusedMoEMethodBase): - _SUPPORTS_MK_INTERALLY = True - """MoE method for FP8. Supports loading FP8 checkpoints with static weight scale and dynamic/static activation scale. @@ -668,6 +666,10 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): # Delay creation of the kernel until after process-weights. self.kernel: mk.FusedMoEModularKernel | None = None + @property + def supports_mk_interally(self) -> bool: + return True + def create_weights( self, layer: Module, @@ -820,28 +822,15 @@ def _setup_kernel( replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale) replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale) - from vllm.model_executor.layers.fused_moe.all2all_utils import ( - maybe_make_prepare_finalize, - ) - # Setup modular kernel for TP case. self.moe_quant_config = self.get_fused_moe_quant_config(layer) if self.moe_quant_config: - prepare_finalize = maybe_make_prepare_finalize( - moe=self.moe, - quant_config=self.moe_quant_config, - routing_tables=None, # TODO: init routing tables here? 
- defer_input_quant=self.experts_cls.should_pf_defer_input_quant(), - allow_new_interface=True, - ) - assert self.experts_cls is not None self.kernel, self.use_inplace = make_fp8_moe_kernel( layer=layer, moe_quant_config=self.moe_quant_config, moe_config=self.moe, fp8_backend=self.fp8_backend, - prepare_finalize=prepare_finalize, experts_cls=self.experts_cls, ) From 02b0848d027d9ed9f70e619ca858c5787cf2a15d Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 13 Jan 2026 12:54:25 -0500 Subject: [PATCH 053/113] updated Signed-off-by: Robert Shaw --- vllm/model_executor/layers/quantization/fp8.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index f0ad852e6fae..715b9809eb50 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -648,7 +648,10 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): static_input_quant=(self.quant_config.activation_scheme == "static"), ) - # Select Fp8 MoE backend. + # Select Fp8 MoE backend + # NOTE(rob): this is kind of a hack. We need to peak into + # the prepare-finalize selection to determine if we are using + # the batched or standard expert format. 
use_batched = ( self.moe.moe_parallel_config.use_deepep_ll_kernels or self.moe.moe_parallel_config.use_pplx_kernels From 85615f1b08c4c8ed71580c737f36528087f3511b Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 13 Jan 2026 16:17:34 -0500 Subject: [PATCH 054/113] update kernel selection logic Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/oracle/fp8.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py index e7f1a7899763..17f7db81a4fb 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py @@ -218,7 +218,11 @@ def _return_or_raise( if not envs.VLLM_USE_DEEP_GEMM or not envs.VLLM_MOE_USE_DEEP_GEMM: AVAILABLE_BACKENDS.remove(Fp8MoeBackend.DEEPGEMM) else: - backend = Fp8MoeBackend.DEEPGEMM + backend = ( + Fp8MoeBackend.DEEPGEMM + if activation_format == mk.FusedMoEActivationFormat.Standard + else Fp8MoeBackend.BATCHED_DEEPGEMM + ) return _return_or_raise(backend, config, quant_scheme, activation_format) elif envs.VLLM_TEST_FORCE_FP8_MARLIN: From 9a4a871f63c31b1291716479bdfb61d5dfba6201 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 13 Jan 2026 22:27:48 +0000 Subject: [PATCH 055/113] seem to have accuracy with dp/ep and tp for deepgemm Signed-off-by: Robert Shaw --- .../layers/fused_moe/batched_deep_gemm_moe.py | 13 ++-- .../layers/fused_moe/cutlass_moe.py | 74 +++++++++---------- .../layers/fused_moe/deep_gemm_moe.py | 63 +++++++--------- .../fused_moe/flashinfer_cutlass_moe.py | 4 +- .../layers/fused_moe/flashinfer_trtllm_moe.py | 2 +- .../layers/fused_moe/fused_moe.py | 16 ++-- vllm/model_executor/layers/fused_moe/layer.py | 2 +- .../layers/fused_moe/modular_kernel.py | 54 ++++++++------ .../layers/fused_moe/oracle/fp8.py | 13 +++- .../layers/fused_moe/rocm_aiter_fused_moe.py | 2 +- .../layers/fused_moe/triton_deep_gemm_moe.py | 7 +- 
.../model_executor/layers/quantization/fp8.py | 6 ++ 12 files changed, 132 insertions(+), 124 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index b1130dd0e495..fc70d0979cbf 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -8,6 +8,7 @@ from vllm.forward_context import get_forward_context, is_forward_context_available from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig, FusedMoEQuantScheme, @@ -258,8 +259,7 @@ def persistent_masked_m_silu_mul_quant( class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - max_num_tokens: int, - num_dispatchers: int, + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, ): """ @@ -267,11 +267,9 @@ def __init__( num_dispatchers: The number of DP dispatchers. quant_config: Quantization configuration """ - super().__init__(quant_config) + super().__init__(moe_config, quant_config) assert self.block_shape == get_mk_alignment_for_contiguous_layout() assert self.quant_config.use_fp8_w8a8 - self.max_num_tokens = max_num_tokens - self.num_dispatchers = num_dispatchers @staticmethod def activation_format() -> mk.FusedMoEActivationFormat: @@ -287,7 +285,7 @@ def _supports_no_act_and_mul() -> bool: @staticmethod def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: - return quant_scheme.is_fp8_w8a8 and quant_scheme.block_size == [128, 128] + return quant_scheme.is_fp8_w8a8 and quant_scheme.block_size == (128, 128) @staticmethod def _supports_activation(activation: str) -> bool: @@ -332,7 +330,8 @@ def workspace_shapes( # end up sending their tokens. This needs to be fixed. 
num_dispatchers = self.num_dispatchers num_experts = local_num_experts - max_num_tokens = M if self.max_num_tokens is None else self.max_num_tokens + # max_num_tokens = M if self.max_num_tokens is None else self.max_num_tokens + max_num_tokens = self.max_num_tokens activation_out_dim = self.adjust_N_for_activation(N, activation) workspace13 = (num_experts, max_num_tokens * num_dispatchers, max(K, N)) workspace2 = (num_experts, max_num_tokens * num_dispatchers, activation_out_dim) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 3bcc5a8b64b9..5da1006d5397 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -19,9 +19,6 @@ moe_permute, moe_unpermute, ) -from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP, -) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate, TopKWeightAndReduceNoOP, @@ -410,7 +407,7 @@ def workspace_shapes( expert_tokens_meta: mk.ExpertTokensMetadata | None, activation: str, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: - num_dp = self.max_dispatchers + num_dp = self.num_dispatchers assert num_dp is not None experts_per_worker = self.moe_config.num_local_experts activation_out_dim = self.adjust_N_for_activation(N, activation) @@ -640,7 +637,8 @@ def workspace_shapes( workspace1: tuple[int, ...] = () workspace2: tuple[int, ...] = () output: tuple[int, ...] = () - if self.use_batched_format: + if False: + # if self.use_batched_format: workspace1 = (self.max_experts_per_worker, M, max(N, K)) workspace2 = (self.max_experts_per_worker, M, activation_out_dim) output = (self.max_experts_per_worker, M, K) @@ -1060,36 +1058,36 @@ def cutlass_moe_w4a8_fp8( Returns: - torch.Tensor: The bf16 output tensor after applying the MoE layer. 
""" - assert quant_config is not None - - num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(0) - - fn = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), - CutlassExpertsW4A8Fp8( - out_dtype=a.dtype, - a_strides1=a_strides1, - a_strides2=a_strides2, - b_strides1=b_strides1, - b_strides2=b_strides2, - c_strides1=c_strides1, - c_strides2=c_strides2, - s_strides1=s_strides1, - s_strides2=s_strides2, - # TODO: - quant_config=quant_config, - group_size=group_size, - ), - ) - - return fn( - a, - w1_q, - w2_q, - topk_weights, - topk_ids, - activation=activation, - global_num_experts=num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) + # assert quant_config is not None + + # num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(0) + + # fn = mk.FusedMoEModularKernel( + # MoEPrepareAndFinalizeNoEP(), + # CutlassExpertsW4A8Fp8( + # out_dtype=a.dtype, + # a_strides1=a_strides1, + # a_strides2=a_strides2, + # b_strides1=b_strides1, + # b_strides2=b_strides2, + # c_strides1=c_strides1, + # c_strides2=c_strides2, + # s_strides1=s_strides1, + # s_strides2=s_strides2, + # # TODO: + # quant_config=quant_config, + # group_size=group_size, + # ), + # ) + + # return fn( + # a, + # w1_q, + # w2_q, + # topk_weights, + # topk_ids, + # activation=activation, + # global_num_experts=num_experts, + # expert_map=expert_map, + # apply_router_weight_on_input=apply_router_weight_on_input, + # ) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 546270eb81e2..871100abdd2a 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -6,19 +6,16 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, FusedMoEParallelConfig, 
FusedMoEQuantConfig, FusedMoEQuantScheme, - fp8_w8a8_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.deep_gemm_utils import ( compute_aligned_M, deepgemm_moe_permute, deepgemm_unpermute_and_reduce, ) -from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP, -) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP, ) @@ -112,8 +109,8 @@ def _valid_deep_gemm( class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__(self, quant_config: FusedMoEQuantConfig): - super().__init__(quant_config) + def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig): + super().__init__(moe_config=moe_config, quant_config=quant_config) assert quant_config.block_shape == get_mk_alignment_for_contiguous_layout() assert quant_config.quant_dtype == torch.float8_e4m3fn assert not quant_config.per_act_token_quant @@ -133,11 +130,7 @@ def _supports_no_act_and_mul() -> bool: @staticmethod def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: - return ( - quant_scheme.is_fp8_w8a8 - and quant_scheme.per_block_quant - and quant_scheme.block_size == [128, 128] - ) + return quant_scheme.is_fp8_w8a8 and quant_scheme.block_size == (128, 128) @staticmethod def _supports_activation(activation: str) -> bool: @@ -360,27 +353,27 @@ def deep_gemm_moe_fp8( Returns: - torch.Tensor: The bfloat16 output tensor after applying the MoE layer. 
""" - quant_config = fp8_w8a8_moe_quant_config( - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, - block_shape=get_mk_alignment_for_contiguous_layout(), - ) - - fn = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), - DeepGemmExperts(quant_config), - ) - return fn( - hidden_states, - w1, - w2, - topk_weights, - topk_ids, - inplace=inplace, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) + # quant_config = fp8_w8a8_moe_quant_config( + # w1_scale=w1_scale, + # w2_scale=w2_scale, + # a1_scale=a1_scale, + # a2_scale=a2_scale, + # block_shape=get_mk_alignment_for_contiguous_layout(), + # ) + + # fn = mk.FusedMoEModularKernel( + # MoEPrepareAndFinalizeNoEP(), + # DeepGemmExperts(quant_config), + # ) + # return fn( + # hidden_states, + # w1, + # w2, + # topk_weights, + # topk_ids, + # inplace=inplace, + # activation=activation, + # global_num_experts=global_num_experts, + # expert_map=expert_map, + # apply_router_weight_on_input=apply_router_weight_on_input, + # ) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 1bd69639288c..2866f889b0ee 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -45,7 +45,7 @@ def __init__( self.use_deepseek_fp8_block_scale = quant_config.is_block_quantized @staticmethod - def should_pf_defer_input_quant(quant_config): + def should_pf_defer_input_quant(quant_config: FusedMoEQuantConfig) -> bool: """ FlashInfer CUTLASS Block FP8 path handles input quantization. 
""" @@ -77,7 +77,7 @@ def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: or (s.is_fp8_w8a8 and s.per_tensor_quant and s.static_input_quant) or ( s.is_fp8_w8a8 - and s.block_size == [128, 128] + and s.block_size == (128, 128) and p.is_device_capability((9, 0)) ) or (s.is_nvfp4_w4a4 and p.has_device_capability((10, 0))) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index 8cc78e476e3e..086ca87d31b6 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -41,7 +41,7 @@ def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: s = quant_scheme return ( (s.is_fp8_w8a8 and s.per_tensor_quant and s.static_input_quant) - or (s.is_fp8_w8a8 and s.block_size == [128, 128]) + or (s.is_fp8_w8a8 and s.block_size == (128, 128)) or (s.is_nvfp4_w4a4) ) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index cf8b3e77186c..6c38c3dfeaaa 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -21,6 +21,7 @@ ) from vllm.model_executor.layers.fused_moe.config import ( FUSED_MOE_UNQUANTIZED_CONFIG, + FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig, FusedMoEQuantScheme, @@ -2301,9 +2302,10 @@ def fused_experts_impl( class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, ): - super().__init__(quant_config) + super().__init__(moe_config, quant_config) @staticmethod def activation_format() -> mk.FusedMoEActivationFormat: @@ -2512,12 +2514,6 @@ def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None: class TritonWNA16Experts(TritonExperts): - def __init__( - self, - quant_config: FusedMoEQuantConfig, - ): - super().__init__(quant_config) - 
@staticmethod def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: # TODO @@ -2660,10 +2656,12 @@ def apply( def modular_triton_fused_moe( - quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None + moe_config: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, + shared_experts: torch.nn.Module | None = None, ) -> mk.FusedMoEModularKernel: return mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), - TritonExperts(quant_config), + TritonExperts(moe_config, quant_config), shared_experts, ) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 18d292dcb99e..19dc5aeafcf9 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -689,7 +689,7 @@ def _get_quant_method() -> FusedMoEMethodBase: # should be safe to swap out the quant_method. def maybe_init_modular_kernel(self) -> None: if self.quant_method.supports_mk_interally: - logger.info("SKIPPING MK INIT --> done by quant integration internally.") + logger.info_once("DEBUG: SKIPPING MK INIT: Handled Internally!!!!") return self.ensure_moe_quant_config_init() diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index ef3a9b572545..baa7c72351b4 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -386,7 +386,7 @@ def __init__( self.moe_config = moe_config self.quant_config = quant_config self._max_num_tokens: int | None = None - self._max_dispatchers: int | None = None + self._num_dispatchers: int | None = None @staticmethod def should_pf_defer_input_quant(quant_config: FusedMoEQuantConfig) -> bool: @@ -402,13 +402,13 @@ def should_pf_defer_input_quant(quant_config: FusedMoEQuantConfig) -> bool: def _init_batched_experts_addl_params( self, max_num_tokens: int, - max_dispatchers: int, + num_dispatchers: int, ): 
""" Initialize any additional parameters needed for batched experts. """ self._max_num_tokens = max_num_tokens - self._max_dispatchers = max_dispatchers + self._num_dispatchers = num_dispatchers @property def max_num_tokens(self) -> int: @@ -417,10 +417,10 @@ def max_num_tokens(self) -> int: return self._max_num_tokens @property - def max_dispatchers(self) -> int: - if self._max_dispatchers is None: - raise AttributeError("max_dispatchers only valid for BatchedExperts") - return self._max_dispatchers + def num_dispatchers(self) -> int: + if self._num_dispatchers is None: + raise AttributeError("num_dispatchers only valid for BatchedExperts") + return self._num_dispatchers @classmethod def make_standard_experts( @@ -440,14 +440,14 @@ def make_batched_experts( moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, max_num_tokens: int, - max_dispatchers: int, + num_dispatchers: int, ) -> "FusedMoEPermuteExpertsUnpermute": """ Factory method to create an instance of this class. """ assert cls.activation_format() == FusedMoEActivationFormat.BatchedExperts instance = cls(moe_config, quant_config) - instance._init_batched_experts_addl_params(max_num_tokens, max_dispatchers) + instance._init_batched_experts_addl_params(max_num_tokens, num_dispatchers) return instance @staticmethod @@ -513,18 +513,26 @@ def is_supported_config( moe_config: FusedMoEConfig, moe_quant_scheme: FusedMoEQuantScheme, activation_format: FusedMoEActivationFormat, - ) -> bool: - return ( - (cls._supports_current_device()) - and (moe_config.is_act_and_mul or cls._supports_no_act_and_mul()) - and cls._supports_activation(moe_config.activation) - and cls._supports_quant_scheme(moe_quant_scheme) - and cls._supports_parallel_config(moe_config.moe_parallel_config) - and (activation_format == cls.activation_format()) - ) + ) -> tuple[bool, str | None]: + def _make_reason(reason: str) -> str: + return f"kernel does not support {reason}" + + if not cls._supports_current_device(): + return False, 
_make_reason("current device") + elif not (moe_config.is_act_and_mul or cls._supports_no_act_and_mul()): + return False, _make_reason("no act_and_mul MLP layer") + elif not cls._supports_activation(moe_config.activation): + return False, _make_reason(f"{moe_config.activation} activation") + elif not cls._supports_quant_scheme(moe_quant_scheme): + return False, _make_reason("quantization scheme") + elif not cls._supports_parallel_config(moe_config.moe_parallel_config): + return False, _make_reason("parallel config") + elif activation_format != cls.activation_format(): + return False, _make_reason(f"{activation_format.value} activation format") + return True, None - @abstractmethod @staticmethod + @abstractmethod def _supports_current_device() -> bool: """ Whether the kernel supports the current device type @@ -532,8 +540,8 @@ def _supports_current_device() -> bool: """ raise NotImplementedError - @abstractmethod @staticmethod + @abstractmethod def _supports_no_act_and_mul() -> bool: """ Whether the kernel supports act_and_mul=False, i.e. @@ -541,21 +549,21 @@ def _supports_no_act_and_mul() -> bool: """ raise NotImplementedError - @abstractmethod @staticmethod + @abstractmethod def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: raise NotImplementedError - @abstractmethod @staticmethod + @abstractmethod def _supports_activation(activation: str) -> bool: """ Whether the kernel supports a particular act function. """ raise NotImplementedError - @abstractmethod @staticmethod + @abstractmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: """ Whether the kernel supports deployment in expert parallel. 
diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py index 17f7db81a4fb..2efc685fe10f 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py @@ -36,7 +36,7 @@ ) if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.fused_moe import FusedMoE + from vllm.model_executor.layers.fused_moe.layer import FusedMoE logger = init_logger(__name__) @@ -143,13 +143,18 @@ def _return_or_raise( activation_format: mk.FusedMoEActivationFormat, ) -> tuple[Fp8MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]: k_cls = backend_2_kernel_cls(backend) - if k_cls.is_supported_config(k_cls, config, scheme, activation_format): + supported, reason = k_cls.is_supported_config( + k_cls, config, scheme, activation_format + ) + if supported: logger.info_once(_make_log_backend(backend)) return backend, k_cls + assert reason is not None raise ValueError( f"Requested FP8 MoE backend `{backend.value}` " - "does not support the deployment configuration." + "does not support the deployment configuration since " + f"{reason}." ) # NOTE: the kernels are selected in the following order. @@ -370,7 +375,7 @@ def make_fp8_moe_kernel( moe=moe_config, quant_config=moe_quant_config, routing_tables=None, # TODO: init routing tables here? 
- defer_input_quant=experts_cls.should_pf_defer_input_quant(), + defer_input_quant=experts_cls.should_pf_defer_input_quant(moe_quant_config), allow_new_interface=True, ) diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 911f85749450..842c3db342b7 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -276,7 +276,7 @@ def activation_format() -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.Standard @staticmethod - def should_pf_defer_input_quant(quant_config): + def should_pf_defer_input_quant(quant_config: FusedMoEQuantConfig) -> bool: """ AITER Fused MoE kernels handle input quantization. """ diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index de6f4896880c..a8beef14c3a1 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -5,6 +5,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig, FusedMoEQuantScheme, @@ -24,10 +25,10 @@ class TritonOrDeepGemmExperts(FallbackExperts): """DeepGemm with fallback to Triton for low latency shapes.""" - def __init__(self, quant_config: FusedMoEQuantConfig): + def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig): super().__init__( - experts=DeepGemmExperts(quant_config), - fallback_experts=TritonExperts(quant_config), + experts=DeepGemmExperts(moe_config, quant_config), + fallback_experts=TritonExperts(moe_config, quant_config), ) @staticmethod diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 715b9809eb50..e72010ea5dc7 100644 --- 
a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -669,6 +669,12 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): # Delay creation of the kernel until after process-weights. self.kernel: mk.FusedMoEModularKernel | None = None + @property + def topk_indices_dtype(self) -> torch.dtype | None: + if self.kernel is not None: + return self.kernel.prepare_finalize.topk_indices_dtype() + return None + @property def supports_mk_interally(self) -> bool: return True From 9f829f02a28a555de38167f6d26f08390d839948 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 13 Jan 2026 23:07:11 +0000 Subject: [PATCH 056/113] updated Signed-off-by: Robert Shaw --- docs/design/fused_moe_modular_kernel.md | 2 +- docs/design/moe_kernel_features.md | 3 +- .../layers/fused_moe/trtllm_mxfp4_moe.py | 162 ------------------ .../layers/quantization/mxfp4.py | 4 +- 4 files changed, 5 insertions(+), 166 deletions(-) delete mode 100644 vllm/model_executor/layers/fused_moe/trtllm_mxfp4_moe.py diff --git a/docs/design/fused_moe_modular_kernel.md b/docs/design/fused_moe_modular_kernel.md index 39898189c440..e1a96be6c344 100644 --- a/docs/design/fused_moe_modular_kernel.md +++ b/docs/design/fused_moe_modular_kernel.md @@ -166,7 +166,7 @@ We suggest picking an already existing `FusedMoEPrepareAndFinalize` implementati FusedMoEPermuteExpertsUnpermute performs the core of the FusedMoE operations. The various functions exposed by the abstract class and their significance is as follows, -`FusedMoEPermuteExpertsUnpermute::activation_format()`: Return the supported activation formats. i.e. Standard / Batched (MaskedGEMM) format. +`FusedMoEPermuteExpertsUnpermute::activation_formats()`: Return the supported Input and Output activation formats. i.e. Contiguous / Batched format. `FusedMoEPermuteExpertsUnpermute::supports_chunking()`: Return True if the implementation supports chunking. 
Typically implementations that input `FusedMoEActivationFormat.Standard` support chunking and `FusedMoEActivationFormat.BatchedExperts` do not. diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index abd30041e44a..5b35b152ff38 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -91,10 +91,11 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels | flashinfer | standard | nvfp4,
fp8 | T | 5 | N | Y | [`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] | | gpt oss triton | standard | N/A | N/A | 5 | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],
[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] | | marlin | standard,
batched | 3 / N/A | 3 / N/A | silu,
swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],
[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],
[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] | -| trtllm | standard | mxfp4,
nvfp4 | G(16),G(32) | 5 | N | Y | [`TrtllmMxFp4Experts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtllmMxFp4Experts] | +| trtllm | standard | mxfp4,
nvfp4 | G(16),G(32) | 5 | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] | | iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] | | rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_experts] | | cpu_fused_moe | standard | N/A | N/A | silu | N | N | [`CPUFusedMOE`][vllm.model_executor.layers.fused_moe.cpu_fused_moe.CPUFusedMOE] | +| naive batched4 | batched | int8,
fp8 | G,A,T | silu, gelu | 6 | Y | [`NaiveBatchedExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.NaiveBatchedExperts] | !!! info "Table key" 1. All types: mxfp4, nvfp4, int4, int8, fp8 diff --git a/vllm/model_executor/layers/fused_moe/trtllm_mxfp4_moe.py b/vllm/model_executor/layers/fused_moe/trtllm_mxfp4_moe.py deleted file mode 100644 index 11a0e7ac557e..000000000000 --- a/vllm/model_executor/layers/fused_moe/trtllm_mxfp4_moe.py +++ /dev/null @@ -1,162 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch - -import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEConfig, - FusedMoEParallelConfig, - FusedMoEQuantConfig, - FusedMoEQuantScheme, -) -from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceNoOP, -) -from vllm.platforms import current_platform - - -class TrtllmMxFp4Experts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__( - self, - moe: FusedMoEConfig, - quant_config: FusedMoEQuantConfig, - gemm1_alpha, - gemm1_beta, - gemm1_clamp_limit, - max_capture_size, - ): - super().__init__(quant_config) - self.moe = moe - self.gemm1_alpha = gemm1_alpha - self.gemm1_beta = gemm1_beta - self.gemm1_clamp_limit = gemm1_clamp_limit - self.max_capture_size = max_capture_size - - @staticmethod - def activation_format() -> mk.FusedMoEActivationFormat: - return mk.FusedMoEActivationFormat.Standard - - @staticmethod - def _supports_current_device() -> bool: - return current_platform.has_device_capability((10, 0)) - - @staticmethod - def _supports_no_act_and_mul() -> bool: - return False - - @staticmethod - def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: - return quant_scheme.is_mxfp4_w4a16 or quant_scheme.is_mxfp4_w4a8 - - @staticmethod - def _supports_activation(activation: str) -> bool: - return activation in ["swigluoai"] - - 
@staticmethod - def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: - return True - - def supports_chunking(self) -> bool: - return True - - def supports_expert_map(self) -> bool: - return True - - def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: - return TopKWeightAndReduceNoOP() - - def workspace_shapes( - self, - M: int, - N: int, - K: int, - topk: int, - global_num_experts: int, - local_num_experts: int, - expert_tokens_meta: mk.ExpertTokensMetadata | None, - activation: str, - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: - # The workspaces for this implementation are managed by flashinfer. - workspace1 = (0,) - workspace2 = (0,) - output = (M, K) - return (workspace1, workspace2, output) - - def apply( - self, - output: torch.Tensor, - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str, - global_num_experts: int, - expert_map: torch.Tensor | None, - a1q_scale: torch.Tensor | None, - a2_scale: torch.Tensor | None, - workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_tokens_meta: mk.ExpertTokensMetadata | None, - apply_router_weight_on_input: bool, - ): - topk = topk_ids.size(-1) - local_num_experts = w1.size(0) - intermediate_size = w2.size(1) - local_expert_offset = self.moe.ep_rank * local_num_experts - - x_quant = hidden_states - x_scale = a1q_scale - if x_scale is not None: - x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x_quant.shape[:-1], -1) - - packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to( - torch.bfloat16 - ).view(torch.int16) - - assert self.w1_scale is not None - assert self.w2_scale is not None - kwargs = { - "topk_ids": packed_tensor, - "routing_bias": None, - "hidden_states": x_quant, - "hidden_states_scale": x_scale, - "gemm1_weights": w1, - "gemm1_weights_scale": self.w1_scale, - "gemm1_bias": self.w1_bias, - "gemm1_alpha": self.gemm1_alpha, - 
"gemm1_beta": self.gemm1_beta, - "gemm1_clamp_limit": self.gemm1_clamp_limit, - "gemm2_weights": w2, - "gemm2_weights_scale": self.w2_scale, - "gemm2_bias": self.w2_bias, - "output1_scale_scalar": None, - "output1_scale_gate_scalar": None, - "output2_scale_scalar": None, - "num_experts": global_num_experts, - "top_k": topk, - "n_group": None, - "topk_group": None, - "intermediate_size": intermediate_size, - "local_expert_offset": local_expert_offset, - "local_num_experts": local_num_experts, - "routed_scaling_factor": None, - "tile_tokens_dim": None, - "routing_method_type": 1, - "do_finalize": True, - "output": output, - "tune_max_num_tokens": max(self.max_capture_size, 1), - } - - from flashinfer import trtllm_fp4_block_scale_routed_moe - - from vllm.utils.flashinfer import autotune - - with autotune(False): - # Enable autotune when, - # https://github.com/flashinfer-ai/flashinfer/issues/2023 is - # resolved. - trtllm_fp4_block_scale_routed_moe(**kwargs) - - return output diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 3fcc61878b3c..8e050b795f94 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -32,7 +32,7 @@ OAITritonExperts, UnfusedOAITritonExperts, ) -from vllm.model_executor.layers.fused_moe.trtllm_mxfp4_moe import TrtllmMxFp4Experts +from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( @@ -873,7 +873,7 @@ def select_gemm_impl( # TODO(bnell): part of quant_config "max_capture_size": self.max_capture_size, } - return TrtllmMxFp4Experts(self.moe, self.moe_quant_config, **kwargs) + return TrtLlmGenExperts(self.moe, self.moe_quant_config, **kwargs) elif self.mxfp4_backend == 
Mxfp4Backend.MARLIN: return MarlinExperts(self.moe_quant_config) elif self.mxfp4_backend == Mxfp4Backend.TRITON: From 03ce528fe34694bea69529a42b2081ba60dca9ff Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 13 Jan 2026 23:07:34 +0000 Subject: [PATCH 057/113] re-add Signed-off-by: Robert Shaw --- .../layers/fused_moe/trtllm_moe.py | 162 ++++++++++++++++++ 1 file changed, 162 insertions(+) create mode 100644 vllm/model_executor/layers/fused_moe/trtllm_moe.py diff --git a/vllm/model_executor/layers/fused_moe/trtllm_moe.py b/vllm/model_executor/layers/fused_moe/trtllm_moe.py new file mode 100644 index 000000000000..97cddd5897a7 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/trtllm_moe.py @@ -0,0 +1,162 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, + FusedMoEQuantConfig, + FusedMoEQuantScheme, +) +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceNoOP, +) +from vllm.platforms import current_platform + + +class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): + def __init__( + self, + moe: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, + gemm1_alpha, + gemm1_beta, + gemm1_clamp_limit, + max_capture_size, + ): + super().__init__(quant_config) + self.moe = moe + self.gemm1_alpha = gemm1_alpha + self.gemm1_beta = gemm1_beta + self.gemm1_clamp_limit = gemm1_clamp_limit + self.max_capture_size = max_capture_size + + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + @staticmethod + def _supports_current_device() -> bool: + return current_platform.has_device_capability((10, 0)) + + @staticmethod + def _supports_no_act_and_mul() -> bool: + return False + + @staticmethod + def 
_supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: + return quant_scheme.is_mxfp4_w4a16 or quant_scheme.is_mxfp4_w4a8 + + @staticmethod + def _supports_activation(activation: str) -> bool: + return activation in ["swigluoai"] + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + return True + + def supports_chunking(self) -> bool: + return True + + def supports_expert_map(self) -> bool: + return True + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + return TopKWeightAndReduceNoOP() + + def workspace_shapes( + self, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + activation: str, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + # The workspaces for this implementation are managed by flashinfer. + workspace1 = (0,) + workspace2 = (0,) + output = (M, K) + return (workspace1, workspace2, output) + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + apply_router_weight_on_input: bool, + ): + topk = topk_ids.size(-1) + local_num_experts = w1.size(0) + intermediate_size = w2.size(1) + local_expert_offset = self.moe.ep_rank * local_num_experts + + x_quant = hidden_states + x_scale = a1q_scale + if x_scale is not None: + x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x_quant.shape[:-1], -1) + + packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to( + torch.bfloat16 + ).view(torch.int16) + + assert self.w1_scale is not None + assert self.w2_scale is not None + kwargs = { 
+ "topk_ids": packed_tensor, + "routing_bias": None, + "hidden_states": x_quant, + "hidden_states_scale": x_scale, + "gemm1_weights": w1, + "gemm1_weights_scale": self.w1_scale, + "gemm1_bias": self.w1_bias, + "gemm1_alpha": self.gemm1_alpha, + "gemm1_beta": self.gemm1_beta, + "gemm1_clamp_limit": self.gemm1_clamp_limit, + "gemm2_weights": w2, + "gemm2_weights_scale": self.w2_scale, + "gemm2_bias": self.w2_bias, + "output1_scale_scalar": None, + "output1_scale_gate_scalar": None, + "output2_scale_scalar": None, + "num_experts": global_num_experts, + "top_k": topk, + "n_group": None, + "topk_group": None, + "intermediate_size": intermediate_size, + "local_expert_offset": local_expert_offset, + "local_num_experts": local_num_experts, + "routed_scaling_factor": None, + "tile_tokens_dim": None, + "routing_method_type": 1, + "do_finalize": True, + "output": output, + "tune_max_num_tokens": max(self.max_capture_size, 1), + } + + from flashinfer import trtllm_fp4_block_scale_routed_moe + + from vllm.utils.flashinfer import autotune + + with autotune(False): + # Enable autotune when, + # https://github.com/flashinfer-ai/flashinfer/issues/2023 is + # resolved. 
+ trtllm_fp4_block_scale_routed_moe(**kwargs) + + return output From 9ce0412a4180a9d3576ede790497542c4b3920c8 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 13 Jan 2026 23:10:54 +0000 Subject: [PATCH 058/113] added back naive batched experts Signed-off-by: Robert Shaw --- .../layers/fused_moe/fused_batched_moe.py | 128 ++++++++++++++++++ 1 file changed, 128 insertions(+) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 0790a18e028f..d66b59c18e22 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -21,6 +21,7 @@ normalize_batched_scales_shape, normalize_scales_shape, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import group_broadcast from vllm.platforms import current_platform from vllm.triton_utils import tl, triton @@ -628,6 +629,133 @@ def finalize( ) +class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): + """ + A reference MoE expert class that operates on expert batched format, + i.e. E x max_num_tokens x K. This is the format that the pplx + dispatch/combine kernels use. 
+ """ + + def __init__( + self, + max_num_tokens: int, + num_dispatchers: int, + quant_config: FusedMoEQuantConfig, + ): + super().__init__(quant_config) + assert not self.quant_config.use_int8_w8a8, "NYI" + assert not self.quant_config.use_int8_w8a16, "NYI" + assert not self.quant_config.use_int4_w4a16, "NYI" + assert self.quant_config.ocp_mx_scheme is None, "NYI" + self.max_num_tokens = max_num_tokens + self.num_dispatchers = num_dispatchers + + @property + def activation_formats( + self, + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return ( + mk.FusedMoEActivationFormat.BatchedExperts, + mk.FusedMoEActivationFormat.BatchedExperts, + ) + + def supports_chunking(self) -> bool: + return False + + def supports_expert_map(self) -> bool: + return False + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + # Let PrepareAndFinalize::finalize() decide the impl. + return TopKWeightAndReduceDelegate() + + def workspace_shapes( + self, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + activation: str, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + num_dp = self.num_dispatchers + num_experts = local_num_experts + workspace13 = (num_experts, self.max_num_tokens * num_dp, K) + workspace2 = (self.max_num_tokens * num_dp, N) + output = workspace13 + return (workspace13, workspace2, output) + + def dequant(self, t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + assert self.quant_config.is_quantized + f32 = torch.float32 + if self.quant_config.is_per_act_token or self.quant_config.is_per_tensor: + return t.to(f32) * scale + else: + return t.to(f32) * group_broadcast(scale, t.shape) + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + 
expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + apply_router_weight_on_input: bool, + ): + assert hidden_states.dim() == 3 + assert expert_tokens_meta is not None + expert_num_tokens = expert_tokens_meta.expert_num_tokens + + num_local_experts = w1.size(0) + assert num_local_experts == w1.size(0), f"{num_local_experts} == {w1.size(0)}" + + N = w1.size(1) // 2 + + for expert in range(num_local_experts): + # Indexing expert_num_tokens doesn't work w/cudagraphs or inductor + if ( + torch.compiler.is_compiling() + or torch.cuda.is_current_stream_capturing() + ): + num = hidden_states.shape[1] + else: + num = int(expert_num_tokens[expert].item()) + + if num == 0: + continue + + tmp = _resize_cache(workspace2, (num, N)) + + if self.quant_config.is_quantized: + assert a1q_scale is not None and self.w1_scale is not None + input = self.dequant(hidden_states[expert, :, :], a1q_scale[expert]) + w1_dq = self.dequant(w1[expert], self.w1_scale[expert]) + input = input[:num] @ w1_dq.transpose(0, 1) + else: + input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1) + + self.activation(activation, tmp, input.to(tmp.dtype)) + + if self.quant_config.is_quantized: + assert self.w2_scale is not None + w2_dq = self.dequant(w2[expert], self.w2_scale[expert]) + else: + w2_dq = w2[expert] + + output[expert, :num, :] = tmp @ w2_dq.transpose(0, 1).to(tmp.dtype) + + def batched_moe_kernel_quantize_input( A: torch.Tensor, A_scale: torch.Tensor | None, From 856d9dc089fb4caf9df0070179834b4b359b44e1 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 13 Jan 2026 23:13:27 +0000 Subject: [PATCH 059/113] added back naive batched experts Signed-off-by: Robert Shaw --- .../kernels/moe/modular_kernel_tools/mk_objects.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git 
a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index 4400d2cf53d5..80caedfdae92 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -20,6 +20,7 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedTritonExperts, + NaiveBatchedExperts, ) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP, @@ -183,6 +184,15 @@ def expert_info(kind) -> ExpertInfo: needs_matching_quant=True, ) +register_experts( + NaiveBatchedExperts, + batched_format, + common_float_and_int_types, + blocked_quantization_support=True, + supports_chunking=False, + supports_expert_map=True, +) + # Disable on blackwell for now if has_deep_ep() and not current_platform.has_device_capability(100): from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( @@ -441,6 +451,10 @@ def make_fused_experts( kwargs = quant_kwargs print(f"Making TritonExperts {kwargs} ...") experts = TritonExperts(**kwargs) + elif fused_experts_type == NaiveBatchedExperts: + kwargs = batch_kwargs | quant_kwargs + print(f"Making NaiveBatchedExperts {kwargs} ...") + experts = NaiveBatchedExperts(**kwargs) elif fused_experts_type == TritonOrDeepGemmExperts: kwargs = quant_kwargs | deepgemm_kwargs print(f"Making TritonOrDeepGemmExperts {kwargs} ...") From 38ae5e7ada8c2649fb2ce57b5f78bc5bd700a67c Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 13 Jan 2026 23:14:28 +0000 Subject: [PATCH 060/113] added back naive batched experts Signed-off-by: Robert Shaw --- tests/kernels/moe/test_batched_moe.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index 08e080f64d2b..c9d425b5b990 100644 --- 
a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -10,6 +10,7 @@ batched_moe, make_quantized_test_activations, make_test_weights, + naive_batched_moe, ) from tests.kernels.quant_utils import native_batched_masked_quant_matmul from tests.kernels.utils import torch_experts @@ -317,6 +318,21 @@ def test_fused_moe_batched_experts( block_shape=block_shape, ) + batched_output = naive_batched_moe( + a, + w1, + w2, + topk_weight, + topk_ids, + w1_scale=w1_s, + w2_scale=w2_s, + a1_scale=a1_scale, + a2_scale=a2_scale, + quant_dtype=quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + ) + triton_output = batched_moe( a, w1, @@ -332,4 +348,6 @@ def test_fused_moe_batched_experts( block_shape=block_shape, ) - torch.testing.assert_close(triton_output, baseline_output, atol=2e-2, rtol=2e-2) + torch.testing.assert_close(batched_output, baseline_output, atol=3e-2, rtol=2e-2) + + torch.testing.assert_close(triton_output, batched_output, atol=2e-2, rtol=2e-2) From 81b5ed36b9cfdd73663bb94acdb86abd5eb79e25 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 13 Jan 2026 23:18:33 +0000 Subject: [PATCH 061/113] added back naive batched experts Signed-off-by: Robert Shaw --- tests/kernels/moe/test_pplx_moe.py | 60 ++++++++++++++++++++++++++++-- 1 file changed, 56 insertions(+), 4 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index aa94f6f7a566..c08a54f0e9f6 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -31,6 +31,7 @@ from tests.kernels.moe.utils import ( make_shared_experts, make_test_weights, + naive_batched_moe, ) from tests.kernels.quant_utils import dequant from tests.kernels.utils import torch_experts @@ -170,6 +171,40 @@ def torch_batched_moe( return torch_finalize(out, topk_weight, topk_ids) +@pytest.mark.parametrize("m,n,k", BATCHED_MOE_MNK_FACTORS) +@pytest.mark.parametrize("e", NUM_EXPERTS) 
+@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +def test_fused_moe_batched_experts( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, + workspace_init, +): + set_random_seed(7) + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + score = torch.randn((m, e), device="cuda", dtype=dtype) + + with set_current_vllm_config(vllm_config): + topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) + baseline_output = torch_experts( + a, w1, w2, topk_weight, topk_ids + ) # only for baseline + torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids) + batched_output = naive_batched_moe( + a, w1, w2, topk_weight, topk_ids + ) # pick torch_experts or this + + torch.testing.assert_close(baseline_output, torch_output, atol=2e-2, rtol=0) + torch.testing.assert_close(baseline_output, batched_output, atol=2e-2, rtol=0) + + def create_pplx_prepare_finalize( num_tokens: int, hidden_dim: int, @@ -681,6 +716,21 @@ def _pplx_moe( block_shape=block_shape, ) + batched_output = naive_batched_moe( + a, + w1, + w2, + topk_weight, + topk_ids, + w1_scale=w1_s, + w2_scale=w2_s, + a1_scale=a1_scale, + a2_scale=a2_scale, + quant_dtype=quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + ) + pplx_outputs = pplx_moe( group_name, rank, @@ -716,12 +766,14 @@ def _pplx_moe( else: chunked_shared_output = None - chunked_torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to( - pplx_output.device - ) + chunked_batch_output = chunk_by_rank( + batched_output, pgi.rank, pgi.world_size + ).to(pplx_output.device) + + torch.testing.assert_close(batched_output, torch_output, atol=3e-2, rtol=3e-2) torch.testing.assert_close( - pplx_output, chunked_torch_output, atol=3e-2, rtol=3e-2 + pplx_output, chunked_batch_output, atol=3e-2, rtol=3e-2 ) 
if shared_experts is not None: From 550b42a0aa4a6cafff60fb9442b64e8390122ec6 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 13 Jan 2026 23:21:54 +0000 Subject: [PATCH 062/113] reduce LOC change Signed-off-by: Robert Shaw --- tests/kernels/moe/test_flashinfer_moe.py | 3 +++ .../fused_moe/flashinfer_cutlass_moe.py | 27 +++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/tests/kernels/moe/test_flashinfer_moe.py b/tests/kernels/moe/test_flashinfer_moe.py index 7739d0040e07..16ccb5e8e5ed 100644 --- a/tests/kernels/moe/test_flashinfer_moe.py +++ b/tests/kernels/moe/test_flashinfer_moe.py @@ -14,6 +14,7 @@ from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( FlashInferExperts, + is_valid_flashinfer_cutlass_fused_moe, ) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( create_flashinfer_prepare_finalize, @@ -88,6 +89,8 @@ def test_flashinfer_fp4_moe_no_graph( FlashInferExperts(out_dtype=dtype, quant_config=quant_config), ) + assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q) + fi_activation = {"silu_and_mul": "silu", "relu2": "relu2_no_mul"}[activation] flashinfer_output = flashinfer_experts( diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 2866f889b0ee..e87fec015c60 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -22,6 +22,33 @@ logger = init_logger(__name__) +def is_valid_flashinfer_cutlass_fused_moe( + hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor +) -> bool: + """ + Check if the given problem size is supported by the FlashInfer CUTLASS MoE + kernel. + """ + if not has_flashinfer_cutlass_fused_moe(): + logger.debug_once( + "FlashInferExperts disabled: flashinfer_cutlass_fused_moe not available." 
+ ) + return False + # Data type checks + if ( + w1.dtype != torch.uint8 + or w2.dtype != torch.uint8 + or hidden_states.dtype not in [torch.float32, torch.float16, torch.bfloat16] + ): + logger.debug_once( + "FlashInferExperts disabled: w1/w2 must be torch.uint8 " + f"(got w1={w1.dtype}, w2={w2.dtype}), hidden_states must be " + f"float32, float16, or bfloat16 (got {hidden_states.dtype})." + ) + return False + return True + + class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, From d85bd8befe9b10b421318c07da94d51055b4eb37 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 13 Jan 2026 23:22:38 +0000 Subject: [PATCH 063/113] reduce LOC change Signed-off-by: Robert Shaw --- tests/kernels/moe/modular_kernel_tools/mk_objects.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index 80caedfdae92..99b168dc7554 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -451,14 +451,14 @@ def make_fused_experts( kwargs = quant_kwargs print(f"Making TritonExperts {kwargs} ...") experts = TritonExperts(**kwargs) - elif fused_experts_type == NaiveBatchedExperts: - kwargs = batch_kwargs | quant_kwargs - print(f"Making NaiveBatchedExperts {kwargs} ...") - experts = NaiveBatchedExperts(**kwargs) elif fused_experts_type == TritonOrDeepGemmExperts: kwargs = quant_kwargs | deepgemm_kwargs print(f"Making TritonOrDeepGemmExperts {kwargs} ...") experts = TritonOrDeepGemmExperts(**kwargs) + elif fused_experts_type == NaiveBatchedExperts: + kwargs = batch_kwargs | quant_kwargs + print(f"Making NaiveBatchedExperts {kwargs} ...") + experts = NaiveBatchedExperts(**kwargs) elif fused_experts_type == CutlassExpertsFp8: strides = make_cutlass_strides(moe.num_experts, N, moe.hidden_dim) kwargs = { From 642b85cedbfcf74215aa2244106190868ccfa244 Mon Sep 17 00:00:00 
2001 From: Robert Shaw Date: Tue, 13 Jan 2026 23:23:35 +0000 Subject: [PATCH 064/113] reduce LOC change Signed-off-by: Robert Shaw --- tests/kernels/moe/test_flashinfer_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/kernels/moe/test_flashinfer_moe.py b/tests/kernels/moe/test_flashinfer_moe.py index 16ccb5e8e5ed..1262eea70bab 100644 --- a/tests/kernels/moe/test_flashinfer_moe.py +++ b/tests/kernels/moe/test_flashinfer_moe.py @@ -84,13 +84,13 @@ def test_flashinfer_fp4_moe_no_graph( score = torch.randn((m, e), device="cuda", dtype=dtype) topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False) + assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q) + flashinfer_experts = FusedMoEModularKernel( create_flashinfer_prepare_finalize(use_dp=False, use_nvfp4=True), FlashInferExperts(out_dtype=dtype, quant_config=quant_config), ) - assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q) - fi_activation = {"silu_and_mul": "silu", "relu2": "relu2_no_mul"}[activation] flashinfer_output = flashinfer_experts( From 5027a285a95a7e86172d89cd27a4da3ff2ccad46 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 13 Jan 2026 18:29:17 -0500 Subject: [PATCH 065/113] reduce loc change Signed-off-by: Robert Shaw --- tests/kernels/moe/utils.py | 41 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/kernels/moe/utils.py b/tests/kernels/moe/utils.py index 8d66fdecee37..f0c8c8033b8e 100644 --- a/tests/kernels/moe/utils.py +++ b/tests/kernels/moe/utils.py @@ -12,6 +12,7 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedPrepareAndFinalize, BatchedTritonExperts, + NaiveBatchedExperts, ) from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input @@ -86,6 +87,46 @@ def batched_moe( return fused_experts(a, w1, w2, topk_weight, topk_ids) +def naive_batched_moe( + 
a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, + quant_dtype: torch.dtype | None = None, + per_act_token_quant: bool = False, + block_shape: list[int] | None = None, +) -> torch.Tensor: + max_num_tokens = round_up(a.shape[0], 64) + + quant_config = FusedMoEQuantConfig.make( + quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) + + fused_experts = FusedMoEModularKernel( + BatchedPrepareAndFinalize( + max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0 + ), + NaiveBatchedExperts( + max_num_tokens=max_num_tokens, + num_dispatchers=1, + quant_config=quant_config, + ), + ) + + return fused_experts(a, w1, w2, topk_weight, topk_ids) + + def chunk_scales( scales: torch.Tensor | None, start: int, end: int ) -> torch.Tensor | None: From 6a009d6a5c012edcd5c3cd75e7493e5888e5dfe1 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 13 Jan 2026 23:30:27 +0000 Subject: [PATCH 066/113] reduce LOC change Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index fc70d0979cbf..47f73e053ca6 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -330,8 +330,7 @@ def workspace_shapes( # end up sending their tokens. This needs to be fixed. 
num_dispatchers = self.num_dispatchers num_experts = local_num_experts - # max_num_tokens = M if self.max_num_tokens is None else self.max_num_tokens - max_num_tokens = self.max_num_tokens + max_num_tokens = M if self.max_num_tokens is None else self.max_num_tokens activation_out_dim = self.adjust_N_for_activation(N, activation) workspace13 = (num_experts, max_num_tokens * num_dispatchers, max(K, N)) workspace2 = (num_experts, max_num_tokens * num_dispatchers, activation_out_dim) From 1cc46bec00118f264f4d8803cf492f8ace81fdf7 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 01:49:57 +0000 Subject: [PATCH 067/113] fix native ag/rs Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/layer.py | 3 +++ vllm/model_executor/layers/fused_moe/oracle/fp8.py | 2 ++ vllm/model_executor/layers/fused_moe/prepare_finalize.py | 4 +++- 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 19dc5aeafcf9..53148036d721 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1901,10 +1901,12 @@ def forward_impl( self.ensure_moe_quant_config_init() self.ensure_dp_chunking_init() + # TODO: this is not right anymore. 
has_separate_shared_experts = ( not isinstance(self.quant_method, FusedMoEModularMethod) and self.shared_experts is not None ) + has_separate_shared_experts = False use_chunked_impl = self.use_dp_chunking @@ -1931,6 +1933,7 @@ def forward_impl( isinstance(self.quant_method, FusedMoEModularMethod) or self.quant_method.supports_mk_interally ) + logger.info_once(f"{do_naive_dispatch_combine=}") ctx = get_forward_context() sp_ctx = ( diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py index 2efc685fe10f..8196b192c63c 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py @@ -379,6 +379,8 @@ def make_fp8_moe_kernel( allow_new_interface=True, ) + logger.info_once("Using %s", prepare_finalize.__class__.__name__) + # Create Experts. if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.Standard: experts = experts_cls.make_standard_experts( diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index b66e2a1dd73f..1546810e457c 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -64,6 +64,7 @@ def prepare( quant_config.per_act_token_quant, quant_config.block_shape, ) + # TODO - this is just for deepgemm? 
expert_tokens_meta = None @@ -92,7 +93,8 @@ def prepare( a1q, topk_weights, topk_ids = res else: a1q, topk_weights, topk_ids, scales = res - a1q_scale = res[0] + assert scales is not None and len(scales) == 1 + a1q_scale = scales[0] if use_int8_view: a1q = a1q.view(current_platform.fp8_dtype()) From 62d9d7c4b0a41a87674c23fb48f252e7b2f4f31e Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 02:46:32 +0000 Subject: [PATCH 068/113] oracle is now working properly Signed-off-by: Robert Shaw --- .../layers/fused_moe/flashinfer_trtllm_moe.py | 29 +++++--- vllm/model_executor/layers/fused_moe/layer.py | 1 - .../layers/fused_moe/oracle/fp8.py | 69 +++++++++++++------ .../layers/fused_moe/rocm_aiter_fused_moe.py | 4 +- 4 files changed, 72 insertions(+), 31 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index 086ca87d31b6..b082f64de0bd 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -3,6 +3,7 @@ import torch +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEParallelConfig, @@ -59,17 +60,29 @@ def _supports_moe_parallel_config(moe_parallel_config: FusedMoEParallelConfig) - def is_supported_config_trtllm( moe_config: FusedMoEConfig, moe_quant_scheme: FusedMoEQuantScheme, -) -> bool: + activation_format: mk.FusedMoEActivationFormat, +) -> tuple[bool, str | None]: """ This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config """ - return ( - _supports_current_device() - and (moe_config.is_act_and_mul or _supports_no_act_and_mul()) - and _supports_activation(moe_config.activation) - and _supports_quant_scheme(moe_quant_scheme) - and _supports_moe_parallel_config(moe_config.moe_parallel_config) - ) + + def _make_reason(reason: str) -> str: + return 
f"kernel does not support {reason}" + + if not _supports_current_device(): + return False, _make_reason("current device") + elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()): + return False, _make_reason("no act_and_mul MLP layer") + elif not _supports_activation(moe_config.activation): + return False, _make_reason(f"{moe_config.activation} activation") + elif not _supports_quant_scheme(moe_quant_scheme): + return False, _make_reason("quantization scheme") + elif not _supports_moe_parallel_config(moe_config.moe_parallel_config): + return False, _make_reason("parallel config") + elif activation_format != mk.FusedMoEActivationFormat.Standard: + return False, _make_reason("activation format") + + return True, None def flashinfer_fused_moe_blockscale_fp8( diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 53148036d721..c4ea35410547 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1933,7 +1933,6 @@ def forward_impl( isinstance(self.quant_method, FusedMoEModularMethod) or self.quant_method.supports_mk_interally ) - logger.info_once(f"{do_naive_dispatch_combine=}") ctx = get_forward_context() sp_ctx = ( diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py index 8196b192c63c..271de5a29e55 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py @@ -49,7 +49,7 @@ class Fp8MoeBackend(Enum): BATCHED_DEEPGEMM = "Batched DeepGEMM" MARLIN = "Marlin" TRITON = "Triton" - BATCHED_TRITON = "Triton" + BATCHED_TRITON = "Batched Triton" AITER = "AITER" VLLM_CUTLASS = "vLLM CUTLASS" @@ -134,7 +134,19 @@ def select_fp8_moe_backend( return Fp8MoeBackend.TRITON, backend_2_kernel_cls(Fp8MoeBackend.TRITON) def _make_log_backend(backend: Fp8MoeBackend): - return f"Using {backend.value} backend for FP8 MoE" + return f"Using 
`{backend.value}` backend for FP8 MoE" + + def _make_log_unsupported(backend: Fp8MoeBackend, reason: str | None) -> str: + if reason: + return ( + f"FP8 MoE backend `{backend.value}` does not support the " + f"deployment configuration since {reason}." + ) + else: + return ( + f"FP8 MoE backend `{backend.value}` does not support the " + "deployment configuration." + ) def _return_or_raise( backend: Fp8MoeBackend, @@ -149,13 +161,7 @@ def _return_or_raise( if supported: logger.info_once(_make_log_backend(backend)) return backend, k_cls - - assert reason is not None - raise ValueError( - f"Requested FP8 MoE backend `{backend.value}` " - "does not support the deployment configuration since " - f"{reason}." - ) + raise ValueError(_make_log_unsupported(backend, reason)) # NOTE: the kernels are selected in the following order. AVAILABLE_BACKENDS = [ @@ -183,13 +189,14 @@ def _return_or_raise( if fi_backend == FlashinferMoeBackend.TENSORRT_LLM: backend = Fp8MoeBackend.FLASHINFER_TRTLLM # TODO: validate activation format - if is_supported_config_trtllm(config, quant_scheme): - logger.info_once(_make_log_backend(backend)) - return backend, None # ? - raise ValueError( - f"Requested FP8 MoE backend `{backend.value}` " - "does not support the deployment configuration." 
+ supported, reason = is_supported_config_trtllm( + config, quant_scheme, activation_format ) + if supported: + logger.info_once(_make_log_backend(backend)) + return backend, None + else: + raise ValueError(_make_log_unsupported(backend, reason)) elif fi_backend == FlashinferMoeBackend.CUTLASS: backend = Fp8MoeBackend.FLASHINFER_CUTLASS @@ -234,19 +241,39 @@ def _return_or_raise( backend = Fp8MoeBackend.MARLIN return _return_or_raise(backend, config, quant_scheme, activation_format) - elif envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MOE: - backend = Fp8MoeBackend.AITER - return _return_or_raise(backend, config, quant_scheme, activation_format) + elif envs.is_set("VLLM_ROCM_USE_AITER") or envs.is_set("VLLM_ROCM_USE_AITER_MOE"): + if not envs.VLLM_ROCM_USE_AITER or not envs.VLLM_ROCM_USE_AITER_MOE: + AVAILABLE_BACKENDS.remove(Fp8MoeBackend.AITER) + else: + backend = Fp8MoeBackend.AITER + return _return_or_raise(backend, config, quant_scheme, activation_format) if not allow_vllm_cutlass: AVAILABLE_BACKENDS.remove(Fp8MoeBackend.VLLM_CUTLASS) # Select kernels in order of backend. for backend in AVAILABLE_BACKENDS: - k_cls = backend_2_kernel_cls(backend) - if k_cls.is_supported_config(k_cls, config, quant_scheme, activation_format): - logger.info_once(_make_log_backend(backend)) + if backend == Fp8MoeBackend.FLASHINFER_TRTLLM: + k_cls = None # type: ignore[assignment] + supported, reason = is_supported_config_trtllm( + config, + quant_scheme, + activation_format, + ) + else: + k_cls = backend_2_kernel_cls(backend) + supported, reason = k_cls.is_supported_config( + k_cls, + config, + quant_scheme, + activation_format, + ) + + if supported: + logger.info_once(_make_log_backend(backend), scope="local") return backend, k_cls + else: + logger.info_once(_make_log_unsupported(backend, reason), scope="local") raise NotImplementedError( "No FP8 MoE backend supports the deployment configuration." 
diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 842c3db342b7..101cfe81ad81 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -284,7 +284,9 @@ def should_pf_defer_input_quant(quant_config: FusedMoEQuantConfig) -> bool: @staticmethod def _supports_current_device() -> bool: - return rocm_aiter_ops.IS_AITER_FOUND + # Figure out a way to check if ROCm AITER is available. + # return rocm_aiter_ops.is_fused_moe_enabled() + return False @staticmethod def _supports_no_act_and_mul() -> bool: From 7d7fcd99eeb3e905654b7bf4ddc2e7128ee86402 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 02:51:52 +0000 Subject: [PATCH 069/113] confirmed aiter env variables work as expected Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 101cfe81ad81..915944d49c4e 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -284,9 +284,7 @@ def should_pf_defer_input_quant(quant_config: FusedMoEQuantConfig) -> bool: @staticmethod def _supports_current_device() -> bool: - # Figure out a way to check if ROCm AITER is available. 
- # return rocm_aiter_ops.is_fused_moe_enabled() - return False + return rocm_aiter_ops.is_fused_moe_enabled() @staticmethod def _supports_no_act_and_mul() -> bool: From 51a93fef9b4f7d923433617fdbac438fe76ce17f Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 02:59:32 +0000 Subject: [PATCH 070/113] fix up oracle Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/oracle/fp8.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py index 271de5a29e55..0dce1fabb30a 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py @@ -176,6 +176,7 @@ def _return_or_raise( Fp8MoeBackend.MARLIN, ] + # Handle explicit FlashInfer FP8 configuration. if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP8"): if not envs.VLLM_USE_FLASHINFER_MOE_FP8: # If the user rejects FlashInfer remove those backends. @@ -226,7 +227,8 @@ def _return_or_raise( "FlashInfer FP8 MoE backend supports the configuration." ) - elif envs.is_set("VLLM_USE_DEEP_GEMM") or envs.is_set("VLLM_MOE_USE_DEEP_GEMM"): + # Handle explicit DeepGEMM FP8 configuration. + if envs.is_set("VLLM_USE_DEEP_GEMM") or envs.is_set("VLLM_MOE_USE_DEEP_GEMM"): if not envs.VLLM_USE_DEEP_GEMM or not envs.VLLM_MOE_USE_DEEP_GEMM: AVAILABLE_BACKENDS.remove(Fp8MoeBackend.DEEPGEMM) else: @@ -237,11 +239,13 @@ def _return_or_raise( ) return _return_or_raise(backend, config, quant_scheme, activation_format) - elif envs.VLLM_TEST_FORCE_FP8_MARLIN: + # Handle explicit MARLIN FP8 configuration. + if envs.VLLM_TEST_FORCE_FP8_MARLIN: backend = Fp8MoeBackend.MARLIN return _return_or_raise(backend, config, quant_scheme, activation_format) - elif envs.is_set("VLLM_ROCM_USE_AITER") or envs.is_set("VLLM_ROCM_USE_AITER_MOE"): + # Handle explicit AITER FP8 configuration. 
+ if envs.is_set("VLLM_ROCM_USE_AITER") or envs.is_set("VLLM_ROCM_USE_AITER_MOE"): if not envs.VLLM_ROCM_USE_AITER or not envs.VLLM_ROCM_USE_AITER_MOE: AVAILABLE_BACKENDS.remove(Fp8MoeBackend.AITER) else: From a600601f946a5ec5672f78e2788d9712f67a73c2 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 03:55:13 +0000 Subject: [PATCH 071/113] trying to make ct work Signed-off-by: Robert Shaw --- .../layers/fused_moe/fused_moe.py | 2 +- .../layers/fused_moe/oracle/fp8.py | 3 +- .../layers/fused_moe/triton_cutlass_moe.py | 11 +- .../compressed_tensors_moe.py | 227 +++++++++--------- 4 files changed, 120 insertions(+), 123 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 6c38c3dfeaaa..b9a4a136b878 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -2327,7 +2327,7 @@ def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: ) return ( - quant_scheme.is_unquantized() + quant_scheme.is_unquantized or quant_scheme.is_fp8_w8a8 and device_supports_fp8 ) diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py index 0dce1fabb30a..0540c2e8b98f 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py @@ -57,8 +57,9 @@ class Fp8MoeBackend(Enum): def backend_2_kernel_cls( backend: Fp8MoeBackend, ) -> type[mk.FusedMoEPermuteExpertsUnpermute]: - if backend == Fp8MoeBackend.NONE or backend == Fp8MoeBackend.FLASHINFER_TRTLLM: + if backend == Fp8MoeBackend.FLASHINFER_TRTLLM: raise NotImplementedError + elif backend == Fp8MoeBackend.FLASHINFER_CUTLASS: from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( FlashInferExperts, diff --git a/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py index 
fafc6d743c89..3865babf20a1 100644 --- a/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py @@ -6,6 +6,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig, FusedMoEQuantScheme, @@ -21,17 +22,13 @@ class TritonOrCutlassExperts(FallbackExperts): def __init__( self, - e: int, - n: int, - k: int, - out_dtype: torch.dtype | None, + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, - device: torch.dtype, ): self.is_sm100 = current_platform.has_device_capability(100) super().__init__( - experts=CutlassExpertsFp8(e, n, k, out_dtype, quant_config, device), - fallback_experts=TritonExperts(quant_config), + experts=CutlassExpertsFp8(moe_config, quant_config), + fallback_experts=TritonExperts(moe_config, quant_config), ) @staticmethod diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index ac3531012621..9c3240d25437 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -19,14 +19,16 @@ from vllm.model_executor.layers.fused_moe import ( FusedMoE, FusedMoEActivationFormat, - FusedMoEConfig, FusedMoEMethodBase, FusedMoEPermuteExpertsUnpermute, FusedMoeWeightScaleSupported, UnquantizedFusedMoEMethod, ) from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, FusedMoEQuantConfig, + FusedMoEQuantScheme, + RoutingMethodType, fp8_w8a8_moe_quant_config, fp8_w8a16_moe_quant_config, int4_w4a16_moe_quant_config, @@ -45,6 +47,7 @@ Fp8MoeBackend, convert_to_fp8_moe_kernel_format, make_fp8_moe_kernel, + select_fp8_moe_backend, ) from vllm.model_executor.layers.fused_moe.oracle.nvfp4 
import ( FLASHINFER_NVFP4_MOE_BACKENDS, @@ -65,6 +68,9 @@ flashinfer_trtllm_fp4_routed_moe, select_nvfp4_gemm_impl, ) +from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( + apply_fi_trtllm_fp8_per_tensor_moe, +) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( process_fp8_input_tensor_strategy_moe, process_fp8_weight_tensor_strategy_moe, @@ -578,33 +584,55 @@ def __init__( "For FP8 Fused MoE layer, we require either per tensor or " "channelwise, dynamic per token quantization." ) - # self.fp8_backend = select_fp8_moe_backend( - # block_quant=self.block_quant, - # tp_size=moe.tp_size, - # with_lora_support=moe.is_lora_enabled, - # # TODO(rob): enable selecting this externally. - # allow_vllm_cutlass=True, - # ) - self.fp8_backend = Fp8MoeBackend.FLASHINFER_CUTLASS - if self.fp8_backend != Fp8MoeBackend.MARLIN: - per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN - per_channel_quant = ( - self.weight_quant.strategy == QuantizationStrategy.CHANNEL - ) - if per_act_token != per_channel_quant: - raise NotImplementedError( - "For FP8 Fused MoE layers, per-token and per-channel must be " - "used together." - ) - # TODO(rob): hook this up in a follow up PR. - if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: - raise NotImplementedError( - "FlashInfer TRTLLM backend not supported for compressed-tensors yet." - ) - self.disable_expert_map = False + # Create quant scheme, will be used later to select the quant scales. + # NOTE(rob): we should update QuantConfig to just be the think ts + # holds the scales. Should change the name. 
+ quant_scheme = FusedMoEQuantScheme( + weight_dtype=current_platform.fp8_dtype(), + act_dtype=current_platform.fp8_dtype(), + per_tensor_quant=per_tensor, + per_token_quant=per_channel, + block_size=( + (self.weight_block_size[0], self.weight_block_size[1]) + if self.weight_block_size is not None + else None + ), + static_input_quant=self.static_input_scales, + ) + + # Select Fp8 MoE backend + # NOTE(rob): this is kind of a hack. We need to peak into + # the prepare-finalize selection to determine if we are using + # the batched or standard expert format. + use_batched = ( + self.moe.moe_parallel_config.use_deepep_ll_kernels + or self.moe.moe_parallel_config.use_pplx_kernels + ) + self.fp8_backend, self.experts_cls = select_fp8_moe_backend( + config=self.moe, + quant_scheme=quant_scheme, + activation_format=( + mk.FusedMoEActivationFormat.BatchedExperts + if use_batched + else mk.FusedMoEActivationFormat.Standard + ), + allow_vllm_cutlass=True, + ) + + # Delay creation of the kernel until after process-weights. 
self.kernel: mk.FusedMoEModularKernel | None = None + @property + def topk_indices_dtype(self) -> torch.dtype | None: + if self.kernel is not None: + return self.kernel.prepare_finalize.topk_indices_dtype() + return None + + @property + def supports_mk_interally(self) -> bool: + return True + def create_weights( self, layer: torch.nn.Module, @@ -818,111 +846,33 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: self.moe_quant_config = self.get_fused_moe_quant_config(layer) if self.moe_quant_config: + assert self.experts_cls is not None self.kernel, self.use_inplace = make_fp8_moe_kernel( + layer=layer, moe_quant_config=self.moe_quant_config, moe_config=self.moe, fp8_backend=self.fp8_backend, + experts_cls=self.experts_cls, ) def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPrepareAndFinalize | None: - if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]: - return None - else: - return super().maybe_make_prepare_finalize(routing_tables) + raise ValueError( + "CompressedTensorsW8A8Fp8MoE uses the new modular kernel " + "initialization logic. This function should not be called." 
+ ) def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalize, layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: - # cutlass path - assert self.moe_quant_config is not None - if self.fp8_backend == Fp8MoeBackend.VLLM_CUTLASS: - from vllm.model_executor.layers.fused_moe import ( - CutlassBatchedExpertsFp8, - CutlassExpertsFp8, - ) - - experts: FusedMoEPermuteExpertsUnpermute - - num_dispatchers = prepare_finalize.num_dispatchers() - - if ( - prepare_finalize.activation_format - == FusedMoEActivationFormat.BatchedExperts - ): - logger.debug("CutlassBatchedExpertsFp8(%s)", self.__class__.__name__) - experts = CutlassBatchedExpertsFp8.make_batched_experts( - moe_config=self.moe, - quant_config=self.moe_quant_config, - max_num_tokens=prepare_finalize.max_num_tokens_per_rank(), - num_dispatchers=num_dispatchers, - ) - else: - logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__) - experts = CutlassExpertsFp8( - out_dtype=self.moe.in_dtype, - e=layer.local_num_experts, - n=layer.intermediate_size_per_partition, - k=layer.hidden_size, - device=layer.w13_weight.device, - quant_config=self.moe_quant_config, - ) - - # TODO(rob): investigate disable_expert_map - self.disable_expert_map = ( - num_dispatchers > 1 or not experts.supports_expert_map() - ) - - return experts - - from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - BatchedDeepGemmExperts, + raise ValueError( + "CompressedTensorsW8A8Fp8MoE uses the new modular kernel " + "initialization logic. This function should not be called." 
) - from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedTritonExperts, - ) - from vllm.model_executor.layers.fused_moe.fused_moe import ( - TritonExperts, - ) - from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( - TritonOrDeepGemmExperts, - ) - - assert self.fp8_backend not in [Fp8MoeBackend.AITER, Fp8MoeBackend.MARLIN] - - if ( - prepare_finalize.activation_format - == FusedMoEActivationFormat.BatchedExperts - ): - max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() - assert max_num_tokens_per_rank is not None - - if self.fp8_backend == Fp8MoeBackend.DEEPGEMM: - logger.debug("BatchedDeepGemmExperts(%s)", self.__class__.__name__) - return BatchedDeepGemmExperts( - max_num_tokens=max_num_tokens_per_rank, - num_dispatchers=prepare_finalize.num_dispatchers(), - quant_config=self.moe_quant_config, - ) - else: - logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__) - return BatchedTritonExperts( - max_num_tokens=max_num_tokens_per_rank, - num_dispatchers=prepare_finalize.num_dispatchers(), - quant_config=self.moe_quant_config, - ) - - else: - if self.fp8_backend == Fp8MoeBackend.DEEPGEMM: - logger.debug("TritonOrDeepGemmExperts(%s)", self.__class__.__name__) - return TritonOrDeepGemmExperts(self.moe_quant_config) - else: - logger.debug("TritonExperts(%s)", self.__class__.__name__) - return TritonExperts(self.moe_quant_config) def get_fused_moe_quant_config( self, layer: torch.nn.Module @@ -954,6 +904,55 @@ def apply( x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: + if layer.enable_eplb: + raise NotImplementedError( + "EPLB not supported for `FlashInfer TRTLLM FP8 MoE`." 
+ ) + assert layer.activation == "silu" + + if self.block_quant: + import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 + + e_score_correction_bias = ( + layer.e_score_correction_bias.to(x.dtype) + if layer.e_score_correction_bias is not None + else None + ) + routing_method_type = layer.routing_method_type + return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( + routing_logits=router_logits.to(torch.float32) + if routing_method_type == RoutingMethodType.DeepSeekV3 + else router_logits, + routing_bias=e_score_correction_bias, + x=x, + w13_weight=layer.w13_weight, + w13_weight_scale_inv=layer.w13_weight_scale, + w2_weight=layer.w2_weight, + w2_weight_scale_inv=layer.w2_weight_scale, + global_num_experts=layer.global_num_experts, + top_k=layer.top_k, + num_expert_group=layer.num_expert_group, + topk_group=layer.topk_group, + intermediate_size=layer.intermediate_size_per_partition, + expert_offset=layer.ep_rank * layer.local_num_experts, + local_num_experts=layer.local_num_experts, + block_shape=self.weight_block_size, + routing_method_type=routing_method_type, + routed_scaling=layer.routed_scaling_factor, + ) + else: + result = apply_fi_trtllm_fp8_per_tensor_moe( + layer=layer, + hidden_states=x, + router_logits=router_logits, + routing_bias=layer.e_score_correction_bias, + global_num_experts=layer.global_num_experts, + top_k=layer.top_k, + num_expert_group=layer.num_expert_group, + topk_group=layer.topk_group, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + ) topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, @@ -971,7 +970,7 @@ def apply( global_num_experts=layer.global_num_experts, # TODO(rob): investigate the disable_expert_map introduced by: # https://github.com/vllm-project/vllm/commit/84166fee9770e6fba71a96978b3e7d149392fb28 # noqa: E501 - expert_map=None if self.disable_expert_map else layer.expert_map, + expert_map=layer.expert_map, 
apply_router_weight_on_input=layer.apply_router_weight_on_input, ) From f83c63e5bc0ac3c6bedbcc359cedb967ca2d9256 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 04:44:38 +0000 Subject: [PATCH 072/113] made llama 4 work via compressed tensors Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/layer.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index c4ea35410547..93cb4778f864 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1901,12 +1901,17 @@ def forward_impl( self.ensure_moe_quant_config_init() self.ensure_dp_chunking_init() - # TODO: this is not right anymore. + # TODO: figure out a better way to express who is responsible for the SE. + mk_has_shared_expert = ( + self.quant_method.supports_mk_interally + and hasattr(self.quant_method, "kernel") + and getattr(self.quant_method.kernel, "shared_experts", None) is not None + ) has_separate_shared_experts = ( not isinstance(self.quant_method, FusedMoEModularMethod) + and not mk_has_shared_expert and self.shared_experts is not None ) - has_separate_shared_experts = False use_chunked_impl = self.use_dp_chunking From 9cb777112e636923ec5b3db14ff1d2f869d0de68 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 04:58:21 +0000 Subject: [PATCH 073/113] initial attempt at modelopt Signed-off-by: Robert Shaw --- .../layers/quantization/modelopt.py | 87 ++++++++++++------- 1 file changed, 56 insertions(+), 31 deletions(-) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index ffe3dffceee7..a687c5201b44 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -16,6 +16,7 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, 
FusedMoEQuantConfig, + FusedMoEQuantScheme, ) from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from vllm.model_executor.layers.fused_moe.layer import ( @@ -28,6 +29,7 @@ convert_to_fp8_moe_kernel_format, make_fp8_moe_kernel, make_fp8_moe_quant_config, + select_fp8_moe_backend, ) from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( FLASHINFER_NVFP4_MOE_BACKENDS, @@ -57,8 +59,6 @@ ) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( apply_fi_trtllm_fp8_per_tensor_moe, - build_flashinfer_fp8_cutlass_moe_prepare_finalize, - select_cutlass_fp8_gemm_impl, ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, @@ -91,6 +91,7 @@ PerTensorScaleParameter, ) from vllm.model_executor.utils import replace_parameter +from vllm.platforms import current_platform from vllm.utils.flashinfer import ( flashinfer_scaled_fp4_mm, has_flashinfer, @@ -728,46 +729,68 @@ def __init__( super().__init__(moe_config) self.quant_config = quant_config assert self.quant_config.is_checkpoint_fp8_serialized - # self.fp8_backend = select_fp8_moe_backend( - # block_quant=False, - # tp_size=moe_config.moe_parallel_config.tp_size, - # with_lora_support=self.moe.is_lora_enabled, - # ) - self.fp8_backend = Fp8MoeBackend.FLASHINFER_CUTLASS + + # Create quant scheme, will be used later to select the quant scales. + # NOTE(rob): we should update QuantConfig to just be the think ts + # holds the scales. Should change the name. + quant_scheme = FusedMoEQuantScheme( + weight_dtype=current_platform.fp8_dtype(), + act_dtype=current_platform.fp8_dtype(), + per_tensor_quant=True, + per_token_quant=False, + block_size=None, + static_input_quant=True, + ) + + # Select Fp8 MoE backend + # NOTE(rob): this is kind of a hack. We need to peak into + # the prepare-finalize selection to determine if we are using + # the batched or standard expert format. 
+ use_batched = ( + self.moe.moe_parallel_config.use_deepep_ll_kernels + or self.moe.moe_parallel_config.use_pplx_kernels + ) + self.fp8_backend, self.experts_cls = select_fp8_moe_backend( + config=self.moe, + quant_scheme=quant_scheme, + activation_format=( + mk.FusedMoEActivationFormat.BatchedExperts + if use_batched + else mk.FusedMoEActivationFormat.Standard + ), + ) + + # Delay creation of the kernel until after process-weights. self.kernel: mk.FusedMoEModularKernel | None = None + @property + def topk_indices_dtype(self) -> torch.dtype | None: + if self.kernel is not None: + return self.kernel.prepare_finalize.topk_indices_dtype() + return None + + @property + def supports_mk_interally(self) -> bool: + return True + def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPrepareAndFinalize | None: - # TRT LLM not supported with all2all yet. - if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: - return None - elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: - # TP case: avoid convert to ModularKernelMethod - to be refactored. - if self.moe.dp_size == 1: - return None - - prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( - self.moe, - use_deepseek_fp8_block_scale=False, - ) - logger.debug_once("%s", prepare_finalize.__class__.__name__) - return prepare_finalize - return super().maybe_make_prepare_finalize(routing_tables) + raise ValueError( + "ModelOptFp8MoEMethod uses the new modular kernel initialization " + "logic. This function should not be called." + ) def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalize, layer: torch.nn.Module, ) -> mk.FusedMoEPermuteExpertsUnpermute: - assert self.moe_quant_config is not None - experts = select_cutlass_fp8_gemm_impl( - self.moe, - self.moe_quant_config, + raise ValueError( + "ModelOptFp8MoEMethod uses the new modular kernel initialization " + "logic. This function should not be called." 
) - logger.debug_once("Using %s", experts.__class__.__name__) - return experts def create_weights( self, @@ -879,14 +902,16 @@ def _setup_kernel( replace_parameter(layer, "w13_weight_scale", w13_scale) replace_parameter(layer, "w2_weight_scale", w2_scale) - # Setup modular kernel for TP case. + # Setup modular kernel. self.moe_quant_config = self.get_fused_moe_quant_config(layer) if self.moe_quant_config: + assert self.experts_cls is not None self.kernel, self.use_inplace = make_fp8_moe_kernel( moe_quant_config=self.moe_quant_config, moe_config=self.moe, fp8_backend=self.fp8_backend, - ) # noqa + experts_cls=self.experts_cls, + ) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: w13 = layer.w13_weight From ad28a184fdc8deaa10af6467f1b1d10e7f353e39 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 05:02:47 +0000 Subject: [PATCH 074/113] attempt to fix naive multicast Signed-off-by: Robert Shaw --- vllm/distributed/device_communicators/all2all.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 891c9439da25..520bb7613ee9 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -66,7 +66,10 @@ def dispatch( topk_ids: torch.Tensor, is_sequence_parallel: bool = False, extra_tensors: list[torch.Tensor] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> ( + tuple[torch.Tensor, torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]] + ): if extra_tensors is not None: raise NotImplementedError( "extra_tensors is not supported for NaiveAll2AllManager" @@ -86,7 +89,14 @@ def dispatch( topk_ids, cu_tokens_across_sp_cpu, is_sequence_parallel ) - return hidden_states, topk_weights, topk_ids + if extra_tensors is None: + return hidden_states, topk_weights, topk_ids + + 
extra_tensors = [ + self.naive_multicast(t, cu_tokens_across_sp_cpu, is_sequence_parallel) + for t in extra_tensors + ] + return hidden_states, topk_weights, topk_ids, extra_tensors def combine( self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False From b507afa7988cfbb9a1c462b35f132e7801728b3b Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 05:06:48 +0000 Subject: [PATCH 075/113] flashinfer appears to be working Signed-off-by: Robert Shaw --- vllm/model_executor/layers/quantization/modelopt.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index a687c5201b44..dc6844fe1e46 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -907,6 +907,7 @@ def _setup_kernel( if self.moe_quant_config: assert self.experts_cls is not None self.kernel, self.use_inplace = make_fp8_moe_kernel( + layer=layer, moe_quant_config=self.moe_quant_config, moe_config=self.moe, fp8_backend=self.fp8_backend, From 50b106bc4f3f70e6b1edc68f07120e7f88da9548 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 11:57:53 -0500 Subject: [PATCH 076/113] stash changes for review Signed-off-by: Robert Shaw --- .../model_executor/layers/fused_moe/config.py | 1 - .../layers/fused_moe/oracle/fp8.py | 1 - .../layers/fused_moe/oracle/nvfp4.py | 288 +++++++++++++----- .../layers/quantization/modelopt.py | 102 ++++--- .../quantization/utils/flashinfer_fp4_moe.py | 63 ++++ 5 files changed, 332 insertions(+), 123 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index be558451e295..2e361edac70f 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -154,7 +154,6 @@ def __post_init__(self): assert not self.per_token_quant assert not self.static_input_quant assert 
self.block_size is not None - assert self.per_block_quant or self.per_token_quant or self.per_tensor_quant if self.is_unquantized: assert self.act_dtype in UNQUANTIZED_DTYPES diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py index 0540c2e8b98f..839695630276 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py @@ -190,7 +190,6 @@ def _return_or_raise( if fi_backend == FlashinferMoeBackend.TENSORRT_LLM: backend = Fp8MoeBackend.FLASHINFER_TRTLLM - # TODO: validate activation format supported, reason = is_supported_config_trtllm( config, quant_scheme, activation_format ) diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py index 6198ff415ab7..7c4dd787e872 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py @@ -1,33 +1,27 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from enum import Enum +from typing import TYPE_CHECKING import torch import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.all2all_utils import ( + maybe_make_prepare_finalize, +) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, + FusedMoEQuantScheme, nvfp4_moe_quant_config, nvfp4_w4a16_moe_quant_config, ) -from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - CutlassExpertsFp4, -) -from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( - FlashInferExperts, -) -from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( - MarlinExperts, -) -from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP, -) from 
vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( - is_flashinfer_fp4_cutedsl_moe_available, - is_flashinfer_fp4_cutlass_moe_available, + is_supported_config_trtllm, + # is_flashinfer_fp4_cutedsl_moe_available, + # is_flashinfer_fp4_cutlass_moe_available, prepare_nvfp4_moe_layer_for_fi_or_cutlass, ) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( @@ -35,19 +29,23 @@ get_flashinfer_moe_backend, ) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - is_fp4_marlin_supported, + # is_fp4_marlin_supported, prepare_nvfp4_moe_layer_for_marlin, ) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - cutlass_fp4_supported, -) + +# from vllm.model_executor.layers.quantization.utils.quant_utils import ( +# cutlass_fp4_supported, +# ) + +if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.layer import FusedMoE logger = init_logger(__name__) class NvFp4MoeBackend(Enum): - FLASHINFER_CUTLASS = "FlashInfer CUTLASS" FLASHINFER_TRTLLM = "FlashInfer TRTLLM" + FLASHINFER_CUTLASS = "FlashInfer CUTLASS" FLASHINFER_CUTEDSL = "FlashInfer CUTEDSL" VLLM_CUTLASS = "vLLM CUTASS" MARLIN = "vLLM MARLIN" @@ -77,38 +75,162 @@ def is_global_sf_supported_for_nvfp4_backend(backend: NvFp4MoeBackend) -> bool: ] -def select_nvfp4_moe_backend() -> NvFp4MoeBackend: +def backend_2_kernel_cls( + backend: NvFp4MoeBackend, +) -> type[mk.FusedMoEPermuteExpertsUnpermute]: + if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: + raise NotImplementedError + + elif backend == NvFp4MoeBackend.FLASHINFER_CUTLASS: + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( + FlashInferExperts, + ) + + return FlashInferExperts + + elif backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL: + from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import ( + FlashInferCuteDSLExperts, + ) + + return FlashInferCuteDSLExperts + + elif backend == NvFp4MoeBackend.VLLM_CUTLASS: + from 
vllm.model_executor.layers.fused_moe.triton_cutlass_moe import ( + TritonOrCutlassExperts, + ) + + return TritonOrCutlassExperts + + elif backend == NvFp4MoeBackend.MARLIN: + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + MarlinExperts, + ) + + return MarlinExperts + else: + raise ValueError(f"Unknown NvFP4 MoE backend: {backend.value}") + + +def select_nvfp4_moe_backend( + config: FusedMoEConfig, + quant_scheme: FusedMoEQuantScheme, + activation_format: mk.FusedMoEActivationFormat, +) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]: + """ + Select the primary NvFP4 MoE backend + Note: Shape-specific fallbacks may still occur at runtime. + """ + def _make_log_backend(backend: NvFp4MoeBackend): return f"Using {backend.value} backend for NvFp4 MoE" - if cutlass_fp4_supported() and not envs.VLLM_TEST_FORCE_FP8_MARLIN: - allow_flashinfer = ( - is_flashinfer_fp4_cutlass_moe_available() - or is_flashinfer_fp4_cutedsl_moe_available() + def _make_log_unsupported(backend: NvFp4MoeBackend, reason: str | None) -> str: + if reason: + return ( + f"NvFP4 MoE backend `{backend.value}` does not support the " + f"deployment configuration since {reason}." + ) + else: + return ( + f"NvFP4 MoE backend `{backend.value}` does not support the " + "deployment configuration." 
+ ) + + def _return_or_raise( + backend: NvFp4MoeBackend, + config: FusedMoEConfig, + scheme: FusedMoEQuantScheme, + activation_format: mk.FusedMoEActivationFormat, + ) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]: + k_cls = backend_2_kernel_cls(backend) + supported, reason = k_cls.is_supported_config( + k_cls, config, scheme, activation_format ) - if allow_flashinfer and envs.VLLM_USE_FLASHINFER_MOE_FP4: - backend = fi_2_vllm_backend_map[get_flashinfer_moe_backend()] + if supported: + logger.info_once(_make_log_backend(backend)) + return backend, k_cls + raise ValueError(_make_log_unsupported(backend, reason)) + + # NOTE: the kernels are selected in the following order. + AVAILABLE_BACKENDS = [ + NvFp4MoeBackend.FLASHINFER_TRTLLM, + NvFp4MoeBackend.FLASHINFER_CUTEDSL, + NvFp4MoeBackend.FLASHINFER_CUTLASS, + NvFp4MoeBackend.VLLM_CUTLASS, + ] + + if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP4"): + if not envs.VLLM_USE_FLASHINFER_MOE_FP8: + # If the user rejects FlashInfer remove those backends. + for fi_backend in FLASHINFER_NVFP4_MOE_BACKENDS: + AVAILABLE_BACKENDS.remove(fi_backend) + + elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"): + # If user is explicit about backend, validate it. + fi_backend = get_flashinfer_moe_backend() + + if fi_backend == FlashinferMoeBackend.TENSORRT_LLM: + backend = NvFp4MoeBackend.FLASHINFER_TRTLLM + supported, reason = is_supported_config_trtllm( + config, quant_scheme, activation_format + ) + if supported: + logger.info_once(_make_log_backend(backend)) + return backend, None + else: + raise ValueError(_make_log_unsupported(backend, reason)) + else: + backend = fi_2_vllm_backend_map[fi_backend] + return _return_or_raise( + backend, config, quant_scheme, activation_format + ) else: - backend = NvFp4MoeBackend.VLLM_CUTLASS - elif is_fp4_marlin_supported(): + # If the user is not explicit about the backend, try each. 
+ for backend in FLASHINFER_NVFP4_MOE_BACKENDS: + k_cls = backend_2_kernel_cls(backend) + if k_cls.is_supported_config( + k_cls, config, quant_scheme, activation_format + ): + logger.info_once(_make_log_backend(backend)) + return backend, k_cls + + raise NotImplementedError( + "Found VLLM_USE_FLASHINFER_MOE_FP4=1, but no " + "FlashInfer NVFP4 MoE backend supports the configuration." + ) + + if envs.VLLM_TEST_FORCE_FP8_MARLIN: backend = NvFp4MoeBackend.MARLIN - else: - raise ValueError("No NvFp4 kernel backend available for NvFp4 MoE.") + return _return_or_raise(backend, config, quant_scheme, activation_format) - # Log warning if FI backend requested but not available. - if ( - backend not in FLASHINFER_NVFP4_MOE_BACKENDS - and envs.VLLM_USE_FLASHINFER_MOE_FP4 - ): - logger.warning_once( - "Requested FlashInfer backend for NvFp4 MoE, but it's not available. " - "Falling back to %s for NvFp4 MoE", - backend.value, - scope="local", - ) - else: - logger.info_once(_make_log_backend(backend), scope="local") - return backend + # Select kernels in order of backend. + for backend in AVAILABLE_BACKENDS: + if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: + k_cls = None # type: ignore[assignment] + supported, reason = is_supported_config_trtllm( + config, + quant_scheme, + activation_format, + ) + else: + k_cls = backend_2_kernel_cls(backend) + supported, reason = k_cls.is_supported_config( + k_cls, + config, + quant_scheme, + activation_format, + ) + + if supported: + logger.info_once(_make_log_backend(backend), scope="local") + return backend, k_cls + else: + logger.info_once(_make_log_unsupported(backend, reason), scope="local") + + raise NotImplementedError( + "No NvFp4 MoE backend supports the deployment configuration." 
+ ) def convert_to_nvfp4_moe_kernel_format( @@ -227,46 +349,52 @@ def make_nvfp4_moe_quant_config( def make_nvfp4_moe_kernel( - backend: NvFp4MoeBackend, - quant_config: FusedMoEQuantConfig, + layer: "FusedMoE", + moe_quant_config: FusedMoEQuantConfig, moe_config: FusedMoEConfig, -) -> mk.FusedMoEModularKernel | None: - assert moe_config.dp_size == 1 - - UNSUPPORTED_BACKENDS = [ - # TRTLLM does not use the modular kernl abstraction. - NvFp4MoeBackend.FLASHINFER_TRTLLM, - # CUTEDSL is used with BATCHED (masked) format only. - # TODO: add here once we support dp/ep via the oracle. - NvFp4MoeBackend.FLASHINFER_CUTEDSL, - ] + nvfp4_backend: NvFp4MoeBackend, + experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute], +) -> mk.FusedMoEModularKernel: + # Create Prepare/Finalize. + prepare_finalize = maybe_make_prepare_finalize( + moe=moe_config, + quant_config=moe_quant_config, + routing_tables=None, # TODO: init routing tables here? + defer_input_quant=experts_cls.should_pf_defer_input_quant(moe_quant_config), + allow_new_interface=True, + ) - if backend in UNSUPPORTED_BACKENDS: - return None + logger.info_once("Using %s", prepare_finalize.__class__.__name__) - elif backend == NvFp4MoeBackend.FLASHINFER_CUTLASS: - return mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(defer_input_quant=True), - FlashInferExperts( - moe_config=moe_config, - quant_config=quant_config, - ), + # Create Experts. 
+ if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.Standard: + experts = experts_cls.make_standard_experts( + moe_config=moe_config, + quant_config=moe_quant_config, ) - - elif backend == NvFp4MoeBackend.VLLM_CUTLASS: - return mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(defer_input_quant=True), - CutlassExpertsFp4( - moe_config=moe_config, - quant_config=quant_config, - ), + else: + max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() + assert max_num_tokens_per_rank is not None + experts = experts_cls.make_batched_experts( + moe_config=moe_config, + quant_config=moe_quant_config, + max_num_tokens=max_num_tokens_per_rank, + num_dispatchers=prepare_finalize.num_dispatchers(), ) - elif backend == NvFp4MoeBackend.MARLIN: - return mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), - MarlinExperts(quant_config=quant_config), - ) + # NOTE(rob): we only want the ModularKernel to control the SharedExpert + # if we are using all2all (for SBO). Need to make a change somewhere + # else to prevent double running the Shared Expert. + # This needs to be refactored. 
+ kernel = mk.FusedMoEModularKernel( + prepare_finalize, + experts, + shared_experts=( + getattr(layer, "shared_expert", None) + if moe_config.moe_parallel_config.use_all2all_kernels + else None + ), + moe_parallel_config=moe_config.moe_parallel_config, + ) - else: - raise ValueError(f"Unknown NvFp4 MoE backend: {backend}") + return kernel diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index dc6844fe1e46..8483d4594cff 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -32,13 +32,12 @@ select_fp8_moe_backend, ) from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( - FLASHINFER_NVFP4_MOE_BACKENDS, NvFp4MoeBackend, convert_to_nvfp4_moe_kernel_format, - is_global_sf_supported_for_nvfp4_backend, make_nvfp4_moe_kernel, make_nvfp4_moe_quant_config, select_nvfp4_moe_backend, + is_global_sf_supported_for_nvfp4_backend, ) from vllm.model_executor.layers.linear import ( LinearBase, @@ -52,10 +51,8 @@ ) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( - build_flashinfer_fp4_cutlass_moe_prepare_finalize, flashinfer_trtllm_fp4_moe, flashinfer_trtllm_fp4_routed_moe, - select_nvfp4_gemm_impl, ) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( apply_fi_trtllm_fp8_per_tensor_moe, @@ -1360,55 +1357,76 @@ def __init__( ) -> None: super().__init__(moe_config) self.quant_config = quant_config - self.nvfp4_backend = select_nvfp4_moe_backend() - # TODO: move this type of check into the oracle. - if ( - not self.moe.is_act_and_mul - and not self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS - ): - raise NotImplementedError( - "Non-gated activations are only supported by FlashInfer " - "CUTLASS NvFP4 MoE backend." - ) + # Create quant scheme, will be used later to select the quant scales. 
+ # NOTE(rob): we should update QuantConfig to just be the think ts + # holds the scales. Should change the name. + quant_scheme = FusedMoEQuantScheme( + weight_dtype="nvfp4", + act_dtype="nvfp4", + per_tensor_quant=False, + per_token_quant=False, + block_size=None, + static_input_quant=False, + ) + + # Select NvFp4 MoE backend + # NOTE(rob): this is kind of a hack. We need to peak into + # the prepare-finalize selection to determine if we are using + # the batched or standard expert format. + use_batched = ( + self.moe.moe_parallel_config.use_deepep_ll_kernels + or self.moe.moe_parallel_config.use_pplx_kernels + ) + self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend( + config=self.moe, + quant_scheme=quant_scheme, + activation_format=( + mk.FusedMoEActivationFormat.BatchedExperts + if use_batched + else mk.FusedMoEActivationFormat.Standard + ), + ) + + # Delay creation of the kernel until after process-weights. + self.kernel: mk.FusedMoEModularKernel | None = None + # Used for weight loading. self.use_global_sf = is_global_sf_supported_for_nvfp4_backend( self.nvfp4_backend ) - self.kernel: mk.FusedMoEModularKernel | None = None + + # prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize( + # self.moe + # ) + + @property + def topk_indices_dtype(self) -> torch.dtype | None: + if self.kernel is not None: + return self.kernel.prepare_finalize.topk_indices_dtype() + return None + + @property + def supports_mk_interally(self) -> bool: + return True def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPrepareAndFinalize | None: - UNSUPPORTED = [NvFp4MoeBackend.MARLIN, NvFp4MoeBackend.FLASHINFER_TRTLLM] - if self.nvfp4_backend in UNSUPPORTED: - return None - elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS: - # TP case: avoid convert to ModularKernelMethod - to be refactored. 
- if self.moe.dp_size == 1: - return None - # For now, fp4 moe only works with the flashinfer dispatcher. - prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize( - self.moe - ) - logger.debug_once("%s", prepare_finalize.__class__.__name__) - return prepare_finalize - else: - return super().maybe_make_prepare_finalize(routing_tables) + raise ValueError( + "ModelOptNvFp4FusedMoE uses the new modular kernel " + "initialization logic. This function should not be called." + ) def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalize, layer: torch.nn.Module, ) -> mk.FusedMoEPermuteExpertsUnpermute: - assert self.moe_quant_config is not None - experts = select_nvfp4_gemm_impl( - self.moe, - self.moe_quant_config, - allow_flashinfer=self.nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS, + raise ValueError( + "ModelOptNvFp4FusedMoE uses the new modular kernel " + "initialization logic. This function should not be called." ) - logger.debug_once("Using %s", experts.__class__.__name__) - return experts def uses_weight_scale_2_pattern(self) -> bool: """ @@ -1581,12 +1599,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: replace_parameter(layer, "w2_input_scale", a2_scale) self.moe_quant_config = self.get_fused_moe_quant_config(layer) - use_dp = self.moe.dp_size > 1 - if self.moe_quant_config is not None and not use_dp: + if self.moe_quant_config is not None: + assert self.experts_cls is not None self.kernel = make_nvfp4_moe_kernel( - backend=self.nvfp4_backend, - quant_config=self.moe_quant_config, + layer=layer, + moe_quant_config=self.moe_quant_config, moe_config=self.moe, + nvfp4_backend=self.nvfp4_backend, + experts_cls=self.experts_cls, ) def prepare_dp_allgather_tensor( diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index 912ff5a4a12a..af349d57cf1b 100644 --- 
a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -11,7 +11,9 @@ from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, + FusedMoEParallelConfig, FusedMoEQuantConfig, + FusedMoEQuantScheme, RoutingMethodType, ) from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import ( @@ -47,6 +49,65 @@ "build_flashinfer_fp4_cutlass_moe_prepare_finalize", ] +# +# Methods used by the oracle for kernel selection. +# + + +def _supports_current_device() -> bool: + """Supports only Blackwell-family GPUs.""" + p = current_platform + # Add check flashinfer trtllm is available + return p.is_cuda() and p.is_device_capability_family(10) + + +def _supports_no_act_and_mul() -> bool: + """Does not support non-gated MoE (i.e. Nemotron-Nano).""" + return False + + +def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: + """Supports Fp8 per-tensor, Fp8 block, and Nvfp4 quantization.""" + return quant_scheme.is_nvfp4_w4a4 + + +def _supports_activation(activation: str) -> bool: + """Supports silu activation only.""" + return activation in ["silu"] + + +def _supports_moe_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + """Supports EP.""" + return True + + +def is_supported_config_trtllm( + moe_config: FusedMoEConfig, + moe_quant_scheme: FusedMoEQuantScheme, + activation_format: mk.FusedMoEActivationFormat, +) -> tuple[bool, str | None]: + """ + This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config + """ + + def _make_reason(reason: str) -> str: + return f"kernel does not support {reason}" + + if not _supports_current_device(): + return False, _make_reason("current device") + elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()): + return False, _make_reason("no act_and_mul MLP layer") + elif not _supports_activation(moe_config.activation): + return False, 
_make_reason(f"{moe_config.activation} activation") + elif not _supports_quant_scheme(moe_quant_scheme): + return False, _make_reason("quantization scheme") + elif not _supports_moe_parallel_config(moe_config.moe_parallel_config): + return False, _make_reason("parallel config") + elif activation_format != mk.FusedMoEActivationFormat.Standard: + return False, _make_reason("activation format") + + return True, None + def is_flashinfer_fp4_cutlass_moe_available() -> bool: """Return `True` when FlashInfer CUTLASS NV-FP4 kernels can be used.""" @@ -284,9 +345,11 @@ def flashinfer_trtllm_fp4_moe( # Quantize input to FP4 if isinstance(x, tuple): + print("ALREADY QUANTIZED") hidden_states_fp4, hidden_states_scale_linear_fp4 = x else: # hidden_states is the already quantized + print("QUANTIIZING HERE") (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( x, layer.a1_gscale, From 75aca72e73bcaa2f0675f6208b9c07c3ac2132bf Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 12:05:31 -0500 Subject: [PATCH 077/113] made modelopt start up Signed-off-by: Robert Shaw --- .../layers/fused_moe/flashinfer_cutlass_moe.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index e87fec015c60..a56cb1e9ccce 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -73,10 +73,11 @@ def __init__( @staticmethod def should_pf_defer_input_quant(quant_config: FusedMoEQuantConfig) -> bool: - """ - FlashInfer CUTLASS Block FP8 path handles input quantization. - """ - return quant_config.is_block_quantized + # NVFP4 kenrels and FP8 block-quantized kernels apply + # input quantization inside FusedMoEPermuteExpertsUnpermute. 
+ return quant_config.use_nvfp4_w4a4 or ( + quant_config.use_fp8_w8a8 and quant_config.is_block_quantized + ) @staticmethod def _supports_current_device() -> bool: From 32875123abefcc4258fbc610ee69456e14512399 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 12:13:47 -0500 Subject: [PATCH 078/113] still enable flashinfer Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py | 2 +- vllm/model_executor/layers/fused_moe/oracle/nvfp4.py | 1 - vllm/model_executor/layers/quantization/modelopt.py | 3 +-- .../layers/quantization/utils/flashinfer_fp4_moe.py | 2 +- 4 files changed, 3 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index b082f64de0bd..910d07f2427b 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -29,7 +29,7 @@ def _supports_current_device() -> bool: """Supports only Blackwell-family GPUs.""" p = current_platform # Add check flashinfer trtllm is available - return p.is_cuda() and p.is_device_capability_family(10) + return p.is_cuda() and p.is_device_capability_family(100) def _supports_no_act_and_mul() -> bool: diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py index 7c4dd787e872..fd1cb77addf2 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py @@ -352,7 +352,6 @@ def make_nvfp4_moe_kernel( layer: "FusedMoE", moe_quant_config: FusedMoEQuantConfig, moe_config: FusedMoEConfig, - nvfp4_backend: NvFp4MoeBackend, experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute], ) -> mk.FusedMoEModularKernel: # Create Prepare/Finalize. 
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 8483d4594cff..d4c3c44b6f85 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -34,10 +34,10 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( NvFp4MoeBackend, convert_to_nvfp4_moe_kernel_format, + is_global_sf_supported_for_nvfp4_backend, make_nvfp4_moe_kernel, make_nvfp4_moe_quant_config, select_nvfp4_moe_backend, - is_global_sf_supported_for_nvfp4_backend, ) from vllm.model_executor.layers.linear import ( LinearBase, @@ -1605,7 +1605,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer=layer, moe_quant_config=self.moe_quant_config, moe_config=self.moe, - nvfp4_backend=self.nvfp4_backend, experts_cls=self.experts_cls, ) diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index af349d57cf1b..a78e1065cbe8 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -58,7 +58,7 @@ def _supports_current_device() -> bool: """Supports only Blackwell-family GPUs.""" p = current_platform # Add check flashinfer trtllm is available - return p.is_cuda() and p.is_device_capability_family(10) + return p.is_cuda() and p.is_device_capability_family(100) def _supports_no_act_and_mul() -> bool: From 84b83b6be1456d45260192a4446dd674f97577ce Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 12:14:25 -0500 Subject: [PATCH 079/113] remove comments Signed-off-by: Robert Shaw --- .../layers/quantization/utils/flashinfer_fp4_moe.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index 
a78e1065cbe8..53591945e47d 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -345,11 +345,9 @@ def flashinfer_trtllm_fp4_moe( # Quantize input to FP4 if isinstance(x, tuple): - print("ALREADY QUANTIZED") hidden_states_fp4, hidden_states_scale_linear_fp4 = x else: # hidden_states is the already quantized - print("QUANTIIZING HERE") (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( x, layer.a1_gscale, From 3f90b88492fd17b9ca55d8ca73db94ac525c8699 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 12:58:49 -0500 Subject: [PATCH 080/113] appears that we have AG/RS working for nvfp4 Signed-off-by: Robert Shaw --- .../layers/fused_moe/all2all_utils.py | 5 ----- vllm/model_executor/layers/fused_moe/config.py | 12 ------------ .../layers/fused_moe/flashinfer_cutlass_moe.py | 13 ++++++++----- vllm/model_executor/layers/fused_moe/layer.py | 13 +------------ .../layers/fused_moe/modular_kernel.py | 4 +++- vllm/model_executor/layers/fused_moe/oracle/fp8.py | 5 ++++- .../model_executor/layers/fused_moe/oracle/nvfp4.py | 13 ++++++++----- .../layers/fused_moe/prepare_finalize.py | 10 +++++++++- .../layers/fused_moe/rocm_aiter_fused_moe.py | 4 +++- .../layers/fused_moe/shared_fused_moe.py | 5 +---- 10 files changed, 37 insertions(+), 47 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py index 7913c88f6d94..b2efb5dcda1c 100644 --- a/vllm/model_executor/layers/fused_moe/all2all_utils.py +++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py @@ -90,11 +90,6 @@ def maybe_make_prepare_finalize( prepare_finalize: FusedMoEPrepareAndFinalize | None = None - # TODO(rob): update this as part of the MoE refactor. 
- assert not moe.use_flashinfer_cutlass_kernels, ( - "Must be created in modelopt.py or fp8.py" - ) - if moe.use_pplx_kernels: assert quant_config is not None diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 2e361edac70f..8b75c5da837a 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -20,7 +20,6 @@ ) from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.platforms import current_platform -from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.utils.import_utils import has_triton_kernels from vllm.utils.math_utils import cdiv @@ -1194,14 +1193,3 @@ def use_fi_all2allv_kernels(self): @property def use_naive_kernels(self): return self.moe_parallel_config.use_naive_kernels - - @property - def use_flashinfer_cutlass_kernels(self): - """ - Whether to use FlashInfer cutlass kernels for NVFP4 MoE. - """ - return ( - envs.VLLM_USE_FLASHINFER_MOE_FP4 - and has_flashinfer_cutlass_fused_moe() - and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput" - ) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index a56cb1e9ccce..1335266773be 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -72,12 +72,15 @@ def __init__( self.use_deepseek_fp8_block_scale = quant_config.is_block_quantized @staticmethod - def should_pf_defer_input_quant(quant_config: FusedMoEQuantConfig) -> bool: - # NVFP4 kenrels and FP8 block-quantized kernels apply + def should_pf_defer_input_quant( + moe_config: mk.FusedMoEConfig, quant_config: FusedMoEQuantConfig + ) -> bool: + # NVFP4 TP kernels and FP8 block-quantized kernels apply # input quantization inside FusedMoEPermuteExpertsUnpermute. 
- return quant_config.use_nvfp4_w4a4 or ( - quant_config.use_fp8_w8a8 and quant_config.is_block_quantized - ) + return ( + quant_config.use_nvfp4_w4a4 + and not moe_config.moe_parallel_config.use_all2all_kernels + ) or (quant_config.use_fp8_w8a8 and quant_config.is_block_quantized) @staticmethod def _supports_current_device() -> bool: diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 93cb4778f864..a1c4955ec6ba 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -599,9 +599,6 @@ def __init__( activation=activation, device=vllm_config.device_config.device, ) - self.moe_config_use_flashinfer_cutlass_kernels = ( - self.moe_config.use_flashinfer_cutlass_kernels - ) self.quant_config = quant_config @@ -770,14 +767,6 @@ def use_deepep_ht_kernels(self): def use_deepep_ll_kernels(self): return self.moe_parallel_config.use_deepep_ll_kernels - @property - def use_flashinfer_cutlass_kernels(self): - return ( - self.moe_quant_config is not None - and self.moe_quant_config.quant_dtype == "nvfp4" - and self.moe_config_use_flashinfer_cutlass_kernels - ) - @property def use_marlin_kernels(self): return getattr(self.quant_method, "use_marlin", False) @@ -787,7 +776,7 @@ def use_dp_chunking(self) -> bool: return ( self.moe_parallel_config.use_pplx_kernels or self.moe_parallel_config.use_deepep_ll_kernels - or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels) + or self.moe_parallel_config.use_fi_all2allv_kernels ) and envs.VLLM_ENABLE_MOE_DP_CHUNK @property diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index baa7c72351b4..39234035b51d 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -389,7 +389,9 @@ def __init__( self._num_dispatchers: int | None = None @staticmethod - def 
should_pf_defer_input_quant(quant_config: FusedMoEQuantConfig) -> bool: + def should_pf_defer_input_quant( + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig + ) -> bool: """ Whether or not the PrepareFinalize should defer input quantization in the prepare step. If True, then the Experts kernel will diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py index 839695630276..fa1cf03fcf08 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py @@ -406,9 +406,12 @@ def make_fp8_moe_kernel( moe=moe_config, quant_config=moe_quant_config, routing_tables=None, # TODO: init routing tables here? - defer_input_quant=experts_cls.should_pf_defer_input_quant(moe_quant_config), + defer_input_quant=experts_cls.should_pf_defer_input_quant( + moe_config, moe_quant_config + ), allow_new_interface=True, ) + assert prepare_finalize is not None logger.info_once("Using %s", prepare_finalize.__class__.__name__) diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py index fd1cb77addf2..15a57dc4540b 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py @@ -123,17 +123,17 @@ def select_nvfp4_moe_backend( """ def _make_log_backend(backend: NvFp4MoeBackend): - return f"Using {backend.value} backend for NvFp4 MoE" + return f"Using '{backend.value}' backend for NvFp4 MoE" def _make_log_unsupported(backend: NvFp4MoeBackend, reason: str | None) -> str: if reason: return ( - f"NvFP4 MoE backend `{backend.value}` does not support the " + f"NvFP4 MoE backend '{backend.value}' does not support the " f"deployment configuration since {reason}." ) else: return ( - f"NvFP4 MoE backend `{backend.value}` does not support the " + f"NvFP4 MoE backend '{backend.value}' does not support the " "deployment configuration." 
) @@ -161,7 +161,7 @@ def _return_or_raise( ] if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP4"): - if not envs.VLLM_USE_FLASHINFER_MOE_FP8: + if not envs.VLLM_USE_FLASHINFER_MOE_FP4: # If the user rejects FlashInfer remove those backends. for fi_backend in FLASHINFER_NVFP4_MOE_BACKENDS: AVAILABLE_BACKENDS.remove(fi_backend) @@ -359,9 +359,12 @@ def make_nvfp4_moe_kernel( moe=moe_config, quant_config=moe_quant_config, routing_tables=None, # TODO: init routing tables here? - defer_input_quant=experts_cls.should_pf_defer_input_quant(moe_quant_config), + defer_input_quant=experts_cls.should_pf_defer_input_quant( + moe_config, moe_quant_config + ), allow_new_interface=True, ) + assert prepare_finalize is not None logger.info_once("Using %s", prepare_finalize.__class__.__name__) diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index 1546810e457c..d04cd3806881 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -57,12 +57,14 @@ def prepare( a1q = a1 a1q_scale = None else: + use_nvfp4 = quant_config.use_nvfp4_w4a4 a1q, a1q_scale = moe_kernel_quantize_input( a1, - quant_config.a1_scale, + quant_config.a1_gscale if use_nvfp4 else quant_config.a1_scale, quant_config.quant_dtype, quant_config.per_act_token_quant, quant_config.block_shape, + is_fp4_scale_swizzled=False, ) # TODO - this is just for deepgemm? @@ -99,6 +101,12 @@ def prepare( if use_int8_view: a1q = a1q.view(current_platform.fp8_dtype()) + # TODO(rob): move this out of the experts. 
+ if use_nvfp4: + from vllm.utils.flashinfer import nvfp4_block_scale_interleave + + a1q_scale = nvfp4_block_scale_interleave(a1q_scale) + return a1q, a1q_scale, expert_tokens_meta, topk_ids, topk_weights def finalize( diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 915944d49c4e..4e86a21bce91 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -276,7 +276,9 @@ def activation_format() -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.Standard @staticmethod - def should_pf_defer_input_quant(quant_config: FusedMoEQuantConfig) -> bool: + def should_pf_defer_input_quant( + fused_moe_config: mk.FusedMoEConfig, quant_config: FusedMoEQuantConfig + ) -> bool: """ AITER Fused MoE kernels handle input quantization. """ diff --git a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py index a143347b19f2..b2e98cae2d14 100644 --- a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py @@ -35,10 +35,7 @@ def __init__( backend = self.moe_parallel_config.all2all_backend self.use_overlapped = ( use_overlapped - and not ( - (self.enable_eplb and backend != "allgather_reducescatter") - or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1) - ) + and not (self.enable_eplb and backend != "allgather_reducescatter") and self._shared_experts is not None ) From 19537ff24ef3ef8db526172f82752160db8cf79c Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 12:59:59 -0500 Subject: [PATCH 081/113] updated Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/prepare_finalize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py 
b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index d04cd3806881..d823e51e6922 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -101,7 +101,7 @@ def prepare( if use_int8_view: a1q = a1q.view(current_platform.fp8_dtype()) - # TODO(rob): move this out of the experts. + # TODO(rob): move this out of the P/F. if use_nvfp4: from vllm.utils.flashinfer import nvfp4_block_scale_interleave From dfefddf9cc3abbaeae332947ad82fa0f70bca41a Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 13:14:16 -0500 Subject: [PATCH 082/113] updated Signed-off-by: Robert Shaw --- .../layers/fused_moe/flashinfer_cutlass_moe.py | 1 + vllm/model_executor/layers/fused_moe/prepare_finalize.py | 8 +++++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 1335266773be..cec628c0c9a9 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -222,6 +222,7 @@ def apply( assert self.w1_scale is not None and self.w2_scale is not None, ( "w1_scale and w2_scale must not be None for FlashInferExperts" ) + # Flashinfer CUTLASS kernel takes scalar global scales, # min because inv_scale. 
quant_scales = [ diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index d823e51e6922..3ca85a5045cb 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -11,6 +11,7 @@ TopKWeightAndReduceDelegate, ) from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input +from vllm.utils.flashinfer import nvfp4_block_scale_interleave class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize): @@ -101,10 +102,11 @@ def prepare( if use_int8_view: a1q = a1q.view(current_platform.fp8_dtype()) - # TODO(rob): move this out of the P/F. + # NOTE(rob): this is for FLASHINFER_CUTLASS. There are + # currently no other kernels that use this prepare/finalize + # with nvfp4. If we add others in the future, we may need + # a way to register how to shuffle into the kernel format. if use_nvfp4: - from vllm.utils.flashinfer import nvfp4_block_scale_interleave - a1q_scale = nvfp4_block_scale_interleave(a1q_scale) return a1q, a1q_scale, expert_tokens_meta, topk_ids, topk_weights From 232a5b908ef85b07a9d893a680367774251a251b Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 13:18:43 -0500 Subject: [PATCH 083/113] updared comment Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/prepare_finalize.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index 3ca85a5045cb..e853c4b2147f 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -102,10 +102,10 @@ def prepare( if use_int8_view: a1q = a1q.view(current_platform.fp8_dtype()) - # NOTE(rob): this is for FLASHINFER_CUTLASS. There are - # currently no other kernels that use this prepare/finalize - # with nvfp4. 
If we add others in the future, we may need - # a way to register how to shuffle into the kernel format. + # NOTE: shuffle into format expected by FLASHINFER_CUTLASS + # There are currently no other kernels that use this P/F + # with nvfp4. If we add other kernels in the future, we + # can regsiter a shuffle that gets called here. if use_nvfp4: a1q_scale = nvfp4_block_scale_interleave(a1q_scale) From d2dd97936c55fa1d98428bbeb5d2e1e773d806e8 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 13:21:46 -0500 Subject: [PATCH 084/113] make marling work with current device Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/fused_marlin_moe.py | 3 +-- vllm/model_executor/layers/fused_moe/oracle/nvfp4.py | 7 ------- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index e333d7ab1a46..09a7bef1b00a 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -558,8 +558,7 @@ def __init__( @staticmethod def _supports_current_device() -> bool: p = current_platform - # Is this right? Can we do < Ampere? 
- return p.is_cuda() and p.has_device_capability((8, 0)) + return p.is_cuda() and p.has_device_capability((7, 5)) @staticmethod def _supports_no_act_and_mul() -> bool: diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py index 15a57dc4540b..bd15688f5e35 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py @@ -20,8 +20,6 @@ ) from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( is_supported_config_trtllm, - # is_flashinfer_fp4_cutedsl_moe_available, - # is_flashinfer_fp4_cutlass_moe_available, prepare_nvfp4_moe_layer_for_fi_or_cutlass, ) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( @@ -29,14 +27,9 @@ get_flashinfer_moe_backend, ) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - # is_fp4_marlin_supported, prepare_nvfp4_moe_layer_for_marlin, ) -# from vllm.model_executor.layers.quantization.utils.quant_utils import ( -# cutlass_fp4_supported, -# ) - if TYPE_CHECKING: from vllm.model_executor.layers.fused_moe.layer import FusedMoE From d35c2475ddf2bfd9b848dee9564670f805e6a848 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 13:46:32 -0500 Subject: [PATCH 085/113] added compressed tensors nvfp4 Signed-off-by: Robert Shaw --- .../model_executor/layers/fused_moe/config.py | 4 + .../compressed_tensors_moe.py | 108 ++++++++++-------- .../layers/quantization/modelopt.py | 4 - 3 files changed, 63 insertions(+), 53 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 8b75c5da837a..297faa9084c0 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -167,6 +167,10 @@ def is_fp8_w8a8(self) -> bool: def is_nvfp4_w4a4(self) -> bool: return self.weight_dtype == "nvfp4" and self.act_dtype == "nvfp4" + 
@property + def is_nvfp4_w4a16(self) -> bool: + return self.weight_dtype == "nvfp4" and self.act_dtype is None + @property def is_mxfp4_w4a4(self) -> bool: return self.weight_dtype == "mxfp4" and self.act_dtype == "mxfp4" diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 9c3240d25437..7c8dc69e0af5 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -50,7 +50,6 @@ select_fp8_moe_backend, ) from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( - FLASHINFER_NVFP4_MOE_BACKENDS, NvFp4MoeBackend, convert_to_nvfp4_moe_kernel_format, is_global_sf_supported_for_nvfp4_backend, @@ -63,10 +62,8 @@ WNA16_SUPPORTED_TYPES_MAP, ) from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( - build_flashinfer_fp4_cutlass_moe_prepare_finalize, flashinfer_trtllm_fp4_moe, flashinfer_trtllm_fp4_routed_moe, - select_nvfp4_gemm_impl, ) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( apply_fi_trtllm_fp8_per_tensor_moe, @@ -82,9 +79,6 @@ marlin_make_workspace_new, marlin_moe_permute_scales, ) -from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - is_fp4_marlin_supported, -) from vllm.model_executor.layers.quantization.utils.quant_utils import ( convert_bf16_scales_to_fp8, convert_packed_uint4b8_to_signed_int4_inplace, @@ -200,7 +194,7 @@ def get_moe_method( f"or None for NVFP4A16, found {input_quant}", ) return CompressedTensorsW4A4Nvfp4MoEMethod( - layer.moe_config, layer_name, use_marlin=input_quant is None + layer.moe_config, layer_name, use_a16=(input_quant is None) ) elif ( quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) @@ -234,7 +228,7 @@ def __init__( self, moe: FusedMoEConfig, layer_name: str | None = None, - 
use_marlin: bool = False, + use_a16: bool = False, ): if not moe.is_act_and_mul: raise ValueError( @@ -243,21 +237,53 @@ def __init__( ) super().__init__(moe) - self.group_size = 16 - if use_marlin: - if is_fp4_marlin_supported(): - self.nvfp4_backend = NvFp4MoeBackend.MARLIN - else: - raise ValueError( - "Marlin FP4 MoE kernel requested but not ", - "supported on current platform.", - ) - else: - self.nvfp4_backend = select_nvfp4_moe_backend() + # Create quant scheme, will be used later to select the quant scales. + # NOTE(rob): we should update QuantConfig to just be the think ts + # holds the scales. Should change the name. + quant_scheme = FusedMoEQuantScheme( + weight_dtype="nvfp4", + act_dtype="nvfp4" if not use_a16 else None, + per_tensor_quant=False, + per_token_quant=False, + block_size=None, + static_input_quant=False, + ) + + # Select NvFp4 MoE backend + # NOTE(rob): this is kind of a hack. We need to peak into + # the prepare-finalize selection to determine if we are using + # the batched or standard expert format. + use_batched = ( + self.moe.moe_parallel_config.use_deepep_ll_kernels + or self.moe.moe_parallel_config.use_pplx_kernels + ) + self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend( + config=self.moe, + quant_scheme=quant_scheme, + activation_format=( + mk.FusedMoEActivationFormat.BatchedExperts + if use_batched + else mk.FusedMoEActivationFormat.Standard + ), + ) + + # Delay creation of the kernel until after process-weights. + self.kernel: mk.FusedMoEModularKernel | None = None + + # Used for weight loading. 
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend( self.nvfp4_backend ) - self.kernel: mk.FusedMoEModularKernel | None = None + + @property + def topk_indices_dtype(self) -> torch.dtype | None: + if self.kernel is not None: + return self.kernel.prepare_finalize.topk_indices_dtype() + return None + + @property + def supports_mk_interally(self) -> bool: + return True def create_weights( self, @@ -430,50 +456,34 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w13_input_scale = a13_scale layer.w2_input_scale = a2_scale - # Initialize the kernel that will be called in apply(). self.moe_quant_config = self.get_fused_moe_quant_config(layer) - use_dp = self.moe.dp_size > 1 - if self.moe_quant_config is not None and not use_dp: + if self.moe_quant_config is not None: + assert self.experts_cls is not None self.kernel = make_nvfp4_moe_kernel( - backend=self.nvfp4_backend, - quant_config=self.moe_quant_config, + layer=layer, + moe_quant_config=self.moe_quant_config, moe_config=self.moe, + experts_cls=self.experts_cls, ) def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPrepareAndFinalize | None: - UNSUPPORTED = [NvFp4MoeBackend.MARLIN, NvFp4MoeBackend.FLASHINFER_TRTLLM] - if self.nvfp4_backend in UNSUPPORTED: - return None - elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS: - # TP case: avoid convert to ModularKernelMethod - to be refactored. - if self.moe.dp_size == 1: - return None - # For now, fp4 moe only works with the flashinfer dispatcher. - prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize( - self.moe - ) - logger.debug_once("%s", prepare_finalize.__class__.__name__) - return prepare_finalize - else: - return super().maybe_make_prepare_finalize(routing_tables) + raise ValueError( + "CompressedTensorsW4A4NvFp4MoE uses the new modular kernel " + "initialization logic. This function should not be called." 
+ ) def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalize, layer: torch.nn.Module, - ) -> mk.FusedMoEPermuteExpertsUnpermute: - assert self.moe_quant_config is not None - """Return the appropriate GEMM experts implementation.""" - experts = select_nvfp4_gemm_impl( - self.moe, - self.moe_quant_config, - allow_flashinfer=(self.nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS), + ) -> FusedMoEPermuteExpertsUnpermute: + raise ValueError( + "CompressedTensorsW4A4NvFp4MoE uses the new modular kernel " + "initialization logic. This function should not be called." ) - logger.debug_once("Using %s", experts.__class__.__name__) - return experts def get_fused_moe_quant_config( self, layer: torch.nn.Module diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index d4c3c44b6f85..da7844c3fcd9 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1395,10 +1395,6 @@ def __init__( self.nvfp4_backend ) - # prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize( - # self.moe - # ) - @property def topk_indices_dtype(self) -> torch.dtype | None: if self.kernel is not None: From 4c9656a3c6098db11ec6ddcb53ff3529b98bbf5d Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 13:49:21 -0500 Subject: [PATCH 086/113] added compressed tensors nvfp4 - nit Signed-off-by: Robert Shaw --- .../quantization/compressed_tensors/compressed_tensors_moe.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 7c8dc69e0af5..72bb77ccc419 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -237,6 +237,8 @@ def 
__init__( ) super().__init__(moe) + self.group_size = 16 + # Create quant scheme, will be used later to select the quant scales. # NOTE(rob): we should update QuantConfig to just be the think ts # holds the scales. Should change the name. From 9d5d3ee962c2520e3141d580a4431d61747af5d0 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 13:58:25 -0500 Subject: [PATCH 087/113] remove shared expert overlap functionality Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/shared_fused_moe.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py index b2e98cae2d14..c3e3744935aa 100644 --- a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py @@ -35,7 +35,10 @@ def __init__( backend = self.moe_parallel_config.all2all_backend self.use_overlapped = ( use_overlapped - and not (self.enable_eplb and backend != "allgather_reducescatter") + and not ( + (self.enable_eplb and backend != "allgather_reducescatter") + or self.moe_config.use_fi_all2allv_kernels + ) and self._shared_experts is not None ) From 553efc15b2ab802d3b786069140cc35991769d4a Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 14:16:38 -0500 Subject: [PATCH 088/113] make cutedsl work Signed-off-by: Robert Shaw --- .../layers/fused_moe/flashinfer_cutedsl_moe.py | 14 ++++++-------- .../layers/fused_moe/oracle/nvfp4.py | 2 +- vllm/utils/flashinfer.py | 1 + 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py index 46914187804c..5aec5abde971 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py @@ -7,6 +7,7 @@ from vllm import envs from vllm.logger 
import init_logger from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig, FusedMoEQuantScheme, @@ -17,7 +18,6 @@ from vllm.platforms import current_platform from vllm.utils.flashinfer import ( flashinfer_cutedsl_grouped_gemm_nt_masked, - has_flashinfer_cutedsl_grouped_gemm_nt_masked, scaled_fp4_grouped_quantize, silu_and_mul_scaled_nvfp4_experts_quantize, ) @@ -28,14 +28,14 @@ class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - out_dtype: torch.dtype, + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, ): - super().__init__(quant_config) + super().__init__(moe_config, quant_config) assert quant_config.quant_dtype == "nvfp4", ( "Only nvfp4 quantization are currently supported." ) - self.out_dtype = out_dtype + self.out_dtype = moe_config.in_dtype @staticmethod def activation_format() -> mk.FusedMoEActivationFormat: @@ -43,10 +43,8 @@ def activation_format() -> mk.FusedMoEActivationFormat: @staticmethod def _supports_current_device() -> bool: - return ( - current_platform.has_device_capability((10, 0)) - and has_flashinfer_cutedsl_grouped_gemm_nt_masked() - ) + # TODO: add check cutedsl support? 
+ return current_platform.has_device_capability((10, 0)) @staticmethod def _supports_no_act_and_mul() -> bool: diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py index bd15688f5e35..44f1b7446bc8 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py @@ -40,7 +40,7 @@ class NvFp4MoeBackend(Enum): FLASHINFER_TRTLLM = "FlashInfer TRTLLM" FLASHINFER_CUTLASS = "FlashInfer CUTLASS" FLASHINFER_CUTEDSL = "FlashInfer CUTEDSL" - VLLM_CUTLASS = "vLLM CUTASS" + VLLM_CUTLASS = "vLLM CUTLASS" MARLIN = "vLLM MARLIN" diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 067c6fb3e785..4aa6da222e84 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -238,6 +238,7 @@ def has_flashinfer_cutedsl_grouped_gemm_nt_masked() -> bool: for module_name, attr_name in required_functions: mod = _get_submodule(module_name) if not mod or not hasattr(mod, attr_name): + print(f"{attr_name} not found in {module_name}") return False return True From 218e9bf22415704901fc3370888fe7e39abd96b5 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 14:24:58 -0500 Subject: [PATCH 089/113] hook up marlin experts properly Signed-off-by: Robert Shaw --- .../layers/fused_moe/fused_marlin_moe.py | 22 +++---------------- .../layers/fused_moe/oracle/nvfp4.py | 6 ++--- 2 files changed, 6 insertions(+), 22 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 09a7bef1b00a..c048a31f7ab0 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -9,6 +9,7 @@ import vllm._custom_ops as ops import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, FusedMoEParallelConfig, 
FusedMoEQuantConfig, FusedMoEQuantScheme, @@ -534,6 +535,7 @@ def batched_fused_marlin_moe( class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, w13_g_idx: torch.Tensor | None = None, w2_g_idx: torch.Tensor | None = None, @@ -553,7 +555,7 @@ def __init__( self.w13_g_idx_sort_indices = w13_g_idx_sort_indices self.w2_g_idx_sort_indices = w2_g_idx_sort_indices self.is_k_full = is_k_full - super().__init__(quant_config) + super().__init__(moe_config, quant_config) @staticmethod def _supports_current_device() -> bool: @@ -626,24 +628,6 @@ def moe_problem_size( class MarlinExperts(MarlinExpertsBase): - def __init__( - self, - quant_config: FusedMoEQuantConfig, - w13_g_idx: torch.Tensor | None = None, - w2_g_idx: torch.Tensor | None = None, - w13_g_idx_sort_indices: torch.Tensor | None = None, - w2_g_idx_sort_indices: torch.Tensor | None = None, - is_k_full: bool = True, - ): - super().__init__( - quant_config, - w13_g_idx, - w2_g_idx, - w13_g_idx_sort_indices, - w2_g_idx_sort_indices, - is_k_full, - ) - def supports_expert_map(self) -> bool: return True diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py index 44f1b7446bc8..a56ddcd04ed1 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py @@ -89,11 +89,11 @@ def backend_2_kernel_cls( return FlashInferCuteDSLExperts elif backend == NvFp4MoeBackend.VLLM_CUTLASS: - from vllm.model_executor.layers.fused_moe.triton_cutlass_moe import ( - TritonOrCutlassExperts, + from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + CutlassExpertsFp4, ) - return TritonOrCutlassExperts + return CutlassExpertsFp4 elif backend == NvFp4MoeBackend.MARLIN: from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( From 3c97fda6e5f8bb70f41501fb31bc9bb230d33b4c Mon Sep 17 00:00:00 2001 From: Robert 
Shaw Date: Wed, 14 Jan 2026 14:25:17 -0500 Subject: [PATCH 090/113] hook up marlin experts properly Signed-off-by: Robert Shaw --- .../layers/fused_moe/fused_marlin_moe.py | 22 ------------------- 1 file changed, 22 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index c048a31f7ab0..810529f0b9ac 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -730,28 +730,6 @@ def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None: class BatchedMarlinExperts(MarlinExpertsBase): - def __init__( - self, - max_num_tokens: int, - num_dispatchers: int, - quant_config: FusedMoEQuantConfig, - w13_g_idx: torch.Tensor | None = None, - w2_g_idx: torch.Tensor | None = None, - w13_g_idx_sort_indices: torch.Tensor | None = None, - w2_g_idx_sort_indices: torch.Tensor | None = None, - is_k_full: bool = True, - ): - super().__init__( - quant_config, - w13_g_idx, - w2_g_idx, - w13_g_idx_sort_indices, - w2_g_idx_sort_indices, - is_k_full, - ) - self.max_num_tokens = max_num_tokens - self.num_dispatchers = num_dispatchers - def supports_expert_map(self) -> bool: return True From f2714d89f8f8a26be9e93bdc67bd1333964d51ab Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 14:35:54 -0500 Subject: [PATCH 091/113] make marlin work with AG/RS Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/oracle/nvfp4.py | 3 ++- vllm/utils/flashinfer.py | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py index a56ddcd04ed1..a58f2487011b 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py @@ -150,7 +150,8 @@ def _return_or_raise( NvFp4MoeBackend.FLASHINFER_TRTLLM, NvFp4MoeBackend.FLASHINFER_CUTEDSL, 
NvFp4MoeBackend.FLASHINFER_CUTLASS, - NvFp4MoeBackend.VLLM_CUTLASS, + NvFp4MoeBackend.MARLIN, + # NvFp4MoeBackend.VLLM_CUTLASS, ] if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP4"): diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 4aa6da222e84..067c6fb3e785 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -238,7 +238,6 @@ def has_flashinfer_cutedsl_grouped_gemm_nt_masked() -> bool: for module_name, attr_name in required_functions: mod = _get_submodule(module_name) if not mod or not hasattr(mod, attr_name): - print(f"{attr_name} not found in {module_name}") return False return True From 471428fcf27727ecbe3f5688a349c877b28b6a5a Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 17:07:15 -0500 Subject: [PATCH 092/113] convert to using quant key Signed-off-by: Robert Shaw --- .../layers/fused_moe/batched_deep_gemm_moe.py | 18 +++- .../model_executor/layers/fused_moe/config.py | 71 --------------- .../layers/fused_moe/cutlass_moe.py | 28 ++++-- .../layers/fused_moe/deep_gemm_moe.py | 16 +++- .../fused_moe/flashinfer_cutedsl_moe.py | 15 +++- .../fused_moe/flashinfer_cutlass_moe.py | 36 ++++++-- .../layers/fused_moe/flashinfer_trtllm_moe.py | 30 ++++--- .../layers/fused_moe/fused_batched_moe.py | 77 +++++++++------- .../layers/fused_moe/fused_marlin_moe.py | 31 +++++-- .../layers/fused_moe/fused_moe.py | 37 +++++--- .../fused_moe/gpt_oss_triton_kernels_moe.py | 25 ++---- .../layers/fused_moe/modular_kernel.py | 14 ++- .../layers/fused_moe/oracle/fp8.py | 50 ++++++++--- .../layers/fused_moe/oracle/nvfp4.py | 42 ++++++--- .../layers/fused_moe/rocm_aiter_fused_moe.py | 20 +++-- .../layers/fused_moe/triton_cutlass_moe.py | 11 ++- .../layers/fused_moe/triton_deep_gemm_moe.py | 11 ++- .../layers/fused_moe/trtllm_moe.py | 19 ++-- .../compressed_tensors_moe.py | 87 +++++++------------ .../model_executor/layers/quantization/fp8.py | 46 ++++------ .../layers/quantization/modelopt.py | 60 ++----------- 
.../quantization/utils/flashinfer_fp4_moe.py | 73 +++------------- .../layers/quantization/utils/quant_utils.py | 28 +++--- 23 files changed, 420 insertions(+), 425 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 47f73e053ca6..49ff4239c14e 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -11,12 +11,16 @@ FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig, - FusedMoEQuantScheme, ) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate, ) from vllm.model_executor.layers.fused_moe.utils import _resize_cache +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kFp8Dynamic128Sym, + kFp8Static128BlockSym, +) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils.deep_gemm import ( @@ -284,8 +288,16 @@ def _supports_no_act_and_mul() -> bool: return False @staticmethod - def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: - return quant_scheme.is_fp8_w8a8 and quant_scheme.block_size == (128, 128) + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + if weight_key is None or activation_key is None: + return False + + return ( + weight_key == kFp8Dynamic128Sym and activation_key == kFp8Static128BlockSym + ) @staticmethod def _supports_activation(activation: str) -> bool: diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 297faa9084c0..55b30c7a43c6 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -123,77 +123,6 @@ class RoutingMethodType(IntEnum): UNQUANTIZED_DTYPES = [torch.bfloat16, torch.float16, torch.float32] -@dataclass -class 
FusedMoEQuantScheme: - weight_dtype: torch.dtype | str | None - act_dtype: torch.dtype | str | None - per_token_quant: bool - per_tensor_quant: bool - static_input_quant: bool - block_size: tuple[int, int] | None - - @property - def per_block_quant(self) -> bool: - return self.block_size is not None - - @property - def is_unquantized(self) -> bool: - return self.weight_dtype in UNQUANTIZED_DTYPES - - def __post_init__(self): - if self.per_tensor_quant: - assert not self.per_token_quant - assert not self.per_block_quant - elif self.per_token_quant: - assert not self.per_tensor_quant - assert not self.per_block_quant - assert not self.static_input_quant - elif self.per_block_quant: - assert not self.per_tensor_quant - assert not self.per_token_quant - assert not self.static_input_quant - assert self.block_size is not None - if self.is_unquantized: - assert self.act_dtype in UNQUANTIZED_DTYPES - - @property - def is_fp8_w8a8(self) -> bool: - return ( - self.weight_dtype == current_platform.fp8_dtype() - and self.act_dtype == current_platform.fp8_dtype() - ) - - @property - def is_nvfp4_w4a4(self) -> bool: - return self.weight_dtype == "nvfp4" and self.act_dtype == "nvfp4" - - @property - def is_nvfp4_w4a16(self) -> bool: - return self.weight_dtype == "nvfp4" and self.act_dtype is None - - @property - def is_mxfp4_w4a4(self) -> bool: - return self.weight_dtype == "mxfp4" and self.act_dtype == "mxfp4" - - @property - def is_mxfp4_w4a8(self) -> bool: - return ( - self.weight_dtype == "mxfp4" - and self.act_dtype == current_platform.fp8_dtype() - ) - - @property - def is_mxfp4_w4a16(self) -> bool: - return self.weight_dtype == "mxfp4" and self.act_dtype in UNQUANTIZED_DTYPES - - @property - def is_int4_w4a8(self) -> bool: - return ( - self.weight_dtype == "int4" - and self.act_dtype == current_platform.fp8_dtype() - ) - - @dataclass class FusedMoEQuantDesc: """ diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py 
b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 5da1006d5397..f48f7c4cb0cc 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -13,7 +13,6 @@ FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig, - FusedMoEQuantScheme, ) from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( moe_permute, @@ -24,6 +23,13 @@ TopKWeightAndReduceNoOP, ) from vllm.model_executor.layers.fused_moe.utils import _resize_cache +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kFp8DynamicTensorSym, + kFp8DynamicTokenSym, + kFp8StaticChannelSym, + kFp8StaticTensorSym, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( cutlass_group_gemm_supported, ) @@ -271,8 +277,16 @@ def _supports_no_act_and_mul() -> bool: return False @staticmethod - def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: - return quant_scheme.is_fp8_w8a8 and not quant_scheme.per_block_quant + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + SUPPORTED_W_A = [ + (kFp8StaticChannelSym, kFp8DynamicTokenSym), + (kFp8StaticTensorSym, kFp8DynamicTensorSym), + (kFp8StaticTensorSym, kFp8StaticTensorSym), + ] + return (weight_key, activation_key) in SUPPORTED_W_A @staticmethod def _supports_activation(activation: str) -> bool: @@ -879,8 +893,12 @@ def _supports_no_act_and_mul() -> bool: return False @staticmethod - def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: - return quant_scheme.is_int4_w4a8 and quant_scheme.per_token_quant + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + # return quant_scheme.is_int4_w4a8 and quant_scheme.per_token_quant + return False @staticmethod def _supports_activation(activation: str) -> bool: diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py 
b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 871100abdd2a..a86f39e96959 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -9,7 +9,6 @@ FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig, - FusedMoEQuantScheme, ) from vllm.model_executor.layers.fused_moe.deep_gemm_utils import ( compute_aligned_M, @@ -25,6 +24,11 @@ per_token_group_quant_fp8_packed_for_deepgemm, silu_mul_per_token_group_quant_fp8_colmajor, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kFp8Dynamic128Sym, + kFp8Static128BlockSym, +) from vllm.utils.deep_gemm import ( DeepGemmQuantScaleFMT, get_mk_alignment_for_contiguous_layout, @@ -129,8 +133,14 @@ def _supports_no_act_and_mul() -> bool: return False @staticmethod - def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: - return quant_scheme.is_fp8_w8a8 and quant_scheme.block_size == (128, 128) + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + SUPPORTED_W_A = [ + (kFp8Static128BlockSym, kFp8Dynamic128Sym), + ] + return (weight_key, activation_key) in SUPPORTED_W_A @staticmethod def _supports_activation(activation: str) -> bool: diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py index 5aec5abde971..a2917a5736ef 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py @@ -10,11 +10,14 @@ FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig, - FusedMoEQuantScheme, ) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kNvfp4Quant, +) from vllm.platforms import current_platform from vllm.utils.flashinfer import ( 
flashinfer_cutedsl_grouped_gemm_nt_masked, @@ -51,8 +54,14 @@ def _supports_no_act_and_mul() -> bool: return False @staticmethod - def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: - return quant_scheme.is_nvfp4_w4a4 + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + SUPPORTED_W_A = [ + (kNvfp4Quant, kNvfp4Quant), + ] + return (weight_key, activation_key) in SUPPORTED_W_A @staticmethod def _supports_activation(activation: str) -> bool: diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index cec628c0c9a9..1e3cd64a3f00 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -8,11 +8,17 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEParallelConfig, FusedMoEQuantConfig, - FusedMoEQuantScheme, ) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kFp8Dynamic128Sym, + kFp8Static128BlockSym, + kFp8StaticTensorSym, + kNvfp4Quant, +) from vllm.platforms import current_platform from vllm.utils.flashinfer import ( flashinfer_cutlass_fused_moe, @@ -86,6 +92,7 @@ def should_pf_defer_input_quant( def _supports_current_device() -> bool: return ( current_platform.is_cuda() + # Is this right? Or 9.0+10.0 only? 
and current_platform.has_device_capability((9, 0)) and has_flashinfer_cutlass_fused_moe() ) @@ -95,23 +102,34 @@ def _supports_no_act_and_mul() -> bool: return False @staticmethod - def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: # Supports: # * unquantized # * fp8 static per-tensor on 9.0+ # * fp8 block on 9.0 # * nvfp4 on 10.0+ - s = quant_scheme + p = current_platform + scheme = (weight_key, activation_key) return ( - (s.is_unquantized) - or (s.is_fp8_w8a8 and s.per_tensor_quant and s.static_input_quant) + ( + scheme + in [ + (None, None), + (kFp8StaticTensorSym, kFp8StaticTensorSym), + ] + ) + or ( + (scheme == (kFp8Static128BlockSym, kFp8Dynamic128Sym)) + and (p.is_device_capability((9, 0))) + ) or ( - s.is_fp8_w8a8 - and s.block_size == (128, 128) - and p.is_device_capability((9, 0)) + (scheme == (kNvfp4Quant, kNvfp4Quant)) + and (p.is_device_capability((10, 0))) # GB? 
) - or (s.is_nvfp4_w4a4 and p.has_device_capability((10, 0))) ) @staticmethod diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index 910d07f2427b..30af9d912dff 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -7,7 +7,6 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEParallelConfig, - FusedMoEQuantScheme, RoutingMethodType, ) from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input @@ -17,6 +16,12 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kFp8Dynamic128Sym, + kFp8Static128BlockSym, + kFp8StaticTensorSym, +) from vllm.platforms import current_platform from vllm.utils.torch_utils import direct_register_custom_op @@ -37,14 +42,16 @@ def _supports_no_act_and_mul() -> bool: return False -def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: - """Supports Fp8 per-tensor, Fp8 block, and Nvfp4 quantization.""" - s = quant_scheme - return ( - (s.is_fp8_w8a8 and s.per_tensor_quant and s.static_input_quant) - or (s.is_fp8_w8a8 and s.block_size == (128, 128)) - or (s.is_nvfp4_w4a4) - ) +def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, +) -> bool: + """Supports Fp8 per-tensor and Fp8 block.""" + SUPPORTED_W_A = [ + (kFp8Static128BlockSym, kFp8Dynamic128Sym), + (kFp8StaticTensorSym, kFp8StaticTensorSym), + ] + return (weight_key, activation_key) in SUPPORTED_W_A def _supports_activation(activation: str) -> bool: @@ -59,7 +66,8 @@ def _supports_moe_parallel_config(moe_parallel_config: FusedMoEParallelConfig) - def is_supported_config_trtllm( moe_config: FusedMoEConfig, - moe_quant_scheme: FusedMoEQuantScheme, + weight_key: QuantKey 
| None, + activation_key: QuantKey | None, activation_format: mk.FusedMoEActivationFormat, ) -> tuple[bool, str | None]: """ @@ -75,7 +83,7 @@ def _make_reason(reason: str) -> str: return False, _make_reason("no act_and_mul MLP layer") elif not _supports_activation(moe_config.activation): return False, _make_reason(f"{moe_config.activation} activation") - elif not _supports_quant_scheme(moe_quant_scheme): + elif not _supports_quant_scheme(weight_key, activation_key): return False, _make_reason("quantization scheme") elif not _supports_moe_parallel_config(moe_config.moe_parallel_config): return False, _make_reason("parallel config") diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index d66b59c18e22..fbb30d408457 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -6,9 +6,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig, - FusedMoEQuantScheme, ) from vllm.model_executor.layers.fused_moe.fused_moe import try_get_optimal_moe_config from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( @@ -21,7 +21,10 @@ normalize_batched_scales_shape, normalize_scales_shape, ) -from vllm.model_executor.layers.quantization.utils.quant_utils import group_broadcast +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + group_broadcast, +) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton @@ -638,26 +641,41 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - max_num_tokens: int, - num_dispatchers: int, + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, ): - super().__init__(quant_config) + super().__init__(moe_config, quant_config) assert not 
self.quant_config.use_int8_w8a8, "NYI" assert not self.quant_config.use_int8_w8a16, "NYI" assert not self.quant_config.use_int4_w4a16, "NYI" assert self.quant_config.ocp_mx_scheme is None, "NYI" - self.max_num_tokens = max_num_tokens - self.num_dispatchers = num_dispatchers - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts, - ) + @staticmethod + def activation_formats() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.BatchedExperts + + @staticmethod + def _supports_current_device() -> bool: + return True + + @staticmethod + def _supports_no_act_and_mul() -> bool: + return False + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + return False + + @staticmethod + def _supports_activation(activation: str) -> bool: + return activation in ["silu", "swigluoai"] + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + return True def supports_chunking(self) -> bool: return False @@ -831,19 +849,14 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - max_num_tokens: int, - num_dispatchers: int, + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, ): - super().__init__(quant_config) + super().__init__(moe_config, quant_config) assert not self.quant_config.use_int8_w8a8, "NYI" assert not self.quant_config.use_int8_w8a16, "NYI" assert not self.quant_config.use_int4_w4a16, "NYI" assert self.quant_config.ocp_mx_scheme is None, "NYI" - assert max_num_tokens > 0 - assert num_dispatchers > 0 - self.max_num_tokens = max_num_tokens - self.num_dispatchers = num_dispatchers @staticmethod def activation_format() -> mk.FusedMoEActivationFormat: @@ -858,14 +871,18 @@ def _supports_no_act_and_mul() -> bool: return False 
@staticmethod - def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: - p = current_platform - device_supports_fp8 = p.is_rocm() or ( - p.is_cuda() and p.has_device_capability((9, 0)) - ) - return quant_scheme.is_unquantized or ( - quant_scheme.is_fp8_w8a8 and device_supports_fp8 - ) + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + # p = current_platform + # device_supports_fp8 = p.is_rocm() or ( + # p.is_cuda() and p.has_device_capability((9, 0)) + # ) + # return quant_scheme.is_unquantized or ( + # quant_scheme.is_fp8_w8a8 and device_supports_fp8 + # ) + return False @staticmethod def _supports_activation(activation: str) -> bool: diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 810529f0b9ac..114e04abe916 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -12,7 +12,6 @@ FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig, - FusedMoEQuantScheme, ) from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( batched_moe_align_block_size, @@ -28,6 +27,11 @@ marlin_moe_intermediate_size, marlin_quant_input, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kFp8StaticChannelSym, + kNvfp4Quant, +) from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types @@ -567,14 +571,25 @@ def _supports_no_act_and_mul() -> bool: return False @staticmethod - def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: - return quant_scheme.weight_dtype in [ - current_platform.fp8_dtype(), - torch.int8, - "int4", - "nvfp4", - "mxfp4", + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + # return quant_scheme.weight_dtype in [ + # current_platform.fp8_dtype(), + # torch.int8, + 
# "int4", + # "nvfp4", + # "mxfp4", + # ] + + # NOTE: Marlin runs activations unquantized, so + # it can run all activation formats. + SUPPORTED_W = [ + kFp8StaticChannelSym, + kNvfp4Quant, ] + return weight_key in SUPPORTED_W @staticmethod def _supports_activation(activation: str) -> bool: diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index b9a4a136b878..4564e156041f 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -24,7 +24,6 @@ FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig, - FusedMoEQuantScheme, _get_config_dtype_str, ) from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( @@ -52,6 +51,14 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4 from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6 from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kFp8Dynamic128Sym, + kFp8DynamicTokenSym, + kFp8Static128BlockSym, + kFp8StaticChannelSym, + kFp8StaticTensorSym, +) from vllm.model_executor.utils import maybe_disable_graph_partition from vllm.platforms import current_platform from vllm.triton_utils import tl, triton @@ -2320,17 +2327,25 @@ def _supports_no_act_and_mul() -> bool: return False @staticmethod - def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: p = current_platform device_supports_fp8 = p.is_rocm() or ( p.is_cuda() and p.has_device_capability((9, 0)) ) - return ( - quant_scheme.is_unquantized - or quant_scheme.is_fp8_w8a8 - and device_supports_fp8 - ) + if not device_supports_fp8: + return (weight_key, activation_key) == (None, None) + + SUPPORTED_W_A = [ + (kFp8Static128BlockSym, 
kFp8Dynamic128Sym), + (kFp8StaticChannelSym, kFp8DynamicTokenSym), + (kFp8StaticTensorSym, kFp8DynamicTokenSym), + (kFp8StaticTensorSym, kFp8StaticTensorSym), + ] + return (weight_key, activation_key) in SUPPORTED_W_A @staticmethod def _supports_activation(activation: str) -> bool: @@ -2515,9 +2530,11 @@ def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None: class TritonWNA16Experts(TritonExperts): @staticmethod - def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: - # TODO - return True + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + return False def apply( self, diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index f54722c7a498..eb9d0e669245 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -11,12 +11,14 @@ FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEParallelConfig, FusedMoEQuantConfig, - FusedMoEQuantScheme, ) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP, ) from vllm.model_executor.layers.fused_moe.utils import _resize_cache +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, +) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils.import_utils import has_triton_kernels @@ -244,9 +246,6 @@ def make_routing_data( class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__(self, quant_config: FusedMoEQuantConfig): - super().__init__(quant_config) - @staticmethod def _supports_current_device() -> bool: p = current_platform @@ -257,8 +256,11 @@ def _supports_no_act_and_mul() -> bool: return False @staticmethod - def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: - return quant_scheme.is_mxfp4_w4a16 + def 
_supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + return False @staticmethod def _supports_activation(activation: str) -> bool: @@ -321,11 +323,6 @@ def _make_routing_data( class OAITritonExperts(BaseOAITritonExperts): - def __init__(self, quant_config: FusedMoEQuantConfig): - # TODO (varun) : Enable activation quantization - assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16" - super().__init__(quant_config) - @staticmethod def activation_format() -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.Standard @@ -410,12 +407,6 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts): One use case for it is to inject LoRA modules on the activation and moe_sum. """ - def __init__(self, quant_config: FusedMoEQuantConfig): - # TODO (varun) : Enable activation quantization - assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16" - super().__init__(quant_config) - self.quant_config = quant_config - @staticmethod def activation_format() -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.Standard diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 39234035b51d..fad57edcd16b 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -16,7 +16,6 @@ FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig, - FusedMoEQuantScheme, ) from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, @@ -24,6 +23,9 @@ count_expert_num_tokens, disable_inplace, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, +) from vllm.utils.math_utils import cdiv from vllm.v1.worker.ubatching import ( dbo_enabled, @@ -513,7 +515,8 @@ def moe_problem_size( def is_supported_config( cls: type["FusedMoEPermuteExpertsUnpermute"], moe_config: FusedMoEConfig, - moe_quant_scheme: FusedMoEQuantScheme, + 
weight_key: QuantKey | None, + activation_key: QuantKey | None, activation_format: FusedMoEActivationFormat, ) -> tuple[bool, str | None]: def _make_reason(reason: str) -> str: @@ -525,7 +528,7 @@ def _make_reason(reason: str) -> str: return False, _make_reason("no act_and_mul MLP layer") elif not cls._supports_activation(moe_config.activation): return False, _make_reason(f"{moe_config.activation} activation") - elif not cls._supports_quant_scheme(moe_quant_scheme): + elif not cls._supports_quant_scheme(weight_key, activation_key): return False, _make_reason("quantization scheme") elif not cls._supports_parallel_config(moe_config.moe_parallel_config): return False, _make_reason("parallel config") @@ -553,7 +556,10 @@ def _supports_no_act_and_mul() -> bool: @staticmethod @abstractmethod - def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: raise NotImplementedError @staticmethod diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py index fa1cf03fcf08..c7b90f1285ba 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py @@ -15,7 +15,6 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, - FusedMoEQuantScheme, fp8_w8a8_moe_quant_config, fp8_w8a16_moe_quant_config, ) @@ -34,6 +33,9 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( prepare_fp8_moe_layer_for_marlin, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, +) if TYPE_CHECKING: from vllm.model_executor.layers.fused_moe.layer import FusedMoE @@ -122,8 +124,8 @@ def backend_2_kernel_cls( def select_fp8_moe_backend( config: FusedMoEConfig, - quant_scheme: FusedMoEQuantScheme, - activation_format: mk.FusedMoEActivationFormat, + weight_key: QuantKey | None, 
+ activation_key: QuantKey | None, allow_vllm_cutlass: bool = False, ) -> tuple[Fp8MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]: """ @@ -134,6 +136,19 @@ def select_fp8_moe_backend( if config.is_lora_enabled: return Fp8MoeBackend.TRITON, backend_2_kernel_cls(Fp8MoeBackend.TRITON) + # NOTE(rob): this is kind of a hack. We need to peak into + # the prepare-finalize selection to determine if we are using + # the batched or standard expert format. + use_batched = ( + config.moe_parallel_config.use_deepep_ll_kernels + or config.moe_parallel_config.use_pplx_kernels + ) + activation_format = ( + mk.FusedMoEActivationFormat.BatchedExperts + if use_batched + else mk.FusedMoEActivationFormat.Standard + ) + def _make_log_backend(backend: Fp8MoeBackend): return f"Using `{backend.value}` backend for FP8 MoE" @@ -152,12 +167,13 @@ def _make_log_unsupported(backend: Fp8MoeBackend, reason: str | None) -> str: def _return_or_raise( backend: Fp8MoeBackend, config: FusedMoEConfig, - scheme: FusedMoEQuantScheme, + weight_key: QuantKey | None, + activation_key: QuantKey | None, activation_format: mk.FusedMoEActivationFormat, ) -> tuple[Fp8MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]: k_cls = backend_2_kernel_cls(backend) supported, reason = k_cls.is_supported_config( - k_cls, config, scheme, activation_format + k_cls, config, weight_key, activation_key, activation_format ) if supported: logger.info_once(_make_log_backend(backend)) @@ -191,7 +207,7 @@ def _return_or_raise( if fi_backend == FlashinferMoeBackend.TENSORRT_LLM: backend = Fp8MoeBackend.FLASHINFER_TRTLLM supported, reason = is_supported_config_trtllm( - config, quant_scheme, activation_format + config, weight_key, activation_key, activation_format ) if supported: logger.info_once(_make_log_backend(backend)) @@ -202,7 +218,7 @@ def _return_or_raise( elif fi_backend == FlashinferMoeBackend.CUTLASS: backend = Fp8MoeBackend.FLASHINFER_CUTLASS return _return_or_raise( - backend, config, quant_scheme, 
activation_format + backend, config, weight_key, activation_key, activation_format ) else: @@ -217,7 +233,7 @@ def _return_or_raise( ]: k_cls = backend_2_kernel_cls(backend) if k_cls.is_supported_config( - k_cls, config, quant_scheme, activation_format + k_cls, config, weight_key, activation_key, activation_format ): logger.info_once(_make_log_backend(backend)) return backend, k_cls @@ -237,12 +253,16 @@ def _return_or_raise( if activation_format == mk.FusedMoEActivationFormat.Standard else Fp8MoeBackend.BATCHED_DEEPGEMM ) - return _return_or_raise(backend, config, quant_scheme, activation_format) + return _return_or_raise( + backend, config, weight_key, activation_key, activation_format + ) # Handle explicit MARLIN FP8 configuration. if envs.VLLM_TEST_FORCE_FP8_MARLIN: backend = Fp8MoeBackend.MARLIN - return _return_or_raise(backend, config, quant_scheme, activation_format) + return _return_or_raise( + backend, config, weight_key, activation_key, activation_format + ) # Handle explicit AITER FP8 configuration. 
if envs.is_set("VLLM_ROCM_USE_AITER") or envs.is_set("VLLM_ROCM_USE_AITER_MOE"): @@ -250,7 +270,9 @@ def _return_or_raise( AVAILABLE_BACKENDS.remove(Fp8MoeBackend.AITER) else: backend = Fp8MoeBackend.AITER - return _return_or_raise(backend, config, quant_scheme, activation_format) + return _return_or_raise( + backend, config, weight_key, activation_key, activation_format + ) if not allow_vllm_cutlass: AVAILABLE_BACKENDS.remove(Fp8MoeBackend.VLLM_CUTLASS) @@ -261,7 +283,8 @@ def _return_or_raise( k_cls = None # type: ignore[assignment] supported, reason = is_supported_config_trtllm( config, - quant_scheme, + weight_key, + activation_key, activation_format, ) else: @@ -269,7 +292,8 @@ def _return_or_raise( supported, reason = k_cls.is_supported_config( k_cls, config, - quant_scheme, + weight_key, + activation_key, activation_format, ) diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py index a58f2487011b..d96334de895f 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py @@ -14,7 +14,6 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, - FusedMoEQuantScheme, nvfp4_moe_quant_config, nvfp4_w4a16_moe_quant_config, ) @@ -29,6 +28,9 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( prepare_nvfp4_moe_layer_for_marlin, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, +) if TYPE_CHECKING: from vllm.model_executor.layers.fused_moe.layer import FusedMoE @@ -107,14 +109,27 @@ def backend_2_kernel_cls( def select_nvfp4_moe_backend( config: FusedMoEConfig, - quant_scheme: FusedMoEQuantScheme, - activation_format: mk.FusedMoEActivationFormat, + weight_key: QuantKey | None, + activation_key: QuantKey | None, ) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]: """ Select the primary NvFP4 MoE backend Note: 
Shape-specific fallbacks may still occur at runtime. """ + # NOTE(rob): this is kind of a hack. We need to peak into + # the prepare-finalize selection to determine if we are using + # the batched or standard expert format. + use_batched = ( + config.moe_parallel_config.use_deepep_ll_kernels + or config.moe_parallel_config.use_pplx_kernels + ) + activation_format = ( + mk.FusedMoEActivationFormat.BatchedExperts + if use_batched + else mk.FusedMoEActivationFormat.Standard + ) + def _make_log_backend(backend: NvFp4MoeBackend): return f"Using '{backend.value}' backend for NvFp4 MoE" @@ -133,12 +148,13 @@ def _make_log_unsupported(backend: NvFp4MoeBackend, reason: str | None) -> str: def _return_or_raise( backend: NvFp4MoeBackend, config: FusedMoEConfig, - scheme: FusedMoEQuantScheme, + weight_key: QuantKey | None, + activation_key: QuantKey | None, activation_format: mk.FusedMoEActivationFormat, ) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]: k_cls = backend_2_kernel_cls(backend) supported, reason = k_cls.is_supported_config( - k_cls, config, scheme, activation_format + k_cls, config, weight_key, activation_key, activation_format ) if supported: logger.info_once(_make_log_backend(backend)) @@ -167,7 +183,7 @@ def _return_or_raise( if fi_backend == FlashinferMoeBackend.TENSORRT_LLM: backend = NvFp4MoeBackend.FLASHINFER_TRTLLM supported, reason = is_supported_config_trtllm( - config, quant_scheme, activation_format + config, weight_key, activation_key, activation_format ) if supported: logger.info_once(_make_log_backend(backend)) @@ -177,14 +193,14 @@ def _return_or_raise( else: backend = fi_2_vllm_backend_map[fi_backend] return _return_or_raise( - backend, config, quant_scheme, activation_format + backend, config, weight_key, activation_key, activation_format ) else: # If the user is not explicit about the backend, try each. 
for backend in FLASHINFER_NVFP4_MOE_BACKENDS: k_cls = backend_2_kernel_cls(backend) if k_cls.is_supported_config( - k_cls, config, quant_scheme, activation_format + k_cls, config, weight_key, activation_key, activation_format ): logger.info_once(_make_log_backend(backend)) return backend, k_cls @@ -196,7 +212,9 @@ def _return_or_raise( if envs.VLLM_TEST_FORCE_FP8_MARLIN: backend = NvFp4MoeBackend.MARLIN - return _return_or_raise(backend, config, quant_scheme, activation_format) + return _return_or_raise( + backend, config, weight_key, activation_key, activation_format + ) # Select kernels in order of backend. for backend in AVAILABLE_BACKENDS: @@ -204,7 +222,8 @@ def _return_or_raise( k_cls = None # type: ignore[assignment] supported, reason = is_supported_config_trtllm( config, - quant_scheme, + weight_key, + activation_key, activation_format, ) else: @@ -212,7 +231,8 @@ def _return_or_raise( supported, reason = k_cls.is_supported_config( k_cls, config, - quant_scheme, + weight_key, + activation_key, activation_format, ) diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 4e86a21bce91..cb1a211e9925 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -11,11 +11,13 @@ FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEParallelConfig, FusedMoEQuantConfig, - FusedMoEQuantScheme, ) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, +) class QuantMethod(IntEnum): @@ -293,12 +295,16 @@ def _supports_no_act_and_mul() -> bool: return False @staticmethod - def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: - return ( - quant_scheme.is_unquantized - or quant_scheme.is_fp8_w8a8 - or quant_scheme.is_mxfp4_w4a4 - ) + def _supports_quant_scheme( + weight_key: 
QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + # return ( + # quant_scheme.is_unquantized + # or quant_scheme.is_fp8_w8a8 + # or quant_scheme.is_mxfp4_w4a4 + # ) + return False @staticmethod def _supports_activation(activation: str) -> bool: diff --git a/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py index 3865babf20a1..5e4366556fc1 100644 --- a/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py @@ -9,11 +9,13 @@ FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig, - FusedMoEQuantScheme, ) from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 from vllm.model_executor.layers.fused_moe.fallback import FallbackExperts from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, +) from vllm.platforms import current_platform @@ -44,8 +46,11 @@ def _supports_no_act_and_mul() -> bool: return CutlassExpertsFp8._supports_no_act_and_mul() @staticmethod - def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: - return CutlassExpertsFp8._supports_quant_scheme(quant_scheme) + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + return CutlassExpertsFp8._supports_quant_scheme(weight_key, activation_key) @staticmethod def _supports_activation(activation: str) -> bool: diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index a8beef14c3a1..3f3a76f80409 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -8,7 +8,6 @@ FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig, - FusedMoEQuantScheme, ) from 
vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( DeepGemmExperts, @@ -17,6 +16,9 @@ ) from vllm.model_executor.layers.fused_moe.fallback import FallbackExperts from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, +) from vllm.utils.deep_gemm import ( is_deep_gemm_e8m0_used, ) @@ -45,8 +47,11 @@ def _supports_no_act_and_mul() -> bool: return DeepGemmExperts._supports_no_act_and_mul() @staticmethod - def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: - return DeepGemmExperts._supports_quant_scheme(quant_scheme) + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + return DeepGemmExperts._supports_quant_scheme(weight_key, activation_key) @staticmethod def _supports_activation(activation: str) -> bool: diff --git a/vllm/model_executor/layers/fused_moe/trtllm_moe.py b/vllm/model_executor/layers/fused_moe/trtllm_moe.py index 97cddd5897a7..e7cec11c237d 100644 --- a/vllm/model_executor/layers/fused_moe/trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/trtllm_moe.py @@ -8,26 +8,27 @@ FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig, - FusedMoEQuantScheme, ) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, +) from vllm.platforms import current_platform class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - moe: FusedMoEConfig, + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, max_capture_size, ): - super().__init__(quant_config) - self.moe = moe + super().__init__(moe_config, quant_config) self.gemm1_alpha = gemm1_alpha self.gemm1_beta = gemm1_beta self.gemm1_clamp_limit = gemm1_clamp_limit @@ -46,8 +47,12 @@ def _supports_no_act_and_mul() -> bool: 
return False @staticmethod - def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: - return quant_scheme.is_mxfp4_w4a16 or quant_scheme.is_mxfp4_w4a8 + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + # return quant_scheme.is_mxfp4_w4a16 or quant_scheme.is_mxfp4_w4a8 + return False @staticmethod def _supports_activation(activation: str) -> bool: @@ -104,7 +109,7 @@ def apply( topk = topk_ids.size(-1) local_num_experts = w1.size(0) intermediate_size = w2.size(1) - local_expert_offset = self.moe.ep_rank * local_num_experts + local_expert_offset = self.moe_config.ep_rank * local_num_experts x_quant = hidden_states x_scale = a1q_scale diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 72bb77ccc419..3aa8c8e01a75 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -27,7 +27,6 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, - FusedMoEQuantScheme, RoutingMethodType, fp8_w8a8_moe_quant_config, fp8_w8a16_moe_quant_config, @@ -82,6 +81,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( convert_bf16_scales_to_fp8, convert_packed_uint4b8_to_signed_int4_inplace, + kFp8Dynamic128Sym, + kFp8DynamicTokenSym, + kFp8Static128BlockSym, + kFp8StaticChannelSym, + kFp8StaticTensorSym, + kNvfp4Quant, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( normalize_e4m3fn_to_e4m3fnuz, @@ -239,34 +244,11 @@ def __init__( super().__init__(moe) self.group_size = 16 - # Create quant scheme, will be used later to select the quant scales. - # NOTE(rob): we should update QuantConfig to just be the think ts - # holds the scales. Should change the name. 
- quant_scheme = FusedMoEQuantScheme( - weight_dtype="nvfp4", - act_dtype="nvfp4" if not use_a16 else None, - per_tensor_quant=False, - per_token_quant=False, - block_size=None, - static_input_quant=False, - ) - - # Select NvFp4 MoE backend - # NOTE(rob): this is kind of a hack. We need to peak into - # the prepare-finalize selection to determine if we are using - # the batched or standard expert format. - use_batched = ( - self.moe.moe_parallel_config.use_deepep_ll_kernels - or self.moe.moe_parallel_config.use_pplx_kernels - ) + # Select experts implementation. self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend( config=self.moe, - quant_scheme=quant_scheme, - activation_format=( - mk.FusedMoEActivationFormat.BatchedExperts - if use_batched - else mk.FusedMoEActivationFormat.Standard - ), + weight_key=kNvfp4Quant, + activation_key=kNvfp4Quant, ) # Delay creation of the kernel until after process-weights. @@ -597,38 +579,28 @@ def __init__( "channelwise, dynamic per token quantization." ) - # Create quant scheme, will be used later to select the quant scales. - # NOTE(rob): we should update QuantConfig to just be the think ts - # holds the scales. Should change the name. 
- quant_scheme = FusedMoEQuantScheme( - weight_dtype=current_platform.fp8_dtype(), - act_dtype=current_platform.fp8_dtype(), - per_tensor_quant=per_tensor, - per_token_quant=per_channel, - block_size=( - (self.weight_block_size[0], self.weight_block_size[1]) - if self.weight_block_size is not None - else None + ct2vllm_weight = { + QuantizationStrategy.CHANNEL: kFp8StaticChannelSym, + QuantizationStrategy.TENSOR: kFp8StaticTensorSym, + QuantizationStrategy.BLOCK: kFp8Static128BlockSym, + } + ct2vllm_act = { + QuantizationStrategy.TOKEN: kFp8DynamicTokenSym, + QuantizationStrategy.TENSOR: ( + kFp8StaticTensorSym if self.static_input_scales else kFp8Dynamic128Sym ), - static_input_quant=self.static_input_scales, - ) + } + weight_key = ct2vllm_weight[self.weight_quant.strategy] + if weight_key == kFp8Static128BlockSym: + activation_key = kFp8Dynamic128Sym + else: + activation_key = ct2vllm_act[self.input_quant.strategy] # Select Fp8 MoE backend - # NOTE(rob): this is kind of a hack. We need to peak into - # the prepare-finalize selection to determine if we are using - # the batched or standard expert format. 
- use_batched = ( - self.moe.moe_parallel_config.use_deepep_ll_kernels - or self.moe.moe_parallel_config.use_pplx_kernels - ) self.fp8_backend, self.experts_cls = select_fp8_moe_backend( config=self.moe, - quant_scheme=quant_scheme, - activation_format=( - mk.FusedMoEActivationFormat.BatchedExperts - if use_batched - else mk.FusedMoEActivationFormat.Standard - ), + weight_key=weight_key, + activation_key=activation_key, allow_vllm_cutlass=True, ) @@ -1441,8 +1413,7 @@ def select_gemm_impl( max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() assert max_num_tokens_per_rank is not None return BatchedMarlinExperts( - max_num_tokens=max_num_tokens_per_rank, - num_dispatchers=prepare_finalize.num_dispatchers(), + moe_config=self.moe, quant_config=self.moe_quant_config, w13_g_idx=layer.w13_weight_g_idx, w2_g_idx=layer.w2_weight_g_idx, @@ -1451,8 +1422,12 @@ def select_gemm_impl( is_k_full=self.is_k_full, ) else: + max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() + assert max_num_tokens_per_rank is not None return MarlinExperts( quant_config=self.moe_quant_config, + max_num_tokens=max_num_tokens_per_rank, # type: ignore[call-arg] + num_dispatchers=prepare_finalize.num_dispatchers(), # type: ignore[call-arg] w13_g_idx=layer.w13_weight_g_idx, w2_g_idx=layer.w2_weight_g_idx, w13_g_idx_sort_indices=layer.w13_g_idx_sort_indices, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index e72010ea5dc7..7fd0fa4f81cb 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -26,7 +26,6 @@ ) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, - FusedMoEQuantScheme, RoutingMethodType, ) from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter @@ -74,6 +73,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, is_layer_skipped, + kFp8Dynamic128Sym, + 
kFp8Static128BlockSym, + kFp8StaticTensorSym, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp, @@ -632,38 +634,24 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): "weight_scale_inv" if self.block_quant else "weight_scale" ) - # Create quant scheme, will be used later to select the quant scales. - # NOTE(rob): we should update QuantConfig to just be the think ts - # holds the scales. Should change the name. - quant_scheme = FusedMoEQuantScheme( - weight_dtype=current_platform.fp8_dtype(), - act_dtype=current_platform.fp8_dtype(), - per_tensor_quant=not self.block_quant, - per_token_quant=False, - block_size=( - (self.weight_block_size[0], self.weight_block_size[1]) - if self.weight_block_size is not None - else None - ), - static_input_quant=(self.quant_config.activation_scheme == "static"), - ) + # Set weight key and activation key for kernel compatibility + if self.block_quant: + weight_key = kFp8Static128BlockSym + activation_key = kFp8Dynamic128Sym + else: + weight_key = kFp8StaticTensorSym + activation_key = ( + kFp8StaticTensorSym + if self.quant_config.activation_scheme == "static" + else kFp8Dynamic128Sym + ) # Select Fp8 MoE backend - # NOTE(rob): this is kind of a hack. We need to peak into - # the prepare-finalize selection to determine if we are using - # the batched or standard expert format. - use_batched = ( - self.moe.moe_parallel_config.use_deepep_ll_kernels - or self.moe.moe_parallel_config.use_pplx_kernels - ) self.fp8_backend, self.experts_cls = select_fp8_moe_backend( config=self.moe, - quant_scheme=quant_scheme, - activation_format=( - mk.FusedMoEActivationFormat.BatchedExperts - if use_batched - else mk.FusedMoEActivationFormat.Standard - ), + weight_key=weight_key, + activation_key=activation_key, + allow_vllm_cutlass=False, ) # Delay creation of the kernel until after process-weights. 
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index da7844c3fcd9..d7227e6a5284 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -16,7 +16,6 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, - FusedMoEQuantScheme, ) from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from vllm.model_executor.layers.fused_moe.layer import ( @@ -74,6 +73,8 @@ GroupShape, cutlass_fp4_supported, is_layer_skipped, + kFp8StaticTensorSym, + kNvfp4Quant, swizzle_blockscale, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -88,7 +89,6 @@ PerTensorScaleParameter, ) from vllm.model_executor.utils import replace_parameter -from vllm.platforms import current_platform from vllm.utils.flashinfer import ( flashinfer_scaled_fp4_mm, has_flashinfer, @@ -727,34 +727,11 @@ def __init__( self.quant_config = quant_config assert self.quant_config.is_checkpoint_fp8_serialized - # Create quant scheme, will be used later to select the quant scales. - # NOTE(rob): we should update QuantConfig to just be the think ts - # holds the scales. Should change the name. - quant_scheme = FusedMoEQuantScheme( - weight_dtype=current_platform.fp8_dtype(), - act_dtype=current_platform.fp8_dtype(), - per_tensor_quant=True, - per_token_quant=False, - block_size=None, - static_input_quant=True, - ) - # Select Fp8 MoE backend - # NOTE(rob): this is kind of a hack. We need to peak into - # the prepare-finalize selection to determine if we are using - # the batched or standard expert format. 
- use_batched = ( - self.moe.moe_parallel_config.use_deepep_ll_kernels - or self.moe.moe_parallel_config.use_pplx_kernels - ) self.fp8_backend, self.experts_cls = select_fp8_moe_backend( config=self.moe, - quant_scheme=quant_scheme, - activation_format=( - mk.FusedMoEActivationFormat.BatchedExperts - if use_batched - else mk.FusedMoEActivationFormat.Standard - ), + weight_key=kFp8StaticTensorSym, + activation_key=kFp8StaticTensorSym, ) # Delay creation of the kernel until after process-weights. @@ -1357,34 +1334,11 @@ def __init__( ) -> None: super().__init__(moe_config) self.quant_config = quant_config - # Create quant scheme, will be used later to select the quant scales. - # NOTE(rob): we should update QuantConfig to just be the think ts - # holds the scales. Should change the name. - quant_scheme = FusedMoEQuantScheme( - weight_dtype="nvfp4", - act_dtype="nvfp4", - per_tensor_quant=False, - per_token_quant=False, - block_size=None, - static_input_quant=False, - ) - - # Select NvFp4 MoE backend - # NOTE(rob): this is kind of a hack. We need to peak into - # the prepare-finalize selection to determine if we are using - # the batched or standard expert format. - use_batched = ( - self.moe.moe_parallel_config.use_deepep_ll_kernels - or self.moe.moe_parallel_config.use_pplx_kernels - ) + # Select experts implementation. self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend( config=self.moe, - quant_scheme=quant_scheme, - activation_format=( - mk.FusedMoEActivationFormat.BatchedExperts - if use_batched - else mk.FusedMoEActivationFormat.Standard - ), + weight_key=kNvfp4Quant, + activation_key=kNvfp4Quant, ) # Delay creation of the kernel until after process-weights. 
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index 53591945e47d..edcda4bf23cd 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -12,20 +12,11 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEParallelConfig, - FusedMoEQuantConfig, - FusedMoEQuantScheme, RoutingMethodType, ) -from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import ( - FlashInferCuteDSLExperts, -) -from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( - FlashInferExperts, -) -from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 - create_flashinfer_prepare_finalize, -) from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kNvfp4Quant, swizzle_blockscale, ) from vllm.platforms import current_platform @@ -46,7 +37,6 @@ "is_flashinfer_fp4_cutlass_moe_available", "is_flashinfer_fp4_cutedsl_moe_available", "reorder_w1w3_to_w3w1", - "build_flashinfer_fp4_cutlass_moe_prepare_finalize", ] # @@ -66,9 +56,15 @@ def _supports_no_act_and_mul() -> bool: return False -def _supports_quant_scheme(quant_scheme: FusedMoEQuantScheme) -> bool: - """Supports Fp8 per-tensor, Fp8 block, and Nvfp4 quantization.""" - return quant_scheme.is_nvfp4_w4a4 +def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, +) -> bool: + """Supports Nvfp4 quantization.""" + SUPPORTED_W_A = [ + (kNvfp4Quant, kNvfp4Quant), + ] + return (weight_key, activation_key) in SUPPORTED_W_A def _supports_activation(activation: str) -> bool: @@ -83,7 +79,8 @@ def _supports_moe_parallel_config(moe_parallel_config: FusedMoEParallelConfig) - def is_supported_config_trtllm( moe_config: FusedMoEConfig, - moe_quant_scheme: FusedMoEQuantScheme, + weight_key: QuantKey | 
None, + activation_key: QuantKey | None, activation_format: mk.FusedMoEActivationFormat, ) -> tuple[bool, str | None]: """ @@ -99,7 +96,7 @@ def _make_reason(reason: str) -> str: return False, _make_reason("no act_and_mul MLP layer") elif not _supports_activation(moe_config.activation): return False, _make_reason(f"{moe_config.activation} activation") - elif not _supports_quant_scheme(moe_quant_scheme): + elif not _supports_quant_scheme(weight_key, activation_key): return False, _make_reason("quantization scheme") elif not _supports_moe_parallel_config(moe_config.moe_parallel_config): return False, _make_reason("parallel config") @@ -146,48 +143,6 @@ def reorder_w1w3_to_w3w1( ) -def build_flashinfer_fp4_cutlass_moe_prepare_finalize( - moe: FusedMoEConfig, -) -> mk.FusedMoEPrepareAndFinalize: - """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel""" - use_dp = moe.moe_parallel_config.dp_size > 1 - enable_alltoallv = moe.moe_parallel_config.all2all_backend == "flashinfer_all2allv" - return create_flashinfer_prepare_finalize( - use_dp=use_dp, use_nvfp4=True, enable_alltoallv=enable_alltoallv - ) - - -def select_nvfp4_gemm_impl( - moe: FusedMoEConfig, - moe_quant_config: FusedMoEQuantConfig, - allow_flashinfer: bool, -) -> mk.FusedMoEPermuteExpertsUnpermute: - """Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers""" - - if allow_flashinfer: - if envs.VLLM_FLASHINFER_MOE_BACKEND == "masked_gemm": - return FlashInferCuteDSLExperts( - out_dtype=moe.in_dtype, - quant_config=moe_quant_config, - ) - elif envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput": - return FlashInferExperts( - out_dtype=moe.in_dtype, - quant_config=moe_quant_config, - ep_rank=moe.moe_parallel_config.ep_rank, - ep_size=moe.moe_parallel_config.ep_size, - tp_rank=moe.moe_parallel_config.tp_rank, - tp_size=moe.moe_parallel_config.tp_size, - use_dp=moe.moe_parallel_config.dp_size > 1, - ) - - # native cutlass experts currently don't support DP; TP case won't call this - raise 
ValueError( - "CutlassExpertsFp4 doesn't support DP. Use flashinfer CUTLASS " - "Fused MoE backend instead (set VLLM_USE_FLASHINFER_MOE_FP4=1)" - ) - - def prepare_static_weights_for_trtllm_fp4_moe( # args_dequant, # args, diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 08262ed1a314..2a5f76121e44 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -20,6 +20,7 @@ FP8_DTYPE = current_platform.fp8_dtype() FP4_DTYPE = torch.uint8 +INT4_DTYPE = torch.int32 def get_fp8_min_max() -> tuple[float, float]: @@ -48,6 +49,7 @@ class GroupShape(_GroupShape): # Aliases for common quantization group shapes PER_TENSOR: ClassVar["GroupShape"] PER_TOKEN: ClassVar["GroupShape"] + PER_CHANNEL: ClassVar["GroupShape"] def is_per_tensor(self) -> bool: return self.row == -1 and self.col == -1 @@ -55,12 +57,16 @@ def is_per_tensor(self) -> bool: def is_per_token(self) -> bool: return self.row == 1 and self.col == -1 + def is_per_channel(self) -> bool: + return self.row == -1 and self.col == 1 + def is_per_group(self) -> bool: return self.row == 1 and self.col >= 1 GroupShape.PER_TENSOR = GroupShape(-1, -1) GroupShape.PER_TOKEN = GroupShape(1, -1) +GroupShape.PER_CHANNEL = GroupShape(-1, 1) @dataclass(frozen=True) @@ -77,16 +83,12 @@ class ScaleDesc: group_shape: GroupShape def __str__(self): - group_shape = ( - "per_tensor" - if self.group_shape == GroupShape.PER_TENSOR - else ( - "per_token" - if self.group_shape == GroupShape.PER_TOKEN - else str(self.group_shape) - ) - ) - + d = { + GroupShape.PER_TENSOR: "per_tensor", + GroupShape.PER_TOKEN: "per_token", + GroupShape.PER_CHANNEL: "per_channel", + } + group_shape = d.get(self.group_shape, str(self.group_shape)) return ( f"{fx.graph.dtype_abbrs[self.dtype]}," f"{'static' if self.static else 'dynamic'},{group_shape}" @@ -123,6 +125,9 @@ def __str__(self): 
kDynamicTensorScale = ScaleDesc(torch.float32, False, GroupShape.PER_TENSOR) kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, kDynamicTensorScale, symmetric=True) +kStaticChannelScale = ScaleDesc(torch.float32, True, GroupShape.PER_CHANNEL) +kFp8StaticChannelSym = QuantKey(FP8_DTYPE, kStaticChannelScale, symmetric=True) + kDynamicTokenScale = ScaleDesc(torch.float32, False, GroupShape.PER_TOKEN) kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True) @@ -132,6 +137,9 @@ def __str__(self): kDynamic128Scale = ScaleDesc(torch.float32, False, GroupShape(1, 128)) kFp8Dynamic128Sym = QuantKey(FP8_DTYPE, kDynamic128Scale, symmetric=True) +kStatic128BlockScale = ScaleDesc(torch.float32, True, GroupShape(128, 128)) +kFp8Static128BlockSym = QuantKey(FP8_DTYPE, kStatic128BlockScale, symmetric=True) + kDynamic64Scale = ScaleDesc(torch.float32, False, GroupShape(1, 64)) kFp8Dynamic64Sym = QuantKey(FP8_DTYPE, kDynamic64Scale, symmetric=True) From e17d456a02c994978f29808549cf4ee44e1aded1 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 17:12:23 -0500 Subject: [PATCH 093/113] convert to using quant key Signed-off-by: Robert Shaw --- .../compressed_tensors/compressed_tensors_moe.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 3aa8c8e01a75..5ec5f9556c9d 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -1422,12 +1422,9 @@ def select_gemm_impl( is_k_full=self.is_k_full, ) else: - max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() - assert max_num_tokens_per_rank is not None return MarlinExperts( + moe_config=self.moe, quant_config=self.moe_quant_config, - max_num_tokens=max_num_tokens_per_rank, 
# type: ignore[call-arg] - num_dispatchers=prepare_finalize.num_dispatchers(), # type: ignore[call-arg] w13_g_idx=layer.w13_weight_g_idx, w2_g_idx=layer.w2_weight_g_idx, w13_g_idx_sort_indices=layer.w13_g_idx_sort_indices, From 8cacf5a882d459c0154ee001820b6887826ab63b Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 17:27:22 -0500 Subject: [PATCH 094/113] reject usage for things that have not migrated over yet Signed-off-by: Robert Shaw --- .../layers/fused_moe/batched_deep_gemm_moe.py | 8 +-- .../fused_moe/flashinfer_cutlass_moe.py | 10 ++-- .../layers/fused_moe/fused_marlin_moe.py | 14 ++---- .../fused_moe/gpt_oss_triton_kernels_moe.py | 27 +++++++--- .../layers/fused_moe/triton_cutlass_moe.py | 25 ++++++++-- .../quantization/utils/flashinfer_utils.py | 50 ------------------- .../layers/quantization/utils/quant_utils.py | 1 - 7 files changed, 51 insertions(+), 84 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 49ff4239c14e..0d82eec3569c 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -292,12 +292,8 @@ def _supports_quant_scheme( weight_key: QuantKey | None, activation_key: QuantKey | None, ) -> bool: - if weight_key is None or activation_key is None: - return False - - return ( - weight_key == kFp8Dynamic128Sym and activation_key == kFp8Static128BlockSym - ) + SUPPORTED_W_A = [(kFp8Static128BlockSym, kFp8Dynamic128Sym)] + return (weight_key, activation_key) in SUPPORTED_W_A @staticmethod def _supports_activation(activation: str) -> bool: diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 1e3cd64a3f00..1cad2f482fef 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ 
b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -106,11 +106,11 @@ def _supports_quant_scheme( weight_key: QuantKey | None, activation_key: QuantKey | None, ) -> bool: - # Supports: - # * unquantized - # * fp8 static per-tensor on 9.0+ - # * fp8 block on 9.0 - # * nvfp4 on 10.0+ + # The following are supported by FlashInferExperts: + # * unquantized + # * fp8 static per-tensor on 9.0+ + # * fp8 block on 9.0 + # * nvfp4 on 10.0+ p = current_platform scheme = (weight_key, activation_key) diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 114e04abe916..6381087660ba 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -29,6 +29,7 @@ ) from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, + kFp8Static128BlockSym, kFp8StaticChannelSym, kNvfp4Quant, ) @@ -575,17 +576,10 @@ def _supports_quant_scheme( weight_key: QuantKey | None, activation_key: QuantKey | None, ) -> bool: - # return quant_scheme.weight_dtype in [ - # current_platform.fp8_dtype(), - # torch.int8, - # "int4", - # "nvfp4", - # "mxfp4", - # ] - - # NOTE: Marlin runs activations unquantized, so - # it can run all activation formats. + # TODO(rob): add int4, mxfp4, int8 as integrations + # are migrated to use the oracle one-by-one. 
SUPPORTED_W = [ + kFp8Static128BlockSym, kFp8StaticChannelSym, kNvfp4Quant, ] diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index eb9d0e669245..b209820cdfa9 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -19,7 +19,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, ) -from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils.import_utils import has_triton_kernels @@ -248,27 +247,41 @@ def make_routing_data( class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): @staticmethod def _supports_current_device() -> bool: - p = current_platform - return (p.is_cuda() and p.has_device_capability((9, 0))) or p.is_rocm() + raise NotImplementedError( + "OAITritonExperts is not yet used by an Oracle. " + "This method should not be called." + ) @staticmethod def _supports_no_act_and_mul() -> bool: - return False + raise NotImplementedError( + "OAITritonExperts is not yet used by an Oracle. " + "This method should not be called." + ) @staticmethod def _supports_quant_scheme( weight_key: QuantKey | None, activation_key: QuantKey | None, ) -> bool: - return False + raise NotImplementedError( + "OAITritonExperts is not yet used by an Oracle. " + "This method should not be called." + ) @staticmethod def _supports_activation(activation: str) -> bool: - return activation in ["swigluoai"] + raise NotImplementedError( + "OAITritonExperts is not yet used by an Oracle. " + "This method should not be called." + ) @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: - return True + raise NotImplementedError( + "OAITritonExperts is not yet used by an Oracle. " + "This method should not be called." 
+ ) def supports_expert_map(self) -> bool: return True diff --git a/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py index 5e4366556fc1..39163bb24f58 100644 --- a/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py @@ -35,30 +35,45 @@ def __init__( @staticmethod def activation_format() -> mk.FusedMoEActivationFormat: + assert ( + CutlassExpertsFp8.activation_format() == TritonExperts.activation_format() + ) return CutlassExpertsFp8.activation_format() @staticmethod def _supports_current_device() -> bool: - return CutlassExpertsFp8._supports_current_device() + return ( + CutlassExpertsFp8._supports_current_device() + and TritonExperts._supports_current_device() + ) @staticmethod def _supports_no_act_and_mul() -> bool: - return CutlassExpertsFp8._supports_no_act_and_mul() + return ( + CutlassExpertsFp8._supports_no_act_and_mul() + and TritonExperts._supports_no_act_and_mul() + ) @staticmethod def _supports_quant_scheme( weight_key: QuantKey | None, activation_key: QuantKey | None, ) -> bool: - return CutlassExpertsFp8._supports_quant_scheme(weight_key, activation_key) + return CutlassExpertsFp8._supports_quant_scheme( + weight_key, activation_key + ) and TritonExperts._supports_quant_scheme(weight_key, activation_key) @staticmethod def _supports_activation(activation: str) -> bool: - return CutlassExpertsFp8._supports_activation(activation) + return CutlassExpertsFp8._supports_activation( + activation + ) and TritonExperts._supports_activation(activation) @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: - return CutlassExpertsFp8._supports_parallel_config(moe_parallel_config) + return CutlassExpertsFp8._supports_parallel_config( + moe_parallel_config + ) and TritonExperts._supports_parallel_config(moe_parallel_config) def workspace_shapes( self, diff --git 
a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index 799854479823..19e7124684ad 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -4,19 +4,8 @@ import torch -import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import envs from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEConfig, - FusedMoEQuantConfig, -) -from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( - FlashInferExperts, -) -from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 - create_flashinfer_prepare_finalize, -) from vllm.platforms import current_platform from vllm.utils.math_utils import round_up @@ -191,45 +180,6 @@ def make_fp8_moe_alpha_scales_for_fi( return g1_alphas, g2_alphas -def build_flashinfer_fp8_cutlass_moe_prepare_finalize( - moe: FusedMoEConfig | None, use_deepseek_fp8_block_scale: bool = False -) -> mk.FusedMoEPrepareAndFinalize: - """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel""" - use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False - # Propagate block-scale flag so prepare/finalize can skip act quantization - # and inform the kernel to consume per-block weight scales. 
- return create_flashinfer_prepare_finalize( - use_dp, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale - ) - - -def select_cutlass_fp8_gemm_impl( - moe: FusedMoEConfig | None, - quant_config: FusedMoEQuantConfig, - out_dtype: torch.dtype | None = None, - use_deepseek_fp8_block_scale: bool = False, -) -> mk.FusedMoEPermuteExpertsUnpermute: - """Return a GEMM *experts* implementation for fused-MoE layers""" - - if moe is not None: - return FlashInferExperts( - out_dtype=moe.in_dtype, - quant_config=quant_config, - ep_rank=moe.moe_parallel_config.ep_rank, - ep_size=moe.moe_parallel_config.ep_size, - tp_rank=moe.moe_parallel_config.tp_rank, - tp_size=moe.moe_parallel_config.tp_size, - use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale, - ) - - assert out_dtype is not None, "If moe config is None, out_dtype must be passed" - return FlashInferExperts( - out_dtype=out_dtype, - quant_config=quant_config, - use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale, - ) - - def get_flashinfer_moe_backend() -> FlashinferMoeBackend: backend_map = { "throughput": FlashinferMoeBackend.CUTLASS, diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 2a5f76121e44..fe0cea539713 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -20,7 +20,6 @@ FP8_DTYPE = current_platform.fp8_dtype() FP4_DTYPE = torch.uint8 -INT4_DTYPE = torch.int32 def get_fp8_min_max() -> tuple[float, float]: From 9518c9751659f8ea75622be28a873cffef923919 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 17:31:27 -0500 Subject: [PATCH 095/113] reject usage for things that have not migrated over yet Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/oracle/nvfp4.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py 
b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py index d96334de895f..875995f1b7e6 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py @@ -64,10 +64,7 @@ def is_global_sf_supported_for_nvfp4_backend(backend: NvFp4MoeBackend) -> bool: # of all experts in Expert Parallel Mode when all experts are not # on the same rank. - return backend in [ - NvFp4MoeBackend.FLASHINFER_CUTLASS, - NvFp4MoeBackend.FLASHINFER_TRTLLM, - ] + return backend in FLASHINFER_NVFP4_MOE_BACKENDS def backend_2_kernel_cls( From ab090c10e9dbee789a8dafaf6d7fda69046a6c4b Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 17:34:19 -0500 Subject: [PATCH 096/113] reject usage for things that have not migrated over yet Signed-off-by: Robert Shaw --- .../layers/fused_moe/cutlass_moe.py | 34 ++++++++++++------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index f48f7c4cb0cc..e917fe203abf 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -33,7 +33,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( cutlass_group_gemm_supported, ) -from vllm.platforms import current_platform from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -868,7 +867,7 @@ def __init__( group_size: int, ): super().__init__(moe_config=moe_config, quant_config=quant_config) - self.out_dtype = moe_config.in_dtype + self.out_dtype = out_dtype self.a_strides1 = a_strides1 self.a_strides2 = a_strides2 self.b_strides1 = b_strides1 @@ -879,34 +878,43 @@ def __init__( self.s_strides2 = s_strides2 self.group_size = group_size - @staticmethod - def activation_format() -> mk.FusedMoEActivationFormat: - return mk.FusedMoEActivationFormat.Standard - @staticmethod def _supports_current_device() -> bool: - p = 
current_platform - return p.is_cuda() and p.has_device_capability((9, 0)) + raise NotImplementedError( + "CutlassExpertsW4A8Fp8 is not yet used by an Oracle. " + "This method should not be called." + ) @staticmethod def _supports_no_act_and_mul() -> bool: - return False + raise NotImplementedError( + "CutlassExpertsW4A8Fp8 is not yet used by an Oracle. " + "This method should not be called." + ) @staticmethod def _supports_quant_scheme( weight_key: QuantKey | None, activation_key: QuantKey | None, ) -> bool: - # return quant_scheme.is_int4_w4a8 and quant_scheme.per_token_quant - return False + raise NotImplementedError( + "CutlassExpertsW4A8Fp8 is not yet used by an Oracle. " + "This method should not be called." + ) @staticmethod def _supports_activation(activation: str) -> bool: - return activation in ["silu", "gelu"] + raise NotImplementedError( + "CutlassExpertsW4A8Fp8 is not yet used by an Oracle. " + "This method should not be called." + ) @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: - return True + raise NotImplementedError( + "CutlassExpertsW4A8Fp8 is not yet used by an Oracle. " + "This method should not be called." 
+ ) def supports_chunking(self) -> bool: return True From 1db50dcd697cf5d15d90be1c7a939edc3ff5edb7 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 17:35:25 -0500 Subject: [PATCH 097/113] reject usage for things that have not migrated over yet Signed-off-by: Robert Shaw --- .../layers/fused_moe/cutlass_moe.py | 70 ++++++++++--------- 1 file changed, 37 insertions(+), 33 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index e917fe203abf..db80c974cc29 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -18,6 +18,9 @@ moe_permute, moe_unpermute, ) +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP, +) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate, TopKWeightAndReduceNoOP, @@ -1027,6 +1030,7 @@ def cutlass_moe_w4a8_fp8( s_strides1: torch.Tensor, s_strides2: torch.Tensor, quant_config: FusedMoEQuantConfig, + moe_config: FusedMoEConfig, activation: str = "silu", expert_map: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, @@ -1084,36 +1088,36 @@ def cutlass_moe_w4a8_fp8( Returns: - torch.Tensor: The bf16 output tensor after applying the MoE layer. 
""" - # assert quant_config is not None - - # num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(0) - - # fn = mk.FusedMoEModularKernel( - # MoEPrepareAndFinalizeNoEP(), - # CutlassExpertsW4A8Fp8( - # out_dtype=a.dtype, - # a_strides1=a_strides1, - # a_strides2=a_strides2, - # b_strides1=b_strides1, - # b_strides2=b_strides2, - # c_strides1=c_strides1, - # c_strides2=c_strides2, - # s_strides1=s_strides1, - # s_strides2=s_strides2, - # # TODO: - # quant_config=quant_config, - # group_size=group_size, - # ), - # ) - - # return fn( - # a, - # w1_q, - # w2_q, - # topk_weights, - # topk_ids, - # activation=activation, - # global_num_experts=num_experts, - # expert_map=expert_map, - # apply_router_weight_on_input=apply_router_weight_on_input, - # ) + assert quant_config is not None + + num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(0) + + fn = mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(), + CutlassExpertsW4A8Fp8( + out_dtype=a.dtype, + a_strides1=a_strides1, + a_strides2=a_strides2, + b_strides1=b_strides1, + b_strides2=b_strides2, + c_strides1=c_strides1, + c_strides2=c_strides2, + s_strides1=s_strides1, + s_strides2=s_strides2, + moe_config=moe_config, + quant_config=quant_config, + group_size=group_size, + ), + ) + + return fn( + a, + w1_q, + w2_q, + topk_weights, + topk_ids, + activation=activation, + global_num_experts=num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) From 6090a060aed4568ad7622247989279df991aa95a Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 17:47:01 -0500 Subject: [PATCH 098/113] differentiate static vs dynamic quantization Signed-off-by: Robert Shaw --- tests/compile/test_fusion_attn.py | 6 +++--- tests/compile/test_silu_mul_quant_fusion.py | 6 +++--- vllm/compilation/activation_quant_fusion.py | 6 +++--- vllm/compilation/fusion.py | 4 ++-- vllm/compilation/fusion_attn.py | 4 ++-- 
vllm/compilation/matcher_utils.py | 4 ++-- .../layers/fused_moe/flashinfer_cutedsl_moe.py | 5 +++-- .../layers/fused_moe/flashinfer_cutlass_moe.py | 5 +++-- .../layers/fused_moe/fused_marlin_moe.py | 4 ++-- .../compressed_tensors/compressed_tensors_moe.py | 6 +++--- vllm/model_executor/layers/quantization/modelopt.py | 6 +++--- .../layers/quantization/utils/flashinfer_fp4_moe.py | 5 +++-- .../layers/quantization/utils/quant_utils.py | 11 +++++++++-- vllm/v1/attention/backends/flashinfer.py | 4 ++-- 14 files changed, 43 insertions(+), 33 deletions(-) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index a1fd098aee5f..66d4f8d2caf3 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -30,7 +30,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kFp8StaticTensorSym, - kNvfp4Quant, + kNvfp4Dynamic, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.platforms import current_platform @@ -202,7 +202,7 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel): """Test model for AttentionNvfp4QuantPattern fusion.""" - quant_key = kNvfp4Quant + quant_key = kNvfp4Dynamic def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -449,7 +449,7 @@ def test_attention_quant_pattern( # Note: for fp8, fully_replaced=False because query quant ops remain in graph. # Only output quant ops are fused into attention. 
- test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Quant) + test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Dynamic) # access the underlying `AttnFusionPass` on the `LazyInitPass` assert attn_pass.pass_.matched_count == sum(attn_fusion_supported) diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index eb0dee8d4e39..446f8c7626b8 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -29,7 +29,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, kFp8StaticTensorSym, - kNvfp4Quant, + kNvfp4Dynamic, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp, @@ -121,11 +121,11 @@ def forward(self, x): def ops_in_model_before(self): return [ SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul, - QUANT_OPS[kNvfp4Quant], + QUANT_OPS[kNvfp4Dynamic], ] def ops_in_model_after(self): - return [FUSED_OPS[kNvfp4Quant]] + return [FUSED_OPS[kNvfp4Dynamic]] class TestSiluMulGroupFp8QuantModel(torch.nn.Module): diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py index f0ce5b3db7ec..8530c0dadb23 100644 --- a/vllm/compilation/activation_quant_fusion.py +++ b/vllm/compilation/activation_quant_fusion.py @@ -18,7 +18,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kFp8StaticTensorSym, - kNvfp4Quant, + kNvfp4Dynamic, ) from vllm.platforms import current_platform @@ -41,7 +41,7 @@ torch.ops._C, "silu_and_mul_nvfp4_quant" ) if silu_and_mul_nvfp4_quant_supported: - FUSED_OPS[kNvfp4Quant] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501 + FUSED_OPS[kNvfp4Dynamic] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501 class ActivationQuantPattern(ABC): @@ -129,7 +129,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern): """ def 
__init__(self) -> None: - super().__init__(kNvfp4Quant) + super().__init__(kNvfp4Dynamic) def get_inputs(self) -> list[torch.Tensor]: result = self.empty_quant(5, 32) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index e3c6c2f20141..667828cc6a5a 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -20,7 +20,7 @@ kFp8DynamicTensorSym, kFp8DynamicTokenSym, kFp8StaticTensorSym, - kNvfp4Quant, + kNvfp4Dynamic, kStaticTensorScale, ) from vllm.platforms import current_platform @@ -63,7 +63,7 @@ def empty_i64(*args: Any, **kwargs: Any) -> torch.Tensor: kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 } if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): - QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default + QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default if current_platform.is_cuda(): QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py index 57448aa0b096..69dc2e3a6868 100644 --- a/vllm/compilation/fusion_attn.py +++ b/vllm/compilation/fusion_attn.py @@ -16,7 +16,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, - kNvfp4Quant, + kNvfp4Dynamic, kStaticTensorScale, ) from vllm.platforms import current_platform @@ -217,7 +217,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): """ def __init__(self, layer: Attention, dtype: torch.dtype) -> None: - super().__init__(layer, kNvfp4Quant, dtype) + super().__init__(layer, kNvfp4Dynamic, dtype) def _register(self, pm_pass: PatternMatcherPass) -> None: def pattern( diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index eda12180d493..7bb98db5e9f6 100644 --- 
a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -21,7 +21,7 @@ kFp8DynamicTensorSym, kFp8DynamicTokenSym, kFp8StaticTensorSym, - kNvfp4Quant, + kNvfp4Dynamic, ) from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.platforms import current_platform @@ -38,7 +38,7 @@ } if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): - QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501 + QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default # noqa: E501 if current_platform.is_cuda(): QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py index a2917a5736ef..82ba560f63d3 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py @@ -16,7 +16,8 @@ ) from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, - kNvfp4Quant, + kNvfp4Dynamic, + kNvfp4Static, ) from vllm.platforms import current_platform from vllm.utils.flashinfer import ( @@ -59,7 +60,7 @@ def _supports_quant_scheme( activation_key: QuantKey | None, ) -> bool: SUPPORTED_W_A = [ - (kNvfp4Quant, kNvfp4Quant), + (kNvfp4Static, kNvfp4Dynamic), ] return (weight_key, activation_key) in SUPPORTED_W_A diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 1cad2f482fef..7a77c302954a 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -17,7 +17,8 @@ kFp8Dynamic128Sym, kFp8Static128BlockSym, kFp8StaticTensorSym, - kNvfp4Quant, + kNvfp4Dynamic, + kNvfp4Static, ) from vllm.platforms import current_platform from vllm.utils.flashinfer import ( @@ 
-127,7 +128,7 @@ def _supports_quant_scheme( and (p.is_device_capability((9, 0))) ) or ( - (scheme == (kNvfp4Quant, kNvfp4Quant)) + (scheme == (kNvfp4Static, kNvfp4Dynamic)) and (p.is_device_capability((10, 0))) # GB? ) ) diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 6381087660ba..e0c31a1da725 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -31,7 +31,7 @@ QuantKey, kFp8Static128BlockSym, kFp8StaticChannelSym, - kNvfp4Quant, + kNvfp4Dynamic, ) from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types @@ -581,7 +581,7 @@ def _supports_quant_scheme( SUPPORTED_W = [ kFp8Static128BlockSym, kFp8StaticChannelSym, - kNvfp4Quant, + kNvfp4Dynamic, ] return weight_key in SUPPORTED_W diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 5ec5f9556c9d..01b7b83b2375 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -86,7 +86,7 @@ kFp8Static128BlockSym, kFp8StaticChannelSym, kFp8StaticTensorSym, - kNvfp4Quant, + kNvfp4Dynamic, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( normalize_e4m3fn_to_e4m3fnuz, @@ -247,8 +247,8 @@ def __init__( # Select experts implementation. self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend( config=self.moe, - weight_key=kNvfp4Quant, - activation_key=kNvfp4Quant, + weight_key=kNvfp4Dynamic, + activation_key=kNvfp4Dynamic, ) # Delay creation of the kernel until after process-weights. 
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index d7227e6a5284..8c68d5742aa0 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -74,7 +74,7 @@ cutlass_fp4_supported, is_layer_skipped, kFp8StaticTensorSym, - kNvfp4Quant, + kNvfp4Dynamic, swizzle_blockscale, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -1337,8 +1337,8 @@ def __init__( # Select experts implementation. self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend( config=self.moe, - weight_key=kNvfp4Quant, - activation_key=kNvfp4Quant, + weight_key=kNvfp4Dynamic, + activation_key=kNvfp4Dynamic, ) # Delay creation of the kernel until after process-weights. diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index edcda4bf23cd..8517162a9512 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -16,7 +16,8 @@ ) from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, - kNvfp4Quant, + kNvfp4Dynamic, + kNvfp4Static, swizzle_blockscale, ) from vllm.platforms import current_platform @@ -62,7 +63,7 @@ def _supports_quant_scheme( ) -> bool: """Supports Nvfp4 quantization.""" SUPPORTED_W_A = [ - (kNvfp4Quant, kNvfp4Quant), + (kNvfp4Static, kNvfp4Dynamic), ] return (weight_key, activation_key) in SUPPORTED_W_A diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index fe0cea539713..15b7f3444cde 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -130,8 +130,15 @@ def __str__(self): kDynamicTokenScale = ScaleDesc(torch.float32, False, 
GroupShape.PER_TOKEN) kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True) -kNvfp4GroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16)) -kNvfp4Quant = QuantKey(FP4_DTYPE, scale=kNvfp4GroupScale, scale2=kStaticTensorScale) +kNvfp4DynamicGroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16)) +kNvfp4Dynamic = QuantKey( + FP4_DTYPE, scale=kNvfp4DynamicGroupScale, scale2=kStaticTensorScale +) + +kNvfp4StaticGroupScale = ScaleDesc(FP8_DTYPE, True, GroupShape(1, 16)) +kNvfp4Static = QuantKey( + FP4_DTYPE, scale=kNvfp4StaticGroupScale, scale2=kStaticTensorScale +) kDynamic128Scale = ScaleDesc(torch.float32, False, GroupShape(1, 128)) kFp8Dynamic128Sym = QuantKey(FP8_DTYPE, kDynamic128Scale, symmetric=True) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 9892c360d3d6..5ed860cc56a9 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -29,7 +29,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kFp8StaticTensorSym, - kNvfp4Quant, + kNvfp4Dynamic, ) from vllm.platforms import current_platform from vllm.platforms.interface import DeviceCapability @@ -1184,7 +1184,7 @@ def fused_output_quant_supported(self, quant_key: QuantKey): return ( self.support_trtllm_attn and self.kv_cache_dtype.startswith("fp8") - and quant_key in (kFp8StaticTensorSym, kNvfp4Quant) + and quant_key in (kFp8StaticTensorSym, kNvfp4Dynamic) ) # FlashInfer requires attention sinks to be float32 From 6a3a75b3b7e2c12c56d1b9a59c3129790f4e5ed4 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 17:47:41 -0500 Subject: [PATCH 099/113] remove newline Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 
7a77c302954a..5d68e8f0cce9 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -241,7 +241,6 @@ def apply( assert self.w1_scale is not None and self.w2_scale is not None, ( "w1_scale and w2_scale must not be None for FlashInferExperts" ) - # Flashinfer CUTLASS kernel takes scalar global scales, # min because inv_scale. quant_scales = [ From 89911ea880e44032b8cadd822c46eac27e253790 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 17:56:44 -0500 Subject: [PATCH 100/113] fix static vs dynamic Signed-off-by: Robert Shaw --- .../quantization/compressed_tensors/compressed_tensors_moe.py | 3 ++- vllm/model_executor/layers/quantization/modelopt.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 01b7b83b2375..a8180f1d6e4f 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -87,6 +87,7 @@ kFp8StaticChannelSym, kFp8StaticTensorSym, kNvfp4Dynamic, + kNvfp4Static, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( normalize_e4m3fn_to_e4m3fnuz, @@ -247,7 +248,7 @@ def __init__( # Select experts implementation. 
self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend( config=self.moe, - weight_key=kNvfp4Dynamic, + weight_key=kNvfp4Static, activation_key=kNvfp4Dynamic, ) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 8c68d5742aa0..d999351d8264 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -75,6 +75,7 @@ is_layer_skipped, kFp8StaticTensorSym, kNvfp4Dynamic, + kNvfp4Static, swizzle_blockscale, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -1337,7 +1338,7 @@ def __init__( # Select experts implementation. self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend( config=self.moe, - weight_key=kNvfp4Dynamic, + weight_key=kNvfp4Static, activation_key=kNvfp4Dynamic, ) From ad8fe2e9809e5de95a542966e78583ba456001d6 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 18:32:29 -0500 Subject: [PATCH 101/113] get things working again Signed-off-by: Robert Shaw --- tests/kernels/moe/test_flashinfer_moe.py | 43 ++++++++++++++++--- .../layers/fused_moe/fused_marlin_moe.py | 4 +- .../layers/fused_moe/oracle/nvfp4.py | 10 ++--- .../compressed_tensors_moe.py | 2 +- 4 files changed, 46 insertions(+), 13 deletions(-) diff --git a/tests/kernels/moe/test_flashinfer_moe.py b/tests/kernels/moe/test_flashinfer_moe.py index 1262eea70bab..a81a1294990e 100644 --- a/tests/kernels/moe/test_flashinfer_moe.py +++ b/tests/kernels/moe/test_flashinfer_moe.py @@ -12,15 +12,19 @@ from tests.kernels.utils import torch_moe from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, +) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( FlashInferExperts, is_valid_flashinfer_cutlass_fused_moe, ) -from 
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( - create_flashinfer_prepare_finalize, -) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP, +) from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.utils.torch_utils import set_random_seed @@ -86,9 +90,38 @@ def test_flashinfer_fp4_moe_no_graph( assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q) + moe_config = FusedMoEConfig( + num_experts=e, + experts_per_token=topk, + hidden_dim=k, + intermediate_size_per_partition=n, + num_local_experts=e, + activation=activation, + device="cuda", + moe_parallel_config=FusedMoEParallelConfig( + tp_size=1, + pcp_size=1, + dp_size=1, + ep_size=1, + tp_rank=0, + pcp_rank=0, + dp_rank=0, + ep_rank=0, + use_ep=False, + all2all_backend="allgather_reducescatter", + ), + in_dtype=dtype, + is_act_and_mul=is_gated_act, + ) + flashinfer_experts = FusedMoEModularKernel( - create_flashinfer_prepare_finalize(use_dp=False, use_nvfp4=True), - FlashInferExperts(out_dtype=dtype, quant_config=quant_config), + MoEPrepareAndFinalizeNoEP( + defer_input_quant=FlashInferExperts.should_pf_defer_input_quant( + moe_config=moe_config, + quant_config=quant_config, + ) + ), + FlashInferExperts(moe_config=moe_config, quant_config=quant_config), ) fi_activation = {"silu_and_mul": "silu", "relu2": "relu2_no_mul"}[activation] diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index e0c31a1da725..5b552dc93287 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -31,7 +31,7 @@ QuantKey, kFp8Static128BlockSym, kFp8StaticChannelSym, - kNvfp4Dynamic, + 
kNvfp4Static, ) from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types @@ -581,7 +581,7 @@ def _supports_quant_scheme( SUPPORTED_W = [ kFp8Static128BlockSym, kFp8StaticChannelSym, - kNvfp4Dynamic, + kNvfp4Static, ] return weight_key in SUPPORTED_W diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py index 875995f1b7e6..e3d8bb4ab062 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py @@ -39,11 +39,11 @@ class NvFp4MoeBackend(Enum): - FLASHINFER_TRTLLM = "FlashInfer TRTLLM" - FLASHINFER_CUTLASS = "FlashInfer CUTLASS" - FLASHINFER_CUTEDSL = "FlashInfer CUTEDSL" - VLLM_CUTLASS = "vLLM CUTLASS" - MARLIN = "vLLM MARLIN" + FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM" + FLASHINFER_CUTLASS = "FLASHINFER_CUTLASS" + FLASHINFER_CUTEDSL = "FLASHINFER_CUTEDSL" + VLLM_CUTLASS = "VLLM_CUTLASS" + MARLIN = "MARLIN" FLASHINFER_NVFP4_MOE_BACKENDS = [ diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index a8180f1d6e4f..ac73970b91f3 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -249,7 +249,7 @@ def __init__( self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend( config=self.moe, weight_key=kNvfp4Static, - activation_key=kNvfp4Dynamic, + activation_key=None if use_a16 else kNvfp4Dynamic, ) # Delay creation of the kernel until after process-weights. 
From 7bc367477a94044bcf084e0d6d616d540adaac26 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 19:52:40 -0500 Subject: [PATCH 102/113] updated Signed-off-by: Robert Shaw --- .../fused_moe/unquantized_fused_moe_method.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index c31c0223eb06..fdbf1d159ab3 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -257,8 +257,10 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: self.use_inplace = True self.kernel = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), - AiterExperts(self.moe_quant_config), - shared_experts=None, + AiterExperts( + moe_config=self.moe, + quant_config=self.moe_quant_config, + ), ) elif self.flashinfer_cutlass_moe_enabled: @@ -270,19 +272,18 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: self.kernel = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), FlashInferExperts( - out_dtype=layer.params_dtype, + moe_config=self.moe, quant_config=self.moe_quant_config, - tp_rank=self.moe.moe_parallel_config.tp_rank, - tp_size=self.moe.moe_parallel_config.tp_size, - ep_rank=self.moe.moe_parallel_config.ep_rank, - ep_size=self.moe.moe_parallel_config.ep_size, ), ) else: self.use_inplace = True self.kernel = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), - TritonExperts(self.moe_quant_config), + TritonExperts( + moe_config=self.moe, + quant_config=self.moe_quant_config, + ), shared_experts=None, ) From f8d3af761b69b00f3deece8ae426d40c5883771f Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 19:53:26 -0500 Subject: [PATCH 103/113] updated Signed-off-by: Robert Shaw ---
.../layers/fused_moe/unquantized_fused_moe_method.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index fdbf1d159ab3..c4405b170ec7 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -121,14 +121,18 @@ def select_gemm_impl( == FusedMoEActivationFormat.BatchedExperts ): logger.debug("BatchedTritonExperts %s", self.moe) - return BatchedTritonExperts( + return BatchedTritonExperts.make_batched_experts( + moe_config=self.moe, + quant_config=self.moe_quant_config, max_num_tokens=self.moe.max_num_tokens, num_dispatchers=prepare_finalize.num_dispatchers(), - quant_config=self.moe_quant_config, ) else: logger.debug("TritonExperts %s", self.moe) - return TritonExperts(self.moe_quant_config) + return TritonExperts( + moe_config=self.moe, + quant_config=self.moe_quant_config, + ) def create_weights( self, From 10b957c05290e70b747ac80041c2bd3540486668 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 20:40:45 -0500 Subject: [PATCH 104/113] attempt to get sp working Signed-off-by: Robert Shaw --- tests/kernels/moe/test_flashinfer.py | 1 + tests/kernels/moe/test_flashinfer_moe.py | 1 + vllm/config/parallel.py | 1 + .../layers/fused_moe/all2all_utils.py | 5 ++++- vllm/model_executor/layers/fused_moe/config.py | 3 +++ .../flashinfer_cutlass_prepare_finalize.py | 8 ++++++-- vllm/model_executor/layers/fused_moe/layer.py | 2 +- .../layers/fused_moe/prepare_finalize.py | 13 ++++++++++--- 8 files changed, 27 insertions(+), 7 deletions(-) diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py index bb2f6b873941..2fa00a8effb7 100644 --- a/tests/kernels/moe/test_flashinfer.py +++ b/tests/kernels/moe/test_flashinfer.py @@ -127,6 +127,7 @@ def make_moe_tensors_8bit( 
ep_rank=0, use_ep=False, all2all_backend="naive", + is_sequence_parallel=False, ) # flashinfer expects swapped rows for w13 diff --git a/tests/kernels/moe/test_flashinfer_moe.py b/tests/kernels/moe/test_flashinfer_moe.py index a81a1294990e..d9b7f4678534 100644 --- a/tests/kernels/moe/test_flashinfer_moe.py +++ b/tests/kernels/moe/test_flashinfer_moe.py @@ -109,6 +109,7 @@ def test_flashinfer_fp4_moe_no_graph( ep_rank=0, use_ep=False, all2all_backend="allgather_reducescatter", + is_sequence_parallel=False, ), in_dtype=dtype, is_act_and_mul=is_gated_act, diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 363078aefc51..ae90d9bd5201 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -425,6 +425,7 @@ def stateless_init_dp_group(self) -> ProcessGroup: # to avoid the excess work. # # Not needed for pplx-kernels as it can handle duplicate input tokens. + # TODO(rob): investigate 'flashinfer_all2allv'? @property def use_sequence_parallel_moe(self) -> bool: return ( diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py index b2efb5dcda1c..df9e81ab7b5e 100644 --- a/vllm/model_executor/layers/fused_moe/all2all_utils.py +++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py @@ -188,6 +188,9 @@ def maybe_make_prepare_finalize( ) elif moe.use_naive_kernels and allow_new_interface: - prepare_finalize = MoEPrepareAndFinalizeNaiveEP(defer_input_quant) + prepare_finalize = MoEPrepareAndFinalizeNaiveEP( + defer_input_quant, + moe.moe_parallel_config.is_sequence_parallel, + ) return prepare_finalize diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 55b30c7a43c6..a391668e9db5 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -864,6 +864,7 @@ class FusedMoEParallelConfig: use_ep: bool # whether to use EP or not all2all_backend: str # all2all backend for
MoE communication + is_sequence_parallel: bool # whether sequence parallelism is used @property def use_all2all_kernels(self): @@ -1013,6 +1014,7 @@ def make( ep_rank=0, use_ep=False, all2all_backend=vllm_parallel_config.all2all_backend, + is_sequence_parallel=vllm_parallel_config.is_sequence_parallel, ) # DP + EP / TP + EP / DP + TP + EP assert use_ep @@ -1031,6 +1033,7 @@ def make( ep_rank=ep_rank, use_ep=True, all2all_backend=vllm_parallel_config.all2all_backend, + is_sequence_parallel=vllm_parallel_config.is_sequence_parallel, ) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index 2c1bcb2db818..2d7a468be2a9 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -364,9 +364,13 @@ def create_flashinfer_prepare_finalize( assert use_nvfp4 return FlashInferAllToAllMoEPrepareAndFinalize(use_dp) defer_input_quant = use_deepseek_fp8_block_scale or use_nvfp4 - return MoEPrepareAndFinalizeNaiveEP(defer_input_quant=defer_input_quant) + return MoEPrepareAndFinalizeNaiveEP( + defer_input_quant=defer_input_quant, is_sequence_parallel=False + ) else: # CUTLASS FP8 BLOCK and CUTLASS NVFP4 apply input quantization # in a single call with the MoE experts kernel. 
defer_input_quant = use_deepseek_fp8_block_scale or use_nvfp4 - return MoEPrepareAndFinalizeNoEP(defer_input_quant=defer_input_quant) + return MoEPrepareAndFinalizeNoEP( + defer_input_quant=defer_input_quant, + ) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index a1c4955ec6ba..2be7303e3dc1 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -360,7 +360,7 @@ def __init__( enable_eplb: bool = False, num_redundant_experts: int = 0, has_bias: bool = False, - is_sequence_parallel=False, + is_sequence_parallel: bool = False, expert_mapping: list[tuple[str, str, int, str]] | None = None, n_shared_experts: int | None = None, routing_method_type: RoutingMethodType | None = None, diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index e853c4b2147f..5fe6b7f54451 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -15,9 +15,14 @@ class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize): - def __init__(self, defer_input_quant: bool = False) -> None: + def __init__( + self, + defer_input_quant: bool = False, + is_sequence_parallel: bool = False, + ) -> None: super().__init__() self.defer_input_quant = defer_input_quant + self.is_sequence_parallel = is_sequence_parallel @property def activation_format(self) -> mk.FusedMoEActivationFormat: @@ -89,7 +94,7 @@ def prepare( a1q, topk_weights, topk_ids, - is_sequence_parallel=False, # TODO: support SP + is_sequence_parallel=self.is_sequence_parallel, extra_tensors=scales, ) if skip_gather_scales: @@ -131,7 +136,9 @@ def finalize( apply_router_weight_on_input=apply_router_weight_on_input, ) - output.copy_(get_ep_group().combine(out, is_sequence_parallel=False)) + output.copy_( + get_ep_group().combine(out, 
is_sequence_parallel=self.is_sequence_parallel) + ) class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): From e8ab5455edf3b2b203ed65868c2b89b0aced5c98 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 20:41:54 -0500 Subject: [PATCH 105/113] nit Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index a391668e9db5..4f0a3a3fe276 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -1014,7 +1014,7 @@ def make( ep_rank=0, use_ep=False, all2all_backend=vllm_parallel_config.all2all_backend, - is_sequence_parallel=vllm_parallel_config.is_sequence_parallel, + is_sequence_parallel=vllm_parallel_config.use_sequence_parallel_moe, ) # DP + EP / TP + EP / DP + TP + EP assert use_ep @@ -1033,7 +1033,7 @@ def make( ep_rank=ep_rank, use_ep=True, all2all_backend=vllm_parallel_config.all2all_backend, - is_sequence_parallel=vllm_parallel_config.is_sequence_parallel, + is_sequence_parallel=vllm_parallel_config.use_sequence_parallel_moe, ) From 50a8c977c83466a70a8e56ddc6dfd0e620201c05 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 20:56:44 -0500 Subject: [PATCH 106/113] appears to be working properly Signed-off-by: Robert Shaw --- .../layers/fused_moe/oracle/fp8.py | 4 ++ .../compressed_tensors_moe.py | 37 +++++++++---------- 2 files changed, 22 insertions(+), 19 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py index c7b90f1285ba..dee7f1a4de39 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py @@ -363,6 +363,8 @@ def make_fp8_moe_quant_config( a1_scale: torch.Tensor | None, a2_scale: torch.Tensor | None, block_shape: list[int] | None = None, 
+ per_act_token_quant: bool = False, + per_out_ch_quant: bool = False, ) -> FusedMoEQuantConfig | None: """ Create FusedMoEQuantConfig for the specifed FP8 Backend. @@ -415,6 +417,8 @@ def make_fp8_moe_quant_config( a1_scale=a1_scale, a2_scale=a2_scale, block_shape=block_shape, + per_act_token_quant=per_act_token_quant, + per_out_ch_quant=per_out_ch_quant, ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index ac73970b91f3..a6060ce01c4c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -28,8 +28,6 @@ FusedMoEConfig, FusedMoEQuantConfig, RoutingMethodType, - fp8_w8a8_moe_quant_config, - fp8_w8a16_moe_quant_config, int4_w4a16_moe_quant_config, int4_w4afp8_moe_quant_config, int8_w8a8_moe_quant_config, @@ -46,6 +44,7 @@ Fp8MoeBackend, convert_to_fp8_moe_kernel_format, make_fp8_moe_kernel, + make_fp8_moe_quant_config, select_fp8_moe_backend, ) from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( @@ -862,24 +861,24 @@ def select_gemm_impl( def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: - if self.fp8_backend == Fp8MoeBackend.MARLIN: - return fp8_w8a16_moe_quant_config( - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - block_shape=self.weight_block_size, - ) - - per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN - per_channel_quant = self.weight_quant.strategy == QuantizationStrategy.CHANNEL + w1_scale = layer.w13_weight_scale + w2_scale = layer.w2_weight_scale + a1_scale = layer.w13_input_scale + a2_scale = layer.w2_input_scale - return fp8_w8a8_moe_quant_config( - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - 
per_act_token_quant=per_act_token, - per_out_ch_quant=per_channel_quant, - block_shape=layer.weight_block_size, + return make_fp8_moe_quant_config( + fp8_backend=self.fp8_backend, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + per_act_token_quant=( + self.input_quant.strategy == QuantizationStrategy.TOKEN + ), + per_out_ch_quant=( + self.weight_quant.strategy == QuantizationStrategy.CHANNEL + ), + block_shape=self.weight_block_size, ) def apply( From ddc2eb14e557c3ab13bd1a3c63aec28d157f3f9a Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 21:06:14 -0500 Subject: [PATCH 107/113] fix pre commit Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/all2all_utils.py | 3 ++- vllm/model_executor/layers/fused_moe/cutlass_moe.py | 4 ++++ vllm/model_executor/layers/fused_moe/layer.py | 1 + vllm/model_executor/layers/fused_moe/prepare_finalize.py | 4 +++- 4 files changed, 10 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py index df9e81ab7b5e..43f75cc31428 100644 --- a/vllm/model_executor/layers/fused_moe/all2all_utils.py +++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py @@ -190,7 +190,8 @@ def maybe_make_prepare_finalize( elif moe.use_naive_kernels and allow_new_interface: prepare_finalize = MoEPrepareAndFinalizeNaiveEP( defer_input_quant, - moe.moe_parallel_config.is_sequence_parallel, + is_sequence_parallel=(moe.moe_parallel_config.is_sequence_parallel), + num_dispatchers=all2all_manager.world_size, ) return prepare_finalize diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index db80c974cc29..e41fa3f7d50c 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -881,6 +881,10 @@ def __init__( self.s_strides2 = s_strides2 self.group_size = group_size + 
@staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + @staticmethod def _supports_current_device() -> bool: raise NotImplementedError( diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 2be7303e3dc1..01bb2e2423e0 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1684,6 +1684,7 @@ def must_reduce_shared_expert_outputs(self) -> bool: early. """ assert self.quant_method is not None + # TODO(rob): investigate this. return ( isinstance(self.quant_method, FusedMoEModularMethod) and self.quant_method.fused_experts.output_is_reduced() diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index 5fe6b7f54451..a4d38d184e9b 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -19,10 +19,12 @@ def __init__( self, defer_input_quant: bool = False, is_sequence_parallel: bool = False, + num_dispatchers: int = 1, ) -> None: super().__init__() self.defer_input_quant = defer_input_quant self.is_sequence_parallel = is_sequence_parallel + self._num_dispatchers = num_dispatchers @property def activation_format(self) -> mk.FusedMoEActivationFormat: @@ -35,7 +37,7 @@ def topk_indices_dtype(self) -> torch.dtype | None: return None def num_dispatchers(self) -> int: - return 1 + return self._num_dispatchers def output_is_reduced(self) -> bool: return False From 5f913ea5acda545d3cadcf1e747931a8cf196c69 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 21:14:05 -0500 Subject: [PATCH 108/113] updated Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/config.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 
4f0a3a3fe276..b5cd5167e7e1 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -120,9 +120,6 @@ class RoutingMethodType(IntEnum): Unspecified = 6.0 -UNQUANTIZED_DTYPES = [torch.bfloat16, torch.float16, torch.float32] - - @dataclass class FusedMoEQuantDesc: """ @@ -360,10 +357,6 @@ def ocp_mx_scheme(self) -> str | None: def use_mxfp4_w4a16(self) -> bool: return self._a1.dtype is None and self._w1.dtype == "mxfp4" - @property - def use_mxfp4_w4a8(self) -> bool: - return self._a1.dtype == "mxfp8" and self._w1.dtype == "mxfp4" - @property def use_mxfp4_w4a4(self) -> bool: return self._a1.dtype == "mxfp4" and self._w1.dtype == "mxfp4" From e58e783946ab750d2cc529785692c3dd3f679caf Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 21:36:07 -0500 Subject: [PATCH 109/113] remove flashinfer constructors Signed-off-by: Robert Shaw --- .../moe/modular_kernel_tools/mk_objects.py | 10 +- .../flashinfer_cutlass_prepare_finalize.py | 120 +----------------- 2 files changed, 5 insertions(+), 125 deletions(-) diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index 99b168dc7554..f085706bb22c 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -237,7 +237,6 @@ def expert_info(kind) -> ExpertInfo: ) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 FlashInferCutlassMoEPrepareAndFinalize, - create_flashinfer_prepare_finalize, ) register_prepare_and_finalize( @@ -389,13 +388,12 @@ def make_prepare_finalize( quant_config: FusedMoEQuantConfig, ) -> mk.FusedMoEPrepareAndFinalize: if backend != "naive" and backend is not None: - prepare_finalize = maybe_make_prepare_finalize(moe, quant_config) + # TODO(rob): add defer input quant. 
+ prepare_finalize = maybe_make_prepare_finalize( + moe, quant_config, allow_new_interface=True + ) assert prepare_finalize is not None return prepare_finalize - elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize: - return create_flashinfer_prepare_finalize( - use_dp=moe.moe_parallel_config.dp_size > 1 - ) else: return MoEPrepareAndFinalizeNoEP() diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index 2d7a468be2a9..44f7b3b9bc0a 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -4,19 +4,12 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.distributed import get_dp_group, get_ep_group +from vllm.distributed import get_ep_group from vllm.distributed.device_communicators.base_device_communicator import ( All2AllManagerBase, ) from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNaiveEP, - MoEPrepareAndFinalizeNoEP, -) -from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceNoOP, -) from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.utils.flashinfer import nvfp4_block_scale_interleave @@ -163,90 +156,6 @@ def finalize( output.copy_(fused_expert_output) -class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFinalize): - def __init__( - self, - use_dp: bool, - num_dispatchers: int = 1, - use_deepseek_fp8_block_scale: bool = False, - ): - super().__init__(use_dp, num_dispatchers, use_deepseek_fp8_block_scale) - - def prepare( - self, - a1: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - 
num_experts: int, - expert_map: torch.Tensor | None, - apply_router_weight_on_input: bool, - quant_config: FusedMoEQuantConfig, - ) -> mk.PrepareResultType: - self._apply_router_weight_on_input( - a1, topk_weights, topk_ids, apply_router_weight_on_input - ) - is_nvfp4 = quant_config.quant_dtype == "nvfp4" - if not self.use_dp and is_nvfp4: - return a1, None, None, topk_ids, topk_weights - - if not self.use_deepseek_fp8_block_scale: - a1q, a1q_scale = moe_kernel_quantize_input( - a1, - quant_config.a1_gscale if is_nvfp4 else quant_config.a1_scale, - quant_config.quant_dtype, - quant_config.per_act_token_quant, - quant_config.block_shape, - is_fp4_scale_swizzled=not self.use_dp, - ) - else: - # Block-scale path: pass activations through, omit per-token scales - a1q = a1 - a1q_scale = None - - if self.use_dp: - # Build gather list conditionally - omit a1q_scale if None - # (block-scale path) - gather_list = [topk_weights, topk_ids, a1q] - if a1q_scale is not None: - gather_list.append(a1q_scale) - gathered = get_dp_group().all_gatherv( - gather_list, - dim=0, - sizes=get_local_sizes(), - ) - topk_weights, topk_ids, a1q, a1q_scale = gathered - else: - gathered = get_dp_group().all_gatherv( - gather_list, - dim=0, - sizes=get_local_sizes(), - ) - topk_weights, topk_ids, a1q = gathered - a1q_scale = None - - if is_nvfp4 and a1q_scale is not None: - a1q_scale = nvfp4_block_scale_interleave(a1q_scale) - - return a1q, a1q_scale, None, topk_ids, topk_weights - - def finalize( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - ) -> None: - assert isinstance(weight_and_reduce_impl, TopKWeightAndReduceNoOP) - - if self.use_dp: - fused_expert_output = get_dp_group().reduce_scatterv( - fused_expert_output, dim=0, sizes=get_local_sizes() - ) - output.copy_(fused_expert_output) - - def flashinfer_alltoall_dispatch( 
all2all_manager: All2AllManagerBase, global_num_tokens_cpu: list[int], @@ -347,30 +256,3 @@ def flashinfer_alltoall_combine( top_k=top_k, token_count=token_count, ) - - -def create_flashinfer_prepare_finalize( - use_dp: bool, - use_nvfp4: bool = False, - enable_alltoallv: bool = False, - use_deepseek_fp8_block_scale: bool = False, -) -> FlashInferCutlassMoEPrepareAndFinalize | MoEPrepareAndFinalizeNoEP: - """Factory function to create the appropriate FlashInfer implementation.""" - - # NOTE(rob): CUTLASS FP8 block quant executes the input - # quantzation and grouped gemm in a single kernel. - if use_dp: - if enable_alltoallv: - assert use_nvfp4 - return FlashInferAllToAllMoEPrepareAndFinalize(use_dp) - defer_input_quant = use_deepseek_fp8_block_scale or use_nvfp4 - return MoEPrepareAndFinalizeNaiveEP( - defer_input_quant=defer_input_quant, is_sequence_parallel=False - ) - else: - # CUTLASS FP8 BLOCK and CUTLASS NVFP4 apply input quantization - # in a single call with the MoE experts kernel. 
- defer_input_quant = use_deepseek_fp8_block_scale or use_nvfp4 - return MoEPrepareAndFinalizeNoEP( - defer_input_quant=defer_input_quant, - ) From a8de4dacbbafe6f97b149c1762744edf4154bf63 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 21:41:46 -0500 Subject: [PATCH 110/113] nits Signed-off-by: Robert Shaw --- .../layers/fused_moe/unquantized_fused_moe_method.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index c4405b170ec7..ef0f79a18846 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -262,7 +262,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: self.kernel = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), AiterExperts( - moe_config=self.moe_quant_config, + moe_config=self.moe, quant_config=self.moe_quant_config, ), ) @@ -285,7 +285,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: self.kernel = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), TritonExperts( - moe_config=self.moe_quant_config, + moe_config=self.moe, quant_config=self.moe_quant_config, ), shared_experts=None, From 2c9e9e6d87b6f676272444c140c4453f9edf9311 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 21:43:23 -0500 Subject: [PATCH 111/113] remove do naive dispach combine comment Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/layer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 01bb2e2423e0..77898da75b1d 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1923,7 +1923,6 @@ def forward_impl( hidden_states, router_logits, 
has_separate_shared_experts ) - # TODO(rob): remove this once we migrate to internal use of MK. do_naive_dispatch_combine: bool = self.dp_size > 1 and not ( isinstance(self.quant_method, FusedMoEModularMethod) or self.quant_method.supports_mk_interally From 9a907e0a62723c9f2a9d053e36a3eec12b501332 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 21:45:23 -0500 Subject: [PATCH 112/113] update backend names Signed-off-by: Robert Shaw --- .../layers/fused_moe/oracle/fp8.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py index dee7f1a4de39..fe6d5cc68bc9 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py @@ -45,15 +45,15 @@ class Fp8MoeBackend(Enum): NONE = 0 - FLASHINFER_TRTLLM = "FlashInfer TRTLLM" - FLASHINFER_CUTLASS = "FlashInfer CUTLASS" - DEEPGEMM = "DeepGEMM" - BATCHED_DEEPGEMM = "Batched DeepGEMM" - MARLIN = "Marlin" - TRITON = "Triton" - BATCHED_TRITON = "Batched Triton" + FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM" + FLASHINFER_CUTLASS = "FLASHINFER_CUTLASS" + DEEPGEMM = "DEEPGEMM" + BATCHED_DEEPGEMM = "BATCHED_DEEPGEMM" + MARLIN = "MARLIN" + TRITON = "TRITON" + BATCHED_TRITON = "BATCHED_TRITON" AITER = "AITER" - VLLM_CUTLASS = "vLLM CUTLASS" + VLLM_CUTLASS = "VLLM_CUTLASS" def backend_2_kernel_cls( From 0670758aa37233688cfc668c7c40b4a1b5a836ad Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 14 Jan 2026 21:50:30 -0500 Subject: [PATCH 113/113] update fallback experts Signed-off-by: Robert Shaw --- .../layers/fused_moe/fallback.py | 52 +++++++++++++++++++ .../layers/fused_moe/triton_cutlass_moe.py | 49 ++--------------- .../layers/fused_moe/triton_deep_gemm_moe.py | 35 ++----------- 3 files changed, 58 insertions(+), 78 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fallback.py 
b/vllm/model_executor/layers/fused_moe/fallback.py index ec59c04f8b89..9c9e416f366a 100644 --- a/vllm/model_executor/layers/fused_moe/fallback.py +++ b/vllm/model_executor/layers/fused_moe/fallback.py @@ -6,11 +6,16 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig +from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC): """Base class for runtime dispatching of expert implementations.""" + _experts_cls = mk.FusedMoEPermuteExpertsUnpermute + _fallback_cls = mk.FusedMoEPermuteExpertsUnpermute + def __init__( self, experts: mk.FusedMoEPermuteExpertsUnpermute, @@ -22,6 +27,53 @@ def __init__( self.fallback_experts = fallback_experts self.experts = experts + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + assert ( + FallbackExperts._experts_cls.activation_format() + == FallbackExperts._fallback_cls.activation_format() + ) + return FallbackExperts._experts_cls.activation_format() + + @staticmethod + def _supports_current_device() -> bool: + return ( + FallbackExperts._experts_cls._supports_current_device() + and FallbackExperts._fallback_cls._supports_current_device() + ) + + @staticmethod + def _supports_no_act_and_mul() -> bool: + return ( + FallbackExperts._experts_cls._supports_no_act_and_mul() + and FallbackExperts._fallback_cls._supports_no_act_and_mul() + ) + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + return FallbackExperts._experts_cls._supports_quant_scheme( + weight_key, activation_key + ) and FallbackExperts._fallback_cls._supports_quant_scheme( + weight_key, activation_key + ) + + @staticmethod + def _supports_activation(activation: str) -> bool: + return FallbackExperts._experts_cls._supports_activation( + activation + ) and 
FallbackExperts._fallback_cls._supports_activation(activation) + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + return FallbackExperts._experts_cls._supports_parallel_config( + moe_parallel_config + ) and FallbackExperts._fallback_cls._supports_parallel_config( + moe_parallel_config + ) + def supports_chunking(self) -> bool: assert ( self.experts.supports_chunking() diff --git a/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py index 39163bb24f58..e21d17226671 100644 --- a/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py @@ -7,21 +7,20 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, - FusedMoEParallelConfig, FusedMoEQuantConfig, ) from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 from vllm.model_executor.layers.fused_moe.fallback import FallbackExperts from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - QuantKey, -) from vllm.platforms import current_platform class TritonOrCutlassExperts(FallbackExperts): """Cutlass with fallback to Triton for low latency shapes on SM100.""" + _experts_cls = CutlassExpertsFp8 + _fallback_cls = TritonExperts + def __init__( self, moe_config: FusedMoEConfig, @@ -33,48 +32,6 @@ def __init__( fallback_experts=TritonExperts(moe_config, quant_config), ) - @staticmethod - def activation_format() -> mk.FusedMoEActivationFormat: - assert ( - CutlassExpertsFp8.activation_format() == TritonExperts.activation_format() - ) - return CutlassExpertsFp8.activation_format() - - @staticmethod - def _supports_current_device() -> bool: - return ( - CutlassExpertsFp8._supports_current_device() - and TritonExperts._supports_current_device() 
- ) - - @staticmethod - def _supports_no_act_and_mul() -> bool: - return ( - CutlassExpertsFp8._supports_no_act_and_mul() - and TritonExperts._supports_no_act_and_mul() - ) - - @staticmethod - def _supports_quant_scheme( - weight_key: QuantKey | None, - activation_key: QuantKey | None, - ) -> bool: - return CutlassExpertsFp8._supports_quant_scheme( - weight_key, activation_key - ) and TritonExperts._supports_quant_scheme(weight_key, activation_key) - - @staticmethod - def _supports_activation(activation: str) -> bool: - return CutlassExpertsFp8._supports_activation( - activation - ) and TritonExperts._supports_activation(activation) - - @staticmethod - def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: - return CutlassExpertsFp8._supports_parallel_config( - moe_parallel_config - ) and TritonExperts._supports_parallel_config(moe_parallel_config) - def workspace_shapes( self, M: int, diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 3f3a76f80409..6c13903fccd3 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -6,7 +6,6 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, - FusedMoEParallelConfig, FusedMoEQuantConfig, ) from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( @@ -16,9 +15,6 @@ ) from vllm.model_executor.layers.fused_moe.fallback import FallbackExperts from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - QuantKey, -) from vllm.utils.deep_gemm import ( is_deep_gemm_e8m0_used, ) @@ -27,40 +23,15 @@ class TritonOrDeepGemmExperts(FallbackExperts): """DeepGemm with fallback to Triton for low latency shapes.""" + _experts_cls = DeepGemmExperts + 
_fallback_cls = TritonExperts + def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig): super().__init__( experts=DeepGemmExperts(moe_config, quant_config), fallback_experts=TritonExperts(moe_config, quant_config), ) - @staticmethod - def activation_format() -> mk.FusedMoEActivationFormat: - assert DeepGemmExperts.activation_format() == TritonExperts.activation_format() - return DeepGemmExperts.activation_format() - - @staticmethod - def _supports_current_device() -> bool: - return DeepGemmExperts._supports_current_device() - - @staticmethod - def _supports_no_act_and_mul() -> bool: - return DeepGemmExperts._supports_no_act_and_mul() - - @staticmethod - def _supports_quant_scheme( - weight_key: QuantKey | None, - activation_key: QuantKey | None, - ) -> bool: - return DeepGemmExperts._supports_quant_scheme(weight_key, activation_key) - - @staticmethod - def _supports_activation(activation: str) -> bool: - return DeepGemmExperts._supports_activation(activation) - - @staticmethod - def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: - return DeepGemmExperts._supports_parallel_config(moe_parallel_config) - def workspace_shapes( self, M: int,