diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py
index afd0d1dd501a..5b07250c6e0c 100644
--- a/vllm/config/parallel.py
+++ b/vllm/config/parallel.py
@@ -4,6 +4,7 @@
 import os
 import socket
 from collections.abc import Callable
+from pathlib import Path
 from typing import TYPE_CHECKING, Any, Literal, overload
 
 import regex as re
@@ -95,12 +96,29 @@ class EPLBConfig:
     - None: Auto-select backend ("torch_gloo" for async, "torch_nccl" for sync)
     """
 
+    save_path: Path | None = None
+    """If set, save the cumulative per-logical-expert load tensor to this file
+    at every rearrange step. The file is overwritten in place. The resulting
+    file is suitable for loading via `load_path` in a subsequent run with a
+    different EP topology."""
+
+    load_path: Path | None = None
+    """If set, load a per-logical-expert load tensor from this file at startup,
+    run the EPLB policy once against the live deploy topology, and apply the
+    resulting physical-to-logical mapping before warmup. Online rearrangement
+    is disabled for the rest of the run (mapping stays static)."""
+
     @model_validator(mode="after")
     def _validate_eplb_config(self) -> Self:
         if self.use_async and self.policy != "default":
             raise ValueError("Async EPLB is only supported with the default policy.")
         if self.log_balancedness and self.log_balancedness_interval <= 0:
             raise ValueError("log_balancedness_interval must be greater than 0.")
+        if self.save_path is not None and self.load_path is not None:
+            raise ValueError(
+                "save_path and load_path cannot both be set: a run is either "
+                "recording stats or replaying them."
+            )
         return self
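For orientation, a minimal sketch of how the two new fields pair up across runs, constructed directly from the `EPLBConfig` above (the paths are placeholders):

```python
from pathlib import Path

from vllm.config.parallel import EPLBConfig

# Recording run: accumulate per-logical-expert load and overwrite the
# file at every rearrange step; no physical rearrangement is applied.
record_cfg = EPLBConfig(save_path=Path("/tmp/eplb_load.safetensors"))

# Replay run: seed the initial expert placement from the recorded load;
# online rearrangement then stays off for the rest of the run.
replay_cfg = EPLBConfig(load_path=Path("/tmp/eplb_load.safetensors"))

# Setting both is rejected by _validate_eplb_config: a run is either
# recording stats or replaying them, never both.
```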
+ """ if self.device.type == "cuda": self.cuda_device_index = self.device.index if self.cuda_device_index is None and torch.cuda.is_available(): @@ -657,6 +665,7 @@ def _init_should_record_tensor(self, model: "MixtureOfExperts") -> None: # type def rearrange( self, is_profile: bool = False, + load_initial: bool = False, rank_mapping: dict[int, int] | None = None, ) -> torch.Tensor | None: """ @@ -687,35 +696,71 @@ def rearrange( "(profile)" if is_profile else "", ) - # Map the physical expert load to global logical experts - global_expert_load_windows = [] - for eplb_model_state in self.model_states.values(): - expert_load_window = eplb_model_state.expert_load_window[ - :, :, : self.num_valid_physical_experts - ] - logical_expert_load_window = torch.zeros( - self.expert_load_window_size, - eplb_model_state.model.num_moe_layers, - eplb_model_state.model.num_logical_experts, - dtype=eplb_model_state.expert_load_window.dtype, - device=eplb_model_state.expert_load_window.device, - ) - logical_expert_load_window.scatter_add_( - dim=-1, - index=eplb_model_state.physical_to_logical_map[ - :, : self.num_valid_physical_experts + if load_initial: + load_path = self.parallel_config.eplb_config.load_path + assert load_path is not None + tensors = load_file(str(load_path)) + global_expert_load_windows = [] + for name, eplb_model_state in self.model_states.items(): + global_expert_load_windows.append( + tensors[name].to( + device=eplb_model_state.expert_load_window.device, + dtype=eplb_model_state.expert_load_window.dtype, + ) + ) + else: + # Map the physical expert load to global logical experts + global_expert_load_windows = [] + for eplb_model_state in self.model_states.values(): + expert_load_window = eplb_model_state.expert_load_window[ + :, :, : self.num_valid_physical_experts ] - .unsqueeze(0) - .expand_as(expert_load_window) - .long(), - src=expert_load_window, - ) + logical_expert_load_window = torch.zeros( + self.expert_load_window_size, + eplb_model_state.model.num_moe_layers, + eplb_model_state.model.num_logical_experts, + dtype=eplb_model_state.expert_load_window.dtype, + device=eplb_model_state.expert_load_window.device, + ) + logical_expert_load_window.scatter_add_( + dim=-1, + index=eplb_model_state.physical_to_logical_map[ + :, : self.num_valid_physical_experts + ] + .unsqueeze(0) + .expand_as(expert_load_window) + .long(), + src=expert_load_window, + ) - global_expert_load_window = logical_expert_load_window.sum(dim=0) - global_expert_load_windows.append(global_expert_load_window) + global_expert_load_window = logical_expert_load_window.sum(dim=0) + global_expert_load_windows.append(global_expert_load_window) # Perform all-reduce to get the expert load across all ranks for each model global_expert_load_windows = self._allreduce_list(global_expert_load_windows) + save_path = self.parallel_config.eplb_config.save_path + if save_path is not None and not is_profile: + for (name, _), global_load in zip( + self.model_states.items(), global_expert_load_windows + ): + load_cpu = global_load.detach().to(dtype=torch.float32, device="cpu") + if name in self.cumulative_logical_load: + self.cumulative_logical_load[name].add_(load_cpu) + else: + self.cumulative_logical_load[name] = load_cpu.clone() + if get_ep_group().device_group.rank() == 0: + self._save_logical_load( + self.cumulative_logical_load, Path(save_path) + ) + logger.info( + "Saved EPLB cumulative logical load to %s.", save_path + ) + # Recording-only mode: skip physical rearrangement so we capture + # the baseline on the unchanged 
@@ -877,6 +922,15 @@ def _sync_load_pass(self) -> list[torch.Tensor]:
             load_pass_list.append(eplb_model_state.expert_load_pass.clone())
         return self._allreduce_list(load_pass_list)
 
+    @staticmethod
+    def _save_logical_load(
+        tensors: dict[str, torch.Tensor],
+        path: Path,
+    ) -> None:
+        """Write per-logical-expert load tensors to `path` as safetensors."""
+        path.parent.mkdir(parents=True, exist_ok=True)
+        save_file(tensors, str(path), metadata={"version": "1"})
+
     @classmethod
     def from_mapping(
         cls,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 0362011a6e6d..c13c2d28da2a 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -3096,6 +3096,8 @@ def eplb_step(self, is_dummy: bool = False, is_profile: bool = False) -> None:
         """
         if not self.parallel_config.enable_eplb or self.eep_eplb_suppressed:
             return
+        if self.eplb_state is None:
+            return
         assert self.eplb_state is not None
 
         model = self.get_model()
@@ -4889,6 +4891,16 @@ def load_model(self, load_dummy_weights: bool = False) -> None:
             self.get_model(), "requires_sequential_video_encoding"
         )  # Temporary hack for dynamic res video w/o support for bs>1 yet
 
+        if (
+            is_mixture_of_experts(self.model)
+            and self.parallel_config.enable_eplb
+            and not load_dummy_weights
+            and self.eplb_state is not None
+            and self.parallel_config.eplb_config.load_path is not None
+        ):
+            self.eplb_state.rearrange(load_initial=True)
+            self.eplb_state = None
+
         if (
             is_mixture_of_experts(self.model)
             and self.parallel_config.enable_eplb
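Putting the pieces together: one recording run on the source topology, then one replay run on the target topology. A sketch via `EngineArgs`, assuming `eplb_config` is forwarded through engine args like vLLM's other config dataclasses (the model name and paths are placeholders):

```python
from pathlib import Path

from vllm.config.parallel import EPLBConfig
from vllm.engine.arg_utils import EngineArgs

# Recording run on the source EP topology: rearrange() only persists
# cumulative logical load and returns early, so placement is untouched.
record = EngineArgs(
    model="deepseek-ai/DeepSeek-V3",
    enable_expert_parallel=True,
    enable_eplb=True,
    eplb_config=EPLBConfig(save_path=Path("/tmp/eplb_load.safetensors")),
)

# Replay run on a different EP topology: load_model() calls
# rearrange(load_initial=True) once before warmup, then drops eplb_state
# so eplb_step() becomes a no-op and the mapping stays static.
replay = EngineArgs(
    model="deepseek-ai/DeepSeek-V3",
    enable_expert_parallel=True,
    enable_eplb=True,
    eplb_config=EPLBConfig(load_path=Path("/tmp/eplb_load.safetensors")),
)
```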