Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
237 changes: 175 additions & 62 deletions vllm/model_executor/models/glm4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from torch import nn
from transformers.models.glm4_moe import Glm4MoeConfig

from vllm._aiter_ops import rocm_aiter_ops
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (
Expand Down Expand Up @@ -168,7 +169,28 @@ def __init__(
self.physical_expert_start + self.n_local_physical_experts
)

if config.n_shared_experts is not None:
# AITER fused MoE / fused shared-expert (FSE) gates. When AITER
# MoE is active, the `routed_scaling_factor` must be applied
# *inside* the kernel (per routed slot) rather than post-hoc to
# the whole MoE output - otherwise the FSE shared-expert slot
# (which the kernel inserts with unit weight) would also be
# scaled by `routed_scaling_factor`, producing a structural
# magnitude error in every MoE layer.
# See vllm/_aiter_ops.py::is_fusion_moe_shared_experts_enabled,
# vllm/model_executor/models/deepseek_v2.py L341 for the same
# pattern, and the equivalent ROCm/ATOM gate at
# atom/model_ops/topK.py.
self.is_rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
self.is_fusion_moe_shared_experts_enabled = (
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
)

if (
config.n_shared_experts is None
or self.is_fusion_moe_shared_experts_enabled
):
self.shared_experts = None
else:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = Glm4MoeMLP(
hidden_size=config.hidden_size,
Expand All @@ -178,8 +200,6 @@ def __init__(
reduce_results=False,
prefix=f"{prefix}.shared_experts",
)
else:
self.shared_experts = None

self.experts = FusedMoE(
shared_experts=self.shared_experts,
Expand All @@ -194,12 +214,21 @@ def __init__(
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func="sigmoid",
# AITER applies `routed_scaling_factor` internally per routed
# slot, so it must NOT also be applied to the whole output
# (which would incorrectly scale the FSE shared-expert slot).
# Mirrors deepseek_v2.py.
routed_scaling_factor=self.routed_scaling_factor,
apply_routed_scale_to_output=True,
apply_routed_scale_to_output=not self.is_rocm_aiter_moe_enabled,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
router_logits_dtype=torch.float32,
n_shared_experts=(
config.n_shared_experts
if self.is_fusion_moe_shared_experts_enabled
else None
),
)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -469,15 +498,33 @@ def forward(
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
# When AITER fused shared experts is enabled, the FusedMoE layer is
# widened by n_shared_experts slots that hold the (split) shared
# expert weights; the mapping must enumerate those slots too so the
# weight loader can route mlp.shared_experts.* tensors there.
num_experts = self.config.n_routed_experts
if (
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
and self.config.n_shared_experts
):
num_experts += self.config.n_shared_experts
return fused_moe_make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts,
num_experts=num_experts,
)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
# AITER fused shared-expert (FSE) weight-loader branch. When FSE is
# on, the FusedMoE layer was widened by n_shared_experts slots, and
# checkpoint tensors named `...mlp.shared_experts.{gate,up,down}_proj.*`
# must be split into n_shared_experts chunks and routed to the
# appended expert slots `...mlp.experts.{n_routed_experts + j}.*`.
rocm_aiter_moe_shared_expert_enabled = (
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
)
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
Expand All @@ -494,6 +541,12 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is not None:
continue

is_fusion_moe_shared_experts_layer = (
rocm_aiter_moe_shared_expert_enabled
and ("mlp.shared_experts" in name)
)

for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
Expand All @@ -506,6 +559,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict:
continue
# Under FSE we treat mlp.shared_experts.* as expert-style
# tensors (handled below) rather than stacked gate/up linears.
if is_fusion_moe_shared_experts_layer:
continue

name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
Expand All @@ -527,65 +584,121 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
break
else:
is_expert_weight = False
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue

# Anyway, this is an expert weight and should not be
# attempted to load as other weights later
is_expert_weight = True

# Do not modify `name` since the loop may continue here
# Instead, create a new variable
name_mapped = name.replace(weight_name, param_name)

if is_pp_missing_parameter(name_mapped, self):
continue

param = params_dict[name_mapped]
# We should ask the weight loader to return success or not
# here since otherwise we may skip experts with other
# available replicas.
weight_loader = typing.cast(
Callable[..., bool], param.weight_loader

# FSE split: if this is a widened mlp.shared_experts tensor,
# slice it into n_shared_experts chunks along the intermediate-
# size axis and synthesize per-slot expert names. For
# ColumnParallel (gate_proj / up_proj) the intermediate dim
# is dim 0; for RowParallel (down_proj) it's dim 1.
num_chunks = 1
split_dim = 0
chunk_size = 0
if is_fusion_moe_shared_experts_layer:
num_chunks = (
getattr(self.config, "n_shared_experts", 1) or 1
)
success = weight_loader(
param,
loaded_weight,
name_mapped,
shard_id=shard_id,
expert_id=expert_id,
return_success=True,
split_dim = (
1
if (
"down_proj.weight" in name
and loaded_weight.ndim > 1
)
else 0
)
if success:
name = name_mapped
break
else:
if is_expert_weight:
# We've checked that this is an expert weight
# However it's not mapped locally to this rank
# So we simply skip it
continue

# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue

# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue

if is_pp_missing_parameter(name, self):
continue

param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
loaded_params.add(name)
total = loaded_weight.shape[split_dim]
if total % num_chunks != 0:
raise ValueError(
f"FSE shared-expert weight {name} has dim "
f"{total} along axis {split_dim} which is not "
f"divisible by n_shared_experts={num_chunks}."
)
chunk_size = total // num_chunks

for j in range(num_chunks):
chunk_name = name
weight_to_load = loaded_weight

if is_fusion_moe_shared_experts_layer:
chunk_slice = slice(
j * chunk_size, (j + 1) * chunk_size
)
if loaded_weight.ndim == 1:
weight_to_load = loaded_weight[chunk_slice]
elif split_dim == 0:
weight_to_load = loaded_weight[chunk_slice, :]
else:
weight_to_load = loaded_weight[:, chunk_slice]
# Synthesize an expert-style name so the expert
# params mapping above can route it via the
# FusedMoE expert-aware weight loader.
chunk_name = name.replace(
"mlp.shared_experts",
f"mlp.experts.{self.config.n_routed_experts + j}",
)

for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in chunk_name:
continue

# Anyway, this is an expert weight and should not be
# attempted to load as other weights later
is_expert_weight = True

# Do not modify `name` since the loop may continue here
# Instead, create a new variable
name_mapped = chunk_name.replace(weight_name, param_name)

if is_pp_missing_parameter(name_mapped, self):
continue

param = params_dict[name_mapped]
# We should ask the weight loader to return success
# or not here since otherwise we may skip experts
# with other available replicas.
weight_loader = typing.cast(
Callable[..., bool], param.weight_loader
)
success = weight_loader(
param,
weight_to_load,
name_mapped,
shard_id=shard_id,
expert_id=expert_id,
return_success=True,
)
if success:
if not is_fusion_moe_shared_experts_layer:
name = name_mapped
else:
loaded_params.add(name_mapped)
break
else:
if is_expert_weight:
# We've checked that this is an expert weight
# However it's not mapped locally to this rank
# So we simply skip it
continue

# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue

# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue

if is_pp_missing_parameter(name, self):
continue

param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
if name is not None and not is_fusion_moe_shared_experts_layer:
loaded_params.add(name)

return loaded_params

Expand Down
Loading
Loading