diff --git a/.buildkite/test_areas/basic_correctness.yaml b/.buildkite/test_areas/basic_correctness.yaml index 042734e8433b..9f2fe068fce3 100644 --- a/.buildkite/test_areas/basic_correctness.yaml +++ b/.buildkite/test_areas/basic_correctness.yaml @@ -9,9 +9,11 @@ steps: - vllm/ - tests/basic_correctness/test_basic_correctness - tests/basic_correctness/test_cpu_offload + - tests/basic_correctness/test_moe_expert_cache - tests/basic_correctness/test_cumem.py commands: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - pytest -v -s basic_correctness/test_cumem.py - pytest -v -s basic_correctness/test_basic_correctness.py - pytest -v -s basic_correctness/test_cpu_offload.py + - pytest -v -s basic_correctness/test_moe_expert_cache.py diff --git a/benchmarks/qwen_122b_test_20260331.txt b/benchmarks/qwen_122b_test_20260331.txt new file mode 100644 index 000000000000..cf4c5f7429b1 --- /dev/null +++ b/benchmarks/qwen_122b_test_20260331.txt @@ -0,0 +1,28 @@ +=== Qwen 3.5-122B GGUF Q4_K — Test Results === + +Status: OOM KILLED during expert store build (~45 GB RAM, 62 GB total) + +Architecture: 48 layers × 256 experts × top_k=8 + Expert FFN: 1024→3072→1024 (small per expert) + Fused format: [out, in, 256] in Q4_K + Total expert data: ~64 GB (BF16 after dequant) + +What worked: + - GGUF parsing: 0.2s (3 shards, 879 tensors) + - Q4_K type ID: correctly mapped (was wrong, fixed) + - Q6_K dequant: added for non-expert tensors + - Multimodal wrapper: text_config extraction works + - Non-expert weight loading: succeeded + - Expert store build: started, processed ~30-35 of 48 layers before OOM + +Why it failed: + _build_expert_store_from_fused_reader accumulates all 12,288 experts + in a Python dict (BF16 tensors) before calling GenericExpertStore.from_dict. + At ~5.3 MB per expert × 12,288 experts = ~63 GB. Exceeds 62 GB RAM. + +What's needed: + Stream-to-mmap: write each expert directly to an mmap'd file as it's + dequanted, instead of accumulating in a dict. 
This would cap peak RAM + at ~2.4 GB (one fused layer at a time) regardless of total model size. + +No performance numbers available. diff --git a/docs/features/moe_cache_policies.md b/docs/features/moe_cache_policies.md new file mode 100644 index 000000000000..b538ea3a303b --- /dev/null +++ b/docs/features/moe_cache_policies.md @@ -0,0 +1,102 @@ +# MoE Expert Weight Caching + +vLLM can run MoE models that exceed available GPU memory by keeping all expert +weights in CPU pinned memory and caching only the most-recently-used +experts in a fixed-size GPU scratch buffer. + +This feature is controlled by the `--moe-expert-cache-size` option. + +| Option | Default | Description | +| --- | --- | --- | +| `--moe-expert-cache-size N` | `0` (disabled) | Number of expert slots to allocate in the GPU buffer per layer | + +!!! note + Expert caching requires `--enforce-eager`. CUDA graph capture is + incompatible with the dynamic Python bookkeeping in `prepare()`. + +!!! note + Expert caching is not compatible with expert parallelism (EP > 1), + data parallelism, or sequence parallelism. + +## Quick start + +```bash +# OLMoE-1B-7B: 64 experts, fits on 8 GB GPU with 16 cached per layer +vllm serve allenai/OLMoE-1B-7B-0924 \ + --moe-expert-cache-size 16 \ + --enforce-eager +``` + +### Python API + +`moe_expert_cache_size` is exposed as a direct `LLM` constructor parameter: + +```python +from vllm import LLM, SamplingParams + +llm = LLM( + model="allenai/OLMoE-1B-7B-0924", + moe_expert_cache_size=16, + enforce_eager=True, +) +``` + +## Architecture (RFC #38256) + +The cache is implemented as a `CachedWeightProvider` — the kernel does not +know or care where weights came from. 
+
+### How it works
+
+```text
+Decode (unique experts <= capacity) — GPU fast path:
+  topk_ids -> provider.prepare():
+    hit  -> update freq/recency bookkeeping in a plain dict (O(1))
+    miss -> evict lowest freq/age score, H2D copy, update mapping[expert] = slot
+  -> kernel.apply(result.w1, result.w2, result.topk_ids)
+```
+
+A persistent `_mapping` tensor (`int32`, GPU) holds the `expert_id -> slot`
+mapping. It is updated in-place for misses and used for a vectorized remap —
+no CPU tensor build or H2D transfer on the hot path.
+
+The `CachedWeightProvider` uses a plain dict with LFRU (frequency/age)
+scoring for eviction (no external dependencies). When unique experts exceed
+capacity, a `RuntimeError` is raised — increase `--moe-expert-cache-size`
+to avoid this.
+
+## Observability
+
+### DEBUG-level hit/miss log
+
+Set `VLLM_LOGGING_LEVEL=DEBUG` to get a per-layer cumulative hit/miss report
+on each forward pass:
+
+```text
+DEBUG vllm...expert_weight_provider: Expert cache: 1234 hits, 56 misses (95.7% hit rate)
+```
+
+## Sizing guidance
+
+Set `--moe-expert-cache-size` to the number of experts that must fit on
+GPU simultaneously per layer. For a model with `E` experts and `top_k`
+routing:
+
+- **Minimum useful**: `top_k` (one slot per active expert per token, no
+  eviction during decode)
+- **Typical decode**: `2 * top_k` – `4 * top_k` gives headroom for
+  locality without wasting VRAM
+- **Maximum** (no-op): `E` (all experts on GPU, equivalent to normal mode)
+
+## GPU memory note
+
+Expert weights in CPU pinned memory are invisible to the `--gpu-memory-utilization`
+profiler. The profiler will underestimate available KV cache headroom by the
+expert weight footprint (a safe margin, not a hazard), but exact
+`gpu-memory-utilization`-based sizing will be off.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Correctness tests for the MoE expert LRU cache (--moe-expert-cache-size).

Runs two vllm serve instances side-by-side via compare_two_settings:
  - baseline: standard MoE (all experts on GPU)
  - cache: expert LRU cache enabled with a small GPU buffer

Token outputs must match exactly.
"""

import pytest

from ..utils import compare_two_settings

# FP8 MoE model small enough for CI; exercises the scale-buffer path too.
_MOE_MODEL = "RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8"


@pytest.mark.parametrize("cache_size", [4, 16])
def test_moe_expert_cache_correctness(cache_size: int) -> None:
    """Output tokens from the cache path must match the no-cache baseline."""
    # Both sides use --enforce-eager so the only delta is the cache itself.
    compare_two_settings(
        model=_MOE_MODEL,
        arg1=["--enforce-eager"],
        arg2=[
            "--enforce-eager",
            "--moe-expert-cache-size",
            str(cache_size),
        ],
    )


def test_moe_expert_cache_disabled_by_default() -> None:
    """Verify that the default (cache_size=0) leaves the existing path intact."""
    compare_two_settings(
        model=_MOE_MODEL,
        arg1=["--enforce-eager"],
        arg2=["--enforce-eager", "--moe-expert-cache-size", "0"],
    )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for CachedWeightProvider (LFRU expert cache)."""

import pytest
import torch

from vllm.model_executor.layers.fused_moe.expert_weight_provider import (
    CachedWeightProvider,
    ExpertWeightResult,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed

pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA required")

NUM_EXPERTS = [8, 64]
DTYPES = [torch.bfloat16, torch.float16]
CAPACITIES = [1, 4]
HIDDEN = 16
INTERMEDIATE = 32


def _make_weights(num_experts: int, dtype: torch.dtype):
    """Random w13/w2 expert weights in the standard fused layout."""
    # NOTE(review): torch.randn may not support float8_e4m3fn directly
    # (used by test_scale_lifecycle) — confirm; may need a .to(dtype) cast.
    w13 = torch.randn(num_experts, 2 * INTERMEDIATE, HIDDEN, dtype=dtype)
    w2 = torch.randn(num_experts, HIDDEN, INTERMEDIATE, dtype=dtype)
    return w13, w2


def _make_scales(num_experts: int):
    """Per-expert scale tensors, shape (E, 1), float32."""
    w13_s = torch.rand(num_experts, 1, dtype=torch.float32)
    w2_s = torch.rand(num_experts, 1, dtype=torch.float32)
    return w13_s, w2_s


def _make_provider(
    num_experts: int = 8,
    capacity: int = 4,
    dtype: torch.dtype = torch.bfloat16,
    with_scales: bool = False,
):
    """Build a provider plus the source tensors for later verification."""
    set_random_seed(42)
    w13, w2 = _make_weights(num_experts, dtype)
    kwargs: dict = dict(capacity=capacity, w13_weight=w13, w2_weight=w2)
    scales = None
    if with_scales:
        w13_s, w2_s = _make_scales(num_experts)
        kwargs.update(w13_scale=w13_s, w2_scale=w2_s)
        scales = (w13_s, w2_s)
    return CachedWeightProvider(**kwargs), w13, w2, scales


def _topk(ids: list[int]) -> torch.Tensor:
    """Shape the given expert IDs as a [1, top_k] int32 CUDA tensor."""
    return torch.tensor(ids, dtype=torch.int32, device="cuda").unsqueeze(0)


# -- Core cache behavior --


@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("capacity", CAPACITIES)
@pytest.mark.parametrize("dtype", DTYPES)
def test_cold_miss_and_warm_hit(num_experts: int, capacity: int, dtype: torch.dtype):
    """Cold access misses, repeat access hits. GPU buffer matches source."""
    provider, w13, w2, _ = _make_provider(num_experts, capacity, dtype)
    expert_ids = list(range(min(capacity, num_experts)))

    # Cold miss
    result = provider.prepare(_topk(expert_ids))
    assert provider.misses == len(expert_ids)
    assert provider.hits == 0
    assert isinstance(result, ExpertWeightResult)
    assert result.w1 is provider.buf_w13
    assert result.w2 is provider.buf_w2
    assert result.topk_ids.shape == (1, len(expert_ids))

    # Verify GPU buffer contents match source weights
    for eid in expert_ids:
        slot = provider._lru[eid][0]
        torch.testing.assert_close(result.w1[slot].cpu(), w13[eid])
        torch.testing.assert_close(result.w2[slot].cpu(), w2[eid])

    # Warm hit
    prev_misses = provider.misses
    provider.prepare(_topk(expert_ids))
    assert provider.hits == len(expert_ids)
    assert provider.misses == prev_misses


@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_cache_full_equals_num_experts(num_experts: int, dtype: torch.dtype):
    """When capacity == num_experts, all fit with zero evictions."""
    provider, _, _, _ = _make_provider(num_experts, capacity=num_experts, dtype=dtype)
    all_ids = list(range(num_experts))
    provider.prepare(_topk(all_ids))
    assert provider.misses == num_experts
    assert len(provider._free_slots) == 0

    provider.prepare(_topk(all_ids))
    assert provider.hits == num_experts


@pytest.mark.parametrize("capacity", CAPACITIES)
def test_topk_ids_remapping(capacity: int):
    """Remapped topk_ids point to the correct GPU buffer slots."""
    provider, _, _, _ = _make_provider(capacity=capacity)
    ids = list(range(min(capacity, 8)))
    result = provider.prepare(_topk(ids))

    for eid, slot in zip(
        _topk(ids).squeeze(0).tolist(),
        result.topk_ids.squeeze(0).tolist(),
    ):
        assert provider._lru[eid][0] == slot


@pytest.mark.parametrize("dtype", [torch.int32, torch.int64])
def test_output_dtype_matches_input(dtype: torch.dtype):
    """Remapped topk_ids preserves input dtype."""
    provider, *_ = _make_provider()
    ids = torch.tensor([[0, 1]], dtype=dtype, device="cuda")
    result = provider.prepare(ids)
    assert result.topk_ids.dtype == dtype


# -- LFRU eviction semantics --


def test_lfru_prefers_evicting_low_frequency():
    """LFRU evicts the expert with lowest freq/age score, not pure LRU.
    A accessed 5x, B accessed 1x. When C arrives, B is evicted, not A.
    """
    provider, w13, _, _ = _make_provider(capacity=2)
    provider.prepare(_topk([0, 1]))
    for _ in range(4):
        provider.prepare(_topk([0]))  # A freq=5
    provider.prepare(_topk([1]))  # touch B for recency parity

    provider.prepare(_topk([2]))  # evicts B (lower freq/age score)
    assert 0 in provider._lru, "High-frequency expert A should survive"
    assert 2 in provider._lru, "New expert C should be cached"
    assert 1 not in provider._lru, "Low-frequency expert B should be evicted"
    slot_c = provider._lru[2][0]
    torch.testing.assert_close(provider.buf_w13[slot_c].cpu(), w13[2])


def test_lfru_evicts_stale_high_freq_expert():
    """High historical freq but old last-access loses to recent low-freq.
    Distinguishes LFRU (score=freq/age) from pure frequency-based caching.
    """
    provider, _, _, _ = _make_provider(capacity=2)

    # Expert 0: accessed 11x early, then becomes stale
    provider.prepare(_topk([0]))
    for _ in range(10):
        provider.prepare(_topk([0]))
    # Expert 1: loaded later, accessed 51x (0 becomes very stale)
    provider.prepare(_topk([1]))
    for _ in range(50):
        provider.prepare(_topk([1]))

    # Expert 0: freq=11, age~62 -> score~0.18. Expert 1: freq=51, age=1 -> 51
    provider.prepare(_topk([2]))
    assert 1 in provider._lru, "Recent high-freq expert should survive"
    assert 0 not in provider._lru, "Stale expert should be evicted"


def test_capacity_one_always_evicts():
    """With capacity=1, every new expert evicts the previous."""
    provider, *_ = _make_provider(capacity=1)
    for eid in range(5):
        provider.prepare(_topk([eid]))
    assert provider.misses == 5
    assert provider.hits == 0
    assert len(provider._lru) == 1
    assert 4 in provider._lru


# -- GPU buffer correctness under eviction --


def test_gpu_buffer_correct_after_eviction():
    """After eviction, the reused slot contains the new expert's weights."""
    provider, w13, w2, _ = _make_provider(capacity=4)
    provider.prepare(_topk([0, 1, 2, 3]))

    # Make 0 the eviction candidate (least recently used, lowest freq)
    provider.prepare(_topk([1, 2, 3]))
    slot_for_0 = provider._lru[0][0]

    provider.prepare(_topk([7]))
    assert provider._lru[7][0] == slot_for_0
    torch.testing.assert_close(provider.buf_w13[slot_for_0].cpu(), w13[7])
    torch.testing.assert_close(provider.buf_w2[slot_for_0].cpu(), w2[7])


# -- Scale buffer handling --


def test_scale_lifecycle():
    """Scales are allocated, copied on load, and updated on eviction."""
    if not current_platform.has_device_capability(89):
        pytest.skip("FP8 requires CUDA capability >= 89")

    provider, _, _, scales = _make_provider(
        capacity=4, dtype=torch.float8_e4m3fn, with_scales=True
    )
    w13_s, w2_s = scales

    # Buffers allocated on GPU
    assert provider.buf_w13_scale is not None
    assert provider.buf_w2_scale is not None
    assert provider.buf_w13_scale.device.type == "cuda"

    # Scales copied correctly on load
    result = provider.prepare(_topk([3, 6]))
    for eid in [3, 6]:
        slot = provider._lru[eid][0]
        torch.testing.assert_close(result.w1_scale[slot].cpu(), w13_s[eid])
        torch.testing.assert_close(result.w2_scale[slot].cpu(), w2_s[eid])

    # Fill cache and evict: scales must be updated in evicted slot
    provider.prepare(_topk([0, 1]))  # cache now full: [3, 6, 0, 1]
    provider.prepare(_topk([3, 6, 0]))  # boost freq on 3,6,0; expert 1 stale

    result = provider.prepare(_topk([7]))  # evicts 1
    assert 1 not in provider._lru
    slot_7 = provider._lru[7][0]
    torch.testing.assert_close(provider.buf_w13_scale[slot_7].cpu(), w13_s[7])
    torch.testing.assert_close(provider.buf_w2_scale[slot_7].cpu(), w2_s[7])


def test_no_scales_when_not_provided():
    """Without scale inputs, scale buffers remain None."""
    provider, *_ = _make_provider()
    assert provider.buf_w13_scale is None
    assert provider.buf_w2_scale is None
    result = provider.prepare(_topk([0]))
    assert result.w1_scale is None
    assert result.w2_scale is None


# -- Invalidation --


def test_invalidate_frees_slot():
    """invalidate() removes an expert and returns its slot to the free list."""
    provider, *_ = _make_provider()
    provider.prepare(_topk([0, 1, 2, 3]))
    old_slot = provider._lru[2][0]
    provider.invalidate(2)
    assert 2 not in provider._lru
    assert old_slot in provider._free_slots


def test_invalidate_noop_when_absent():
    """invalidate() on an uncached expert is a no-op."""
    provider, *_ = _make_provider()
    provider.invalidate(99)  # must not raise


# -- Overflow (unique experts > capacity) --


def test_overflow_raises():
    """When unique experts exceed capacity, raise RuntimeError immediately."""
    provider, *_ = _make_provider(capacity=2)
    with pytest.raises(RuntimeError, match="unique experts"):
        provider.prepare(_topk([0, 1, 2, 3]))


# -- CPU pinned memory --


def test_cpu_backing_is_pinned():
    """CPU weight tensors must be pinned for async H2D copies."""
    provider, *_ = _make_provider()
    assert provider._cpu_w13.is_pinned()
    assert provider._cpu_w2.is_pinned()
b/vllm/config/offload.py @@ -94,6 +94,9 @@ class OffloadConfig: prefetch: PrefetchOffloadConfig = Field(default_factory=PrefetchOffloadConfig) """Parameters for prefetch offloading backend.""" + moe_expert_cache_size: int = Field(default=0, ge=0) + """Number of MoE expert weight rows to keep in a GPU cache buffer.""" + @model_validator(mode="after") def validate_offload_config(self) -> "OffloadConfig": """Validate offload configuration constraints.""" diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 6229b44d52a8..c21a32188c57 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -1836,6 +1836,21 @@ def validate_mamba_block_size(self) -> "VllmConfig": ) return self + @model_validator(mode="after") + def validate_moe_expert_cache(self) -> "VllmConfig": + if self.model_config is None: + return self + if ( + self.offload_config.moe_expert_cache_size > 0 + and not self.model_config.enforce_eager + ): + raise ValueError( + "--moe-expert-cache-size requires --enforce-eager. " + "The expert LRU cache uses dynamic Python state in prepare() " + "that is incompatible with CUDA graph capture." 
+ ) + return self + _current_vllm_config: VllmConfig | None = None _current_prefix: str | None = None diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 55c87bf356c5..5bc9ea6f2da7 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -470,6 +470,7 @@ class EngineArgs: offload_num_in_group: int = PrefetchOffloadConfig.offload_num_in_group offload_prefetch_step: int = PrefetchOffloadConfig.offload_prefetch_step offload_params: set[str] = get_field(PrefetchOffloadConfig, "offload_params") + moe_expert_cache_size: int = OffloadConfig.moe_expert_cache_size gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization kv_cache_memory_bytes: int | None = CacheConfig.kv_cache_memory_bytes max_num_batched_tokens: int | None = None @@ -1094,6 +1095,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: offload_group.add_argument( "--offload-params", **prefetch_kwargs["offload_params"] ) + offload_group.add_argument( + "--moe-expert-cache-size", **offload_kwargs["moe_expert_cache_size"] + ) # Multimodal related configs multimodal_kwargs = get_kwargs(MultiModalConfig) @@ -2015,6 +2019,7 @@ def create_engine_config( offload_prefetch_step=self.offload_prefetch_step, offload_params=self.offload_params, ), + moe_expert_cache_size=self.moe_expert_cache_size, ) if self.gdn_prefill_backend is not None: diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index d296e84d0411..d7487422fb67 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -176,6 +176,11 @@ class LLM: these segments will be offloaded (e.g., {"gate_up_proj", "down_proj"} for MLP weights, or {"w13_weight", "w2_weight"} for MoE expert weights). If None or empty, all parameters are offloaded. + moe_expert_cache_size: Number of MoE expert weight rows to keep in a + GPU LRU buffer. 
When greater than zero, expert weights are stored in + CPU pinned memory and only the N most-recently-used experts are + mirrored onto the GPU on each forward pass. Requires + ``enforce_eager=True``. Default is 0 (disabled). enforce_eager: Whether to enforce eager execution. If True, we will disable CUDA graph and always execute the model in eager mode. If False, we will use CUDA graph and eager execution in hybrid. @@ -234,6 +239,7 @@ def __init__( offload_num_in_group: int = 1, offload_prefetch_step: int = 1, offload_params: set[str] | None = None, + moe_expert_cache_size: int = 0, enforce_eager: bool = False, enable_return_routed_experts: bool = False, disable_custom_all_reduce: bool = False, @@ -360,6 +366,7 @@ def _make_config(value: Any, cls: type[_R]) -> _R: offload_num_in_group=offload_num_in_group, offload_prefetch_step=offload_prefetch_step, offload_params=offload_params or set(), + moe_expert_cache_size=moe_expert_cache_size, enforce_eager=enforce_eager, enable_return_routed_experts=enable_return_routed_experts, disable_custom_all_reduce=disable_custom_all_reduce, diff --git a/vllm/model_executor/layers/fused_moe/expert_weight_provider.py b/vllm/model_executor/layers/fused_moe/expert_weight_provider.py new file mode 100644 index 000000000000..0d9355712c6a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/expert_weight_provider.py @@ -0,0 +1,216 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass + +import torch + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@dataclass +class ExpertWeightResult: + """GPU-resident expert weights ready for kernel consumption.""" + + w1: torch.Tensor + w2: torch.Tensor + topk_ids: torch.Tensor + w1_scale: torch.Tensor | None = None + w2_scale: torch.Tensor | None = None + + +class CachedWeightProvider: + """GPU LRU cache backed by CPU pinned memory. 
+ + Keeps capacity expert weight tensors in a fixed-size GPU scratch + buffer. All expert weights reside in CPU pinned memory; only the N + hottest experts are mirrored into the GPU buffer. + + Uses LFRU (frequency-weighted LRU) eviction: score = freq / age. + This prevents early layers from monopolizing the cache — a known + problem with pure LRU in sequential MoE execution where early + layers always appear "recently used." + + On each forward pass, prepare() identifies which experts are needed, + copies any misses from CPU to GPU (evicting the lowest-scored entry + when the buffer is full), and returns an ExpertWeightResult with + remapped topk_ids whose values are GPU-buffer slot indices. + """ + + def __init__( + self, + capacity: int, + w13_weight: torch.Tensor, + w2_weight: torch.Tensor, + w13_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + ) -> None: + num_experts = w13_weight.size(0) + + self.capacity = capacity + self._num_experts = num_experts + self.hits = 0 + self.misses = 0 + + if w13_weight.device.type == "cpu": + cuda_device = torch.accelerator.current_accelerator() + self._cpu_w13: torch.Tensor = ( + w13_weight if w13_weight.is_pinned() else w13_weight.pin_memory() + ) + self._cpu_w2: torch.Tensor = ( + w2_weight if w2_weight.is_pinned() else w2_weight.pin_memory() + ) + else: + cuda_device = w13_weight.device + self._cpu_w13 = w13_weight.cpu().pin_memory() + self._cpu_w2 = w2_weight.cpu().pin_memory() + + self._buf_w13: torch.Tensor = torch.empty( + capacity, + *w13_weight.shape[1:], + dtype=w13_weight.dtype, + device=cuda_device, + ) + self._buf_w2: torch.Tensor = torch.empty( + capacity, + *w2_weight.shape[1:], + dtype=w2_weight.dtype, + device=cuda_device, + ) + + if w13_scale is not None and w2_scale is not None: + self._cpu_w13_scale: torch.Tensor | None = w13_scale.cpu() + self._cpu_w2_scale: torch.Tensor | None = w2_scale.cpu() + self._buf_w13_scale: torch.Tensor | None = torch.empty( + capacity, + 
*w13_scale.shape[1:], + dtype=w13_scale.dtype, + device=cuda_device, + ) + self._buf_w2_scale: torch.Tensor | None = torch.empty( + capacity, + *w2_scale.shape[1:], + dtype=w2_scale.dtype, + device=cuda_device, + ) + else: + self._cpu_w13_scale = None + self._cpu_w2_scale = None + self._buf_w13_scale = None + self._buf_w2_scale = None + + # LFRU state: {expert_id: [slot, freq, last_access_clock]} + # Eviction score = freq / (clock - last_access + 1). Lower = evict first. + self._lru: dict[int, list] = {} + self._clock: int = 0 + self._free_slots: list[int] = list(range(capacity)) + + # Persistent GPU mapping tensor: _mapping[expert_id] = slot. + self._mapping: torch.Tensor = torch.zeros( + num_experts, dtype=torch.int32, device=cuda_device + ) + + @property + def buf_w13(self) -> torch.Tensor: + return self._buf_w13 + + @property + def buf_w2(self) -> torch.Tensor: + return self._buf_w2 + + @property + def buf_w13_scale(self) -> torch.Tensor | None: + return self._buf_w13_scale + + @property + def buf_w2_scale(self) -> torch.Tensor | None: + return self._buf_w2_scale + + def invalidate(self, expert_id: int) -> None: + """Remove *expert_id* from the cache, returning its slot to the free + list. No-op if the expert is not currently cached.""" + if expert_id in self._lru: + entry = self._lru.pop(expert_id) + self._free_slots.append(entry[0]) + + @torch.compiler.disable + def prepare(self, topk_ids: torch.Tensor) -> ExpertWeightResult: + """Populate the GPU buffer and return slot-remapped expert IDs. + + Args: + topk_ids: Shape ``[num_tokens, top_k]``, global expert IDs. + + Returns: + ExpertWeightResult with remapped topk_ids and GPU buffer refs. + + Raises: + RuntimeError: if unique experts in the batch exceed capacity. + """ + unique_ids = topk_ids.unique().tolist() + if len(unique_ids) > self.capacity: + raise RuntimeError( + f"CachedWeightProvider: {len(unique_ids)} unique experts " + f"requested but --moe-expert-cache-size={self.capacity}. 
" + f"Set --moe-expert-cache-size >= {len(unique_ids)}." + ) + + for expert_id in unique_ids: + if expert_id in self._lru: + # Cache hit: update frequency and recency + self._clock += 1 + entry = self._lru[expert_id] + entry[1] += 1 # freq + entry[2] = self._clock # last access + self.hits += 1 + else: + # Cache miss: need to load expert + if self._free_slots: + slot = self._free_slots.pop() + else: + # Evict entry with lowest freq/age score + best_key = None + best_score = float("inf") + for k, (s, freq, last) in self._lru.items(): + age = self._clock - last + 1 + score = freq / age + if score < best_score: + best_score = score + best_key = k + assert best_key is not None # _lru non-empty when capacity > 0 + slot = self._lru.pop(best_key)[0] + + # Copy expert weights from CPU to GPU slot + self._buf_w13[slot].copy_(self._cpu_w13[expert_id]) + self._buf_w2[slot].copy_(self._cpu_w2[expert_id]) + if self._buf_w13_scale is not None: + assert self._cpu_w13_scale is not None + assert self._cpu_w2_scale is not None + assert self._buf_w2_scale is not None + self._buf_w13_scale[slot].copy_(self._cpu_w13_scale[expert_id]) + self._buf_w2_scale[slot].copy_(self._cpu_w2_scale[expert_id]) + + self._clock += 1 + self._lru[expert_id] = [slot, 1, self._clock] + self._mapping[expert_id] = slot + self.misses += 1 + + total = self.hits + self.misses + if total > 0: + logger.debug( + "Expert cache: %d hits, %d misses (%.1f%% hit rate)", + self.hits, + self.misses, + 100.0 * self.hits / total, + ) + + remapped_ids = self._mapping[topk_ids.long()].to(dtype=topk_ids.dtype) + + return ExpertWeightResult( + w1=self._buf_w13, + w2=self._buf_w2, + topk_ids=remapped_ids, + w1_scale=self._buf_w13_scale, + w2_scale=self._buf_w2_scale, + ) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py index a239dfea92e4..b3f4f9cbfe4f 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py +++ 
b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py @@ -30,6 +30,16 @@ def __init__(self, moe: FusedMoEConfig): self.moe_quant_config: FusedMoEQuantConfig | None = None self.moe_kernel: mk.FusedMoEKernel | None = None + @property + def supports_expert_lru_cache(self) -> bool: + """True if this quant method is compatible with expert LRU caching. + + Subclasses override to True when the method allocates w13_weight / + w2_weight in the standard per-expert layout and does not reorder or + repack weights in a way that is incompatible with slot-based remapping. + """ + return False + @property def supports_internal_mk(self) -> bool: # NOTE(rob): temporary attribute to indicate support for diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py index 142e180786c6..7ccaa9aec9cf 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py @@ -93,6 +93,23 @@ def apply( shared_experts_input: torch.Tensor | None, ) -> torch.Tensor: assert self.moe_kernel is not None + + provider = getattr(layer, "expert_weight_provider", None) + if provider is not None: + result = provider.prepare(topk_ids) + return self.moe_kernel.apply( + hidden_states=x, + w1=result.w1, + w2=result.w2, + topk_weights=topk_weights, + topk_ids=result.topk_ids, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + apply_router_weight_on_input=(layer.apply_router_weight_on_input), + expert_map=(None if self.disable_expert_map else layer.expert_map), + shared_experts_input=shared_experts_input, + ) + return self.moe_kernel.apply( hidden_states=x, w1=layer.w13_weight, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index c4fc1fd2557e..f8f0c43fd8af 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ 
b/vllm/model_executor/layers/fused_moe/layer.py @@ -3,7 +3,12 @@ from collections.abc import Callable, Iterable from enum import Enum -from typing import Literal, cast, get_args, overload +from typing import TYPE_CHECKING, Literal, cast, get_args, overload + +if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.expert_weight_provider import ( + CachedWeightProvider, + ) import torch from torch.nn.parameter import UninitializedParameter @@ -54,6 +59,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, ) +from vllm.model_executor.utils import replace_parameter from vllm.platforms import current_platform logger = init_logger(__name__) @@ -279,6 +285,8 @@ def __init__( super().__init__() self._routed_input_transform = routed_input_transform + self._gate = gate + self._shared_experts_init = shared_experts if params_dtype is None: params_dtype = torch.get_default_dtype() @@ -564,11 +572,110 @@ def _get_quant_method() -> FusedMoEMethodBase: ): moe_quant_params["intermediate_size_full"] = intermediate_size + # Expert weight provider: populated after weight loading via + # _maybe_init_expert_lru_cache(). Initialized *before* + # create_weights so that create_weights can inspect + # _moe_expert_cache_size and allocate expert weights on CPU pinned + # memory when offloading is requested. + self.expert_weight_provider: CachedWeightProvider | None = None + self._moe_expert_cache_size = vllm_config.offload_config.moe_expert_cache_size + if self._moe_expert_cache_size > 0 and self.use_ep: + raise ValueError( + "moe_expert_cache_size is not compatible with expert " + f"parallelism (ep_size={self.ep_size})." + ) + if self._moe_expert_cache_size > 0 and ( + self.moe_parallel_config.dp_size > 1 or self.is_sequence_parallel + ): + raise ValueError( + "moe_expert_cache_size is not compatible with data parallelism " + "or sequence parallelism." 
+ ) + if self._moe_expert_cache_size > 0 and ( + not vllm_config.model_config.enforce_eager + ): + raise ValueError( + "moe_expert_cache_size requires --enforce-eager; CUDA graph " + "capture with an active expert cache produces incorrect " + "results." + ) + + # Disable shared expert overlap if: + # - we are using eplb with non-default backend, because of correctness issues + # - we are using flashinfer with DP, since there nothing to gain + # - we are using marlin kernels + backend = self.moe_parallel_config.all2all_backend + self.use_overlapped = ( + not ( + (self.enable_eplb and backend != "allgather_reducescatter") + or self.moe_parallel_config.use_fi_nvl_two_sided_kernels + ) + and getattr(self, "shared_experts", None) is not None + ) + self.quant_method.create_weights(layer=self, **moe_quant_params) # TODO(bnell): this is un-needed and removed in a follow up PR. self.base_quant_method = self.quant_method + self.runner = self._init_runner() + + def _maybe_init_expert_lru_cache(self) -> None: + """Initialize the expert weight provider after weights have been loaded. + + Expert weights may reside on CPU (loaded directly into pinned memory + when GPU capacity is insufficient) or on GPU (standard load path). + Allocates GPU scratch buffers of size ``moe_expert_cache_size`` and + releases the full weight tensors from whichever device they were on. + + Must be called only once, after :meth:`process_weights_after_loading`. + """ + if self._moe_expert_cache_size == 0 or self.expert_weight_provider is not None: + return + if not hasattr(self, "w13_weight") or not hasattr(self, "w2_weight"): + raise ValueError( + "moe_expert_cache_size requires w13_weight and w2_weight " + f"parameters but they are missing on layer {self.layer_name}." + ) + if self.moe_config.has_bias: + raise ValueError( + "Expert LRU cache does not support MoE layers with bias " + "terms (fused_experts() receives w1/w2 only, not bias). " + f"Layer: {self.layer_name}." 
+ ) + from vllm.model_executor.layers.fused_moe.expert_weight_provider import ( + CachedWeightProvider, + ) + + w13_scale = getattr(self, "w13_weight_scale", None) + w2_scale = getattr(self, "w2_weight_scale", None) + if w13_scale is not None and ( + w13_scale.dim() != 1 or w13_scale.size(0) != self.local_num_experts + ): + w13_scale = None + w2_scale = None + + capacity = min(self._moe_expert_cache_size, self.local_num_experts) + self.expert_weight_provider = CachedWeightProvider( + capacity=capacity, + w13_weight=self.w13_weight.data, + w2_weight=self.w2_weight.data, + w13_scale=w13_scale, + w2_scale=w2_scale, + ) + + # Release the full weight tensors (CachedWeightProvider holds its own + # reference to the CPU pinned backing store). + replace_parameter(self, "w13_weight", torch.empty(0)) + replace_parameter(self, "w2_weight", torch.empty(0)) + logger.info( + "Expert LRU cache enabled for %s: %d/%d experts cached on GPU.", + self.layer_name, + capacity, + self.local_num_experts, + ) + + def _init_runner(self): # Storing the runner in the FusedMoE is an intermediate state, eventually # the runner will own the FusedMoE layer and provide the execution interface # for MoE ops. @@ -577,12 +684,13 @@ def _get_quant_method() -> FusedMoEMethodBase: moe_config=self.moe_config, router=self.router, routed_input_transform=self._routed_input_transform, - gate=gate, - shared_experts=shared_experts, + gate=self._gate, + shared_experts=self._shared_experts_init, quant_method=self.quant_method, reduce_results=self.reduce_results, enable_dbo=self.vllm_config.parallel_config.enable_dbo, ) + return self.runner # TODO(bnell): This method is provided as a hook so vllm/lora/layers/fused_moe.py # can safely swap out the quant_method. 
We should figure out a less diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index 190821562130..c7cf135101b0 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -43,6 +43,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): # --8<-- [end:unquantized_fused_moe] + @property + def supports_expert_lru_cache(self) -> bool: + # FLASHINFER_TRTLLM reorders weights into a tiled block layout that is + # incompatible with the generic per-expert slot-based remapping. + return self.unquantized_backend != UnquantizedMoeBackend.FLASHINFER_TRTLLM + def __init__(self, moe: FusedMoEConfig): super().__init__(moe) self.unquantized_backend, self.experts_cls = select_unquantized_moe_backend( @@ -93,16 +99,27 @@ def create_weights( w13_up_dim = 2 * intermediate_size_per_partition else: w13_up_dim = intermediate_size_per_partition + + # When the expert LRU cache is enabled, allocate expert weights in CPU + # pinned memory so that checkpoint loading never allocates GPU memory + # for them. This allows models whose expert weights exceed GPU capacity + # to load successfully; the cache init later populates a small GPU + # scratch buffer (size = moe_expert_cache_size) from these CPU tensors. + use_cpu_pinned = getattr(layer, "_moe_expert_cache_size", 0) > 0 + # Fused gate_up_proj (column parallel) - w13_weight = torch.nn.Parameter( - torch.empty( - num_experts, - w13_up_dim, - hidden_size, - dtype=params_dtype, - ), - requires_grad=False, - ) + if use_cpu_pinned: + # Explicitly set device="cpu" to override vLLM's torch.device("cuda") + # context, then pin. pin_memory=True cannot be used alone here because + # the device context would silently move the allocation to CUDA first. 
+ _w13_data = torch.empty( + num_experts, w13_up_dim, hidden_size, dtype=params_dtype, device="cpu" + ).pin_memory() + else: + _w13_data = torch.empty( + num_experts, w13_up_dim, hidden_size, dtype=params_dtype + ) + w13_weight = torch.nn.Parameter(_w13_data, requires_grad=False) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) if self.moe.has_bias: @@ -113,15 +130,22 @@ def create_weights( layer.register_parameter("w13_bias", w13_bias) set_weight_attrs(w13_bias, extra_weight_attrs) # down_proj (row parallel) - w2_weight = torch.nn.Parameter( - torch.empty( + if use_cpu_pinned: + _w2_data = torch.empty( num_experts, hidden_size, intermediate_size_per_partition, dtype=params_dtype, - ), - requires_grad=False, - ) + device="cpu", + ).pin_memory() + else: + _w2_data = torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype, + ) + w2_weight = torch.nn.Parameter(_w2_data, requires_grad=False) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) if self.moe.has_bias: @@ -222,24 +246,46 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer) else: self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer) - elif self.unquantized_backend == UnquantizedMoeBackend.XPU: - w13 = layer.w13_weight - w2 = layer.w2_weight - - w13.data = w13.transpose(-1, -2).contiguous() - w2.data = w2.transpose(-1, -2).contiguous() - - self._setup_kernel( - layer=layer, - w13=w13, - w2=w2, - ) - else: - self._setup_kernel( - layer=layer, - w13=layer.w13_weight, - w2=layer.w2_weight, + elif current_platform.is_cuda_alike() or current_platform.is_xpu(): + # When the expert LRU cache is active, expert weights were loaded + # directly into CPU pinned memory (see create_weights). 
Skip the + # kernel setup (which requires CUDA weights) and go straight to + # cache initialization, which allocates the small GPU scratch buffer. + cache_active = ( + self.supports_expert_lru_cache and layer.w13_weight.device.type == "cpu" ) + if cache_active: + self.moe_quant_config = self.get_fused_moe_quant_config(layer) + layer._maybe_init_expert_lru_cache() + # Still need to create the kernel (forward_native requires it) + if self.moe_kernel is None: + assert self.experts_cls is not None + self.moe_kernel = make_unquantized_moe_kernel( + quant_config=self.moe_quant_config, + moe_config=self.moe, + backend=self.unquantized_backend, + experts_cls=self.experts_cls, + routing_tables=layer._maybe_init_expert_routing_tables(), + shared_experts=getattr(layer, "shared_experts", None), + ) + else: + w13 = layer.w13_weight + w2 = layer.w2_weight + # XPU requires transposed weight layout + if current_platform.is_xpu(): + w13.data = w13.transpose(-1, -2).contiguous() + w2.data = w2.transpose(-1, -2).contiguous() + self._setup_kernel( + layer=layer, + w13=w13, + w2=w2, + ) + # Initialize expert LRU cache after kernel setup so the CPU + # backing store captures the final (possibly padded/shuffled) + # weights. Skipped for backends whose weight layout is + # incompatible with the generic fused_experts() path. 
+ if self.supports_expert_lru_cache: + layer._maybe_init_expert_lru_cache() def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig: if self.moe.has_bias: @@ -275,6 +321,23 @@ def forward_native( shared_experts_input: torch.Tensor | None, ) -> torch.Tensor: assert self.moe_kernel is not None + + provider = getattr(layer, "expert_weight_provider", None) + if provider is not None: + result = provider.prepare(topk_ids) + return self.moe_kernel.apply( + hidden_states=x, + w1=result.w1, + w2=result.w2, + topk_weights=topk_weights, + topk_ids=result.topk_ids, + activation=layer.activation, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + shared_experts_input=shared_experts_input, + ) + return self.moe_kernel.apply( hidden_states=x, w1=layer.w13_weight, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index dfb09d57361e..f2a9b3e9e33f 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -574,6 +574,22 @@ class Fp8MoEMethod(FusedMoEMethodBase): quant_config: The quantization config. """ + @property + def supports_expert_lru_cache(self) -> bool: + from vllm.model_executor.layers.fused_moe.oracle.fp8 import Fp8MoeBackend + + _compatible_backends = { + Fp8MoeBackend.TRITON, + Fp8MoeBackend.BATCHED_TRITON, + Fp8MoeBackend.VLLM_CUTLASS, + Fp8MoeBackend.BATCHED_VLLM_CUTLASS, + Fp8MoeBackend.XPU, + } + # Block-quant scales are multi-dimensional and cannot be remapped + # per slot; backends that reorder weights into opaque layouts + # (DEEPGEMM, MARLIN, AITER, FLASHINFER) are also incompatible. 
+ return not self.block_quant and self.fp8_backend in _compatible_backends + def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): super().__init__(layer.moe_config) self.quant_config = quant_config @@ -829,6 +845,13 @@ def process_weights_after_loading(self, layer: Module) -> None: layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale ) + # Initialize expert LRU cache when requested and compatible. + if self.supports_expert_lru_cache: + layer._maybe_init_expert_lru_cache() + + # Prevent duplicate processing (e.g., during weight reload) + layer._already_called_process_weights_after_loading = True + def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, @@ -902,6 +925,23 @@ def apply( ) -> torch.Tensor: assert not self.is_monolithic assert self.moe_kernel is not None + + provider = getattr(layer, "expert_weight_provider", None) + if provider is not None: + result = provider.prepare(topk_ids) + return self.moe_kernel.apply( + x, + result.w1, + result.w2, + topk_weights, + result.topk_ids, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=(layer.apply_router_weight_on_input), + shared_experts_input=shared_experts_input, + ) + return self.moe_kernel.apply( x, layer.w13_weight,