Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions vllm/config/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,13 @@ class ParallelConfig:
"""Whether the deployed model is MoE (if known)."""
enable_expert_parallel: bool = False
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
enable_ep_weight_filter: bool = False
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

To improve robustness and prevent user confusion from misconfiguration, it's a good practice to validate that enable_expert_parallel is enabled when enable_ep_weight_filter is used. Currently, if a user enables enable_ep_weight_filter without enable_expert_parallel, the flag is silently ignored (the loader's early-return guard makes it a no-op), giving no indication that weight filtering is not active.

Consider adding a validation check in the _validate_parallel_config method of this class, similar to how enable_eplb is validated. This would raise an error for invalid combinations.

Example:

if self.enable_ep_weight_filter and not self.enable_expert_parallel:
    raise ValueError(
        "enable_expert_parallel must be True to use enable_ep_weight_filter."
    )

"""Skip non-local expert weights during model loading when expert
parallelism is active. Each rank only reads its own expert shard from
disk, which can drastically reduce storage I/O for MoE models with
per-expert weight tensors (e.g. DeepSeek, Mixtral, Kimi-K2.5). Has no
effect on 3D fused-expert checkpoints (e.g. GPT-OSS) or non-MoE
models."""
enable_eplb: bool = False
"""Enable expert parallelism load balancing for MoE layers."""
eplb_config: EPLBConfig = Field(default_factory=EPLBConfig)
Expand Down
6 changes: 6 additions & 0 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,7 @@ class EngineArgs:
data_parallel_external_lb: bool = False
data_parallel_backend: DataParallelBackend = ParallelConfig.data_parallel_backend
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
enable_ep_weight_filter: bool = ParallelConfig.enable_ep_weight_filter
moe_backend: MoEBackend = KernelConfig.moe_backend
all2all_backend: All2AllBackend = ParallelConfig.all2all_backend
enable_elastic_ep: bool = ParallelConfig.enable_elastic_ep
Expand Down Expand Up @@ -902,6 +903,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"-ep",
**parallel_kwargs["enable_expert_parallel"],
)
parallel_group.add_argument(
"--enable-ep-weight-filter",
**parallel_kwargs["enable_ep_weight_filter"],
)
parallel_group.add_argument(
"--all2all-backend", **parallel_kwargs["all2all_backend"]
)
Expand Down Expand Up @@ -1731,6 +1736,7 @@ def create_engine_config(
data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
is_moe_model=model_config.is_moe,
enable_expert_parallel=self.enable_expert_parallel,
enable_ep_weight_filter=self.enable_ep_weight_filter,
all2all_backend=self.all2all_backend,
enable_elastic_ep=self.enable_elastic_ep,
enable_dbo=self.enable_dbo,
Expand Down
6 changes: 5 additions & 1 deletion vllm/model_executor/model_loader/default_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,11 @@ def _init_ep_weight_filter(self, model_config: ModelConfig) -> None:
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config

if not (model_config.is_moe and parallel_config.enable_expert_parallel):
if not (
model_config.is_moe
and parallel_config.enable_expert_parallel
and parallel_config.enable_ep_weight_filter
):
return

num_experts = model_config.get_num_experts()
Expand Down
Loading