vllm/model_executor/models/ernie45_moe.py (34 changes: 22 additions & 12 deletions)
@@ -23,7 +23,8 @@
 # limitations under the License.
 """Inference-only ErineMoE model compatible with HuggingFace weights."""
 
-from collections.abc import Iterable
+import typing
+from collections.abc import Callable, Iterable
 from itertools import islice
 from typing import Any
 
@@ -139,10 +140,10 @@ def __init__(
 
         # Load balancing settings.
         vllm_config = get_current_vllm_config()
-        parallel_config = vllm_config.parallel_config
+        eplb_config = vllm_config.parallel_config.eplb_config
         self.enable_eplb = enable_eplb
 
-        self.n_redundant_experts = parallel_config.num_redundant_experts
+        self.n_redundant_experts = eplb_config.num_redundant_experts
         self.n_logical_experts = self.n_routed_experts
         self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
         self.n_local_physical_experts = self.n_physical_experts // self.ep_size
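For context on the arithmetic in this hunk: EPLB adds `num_redundant_experts` replicas on top of the model's logical (routed) experts, and the resulting physical experts are divided evenly across expert-parallel ranks. A minimal sketch with hypothetical numbers (none of these values come from the PR):

```python
# Hypothetical values, chosen only to illustrate the expert-count bookkeeping.
n_logical_experts = 64       # routed experts from the model config
n_redundant_experts = 16     # extra replicas, now read from eplb_config
ep_size = 8                  # expert-parallel world size

n_physical_experts = n_logical_experts + n_redundant_experts   # 80
n_local_physical_experts = n_physical_experts // ep_size       # 10 per rank

assert n_physical_experts % ep_size == 0
```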
@@ -426,8 +427,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.vocab_size = config.vocab_size
         self.config = config
         parallel_config = vllm_config.parallel_config
+        eplb_config = parallel_config.eplb_config
         enable_eplb = parallel_config.enable_eplb
-        self.num_redundant_experts = parallel_config.num_redundant_experts
+
+        self.num_redundant_experts = eplb_config.num_redundant_experts
 
         if get_pp_group().is_first_rank:
             self.embed_tokens = VocabParallelEmbedding(
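Both constructor hunks make the same mechanical move: `num_redundant_experts` is now read from the nested EPLB config instead of directly off `parallel_config`. A rough sketch of the config shape this implies (the dataclass definitions below are assumptions inferred from the diff, not vLLM's actual config classes):

```python
from dataclasses import dataclass, field


@dataclass
class EPLBConfig:
    num_redundant_experts: int = 0


@dataclass
class ParallelConfig:
    enable_eplb: bool = False
    eplb_config: EPLBConfig = field(default_factory=EPLBConfig)


parallel_config = ParallelConfig(
    enable_eplb=True,
    eplb_config=EPLBConfig(num_redundant_experts=16),
)

# Before this PR:  parallel_config.num_redundant_experts
# After this PR:   parallel_config.eplb_config.num_redundant_experts
num_redundant_experts = parallel_config.eplb_config.num_redundant_experts
```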
@@ -570,20 +573,27 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
 
                     # Skip loading extra bias for GPTQ models.
                     if (
-                        name.endswith(".bias") or name.endswith("_bias")
-                    ) and name not in params_dict:
+                        name_mapped.endswith(".bias") or name_mapped.endswith("_bias")
+                    ) and name_mapped not in params_dict:
                         continue
-                    param = params_dict[name]
-
-                    weight_loader = param.weight_loader
-                    weight_loader(
+                    param = params_dict[name_mapped]
+                    # We should ask the weight loader to return success or not
+                    # here since otherwise we may skip experts with other
+                    # available replicas.
+                    weight_loader = typing.cast(
+                        Callable[..., bool], param.weight_loader
+                    )
+                    success = weight_loader(
                         param,
                         loaded_weight,
-                        name,
+                        name_mapped,
                         shard_id=shard_id,
                         expert_id=expert_id,
+                        return_success=True,
                     )
-                    break
+                    if success:
+                        name = name_mapped
+                        break
                 else:
                     if is_expert_weight:
                         # We've checked that this is an expert weight
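The new comment in this hunk states the rationale: with redundant experts, a logical expert can have several physical replicas and any one rank hosts only some of them, so the loader must report whether it actually consumed the weight instead of the caller unconditionally breaking out of the candidate loop. A toy sketch of the calling pattern (the loader below is illustrative, not vLLM's FusedMoE weight loader):

```python
import typing
from collections.abc import Callable

import torch


def make_expert_weight_loader(local_expert_ids: set[int]) -> Callable[..., bool]:
    """Build a toy loader that only accepts experts hosted on this rank."""

    def weight_loader(
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        name: str,
        shard_id: str,
        expert_id: int,
        return_success: bool = False,
    ) -> bool:
        if expert_id not in local_expert_ids:
            # Not an error: another rank (or replica) owns this expert, so
            # the caller should keep scanning the remaining candidates.
            return False
        param.data.copy_(loaded_weight)
        return True

    return weight_loader


param = torch.nn.Parameter(torch.empty(4, 4))
loader = typing.cast(Callable[..., bool], make_expert_weight_loader({0, 1}))
success = loader(
    param,
    torch.ones(4, 4),
    "experts.3.w1",
    shard_id="w1",
    expert_id=3,  # not hosted locally, so the load reports failure
    return_success=True,
)
assert success is False
```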