# ---------------------------------------------------------------------------
# Unit tests for parse_expert_id
# ---------------------------------------------------------------------------


class TestParseExpertId:
    """Unit tests for ``parse_expert_id``.

    Only tensor names containing ``.experts.<digits>.`` carry an expert id;
    every other name must yield ``None``.
    """

    def test_routed_expert(self):
        name = "model.layers.0.mlp.experts.42.gate_proj.weight"
        assert parse_expert_id(name) == 42

    def test_large_expert_id(self):
        name = "model.layers.60.mlp.experts.383.down_proj.weight"
        assert parse_expert_id(name) == 383

    def test_shared_expert(self):
        # Shared experts use a different naming convention in most models.
        name = "model.layers.0.mlp.shared_experts.gate_proj.weight"
        assert parse_expert_id(name) is None

    def test_attention_weight(self):
        name = "model.layers.0.self_attn.q_proj.weight"
        assert parse_expert_id(name) is None

    def test_embedding(self):
        name = "model.embed_tokens.weight"
        assert parse_expert_id(name) is None

    def test_layernorm(self):
        name = "model.layers.0.input_layernorm.weight"
        assert parse_expert_id(name) is None

    def test_fused_3d_expert(self):
        # 3D fused-expert tensors (e.g. gpt-oss) have no numeric expert id.
        # They must NOT be filtered — slicing happens later in weight_loader.
        name = "model.layers.0.mlp.experts.gate_proj.weight"
        assert parse_expert_id(name) is None

    def test_fused_3d_expert_down_proj(self):
        name = "model.layers.10.mlp.experts.down_proj.weight"
        assert parse_expert_id(name) is None

    def test_expert_scale(self):
        # NVFP4 quantized models have scale tensors for experts.
        name = "model.layers.5.mlp.experts.100.gate_proj.weight_scale"
        assert parse_expert_id(name) == 100

    def test_expert_zero_id(self):
        # Expert id 0 is falsy; it must still be reported as an id, not None.
        name = "model.layers.0.mlp.experts.0.up_proj.weight"
        assert parse_expert_id(name) == 0


# ---------------------------------------------------------------------------
# Unit tests for compute_local_expert_ids
# ---------------------------------------------------------------------------


class TestComputeLocalExpertIds:
    """Unit tests for ``compute_local_expert_ids`` (linear and round-robin)."""

    def test_ep_disabled(self):
        assert compute_local_expert_ids(64, ep_size=1, ep_rank=0) is None

    def test_even_split(self):
        # 64 experts, EP=8 → 8 per rank
        ids = compute_local_expert_ids(64, ep_size=8, ep_rank=0)
        assert ids == set(range(0, 8))

        ids = compute_local_expert_ids(64, ep_size=8, ep_rank=7)
        assert ids == set(range(56, 64))

    def test_uneven_split(self):
        # 10 experts, EP=3 → ranks get 4, 3, 3
        ids_0 = compute_local_expert_ids(10, ep_size=3, ep_rank=0)
        ids_1 = compute_local_expert_ids(10, ep_size=3, ep_rank=1)
        ids_2 = compute_local_expert_ids(10, ep_size=3, ep_rank=2)

        assert len(ids_0) == 4
        assert len(ids_1) == 3
        assert len(ids_2) == 3
        # All experts covered, and every pair of ranks is disjoint.
        assert ids_0 | ids_1 | ids_2 == set(range(10))
        assert ids_0.isdisjoint(ids_1)
        assert ids_1.isdisjoint(ids_2)
        assert ids_0.isdisjoint(ids_2)

    def test_384_experts_ep8(self):
        # Kimi-K2.5 config: 384 experts, EP=8 → 48 per rank, full coverage,
        # no expert assigned to two ranks.
        seen: set[int] = set()
        for rank in range(8):
            ids = compute_local_expert_ids(384, ep_size=8, ep_rank=rank)
            assert ids is not None and len(ids) == 48
            assert seen.isdisjoint(ids)
            seen |= ids
        assert seen == set(range(384))

    def test_384_experts_ep16(self):
        for rank in range(16):
            ids = compute_local_expert_ids(384, ep_size=16, ep_rank=rank)
            assert len(ids) == 24

    def test_384_experts_ep24(self):
        # 384 / 24 = 16 exactly
        for rank in range(24):
            ids = compute_local_expert_ids(384, ep_size=24, ep_rank=rank)
            assert len(ids) == 16

    def test_unknown_placement_raises(self):
        # An unrecognised strategy must fail loudly rather than silently
        # filtering with the wrong partition.
        with pytest.raises(ValueError):
            compute_local_expert_ids(8, 2, 0, placement="bogus")

    # round_robin placement tests

    def test_round_robin_basic(self):
        # 8 experts, EP=2: rank 0 → {0,2,4,6}, rank 1 → {1,3,5,7}
        rr = "round_robin"
        ids_0 = compute_local_expert_ids(8, 2, 0, placement=rr)
        ids_1 = compute_local_expert_ids(8, 2, 1, placement=rr)
        assert ids_0 == {0, 2, 4, 6}
        assert ids_1 == {1, 3, 5, 7}

    def test_round_robin_full_coverage(self):
        # 384 experts, EP=8: all experts covered, no overlap
        rr = "round_robin"
        all_ids: set[int] = set()
        for rank in range(8):
            ids = compute_local_expert_ids(384, 8, rank, placement=rr)
            assert ids is not None and len(ids) == 48
            assert all_ids.isdisjoint(ids)
            all_ids |= ids
        assert all_ids == set(range(384))

    def test_round_robin_uneven(self):
        # 10 experts, EP=3: rank 0→{0,3,6,9}, rank 1→{1,4,7}, rank 2→{2,5,8}
        rr = "round_robin"
        ids_0 = compute_local_expert_ids(10, 3, 0, placement=rr)
        ids_1 = compute_local_expert_ids(10, 3, 1, placement=rr)
        ids_2 = compute_local_expert_ids(10, 3, 2, placement=rr)
        assert ids_0 == {0, 3, 6, 9}
        assert ids_1 == {1, 4, 7}
        assert ids_2 == {2, 5, 8}
        assert ids_0 | ids_1 | ids_2 == set(range(10))


# ---------------------------------------------------------------------------
# Unit tests for should_skip_weight
# ---------------------------------------------------------------------------


class TestShouldSkipWeight:
    """Unit tests for ``should_skip_weight`` against an EP=8, rank-0 filter."""

    def setup_method(self):
        # Simulate EP=8, rank=0 → experts 0-47
        self.local_ids = compute_local_expert_ids(384, ep_size=8, ep_rank=0)

    def test_no_filter(self):
        assert not should_skip_weight("anything", None)

    def test_dense_not_skipped(self):
        assert not should_skip_weight(
            "model.layers.0.self_attn.q_proj.weight", self.local_ids
        )

    def test_local_expert_not_skipped(self):
        assert not should_skip_weight(
            "model.layers.0.mlp.experts.10.gate_proj.weight", self.local_ids
        )

    def test_remote_expert_skipped(self):
        assert should_skip_weight(
            "model.layers.0.mlp.experts.200.gate_proj.weight", self.local_ids
        )

    def test_boundary_expert(self):
        # Expert 47 is local (last one), 48 is not
        assert not should_skip_weight(
            "model.layers.0.mlp.experts.47.gate_proj.weight", self.local_ids
        )
        assert should_skip_weight(
            "model.layers.0.mlp.experts.48.gate_proj.weight", self.local_ids
        )

    def test_shared_expert_not_skipped(self):
        assert not should_skip_weight(
            "model.layers.0.mlp.shared_experts.gate_proj.weight", self.local_ids
        )

    def test_embedding_not_skipped(self):
        assert not should_skip_weight("model.embed_tokens.weight", self.local_ids)

    def test_fused_3d_expert_not_skipped(self):
        # 3D fused-expert tensors (gpt-oss style) have no numeric id.
        # Must not be skipped — weight_loader handles slicing later.
        assert not should_skip_weight(
            "model.layers.0.mlp.experts.gate_proj.weight", self.local_ids
        )


# ---------------------------------------------------------------------------
# Integration test: safetensors_weights_iterator with EP filtering
# ---------------------------------------------------------------------------


class TestSafetensorsWeightsIteratorWithEpFilter:
    """Verify that EP filtering produces a strict subset of unfiltered loading
    and that all expected dense + local expert weights are present."""

    @pytest.fixture(scope="class")
    def gpt2_files(self):
        """Download GPT-2 safetensors to a temp dir (shared across class).

        ``HF_HUB_OFFLINE`` is a process-global flag, so the value found on
        entry is put back on teardown to avoid leaking into other tests.
        """
        previous_offline = huggingface_hub.constants.HF_HUB_OFFLINE
        with tempfile.TemporaryDirectory() as tmpdir:
            huggingface_hub.constants.HF_HUB_OFFLINE = False
            try:
                from vllm.model_executor.model_loader.weight_utils import (
                    download_weights_from_hf,
                )

                download_weights_from_hf(
                    "openai-community/gpt2",
                    allow_patterns=["*.safetensors"],
                    cache_dir=tmpdir,
                )
                files = glob.glob(f"{tmpdir}/**/*.safetensors", recursive=True)
                assert len(files) > 0
                yield files
            finally:
                huggingface_hub.constants.HF_HUB_OFFLINE = previous_offline

    def test_no_filter_returns_all(self, gpt2_files):
        """With local_expert_ids=None, all weights are returned (no MoE)."""
        all_weights = dict(safetensors_weights_iterator(gpt2_files, False))
        filtered_weights = dict(
            safetensors_weights_iterator(gpt2_files, False, local_expert_ids=None)
        )
        assert set(all_weights.keys()) == set(filtered_weights.keys())

    def test_empty_filter_skips_experts_only(self, gpt2_files):
        """GPT-2 has no expert weights, so even an empty local_expert_ids
        set should return all weights (all are dense)."""
        all_weights = dict(safetensors_weights_iterator(gpt2_files, False))
        filtered_weights = dict(
            safetensors_weights_iterator(gpt2_files, False, local_expert_ids=set())
        )
        # GPT-2 has no experts, so nothing should be filtered
        assert set(all_weights.keys()) == set(filtered_weights.keys())


class TestEpFilterOnSyntheticMoeWeights:
    """Create synthetic safetensors files with expert-like naming and verify
    that the filter correctly skips non-local experts."""

    @pytest.fixture
    def synthetic_moe_files(self, tmp_path):
        """Create synthetic safetensors with expert-patterned tensor names.

        Returns a ``(file_list, tensors)`` pair where ``tensors`` is the
        name → tensor dict that was written to disk.
        """
        from safetensors.torch import save_file

        tensors = {}
        # Dense weights
        tensors["model.embed_tokens.weight"] = torch.randn(100, 64)
        tensors["model.layers.0.self_attn.q_proj.weight"] = torch.randn(64, 64)
        tensors["model.layers.0.input_layernorm.weight"] = torch.randn(64)
        # Expert weights: 8 experts
        for expert_id in range(8):
            tensors[f"model.layers.0.mlp.experts.{expert_id}.gate_proj.weight"] = (
                torch.randn(128, 64)
            )
            tensors[f"model.layers.0.mlp.experts.{expert_id}.up_proj.weight"] = (
                torch.randn(128, 64)
            )
            tensors[f"model.layers.0.mlp.experts.{expert_id}.down_proj.weight"] = (
                torch.randn(64, 128)
            )
        # Shared expert (should never be filtered)
        tensors["model.layers.0.mlp.shared_experts.gate_proj.weight"] = torch.randn(
            128, 64
        )

        filepath = str(tmp_path / "model-00001-of-00001.safetensors")
        save_file(tensors, filepath)
        return [filepath], tensors

    def test_no_filter_returns_all(self, synthetic_moe_files):
        files, expected = synthetic_moe_files
        loaded = dict(safetensors_weights_iterator(files, False))
        assert set(loaded.keys()) == set(expected.keys())

    def test_ep2_rank0_gets_half_experts(self, synthetic_moe_files):
        files, _ = synthetic_moe_files
        # EP=2, rank=0 → experts 0-3
        local_ids = compute_local_expert_ids(8, ep_size=2, ep_rank=0)
        loaded = dict(
            safetensors_weights_iterator(files, False, local_expert_ids=local_ids)
        )

        # Should have all dense + shared + experts 0-3 only
        for name in loaded:
            eid = parse_expert_id(name)
            if eid is not None:
                assert eid in local_ids, f"Non-local expert {eid} was loaded"

        # Check expert count: 4 experts × 3 weights = 12
        expert_names = [n for n in loaded if parse_expert_id(n) is not None]
        assert len(expert_names) == 4 * 3

        # Check all dense weights present
        assert "model.embed_tokens.weight" in loaded
        assert "model.layers.0.self_attn.q_proj.weight" in loaded
        assert "model.layers.0.input_layernorm.weight" in loaded
        assert "model.layers.0.mlp.shared_experts.gate_proj.weight" in loaded

    def test_ep2_rank1_gets_other_half(self, synthetic_moe_files):
        files, _ = synthetic_moe_files
        local_ids = compute_local_expert_ids(8, ep_size=2, ep_rank=1)
        loaded = dict(
            safetensors_weights_iterator(files, False, local_expert_ids=local_ids)
        )

        expert_names = [n for n in loaded if parse_expert_id(n) is not None]
        assert len(expert_names) == 4 * 3
        for name in expert_names:
            assert parse_expert_id(name) in local_ids

    def test_ep8_each_rank_gets_one_expert(self, synthetic_moe_files):
        files, _ = synthetic_moe_files
        all_expert_names = set()
        for rank in range(8):
            local_ids = compute_local_expert_ids(8, ep_size=8, ep_rank=rank)
            loaded = dict(
                safetensors_weights_iterator(files, False, local_expert_ids=local_ids)
            )
            expert_names = {n for n in loaded if parse_expert_id(n) is not None}
            # 1 expert × 3 weights, and it must be this rank's own expert.
            assert len(expert_names) == 3
            assert {parse_expert_id(n) for n in expert_names} == local_ids
            # No expert tensor may be handed to two different ranks.
            assert all_expert_names.isdisjoint(expert_names)
            all_expert_names |= expert_names

        # All 8 experts × 3 weights covered across ranks
        assert len(all_expert_names) == 24

    def test_tensor_values_match(self, synthetic_moe_files):
        """Filtered tensors have identical values to unfiltered ones."""
        files, _ = synthetic_moe_files
        all_weights = dict(safetensors_weights_iterator(files, False))

        local_ids = compute_local_expert_ids(8, ep_size=2, ep_rank=0)
        filtered = dict(
            safetensors_weights_iterator(files, False, local_expert_ids=local_ids)
        )

        for name, tensor in filtered.items():
            assert torch.equal(tensor, all_weights[name]), f"Tensor mismatch for {name}"
def _init_ep_weight_filter(self, model_config: ModelConfig) -> None:
    """Compute local expert ids for EP weight filtering.

    When expert parallelism is active, each rank only needs a subset of
    expert weights. By computing the set upfront we can skip non-local
    expert tensors *before* reading them from disk.

    The result is stored in ``self.local_expert_ids``. It is ``None``
    (no filtering) when the model is not MoE, expert parallelism is
    disabled, the model reports no experts, or the flattened EP size is 1.

    Args:
        model_config: Config of the model about to be loaded; queried for
            ``is_moe`` and ``get_num_experts()``.
    """
    from vllm.config import get_current_vllm_config

    # Always start from "no filter". Without this reset, an early return
    # below would leave a filter computed by a previous load_weights() call
    # in place, and it would silently be applied to the current load.
    self.local_expert_ids = None

    vllm_config = get_current_vllm_config()
    parallel_config = vllm_config.parallel_config

    if not (model_config.is_moe and parallel_config.enable_expert_parallel):
        return

    num_experts = model_config.get_num_experts()
    if num_experts <= 0:
        return

    # NOTE(review): this assumes the rank's experts are exactly its share of
    # the *logical* expert ids. If expert load balancing with redundant
    # experts can place replicas of other ranks' experts here, the filter
    # would drop weights that are needed — confirm and bail out if so.
    # NOTE(review): the configured placement strategy is forwarded as-is.
    # If the MoE layer can fall back to a different strategy at runtime,
    # the filter and the layer's expert map would disagree — confirm.

    # EP size/rank computation mirrors FusedMoEParallelConfig.make():
    #   ep_size = dp_size * pcp_size * tp_size (flattened)
    #   ep_rank = dp_rank * pcp_size * tp_size + pcp_rank * tp_size + tp_rank
    from vllm.distributed import (
        get_dp_group,
        get_pcp_group,
        get_tensor_model_parallel_rank,
    )

    dp_size = parallel_config.data_parallel_size
    tp_size = parallel_config.tensor_parallel_size
    pcp_size = parallel_config.prefill_context_parallel_size
    # Groups are only queried when the corresponding dimension is > 1;
    # otherwise the rank along that dimension is taken to be 0.
    dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
    tp_rank = get_tensor_model_parallel_rank() if tp_size > 1 else 0
    pcp_rank = get_pcp_group().rank_in_group if pcp_size > 1 else 0
    ep_size = dp_size * pcp_size * tp_size
    ep_rank = dp_rank * pcp_size * tp_size + pcp_rank * tp_size + tp_rank

    self.local_expert_ids = compute_local_expert_ids(
        num_experts,
        ep_size,
        ep_rank,
        placement=parallel_config.expert_placement_strategy,
    )
    if self.local_expert_ids is not None:
        logger.info_once(
            "EP weight filter: ep_size=%d, ep_rank=%d, loading %d/%d experts",
            ep_size,
            ep_rank,
            len(self.local_expert_ids),
            num_experts,
        )
a/vllm/model_executor/model_loader/ep_weight_filter.py b/vllm/model_executor/model_loader/ep_weight_filter.py new file mode 100644 index 000000000000..1ef7f0174511 --- /dev/null +++ b/vllm/model_executor/model_loader/ep_weight_filter.py @@ -0,0 +1,76 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Filter out non-local expert weights during loading to avoid redundant I/O. + +In DP+EP deployments each rank only needs its own expert shard. Skipping +non-local expert tensors *before* they are read from disk eliminates the +majority of storage I/O for MoE models (experts typically account for +~85-90 % of total weight bytes). +""" + +import regex as re + +# Matches per-expert weight names like ".experts.42.gate_proj.weight". +# Does NOT match 3D fused-expert names like ".experts.gate_proj.weight" +# (no numeric id) — those are intentionally left unfiltered so the full +# tensor is loaded and sliced later by FusedMoE.weight_loader. +_EXPERT_ID_RE = re.compile(r"\.experts\.(\d+)\.") + + +def parse_expert_id(weight_name: str) -> int | None: + """Return the expert id embedded in *weight_name*, or ``None`` if it is + not an per-expert weight. + + Returns ``None`` for dense weights (attention, layernorm, embedding), + shared experts, and 3D fused-expert tensors where all experts are stored + in a single tensor without a numeric expert id in the name.""" + m = _EXPERT_ID_RE.search(weight_name) + return int(m.group(1)) if m else None + + +def compute_local_expert_ids( + num_experts: int, + ep_size: int, + ep_rank: int, + placement: str = "linear", +) -> set[int] | None: + """Compute the set of global expert ids owned by *ep_rank*. + + Returns ``None`` when EP is not active (``ep_size <= 1``), meaning all + experts are local and no filtering should be performed. + + The distribution logic mirrors + :func:`vllm.model_executor.layers.fused_moe.layer.determine_expert_map`. 
+ + Args: + placement: ``"linear"`` for contiguous assignment, + ``"round_robin"`` for interleaved assignment. + """ + if ep_size <= 1: + return None + + if placement == "linear": + base = num_experts // ep_size + remainder = num_experts % ep_size + start = ep_rank * base + min(ep_rank, remainder) + local_count = base + (1 if ep_rank < remainder else 0) + return set(range(start, start + local_count)) + elif placement == "round_robin": + return set(range(ep_rank, num_experts, ep_size)) + else: + raise ValueError(f"Unknown expert placement strategy: {placement}") + + +def should_skip_weight( + weight_name: str, + local_expert_ids: set[int] | None, +) -> bool: + """Return ``True`` if *weight_name* is an expert weight that does not + belong to the local rank and should be skipped during loading.""" + if local_expert_ids is None: + return False + eid = parse_expert_id(weight_name) + if eid is None: + # Not an expert weight (dense / shared-expert / embedding) → keep. + return False + return eid not in local_expert_ids diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index ff0214ff55be..0a67a6a42aba 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -35,6 +35,9 @@ QuantizationConfig, get_quantization_config, ) +from vllm.model_executor.model_loader.ep_weight_filter import ( + should_skip_weight, +) from vllm.platforms import current_platform from vllm.tracing import instrument from vllm.utils.import_utils import PlaceholderModule @@ -721,8 +724,14 @@ def safetensors_weights_iterator( hf_weights_files: list[str], use_tqdm_on_load: bool, safetensors_load_strategy: str = "lazy", + local_expert_ids: set[int] | None = None, ) -> Generator[tuple[str, torch.Tensor], None, None]: - """Iterate over the weights in the model safetensor files.""" + """Iterate over the weights in the model safetensor files. 
+ + When *local_expert_ids* is provided, expert weights not belonging to + this rank are skipped **before** reading from disk, which drastically + reduces storage I/O for MoE models under EP. + """ loading_desc = "Loading safetensors checkpoint shards" if safetensors_load_strategy == "eager": loading_desc += " (eager)" @@ -737,7 +746,9 @@ def safetensors_weights_iterator( if safetensors_load_strategy == "eager": with open(st_file, "rb") as f: state_dict = load(f.read()) - yield from state_dict.items() + for name, param in state_dict.items(): + if not should_skip_weight(name, local_expert_ids): + yield name, param elif safetensors_load_strategy == "torchao": # we can't load flattened torchao tensor subclasses directly into the model # instead we reconstruct the subclasses here before returning @@ -753,6 +764,8 @@ def safetensors_weights_iterator( with safe_open(st_file, framework="pt") as f: state_dict = {} for name in f.keys(): # noqa: SIM118 + if should_skip_weight(name, local_expert_ids): + continue state_dict[name] = f.get_tensor(name) # update with leftover tensor data from previous iteration, if any @@ -769,6 +782,8 @@ def safetensors_weights_iterator( else: with safe_open(st_file, framework="pt") as f: for name in f.keys(): # noqa: SIM118 + if should_skip_weight(name, local_expert_ids): + continue param = f.get_tensor(name) yield name, param