Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions vllm/model_executor/layers/fused_moe/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,6 @@ class FusedMoEQuantConfig:
_w1: FusedMoEQuantDesc
_w2: FusedMoEQuantDesc
is_nvfp4_scale_swizzled: bool = True
# CK MXFP4 (gfx950) padding info for rocm_aiter_ops.fused_moe()
hidden_pad: int = 0
intermediate_pad: int = 0

def __post_init__(self):
assert not self.per_act_token_quant or self.block_shape is None, (
Expand Down Expand Up @@ -1172,6 +1169,11 @@ class FusedMoEConfig:
# Defaults to in_dtype if not specified.
router_logits_dtype: torch.dtype | None = None

# Defaults to hidden_dim if not specified.
hidden_dim_unpadded: int | None = None
# Defaults to intermediate_size_per_partition if not specified.
intermediate_size_per_partition_unpadded: int | None = None

moe_backend: str = "auto"
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
has_bias: bool = False
Expand All @@ -1195,6 +1197,13 @@ def __post_init__(self):
if self.router_logits_dtype is None:
self.router_logits_dtype = self.in_dtype

if self.hidden_dim_unpadded is None:
self.hidden_dim_unpadded = self.hidden_dim
if self.intermediate_size_per_partition_unpadded is None:
self.intermediate_size_per_partition_unpadded = (
self.intermediate_size_per_partition
)

@property
def tp_size(self):
return self.moe_parallel_config.tp_size
Expand Down
33 changes: 33 additions & 0 deletions vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
Expand Down Expand Up @@ -65,6 +66,38 @@ def uses_weight_scale_2_pattern(self) -> bool:
"""
return False

def maybe_roundup_sizes(
    self,
    hidden_size: int,
    intermediate_size_per_partition: int,
    act_dtype: torch.dtype,
    moe_parallel_config: FusedMoEParallelConfig,
) -> tuple[int, int]:
    """Round up the layer sizes when the MoE configuration requires it.

    The base implementation only adjusts the hidden size (delegating to
    ``maybe_roundup_layer_hidden_size``); the intermediate size per
    partition is returned unchanged.

    Args:
        hidden_size: Layer hidden-size.
        intermediate_size_per_partition: Intermediate size per partition
            for the layer.
        act_dtype: Data type of the layer activations.
        moe_parallel_config: Fused MoE parallelization strategy
            configuration.

    Returns:
        A ``(rounded_hidden_size, rounded_intermediate_size_per_partition)``
        tuple, where each element is the possibly rounded-up size.
    """
    # Imported at call time rather than module scope — presumably to
    # avoid an import cycle with all2all_utils (TODO confirm).
    from .all2all_utils import maybe_roundup_layer_hidden_size

    rounded_hidden_size = maybe_roundup_layer_hidden_size(
        hidden_size, act_dtype, moe_parallel_config
    )
    return rounded_hidden_size, intermediate_size_per_partition

def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -428,13 +428,9 @@ def triton_kernel_fused_mxfp4_w4a8_experts(
assert quant_config.w1_bias is None or quant_config.w1_bias.dtype == torch.float32
assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32

# Shape check: when weights are padded (e.g. hidden_size padded for
# GFX950 swizzle), unpadded_K_w1 carries the original dimension.
expected_K_w1 = unpadded_K_w1 if unpadded_K_w1 is not None else w1.shape[-2]
assert hidden_states.shape[-1] == expected_K_w1, (
f"hidden_states K={hidden_states.shape[-1]} != "
f"expected K={expected_K_w1} (w1 K={w1.shape[-2]})"
)
# Shape check: weights are padded (e.g. hidden_size padded for
# GFX950 swizzle).
assert hidden_states.shape[-1] == w1.shape[-2]
assert w2.shape[-1] == w1.shape[1]

E, _, N = w1.shape
Expand Down Expand Up @@ -494,12 +490,6 @@ def triton_kernel_fused_mxfp4_w4a8_experts(
unpadded_K=unpadded_K_w2,
)

# When hidden_size was padded for alignment (e.g. GFX950 swizzle),
# the kernel output has the padded dimension. Slice back to the
# original hidden_size so downstream layers see the expected shape.
if unpadded_N_w2 is not None and intermediate_cache3.shape[-1] != unpadded_N_w2:
intermediate_cache3 = intermediate_cache3[..., :unpadded_N_w2].contiguous()

return intermediate_cache3


Expand Down
134 changes: 63 additions & 71 deletions vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,42 +210,6 @@ def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
)


# TODO(rob): move this down to the kernel.
def maybe_roundup_hidden_size(
    hidden_size: int,
    act_dtype: torch.dtype,
    moe_parallel_config: FusedMoEParallelConfig,
    is_lora_enabled: bool,
    model_type: str | None,
) -> int:
    """Return ``hidden_size``, rounded up if the MoE configs require it.

    Args:
        hidden_size: Layer hidden-size.
        act_dtype: Data type of the layer activations.
        moe_parallel_config: Fused MoE parallelization strategy configuration.
        is_lora_enabled: True if the engine is enabled with LoRA. Kept for
            interface compatibility; not consumed by this implementation.
        model_type: Model type string (e.g. for gpt-oss checks). Kept for
            interface compatibility; not consumed by this implementation.

    Return:
        The rounded-up hidden size when rounding is required by the
        configs; the original ``hidden_size`` otherwise.
    """
    from vllm.model_executor.layers.fused_moe.all2all_utils import (
        maybe_roundup_layer_hidden_size,
    )

    return maybe_roundup_layer_hidden_size(
        hidden_size, act_dtype, moe_parallel_config
    )


# --8<-- [start:fused_moe]
@CustomOp.register("fused_moe")
class FusedMoE(CustomOp):
Expand Down Expand Up @@ -459,7 +423,7 @@ def __init__(
), "Aiter Fused MoE kernel only supports expert_map with 0 and 1s."

assert intermediate_size % self.tp_size == 0
self.intermediate_size_per_partition = intermediate_size // self.tp_size
intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results
self.renormalize = renormalize

Expand Down Expand Up @@ -501,28 +465,13 @@ def __init__(
)
self.routing_method_type: RoutingMethodType = self.router.routing_method_type

# Round up hidden size before creating moe_config.
# This way moe_config is created with the correct hidden_size from the start.
unpadded_hidden_size = hidden_size
self.model_type = (
self.vllm_config.model_config.hf_config.model_type
if self.vllm_config.model_config is not None
else None
)
hidden_size = maybe_roundup_hidden_size(
hidden_size=hidden_size,
act_dtype=moe_in_dtype,
moe_parallel_config=self.moe_parallel_config,
is_lora_enabled=vllm_config.lora_config is not None,
model_type=self.model_type,
)
self.hidden_size = hidden_size

self.moe_config: FusedMoEConfig = FusedMoEConfig(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
intermediate_size_per_partition=self.intermediate_size_per_partition,
hidden_dim_unpadded=hidden_size,
intermediate_size_per_partition=intermediate_size_per_partition,
intermediate_size_per_partition_unpadded=intermediate_size_per_partition,
num_local_experts=self.local_num_experts,
num_logical_experts=self.logical_num_experts,
moe_parallel_config=self.moe_parallel_config,
Expand Down Expand Up @@ -567,13 +516,6 @@ def _get_quant_method() -> FusedMoEMethodBase:
# for heuristic purposes, so it must be initialized first.
self.quant_method: FusedMoEMethodBase = _get_quant_method()

# Quant methods (e.g. Mxfp4MoEMethod) may round up hidden_dim
# and intermediate_size in moe_config during __init__. Sync
# self.hidden_size so downstream consumers (e.g. LoRA) see the
# padded value.
if self.moe_config.hidden_dim != self.hidden_size:
self.hidden_size = self.moe_config.hidden_dim

if not self.moe_config.is_act_and_mul and not current_platform.is_cuda_alike():
raise NotImplementedError(
"is_act_and_mul=False is supported only for CUDA and ROCm for now"
Expand All @@ -591,11 +533,24 @@ def _get_quant_method() -> FusedMoEMethodBase:
f"EPLB is not supported {self.quant_method.__class__.__name__}."
)

# Round up hidden size and update moe_config.
hidden_size, intermediate_size_per_partition = (
self.quant_method.maybe_roundup_sizes(
hidden_size,
intermediate_size_per_partition,
moe_in_dtype,
self.moe_parallel_config,
)
)
self.moe_config.hidden_dim = hidden_size
self.moe_config.intermediate_size_per_partition = (
intermediate_size_per_partition
)

moe_quant_params = {
"num_experts": self.local_num_experts,
"hidden_size": self.hidden_size,
"unpadded_hidden_size": unpadded_hidden_size,
"intermediate_size_per_partition": self.intermediate_size_per_partition,
"hidden_size": hidden_size,
"intermediate_size_per_partition": intermediate_size_per_partition,
"params_dtype": params_dtype,
"weight_loader": self.weight_loader,
"global_num_experts": self.global_num_experts,
Expand Down Expand Up @@ -933,9 +888,17 @@ def _load_w13(
# Only narrow if the loaded_weight is not a scalar (0-dim tensor)
# and we're not loading the full weight
if not load_full and loaded_weight.ndim > 0:
loaded_weight = loaded_weight.narrow(
shard_dim, shard_size * tp_rank, shard_size
)
# Handle padding: loaded_weight might be smaller than shard_size on last
# TP rank
start_offset = shard_size * tp_rank
available = loaded_weight.shape[shard_dim] - start_offset
if available <= 0:
# If there is no available weight to load for this TP rank
# (can happen on last TP rank with padding), we can skip
# loading and return early
return
narrow_size = min(shard_size, available)
loaded_weight = loaded_weight.narrow(shard_dim, start_offset, narrow_size)
# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
if shard_id == "w1":
Expand All @@ -944,6 +907,13 @@ def _load_w13(
else:
assert shard_id == "w3"
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)

# Handle padding: if loaded_weight is smaller than expert_data (can happen
# on last TP shard with padding), copy to top-left corner
if expert_data.shape != loaded_weight.shape:
expert_data = expert_data[
: loaded_weight.shape[0], : loaded_weight.shape[1]
]
expert_data.copy_(loaded_weight)

def _load_w2(
Expand All @@ -961,10 +931,24 @@ def _load_w2(
# Only narrow if the loaded_weight is not a scalar (0-dim tensor)
# and we're not loading the full weight
if not load_full and loaded_weight.ndim > 0:
loaded_weight = loaded_weight.narrow(
shard_dim, shard_size * tp_rank, shard_size
)
# Handle padding: loaded_weight might be smaller than shard_size on last
# TP rank
start_offset = shard_size * tp_rank
available = loaded_weight.shape[shard_dim] - start_offset
if available <= 0:
# If there is no available weight to load for this TP rank
# (can happen on last TP rank with padding), we can skip
# loading and return early
return
narrow_size = min(shard_size, available)
loaded_weight = loaded_weight.narrow(shard_dim, start_offset, narrow_size)
# w2, down_proj: Load into only logical weight of w2.
# Handle padding: if loaded_weight is smaller than expert_data (can happen
# on last TP shard with padding), copy to top-left corner
if expert_data.shape != loaded_weight.shape:
expert_data = expert_data[
: loaded_weight.shape[0], : loaded_weight.shape[1]
]
expert_data.copy_(loaded_weight)

def _load_single_value(
Expand Down Expand Up @@ -1549,6 +1533,14 @@ def make_expert_params_mapping(
]
]

@property
def hidden_size(self) -> int:
    """Layer hidden dimension, as recorded in ``self.moe_config``."""
    return self.moe_config.hidden_dim

@property
def intermediate_size_per_partition(self) -> int:
    """Per-partition intermediate size, as recorded in ``self.moe_config``."""
    return self.moe_config.intermediate_size_per_partition

def extra_repr(self) -> str:
s = (
f"global_num_experts={self.global_num_experts}, "
Expand Down
8 changes: 1 addition & 7 deletions vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,8 +779,6 @@ def make_mxfp4_moe_quant_config(
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
hidden_pad: int = 0,
intermediate_pad: int = 0,
) -> FusedMoEQuantConfig | None:
"""Create a FusedMoEQuantConfig for the given MXFP4 backend."""
if mxfp4_backend in (
Expand All @@ -802,16 +800,12 @@ def make_mxfp4_moe_quant_config(
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
Mxfp4MoeBackend.CK,
):
config = mxfp4_w4a16_moe_quant_config(
return mxfp4_w4a16_moe_quant_config(
w1_bias=w1_bias,
w2_bias=w2_bias,
w1_scale=w1_scale,
w2_scale=w2_scale,
)
if mxfp4_backend == Mxfp4MoeBackend.CK:
config.hidden_pad = hidden_pad
config.intermediate_pad = intermediate_pad
return config
else:
return ocp_mx_moe_quant_config(
quant_dtype="mxfp4",
Expand Down
18 changes: 16 additions & 2 deletions vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
Expand Down Expand Up @@ -186,6 +187,7 @@ def rocm_aiter_fused_experts(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
moe_config: FusedMoEConfig,
activation: MoEActivation = MoEActivation.SILU,
apply_router_weight_on_input: bool = False,
expert_map: torch.Tensor | None = None,
Expand Down Expand Up @@ -276,6 +278,17 @@ def rocm_aiter_fused_experts(
"Only support topk=1 when `apply_router_weight_on_input` is True"
)

# Compute padding on-the-fly for CK MXFP4 kernels
hidden_pad = 0
intermediate_pad = 0
assert moe_config.hidden_dim_unpadded is not None
assert moe_config.intermediate_size_per_partition_unpadded is not None
hidden_pad = hidden_states.shape[1] - moe_config.hidden_dim_unpadded
intermediate_pad = (
moe_config.intermediate_size_per_partition
- moe_config.intermediate_size_per_partition_unpadded
)

return rocm_aiter_ops.fused_moe(
hidden_states,
w1,
Expand All @@ -292,8 +305,8 @@ def rocm_aiter_fused_experts(
doweight_stage1=apply_router_weight_on_input,
num_local_tokens=num_local_tokens,
output_dtype=output_dtype,
hidden_pad=quant_config.hidden_pad,
intermediate_pad=quant_config.intermediate_pad,
hidden_pad=hidden_pad,
intermediate_pad=intermediate_pad,
bias1=quant_config.w1_bias if quant_config.use_mxfp4_w4a16 else None,
bias2=quant_config.w2_bias if quant_config.use_mxfp4_w4a16 else None,
)
Expand Down Expand Up @@ -419,6 +432,7 @@ def apply(
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
quant_config=self.quant_config,
moe_config=self.moe_config,
a1q_scale=a1q_scale,
num_local_tokens=num_local_tokens,
output_dtype=output.dtype,
Expand Down
Loading
Loading