2 changes: 2 additions & 0 deletions .buildkite/test_areas/basic_correctness.yaml
@@ -9,9 +9,11 @@ steps:
- vllm/
- tests/basic_correctness/test_basic_correctness
- tests/basic_correctness/test_cpu_offload
- tests/basic_correctness/test_moe_expert_cache
- tests/basic_correctness/test_cumem.py
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s basic_correctness/test_cumem.py
- pytest -v -s basic_correctness/test_basic_correctness.py
- pytest -v -s basic_correctness/test_cpu_offload.py
- pytest -v -s basic_correctness/test_moe_expert_cache.py
28 changes: 28 additions & 0 deletions benchmarks/qwen_122b_test_20260331.txt
@@ -0,0 +1,28 @@
=== Qwen 3.5-122B GGUF Q4_K — Test Results ===

Status: OOM KILLED during expert store build (~45 GB RAM, 62 GB total)

Architecture: 48 layers × 256 experts × top_k=8
Expert FFN: 1024→3072→1024 (small per expert)
Fused format: [out, in, 256] in Q4_K
Total expert data: ~64 GB (BF16 after dequant)

What worked:
- GGUF parsing: 0.2s (3 shards, 879 tensors)
- Q4_K type ID: correctly mapped (was wrong, fixed)
- Q6_K dequant: added for non-expert tensors
- Multimodal wrapper: text_config extraction works
- Non-expert weight loading: succeeded
- Expert store build: started, processed ~30-35 of 48 layers before OOM

Why it failed:
_build_expert_store_from_fused_reader accumulates all 12,288 experts
in a Python dict (BF16 tensors) before calling GenericExpertStore.from_dict.
At ~5.3 MB per expert × 12,288 experts = ~63 GB. Exceeds 62 GB RAM.

What's needed:
Stream-to-mmap: write each expert directly to an mmap'd file as it's
dequanted, instead of accumulating in a dict. This would cap peak RAM
at ~2.4 GB (one fused layer at a time) regardless of total model size.
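A minimal sketch of the proposed stream-to-mmap fix, using NumPy's `open_memmap`. The `reader.dequant_layer` API is hypothetical, and FP16 stands in for BF16 (NumPy has no bfloat16); the point is that only one dequantized layer is ever resident in RAM:

```python
import numpy as np

def build_expert_store_streaming(reader, store_path, num_layers,
                                 experts_per_layer, expert_shape,
                                 dtype=np.float16):
    """Dequantize fused experts one layer at a time, writing each layer
    straight into a memory-mapped file so peak RAM stays at one layer."""
    total = num_layers * experts_per_layer
    store = np.lib.format.open_memmap(
        store_path, mode="w+", dtype=dtype, shape=(total, *expert_shape)
    )
    for layer in range(num_layers):
        # Hypothetical reader API: returns the dequantized fused tensor
        # for one layer, shape (experts_per_layer, *expert_shape).
        fused = reader.dequant_layer(layer)
        base = layer * experts_per_layer
        store[base:base + experts_per_layer] = fused
        store.flush()  # written pages can now be evicted by the OS
        del fused      # drop the only in-RAM copy before the next layer
    return store
```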

No performance numbers available.
102 changes: 102 additions & 0 deletions docs/features/moe_cache_policies.md
@@ -0,0 +1,102 @@
# MoE Expert Weight Caching

vLLM can run MoE models that exceed available GPU memory by keeping all expert
weights in CPU pinned memory and caching only the most-recently-used
experts in a fixed-size GPU scratch buffer.

This feature is controlled by the `--moe-expert-cache-size` option.

| Option | Default | Description |
| --- | --- | --- |
| `--moe-expert-cache-size N` | `0` (disabled) | Number of expert slots to allocate in the GPU buffer per layer |

!!! note
Expert caching requires `--enforce-eager`. CUDA graph capture is
incompatible with the dynamic Python bookkeeping in `prepare()`.

!!! note
Expert caching is not compatible with expert parallelism (EP > 1),
data parallelism, or sequence parallelism.

## Quick start

```bash
# OLMoE-1B-7B: 64 experts, fits on 8 GB GPU with 16 cached per layer
vllm serve allenai/OLMoE-1B-7B-0924 \
--moe-expert-cache-size 16 \
--enforce-eager
```

### Python API

`moe_expert_cache_size` is exposed as a direct `LLM` constructor parameter:

```python
from vllm import LLM, SamplingParams

llm = LLM(
model="allenai/OLMoE-1B-7B-0924",
moe_expert_cache_size=16,
enforce_eager=True,
)
```

## Architecture (RFC #38256)

The cache is implemented as a `CachedWeightProvider` — the kernel does not
know or care where weights came from.

### How it works

```text
Decode (unique experts <= capacity) — GPU fast path:
topk_ids -> provider.prepare():
hit -> move_to_end in OrderedDict (O(1))
miss -> evict LRU, H2D copy, update mapping[expert] = slot
-> kernel.apply(result.w1, result.w2, result.topk_ids)
```

A persistent `_mapping` tensor (`int32`, GPU) holds the `expert_id -> slot`
mapping. It is updated in-place for misses and used for a vectorized remap —
no CPU tensor build or H2D transfer on the hot path.
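The vectorized remap can be illustrated with NumPy standing in for the GPU `int32` tensor (the expert IDs and slot assignments below are made up for the example):

```python
import numpy as np

# Persistent per-layer mapping, one entry per expert; -1 = not resident.
mapping = np.full(64, -1, dtype=np.int32)
mapping[[3, 7, 42]] = [0, 1, 2]          # experts 3, 7, 42 occupy slots 0..2

topk_ids = np.array([[3, 42], [7, 3]])   # router output for two tokens
slot_ids = mapping[topk_ids]             # vectorized remap, no Python loop
```

Because the mapping lives on the GPU and is updated in-place on misses, the remap is a single gather with no per-token host work.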

The `CachedWeightProvider` uses `collections.OrderedDict` for LRU eviction
(no external dependencies). If the number of unique experts in a batch
exceeds the cache capacity, a `RuntimeError` is raised; increase
`--moe-expert-cache-size` to avoid this.
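The `prepare()` bookkeeping described above can be sketched in a few lines (class and method names are illustrative, not vLLM's actual implementation, and the H2D copy is elided):

```python
from collections import OrderedDict

class LRUSlotCache:
    """Illustrative OrderedDict-based expert -> GPU-slot assignment."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.slots = OrderedDict()        # expert_id -> slot index
        self.free = list(range(capacity))
        self.hits = 0
        self.misses = 0

    def prepare(self, expert_ids):
        unique = list(dict.fromkeys(expert_ids))
        if len(unique) > self.capacity:
            raise RuntimeError(
                f"{len(unique)} unique experts exceed capacity "
                f"{self.capacity}; increase --moe-expert-cache-size")
        for eid in unique:
            if eid in self.slots:
                self.hits += 1
                self.slots.move_to_end(eid)   # O(1) LRU touch
            else:
                self.misses += 1
                if self.free:
                    slot = self.free.pop()
                else:
                    _, slot = self.slots.popitem(last=False)  # evict LRU
                # Real implementation: H2D-copy this expert's weights
                # into GPU buffer slot `slot` here.
                self.slots[eid] = slot
                self.slots.move_to_end(eid)
        return [self.slots[eid] for eid in expert_ids]
```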

## Observability

### DEBUG-level hit/miss log

Set `VLLM_LOGGING_LEVEL=DEBUG` to get a per-layer hit/miss report every
60 seconds:

```text
DEBUG vllm...expert_weight_provider: Expert cache: 1234 hits, 56 misses (95.7% hit rate)
```

## Sizing guidance

Set `--moe-expert-cache-size` to the number of experts that must fit on
GPU simultaneously per layer. For a model with `E` experts and `top_k`
routing:

- **Minimum useful**: `top_k` (one slot per active expert per token, no
eviction during decode)
- **Typical decode**: `2 * top_k` to `4 * top_k` gives headroom for
  locality without wasting VRAM
- **Maximum** (no-op): `E` (all experts on GPU, equivalent to normal mode)
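As a rough sanity check, the buffer footprint can be estimated from the slot count and per-expert weight size (assuming three unfused projection matrices per expert; real layouts may fuse gate+up or quantize, so treat this as an upper bound):

```python
def expert_buffer_bytes(num_layers: int, slots_per_layer: int,
                        hidden: int, intermediate: int,
                        bytes_per_param: int = 2) -> int:
    """Estimate GPU scratch buffer size for the expert cache.

    Assumes gate, up, and down projections per expert, each of size
    hidden * intermediate, at bytes_per_param bytes (2 for BF16).
    """
    per_expert = 3 * hidden * intermediate * bytes_per_param
    return num_layers * slots_per_layer * per_expert

# Illustrative numbers (not taken from any model config): 16 layers,
# 16 slots per layer, hidden=2048, intermediate=1024, BF16 params.
print(expert_buffer_bytes(16, 16, 2048, 1024) / 2**30)  # 3.0 (GiB)
```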

## GPU memory note

Expert weights in CPU pinned memory are invisible to the `--gpu-memory-utilization`
profiler. The profiler will underestimate available KV cache headroom by the
expert weight footprint (a safe margin, not a hazard), but exact
`--gpu-memory-utilization`-based sizing will be off.

## Tests

```bash
# Unit tests: CachedWeightProvider
pytest tests/kernels/moe/test_expert_lru_cache.py -v
```
39 changes: 39 additions & 0 deletions tests/basic_correctness/test_moe_expert_cache.py
@@ -0,0 +1,39 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Correctness tests for the MoE expert LRU cache (--moe-expert-cache-size).

Runs two vllm serve instances side-by-side via compare_two_settings:
- baseline: standard MoE (all experts on GPU)
- cache: expert LRU cache enabled with a small GPU buffer

Token outputs must match exactly.
"""

import pytest

from ..utils import compare_two_settings

_MOE_MODEL = "RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8"


@pytest.mark.parametrize("cache_size", [4, 16])
def test_moe_expert_cache_correctness(cache_size: int) -> None:
"""Output tokens from the cache path must match the no-cache baseline."""
compare_two_settings(
model=_MOE_MODEL,
arg1=["--enforce-eager"],
arg2=[
"--enforce-eager",
"--moe-expert-cache-size",
str(cache_size),
],
)


def test_moe_expert_cache_disabled_by_default() -> None:
"""Verify that the default (cache_size=0) leaves the existing path intact."""
compare_two_settings(
model=_MOE_MODEL,
arg1=["--enforce-eager"],
arg2=["--enforce-eager", "--moe-expert-cache-size", "0"],
)