diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 34bd58832f28..5b58353927a4 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -229,9 +229,6 @@ class FusedMoEQuantConfig: _w1: FusedMoEQuantDesc _w2: FusedMoEQuantDesc is_nvfp4_scale_swizzled: bool = True - # CK MXFP4 (gfx950) padding info for rocm_aiter_ops.fused_moe() - hidden_pad: int = 0 - intermediate_pad: int = 0 def __post_init__(self): assert not self.per_act_token_quant or self.block_shape is None, ( @@ -1172,6 +1169,11 @@ class FusedMoEConfig: # Defaults to in_dtype if not specified. router_logits_dtype: torch.dtype | None = None + # Defaults to hidden_dim if not specified. + hidden_dim_unpadded: int | None = None + # Defaults to intermediate_size_per_partition if not specified. + intermediate_size_per_partition_unpadded: int | None = None + moe_backend: str = "auto" max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE has_bias: bool = False @@ -1195,6 +1197,13 @@ def __post_init__(self): if self.router_logits_dtype is None: self.router_logits_dtype = self.in_dtype + if self.hidden_dim_unpadded is None: + self.hidden_dim_unpadded = self.hidden_dim + if self.intermediate_size_per_partition_unpadded is None: + self.intermediate_size_per_partition_unpadded = ( + self.intermediate_size_per_partition + ) + @property def tp_size(self): return self.moe_parallel_config.tp_size diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py index f6a303e7988e..d951439d34a0 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py @@ -9,6 +9,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, + FusedMoEParallelConfig, FusedMoEQuantConfig, ) from 
vllm.model_executor.layers.fused_moe.modular_kernel import ( @@ -65,6 +66,38 @@ def uses_weight_scale_2_pattern(self) -> bool: """ return False + def maybe_roundup_sizes( + self, + hidden_size: int, + intermediate_size_per_partition: int, + act_dtype: torch.dtype, + moe_parallel_config: FusedMoEParallelConfig, + ) -> tuple[int, int]: + """ + Given layer hidden size and intermediate size per partition and MoE + configurations, round up hidden_size and intermediate_size_per_partition + if necessary. + + Args: + hidden_size: Layer hidden-size + intermediate_size_per_partition: Intermediate size per partition for + the layer. + act_dtype: Data type of the layer activations. + moe_parallel_config: Fused MoE parallelization strategy configuration. + + Return: + A tuple of (rounded_hidden_size, rounded_intermediate_size_per_partition), + where: + - rounded_hidden_size is the possibly rounded up hidden size. + - rounded_intermediate_size_per_partition is the possibly rounded + up intermediate size per partition. + """ + from .all2all_utils import maybe_roundup_layer_hidden_size + + return maybe_roundup_layer_hidden_size( + hidden_size, act_dtype, moe_parallel_config + ), intermediate_size_per_partition + def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index 7caa66a5bf57..e03ecd01ae79 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -428,13 +428,9 @@ def triton_kernel_fused_mxfp4_w4a8_experts( assert quant_config.w1_bias is None or quant_config.w1_bias.dtype == torch.float32 assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32 - # Shape check: when weights are padded (e.g. 
hidden_size padded for - # GFX950 swizzle), unpadded_K_w1 carries the original dimension. - expected_K_w1 = unpadded_K_w1 if unpadded_K_w1 is not None else w1.shape[-2] - assert hidden_states.shape[-1] == expected_K_w1, ( - f"hidden_states K={hidden_states.shape[-1]} != " - f"expected K={expected_K_w1} (w1 K={w1.shape[-2]})" - ) + # Shape check: hidden_states must match w1's K dimension, which may be + # padded (e.g. hidden_size padded for the GFX950 swizzle). + assert hidden_states.shape[-1] == w1.shape[-2] assert w2.shape[-1] == w1.shape[1] E, _, N = w1.shape @@ -494,12 +490,6 @@ def triton_kernel_fused_mxfp4_w4a8_experts( unpadded_K=unpadded_K_w2, ) - # When hidden_size was padded for alignment (e.g. GFX950 swizzle), - # the kernel output has the padded dimension. Slice back to the - # original hidden_size so downstream layers see the expected shape. - if unpadded_N_w2 is not None and intermediate_cache3.shape[-1] != unpadded_N_w2: - intermediate_cache3 = intermediate_cache3[..., :unpadded_N_w2].contiguous() - return intermediate_cache3 diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index f1905fd28e3d..a95481a7e6a0 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -210,42 +210,6 @@ def get_compressed_expert_map(expert_map: torch.Tensor) -> str: ) -# TODO(rob): move this down to the kernel. -def maybe_roundup_hidden_size( - hidden_size: int, - act_dtype: torch.dtype, - moe_parallel_config: FusedMoEParallelConfig, - is_lora_enabled: bool, - model_type: str | None, -) -> int: - """ - Given layer hidden size and MoE configurations, round up hidden_size - if necessary. - - Args: - hidden_size: Layer hidden-size - act_dtype: Data type of the layer activations. - moe_parallel_config: Fused MoE parallelization strategy configuration. - is_lora_enabled: True if the engine is enabled with LoRA. This - is used in the case of mxfp4 quantization in selecting the - MxFP4Backend. 
- model_type: for checking if gpt-oss - - Return: - Rounded up hidden_size if rounding up is required based on the configs. - Original hidden size otherwise. - """ - from vllm.model_executor.layers.fused_moe.all2all_utils import ( - maybe_roundup_layer_hidden_size, - ) - - hidden_size = maybe_roundup_layer_hidden_size( - hidden_size, act_dtype, moe_parallel_config - ) - - return hidden_size - - # --8<-- [start:fused_moe] @CustomOp.register("fused_moe") class FusedMoE(CustomOp): @@ -459,7 +423,7 @@ def __init__( ), "Aiter Fused MoE kernel only supports expert_map with 0 and 1s." assert intermediate_size % self.tp_size == 0 - self.intermediate_size_per_partition = intermediate_size // self.tp_size + intermediate_size_per_partition = intermediate_size // self.tp_size self.reduce_results = reduce_results self.renormalize = renormalize @@ -501,28 +465,13 @@ def __init__( ) self.routing_method_type: RoutingMethodType = self.router.routing_method_type - # Round up hidden size before creating moe_config. - # This way moe_config is created with the correct hidden_size from the start. 
- unpadded_hidden_size = hidden_size - self.model_type = ( - self.vllm_config.model_config.hf_config.model_type - if self.vllm_config.model_config is not None - else None - ) - hidden_size = maybe_roundup_hidden_size( - hidden_size=hidden_size, - act_dtype=moe_in_dtype, - moe_parallel_config=self.moe_parallel_config, - is_lora_enabled=vllm_config.lora_config is not None, - model_type=self.model_type, - ) - self.hidden_size = hidden_size - self.moe_config: FusedMoEConfig = FusedMoEConfig( num_experts=self.global_num_experts, experts_per_token=top_k, hidden_dim=hidden_size, - intermediate_size_per_partition=self.intermediate_size_per_partition, + hidden_dim_unpadded=hidden_size, + intermediate_size_per_partition=intermediate_size_per_partition, + intermediate_size_per_partition_unpadded=intermediate_size_per_partition, num_local_experts=self.local_num_experts, num_logical_experts=self.logical_num_experts, moe_parallel_config=self.moe_parallel_config, @@ -567,13 +516,6 @@ def _get_quant_method() -> FusedMoEMethodBase: # for heuristic purposes, so it must be initialized first. self.quant_method: FusedMoEMethodBase = _get_quant_method() - # Quant methods (e.g. Mxfp4MoEMethod) may round up hidden_dim - # and intermediate_size in moe_config during __init__. Sync - # self.hidden_size so downstream consumers (e.g. LoRA) see the - # padded value. - if self.moe_config.hidden_dim != self.hidden_size: - self.hidden_size = self.moe_config.hidden_dim - if not self.moe_config.is_act_and_mul and not current_platform.is_cuda_alike(): raise NotImplementedError( "is_act_and_mul=False is supported only for CUDA and ROCm for now" @@ -591,11 +533,24 @@ def _get_quant_method() -> FusedMoEMethodBase: f"EPLB is not supported {self.quant_method.__class__.__name__}." ) + # Round up hidden and intermediate sizes, then update moe_config. 
+ hidden_size, intermediate_size_per_partition = ( + self.quant_method.maybe_roundup_sizes( + hidden_size, + intermediate_size_per_partition, + moe_in_dtype, + self.moe_parallel_config, + ) + ) + self.moe_config.hidden_dim = hidden_size + self.moe_config.intermediate_size_per_partition = ( + intermediate_size_per_partition + ) + moe_quant_params = { "num_experts": self.local_num_experts, - "hidden_size": self.hidden_size, - "unpadded_hidden_size": unpadded_hidden_size, - "intermediate_size_per_partition": self.intermediate_size_per_partition, + "hidden_size": hidden_size, + "intermediate_size_per_partition": intermediate_size_per_partition, "params_dtype": params_dtype, "weight_loader": self.weight_loader, "global_num_experts": self.global_num_experts, @@ -933,9 +888,17 @@ def _load_w13( # Only narrow if the loaded_weight is not a scalar (0-dim tensor) # and we're not loading the full weight if not load_full and loaded_weight.ndim > 0: - loaded_weight = loaded_weight.narrow( - shard_dim, shard_size * tp_rank, shard_size - ) + # Handle padding: loaded_weight might be smaller than shard_size on last + # TP rank + start_offset = shard_size * tp_rank + available = loaded_weight.shape[shard_dim] - start_offset + if available <= 0: + # If there is no available weight to load for this TP rank + # (can happen on last TP rank with padding), we can skip + # loading and return early + return + narrow_size = min(shard_size, available) + loaded_weight = loaded_weight.narrow(shard_dim, start_offset, narrow_size) # Narrow parameter and load. # w1, gate_proj: Load into first logical weight of w13. 
if shard_id == "w1": @@ -944,6 +907,13 @@ def _load_w13( else: assert shard_id == "w3" expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) + + # Handle padding: if loaded_weight is smaller than expert_data (can happen + # on last TP shard with padding), copy to top-left corner + if expert_data.shape != loaded_weight.shape: + expert_data = expert_data[ + : loaded_weight.shape[0], : loaded_weight.shape[1] + ] expert_data.copy_(loaded_weight) def _load_w2( @@ -961,10 +931,24 @@ def _load_w2( # Only narrow if the loaded_weight is not a scalar (0-dim tensor) # and we're not loading the full weight if not load_full and loaded_weight.ndim > 0: - loaded_weight = loaded_weight.narrow( - shard_dim, shard_size * tp_rank, shard_size - ) + # Handle padding: loaded_weight might be smaller than shard_size on last + # TP rank + start_offset = shard_size * tp_rank + available = loaded_weight.shape[shard_dim] - start_offset + if available <= 0: + # If there is no available weight to load for this TP rank + # (can happen on last TP rank with padding), we can skip + # loading and return early + return + narrow_size = min(shard_size, available) + loaded_weight = loaded_weight.narrow(shard_dim, start_offset, narrow_size) # w2, down_proj: Load into only logical weight of w2. 
+ # Handle padding: if loaded_weight is smaller than expert_data (can happen + # on last TP shard with padding), copy to top-left corner + if expert_data.shape != loaded_weight.shape: + expert_data = expert_data[ + : loaded_weight.shape[0], : loaded_weight.shape[1] + ] expert_data.copy_(loaded_weight) def _load_single_value( @@ -1549,6 +1533,14 @@ def make_expert_params_mapping( ] ] + @property + def hidden_size(self) -> int: + return self.moe_config.hidden_dim + + @property + def intermediate_size_per_partition(self) -> int: + return self.moe_config.intermediate_size_per_partition + def extra_repr(self) -> str: s = ( f"global_num_experts={self.global_num_experts}, " diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py index 09725631d67b..77df6edf9e94 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py @@ -779,8 +779,6 @@ def make_mxfp4_moe_quant_config( w2_scale: Union[torch.Tensor, "PrecisionConfig"], w1_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None, - hidden_pad: int = 0, - intermediate_pad: int = 0, ) -> FusedMoEQuantConfig | None: """Create a FusedMoEQuantConfig for the given MXFP4 backend.""" if mxfp4_backend in ( @@ -802,16 +800,12 @@ def make_mxfp4_moe_quant_config( Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16, Mxfp4MoeBackend.CK, ): - config = mxfp4_w4a16_moe_quant_config( + return mxfp4_w4a16_moe_quant_config( w1_bias=w1_bias, w2_bias=w2_bias, w1_scale=w1_scale, w2_scale=w2_scale, ) - if mxfp4_backend == Mxfp4MoeBackend.CK: - config.hidden_pad = hidden_pad - config.intermediate_pad = intermediate_pad - return config else: return ocp_mx_moe_quant_config( quant_dtype="mxfp4", diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 98c98b7c5412..d24bda101ffa 100644 --- 
a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -10,6 +10,7 @@ from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( FUSED_MOE_UNQUANTIZED_CONFIG, + FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig, ) @@ -186,6 +187,7 @@ def rocm_aiter_fused_experts( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + moe_config: FusedMoEConfig, activation: MoEActivation = MoEActivation.SILU, apply_router_weight_on_input: bool = False, expert_map: torch.Tensor | None = None, @@ -276,6 +278,17 @@ def rocm_aiter_fused_experts( "Only support topk=1 when `apply_router_weight_on_input` is True" ) + # Compute padding on-the-fly for CK MXFP4 kernels + hidden_pad = 0 + intermediate_pad = 0 + assert moe_config.hidden_dim_unpadded is not None + assert moe_config.intermediate_size_per_partition_unpadded is not None + hidden_pad = hidden_states.shape[1] - moe_config.hidden_dim_unpadded + intermediate_pad = ( + moe_config.intermediate_size_per_partition + - moe_config.intermediate_size_per_partition_unpadded + ) + return rocm_aiter_ops.fused_moe( hidden_states, w1, @@ -292,8 +305,8 @@ def rocm_aiter_fused_experts( doweight_stage1=apply_router_weight_on_input, num_local_tokens=num_local_tokens, output_dtype=output_dtype, - hidden_pad=quant_config.hidden_pad, - intermediate_pad=quant_config.intermediate_pad, + hidden_pad=hidden_pad, + intermediate_pad=intermediate_pad, bias1=quant_config.w1_bias if quant_config.use_mxfp4_w4a16 else None, bias2=quant_config.w2_bias if quant_config.use_mxfp4_w4a16 else None, ) @@ -419,6 +432,7 @@ def apply( apply_router_weight_on_input=apply_router_weight_on_input, expert_map=expert_map, quant_config=self.quant_config, + moe_config=self.moe_config, a1q_scale=a1q_scale, num_local_tokens=num_local_tokens, output_dtype=output.dtype, diff --git 
a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 5e14d1712aec..1b8b726d9714 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -715,8 +715,6 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ): - layer.intermediate_size_per_partition = intermediate_size_per_partition - layer.hidden_size = hidden_size layer.num_experts = num_experts layer.orig_dtype = params_dtype layer.weight_block_size = None @@ -2274,8 +2272,6 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ): - layer.intermediate_size_per_partition = intermediate_size_per_partition - layer.hidden_size = hidden_size layer.num_experts = num_experts layer.orig_dtype = params_dtype layer.weight_block_size = None diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index d758edd9ca50..f055cf7bb81f 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -667,8 +667,6 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ): - layer.intermediate_size_per_partition = intermediate_size_per_partition - layer.hidden_size = hidden_size layer.num_experts = num_experts layer.orig_dtype = params_dtype layer.weight_block_size = None @@ -1006,8 +1004,6 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ): - layer.intermediate_size_per_partition = intermediate_size_per_partition - layer.hidden_size = hidden_size layer.num_experts = num_experts layer.orig_dtype = params_dtype layer.weight_block_size = None diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 8ce5432fed83..c69e99a68126 100644 --- 
a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -13,6 +13,7 @@ ) from vllm.model_executor.layers.fused_moe import modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEParallelConfig, FusedMoEQuantConfig, ) from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import ( @@ -107,18 +108,6 @@ def __init__(self, moe: FusedMoEConfig): self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {} self.moe_kernel: mk.FusedMoEKernel | None = None - # Round up dims once based on backend. This mutates the shared - # FusedMoEConfig in-place so that create_weights() and all - # downstream code see the padded dimensions. This must happen - # before create_weights() is called. - self.moe.hidden_dim, self.moe.intermediate_size_per_partition = ( - mxfp4_round_up_hidden_size_and_intermediate_size( - self.mxfp4_backend, - self.moe.hidden_dim, - self.moe.intermediate_size_per_partition, - ) - ) - # Used for triton kernel precision configs self.w13_precision_config = None self.w2_precision_config = None @@ -129,6 +118,23 @@ def skip_forward_padding(self) -> bool: # so can skip the padding in the forward before applying the moe method return self.mxfp4_backend == Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8 + def maybe_roundup_sizes( + self, + hidden_size: int, + intermediate_size_per_partition: int, + act_dtype: torch.dtype, + moe_parallel_config: FusedMoEParallelConfig, + ) -> tuple[int, int]: + hidden_size, intermediate_size_per_partition = super().maybe_roundup_sizes( + hidden_size=hidden_size, + intermediate_size_per_partition=intermediate_size_per_partition, + act_dtype=act_dtype, + moe_parallel_config=moe_parallel_config, + ) + return mxfp4_round_up_hidden_size_and_intermediate_size( + self.mxfp4_backend, hidden_size, intermediate_size_per_partition + ) + def create_weights( self, layer: torch.nn.Module, @@ -143,32 +149,16 @@ def create_weights( scale_dtype = torch.uint8 
mxfp4_block = 32 - # Use pre-rounded sizes from config - self.intermediate_size = intermediate_size_per_partition_after_pad = ( - self.moe.intermediate_size_per_partition - ) - self.hidden_size = hidden_size = self.moe.hidden_dim - - # Expose padded dimensions on the layer for LoRA and Marlin code - # that reads layer.hidden_size / layer.intermediate_size_per_partition. layer.params_dtype = params_dtype layer.num_experts = num_experts - layer.hidden_size = hidden_size - layer.intermediate_size_per_partition = ( - intermediate_size_per_partition_after_pad - ) - - # CK (gfx950) padding info for rocm_aiter_ops.fused_moe() - self.hidden_pad = extra_weight_attrs.get("hidden_pad", 0) - self.intermediate_pad = ( - intermediate_size_per_partition_after_pad - intermediate_size_per_partition - ) + self.intermediate_size = intermediate_size_per_partition + self.hidden_size = hidden_size # Fused gate_up_proj (column parallel) w13_weight = torch.nn.Parameter( torch.zeros( num_experts, - 2 * intermediate_size_per_partition_after_pad, + 2 * intermediate_size_per_partition, hidden_size // 2, dtype=weight_dtype, ), @@ -180,7 +170,7 @@ def create_weights( w13_weight_scale = torch.nn.Parameter( torch.zeros( num_experts, - 2 * intermediate_size_per_partition_after_pad, + 2 * intermediate_size_per_partition, hidden_size // mxfp4_block, dtype=scale_dtype, ), @@ -194,7 +184,7 @@ def create_weights( torch.zeros( num_experts, hidden_size, - intermediate_size_per_partition_after_pad // 2, + intermediate_size_per_partition // 2, dtype=weight_dtype, ), requires_grad=False, @@ -206,7 +196,7 @@ def create_weights( torch.zeros( num_experts, hidden_size, - intermediate_size_per_partition_after_pad // mxfp4_block, + intermediate_size_per_partition // mxfp4_block, dtype=scale_dtype, ), requires_grad=False, @@ -218,7 +208,7 @@ def create_weights( w13_bias = torch.nn.Parameter( torch.zeros( num_experts, - 2 * intermediate_size_per_partition_after_pad, + 2 * intermediate_size_per_partition, 
dtype=torch.bfloat16, ), requires_grad=False, @@ -368,8 +358,6 @@ def get_fused_moe_quant_config( w2_scale=w2_scale, w1_bias=w1_bias, w2_bias=w2_bias, - hidden_pad=self.hidden_pad, - intermediate_pad=self.intermediate_pad, ) def select_gemm_impl( diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 68eb655664b0..a58ee5c44e00 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -18,6 +18,7 @@ MoEActivation, ) from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEParallelConfig, FusedMoEQuantConfig, fp8_w8a8_moe_quant_config, mxfp4_w4a8_moe_quant_config, @@ -27,13 +28,13 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import ( Mxfp4MoeBackend, + mxfp4_round_up_hidden_size_and_intermediate_size, select_mxfp4_moe_backend, ) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( prepare_fp8_moe_layer_for_marlin, ) from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( - CK_MXFP4_MOE_DIM_ALIGNMENT, _swizzle_mxfp4, ) from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import ( @@ -49,7 +50,6 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.scalar_type import scalar_types -from vllm.utils.math_utils import round_up logger = init_logger(__name__) @@ -173,8 +173,6 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ): - layer.intermediate_size_per_partition = intermediate_size_per_partition - layer.hidden_size = hidden_size layer.num_experts = num_experts layer.orig_dtype = params_dtype layer.weight_block_size = None @@ -182,7 +180,7 @@ def create_weights( # WEIGHTS w13_weight = torch.nn.Parameter( - torch.empty( + torch.zeros( num_experts, 2 * 
intermediate_size_per_partition, hidden_size, @@ -194,7 +192,7 @@ def create_weights( set_weight_attrs(w13_weight, extra_weight_attrs) w2_weight = torch.nn.Parameter( - torch.empty( + torch.zeros( num_experts, hidden_size, intermediate_size_per_partition, @@ -461,6 +459,7 @@ def apply( activation=layer.activation, apply_router_weight_on_input=layer.apply_router_weight_on_input, quant_config=self.moe_quant_config, + moe_config=layer.moe_config, expert_map=layer.expert_map, ) elif self.use_marlin: @@ -527,7 +526,7 @@ def create_weights( ): params_dtype = torch.uint32 w13_weight = torch.nn.Parameter( - torch.empty( + torch.zeros( num_experts, 2 * intermediate_size_per_partition, hidden_size // 8, # INT32 packing for W4 @@ -536,7 +535,7 @@ def create_weights( requires_grad=False, ) w2_weight = torch.nn.Parameter( - torch.empty( + torch.zeros( num_experts, hidden_size, intermediate_size_per_partition // 8, # INT32 packing for W4 @@ -649,6 +648,7 @@ def apply( activation=layer.activation, apply_router_weight_on_input=layer.apply_router_weight_on_input, quant_config=self.moe_quant_config, + moe_config=layer.moe_config, expert_map=layer.expert_map, ) @@ -702,6 +702,9 @@ def __init__( self.mxfp4_backend: Mxfp4MoeBackend | None = None if self.ocp_mx_scheme == "w_mxfp4": self.mxfp4_backend, _ = select_mxfp4_moe_backend(moe) + elif self.ocp_mx_scheme.startswith("w_mxfp4"): + # TODO(bowenbao): refactor and introduce backends for other OCP MX schemes. + self.mxfp4_backend = Mxfp4MoeBackend.NONE if self.input_quant is not None: self.static_input_scales = not self.input_quant.get("is_dynamic") @@ -734,36 +737,11 @@ def __init__( self.emulate = ( not current_platform.supports_mx() or not self.ocp_mx_scheme.startswith("w_mxfp4") - ) and (self.mxfp4_backend is None or not self.use_rocm_aiter_moe) - - # CK's pre-compiled MXFP4 MoE GEMM kernel instances have dimension - # alignment requirements. When violated (e.g. 
MiniMax-M2.1 with - # TP=4 yields intermediate_size_per_partition=384), AITER raises: - # "device_gemm ... does not support this GEMM problem". - # Fall back to emulation in that case. - # For gpt_oss models, create_weights rounds up the dimensions - # internally, so the alignment check is skipped. - if ( - not self.emulate - and self.use_rocm_aiter_moe - and self.ocp_mx_scheme is not None - and self.ocp_mx_scheme.startswith("w_mxfp4") - and self.model_type != "gpt_oss" - and moe.intermediate_size_per_partition % CK_MXFP4_MOE_DIM_ALIGNMENT != 0 - ): - logger.warning_once( - "AITER CK MXFP4 MoE GEMM does not support " - "intermediate_size_per_partition=%d (not a multiple of %d). " - "This typically happens when intermediate_size / " - "tensor_parallel_size produces an incompatible dimension. " - "Falling back to emulation mode. To avoid this overhead, " - "use a compatible tensor_parallel_size or set " - "VLLM_ROCM_USE_AITER_MOE=0.", - moe.intermediate_size_per_partition, - CK_MXFP4_MOE_DIM_ALIGNMENT, - ) - self.use_rocm_aiter_moe = False - self.emulate = True + ) and ( + self.mxfp4_backend is None + or self.mxfp4_backend is Mxfp4MoeBackend.NONE + or not self.use_rocm_aiter_moe + ) if self.emulate: logger.warning_once( @@ -780,6 +758,27 @@ def __init__( "The current mode supports native MoE MXFP4 computation" ) + def maybe_roundup_sizes( + self, + hidden_size: int, + intermediate_size_per_partition: int, + act_dtype: torch.dtype, + moe_parallel_config: FusedMoEParallelConfig, + ) -> tuple[int, int]: + hidden_size, intermediate_size_per_partition = super().maybe_roundup_sizes( + hidden_size=hidden_size, + intermediate_size_per_partition=intermediate_size_per_partition, + act_dtype=act_dtype, + moe_parallel_config=moe_parallel_config, + ) + if self.mxfp4_backend is not None: + hidden_size, intermediate_size_per_partition = ( + mxfp4_round_up_hidden_size_and_intermediate_size( + self.mxfp4_backend, hidden_size, intermediate_size_per_partition + ) + ) + return 
hidden_size, intermediate_size_per_partition + def get_packed_dim(self, dim: int, quant_dtype: str): if quant_dtype == "mxfp4": assert dim % 2 == 0 @@ -805,40 +804,12 @@ def create_weights( ) params_dtype = torch.uint8 - self.intermediate_size_per_partition = intermediate_size_per_partition - if self.model_type == "gpt_oss": - if current_platform.is_rocm(): - intermediate_size_per_partition_after_pad = round_up( - intermediate_size_per_partition, 256 - ) - else: - intermediate_size_per_partition_after_pad = round_up( - intermediate_size_per_partition, 64 - ) - else: - intermediate_size_per_partition_after_pad = intermediate_size_per_partition - - self.unpadded_hidden_size = extra_weight_attrs.get( - "unpadded_hidden_size", hidden_size - ) - - # On GFX950, the GFX950MXScaleLayout swizzle requires - # hidden_size to be a multiple of 256 (SCALE_K = hidden_size / 32 - # must be divisible by 8). Pad hidden_size for weight/scale - # allocation; the original value is preserved in unpadded_hidden_size. - # Only applies to the native (non-emulated) CK path on GFX950. 
- if ( - self.model_type == "gpt_oss" - and current_platform.is_rocm() - and not self.emulate - ): - hidden_size = round_up(hidden_size, 256) # WEIGHTS w13_weight = torch.nn.Parameter( - torch.empty( + torch.zeros( num_experts, - 2 * intermediate_size_per_partition_after_pad, + 2 * intermediate_size_per_partition, self.get_packed_dim(hidden_size, self.weight_dtype), dtype=params_dtype, ), @@ -849,12 +820,10 @@ def create_weights( set_weight_attrs(w13_weight, extra_weight_attrs) w2_weight = torch.nn.Parameter( - torch.empty( + torch.zeros( num_experts, hidden_size, - self.get_packed_dim( - intermediate_size_per_partition_after_pad, self.weight_dtype - ), + self.get_packed_dim(intermediate_size_per_partition, self.weight_dtype), dtype=params_dtype, ), requires_grad=False, @@ -867,7 +836,7 @@ def create_weights( w13_weight_scale = torch.nn.Parameter( torch.ones( num_experts, - 2 * intermediate_size_per_partition_after_pad, + 2 * intermediate_size_per_partition, hidden_size // OCP_MX_BLOCK_SIZE, dtype=params_dtype, ), @@ -877,7 +846,7 @@ def create_weights( torch.ones( num_experts, hidden_size, - intermediate_size_per_partition_after_pad // OCP_MX_BLOCK_SIZE, + intermediate_size_per_partition // OCP_MX_BLOCK_SIZE, dtype=params_dtype, ), requires_grad=False, @@ -892,7 +861,7 @@ def create_weights( w13_bias = torch.nn.Parameter( torch.zeros( num_experts, - 2 * intermediate_size_per_partition_after_pad, + 2 * intermediate_size_per_partition, dtype=torch.float32, ), requires_grad=False, @@ -1072,6 +1041,7 @@ def apply( topk_ids=topk_ids, activation=layer.activation, quant_config=self.moe_quant_config, + moe_config=layer.moe_config, expert_map=layer.expert_map, ) else: @@ -1204,6 +1174,8 @@ def apply_monolithic( triton_kernel_moe_forward, ) + assert self.moe.hidden_dim_unpadded is not None + assert self.moe.intermediate_size_per_partition_unpadded is not None return triton_kernel_moe_forward( hidden_states=x, w1=self.w13_weight_triton_tensor, @@ -1215,8 +1187,8 @@ def 
apply_monolithic( expert_map=expert_map, quant_config=self.moe_quant_config, apply_router_weight_on_input=layer.apply_router_weight_on_input, - unpadded_N_w1=self.intermediate_size_per_partition * 2, - unpadded_K_w1=self.unpadded_hidden_size, - unpadded_N_w2=self.unpadded_hidden_size, - unpadded_K_w2=self.intermediate_size_per_partition, + unpadded_N_w1=self.moe.intermediate_size_per_partition_unpadded * 2, + unpadded_K_w1=self.moe.hidden_dim_unpadded, + unpadded_N_w2=self.moe.hidden_dim_unpadded, + unpadded_K_w2=self.moe.intermediate_size_per_partition_unpadded, ) diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index 5b4f7caa37d4..397442aecedf 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -348,7 +348,6 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass( w13, w13_scale, w2, w2_scale, is_act_and_mul, min_alignment ) ) - layer.intermediate_size_per_partition = padded_intermediate layer.moe_config.intermediate_size_per_partition = padded_intermediate w13, w13_scale, w2, w2_scale = prepare_static_weights_for_trtllm_fp4_moe( diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index 271bcf168386..66827488ffed 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -439,7 +439,6 @@ def prepare_fp8_moe_layer_for_fi( layer.moe_config.is_act_and_mul, min_alignment, ) - layer.intermediate_size_per_partition = new_intermediate layer.moe_config.intermediate_size_per_partition = new_intermediate # FI kernels require W31 layout rather than W13.