
[MoE][Offload] Run MoE models exceeding VRAM via expert CPU offloading with GPU cache (--moe-expert-cache-size)#37190

Open
e1n00r wants to merge 7 commits into vllm-project:main from e1n00r:feature/moe-expert-lru-cache

Conversation

@e1n00r

@e1n00r e1n00r commented Mar 16, 2026

Purpose

Implements ExpertWeightProvider — a weight provider abstraction for MoE expert offloading with GPU LRU cache, addressing RFC #38256.

Expert weights live in CPU pinned memory; a fixed-size GPU cache holds the hottest N experts per layer. LRU eviction adapts to runtime routing — hot experts stay cached, cold ones are evicted. Models that exceed GPU VRAM can now run on smaller hardware.

Key architectural choice: The cache is a weight provider, not a special forward path. No bypass of the runner pipeline — all paths go through runner.forward() → quant_method.apply(). EP dispatch, DP chunking, and shared expert overlap work unchanged.


Architecture

ExpertWeightProvider (ABC)
├── FullGPUProvider        — zero-cost passthrough (default, no overhead)
└── CachedWeightProvider   — GPU LRU cache + CPU backing store
      ├── GPUSlotManager   — fixed-address GPU buffers [capacity, ...]
      ├── LRU eviction     — collections.OrderedDict (no external deps)
      └── CPUBackingStore  — pinned DRAM, all local experts
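
The provider split above can be sketched in a few lines of plain Python (class and method names follow the tree; tensor details are omitted, so the weights here are placeholder objects, not GPU tensors):

```python
from abc import ABC, abstractmethod


class ExpertWeightProvider(ABC):
    """Abstracts where expert weights live; callers only ever see GPU-ready results."""

    @abstractmethod
    def prepare(self, topk_ids):
        """Return (weights, remapped_topk_ids), both GPU-resident in the real code."""


class FullGPUProvider(ExpertWeightProvider):
    """Zero-cost passthrough: all experts already resident, ids returned unchanged."""

    def __init__(self, weights):
        self.weights = weights

    def prepare(self, topk_ids):
        # Identity remap, no copies — this is the "default, no overhead" path.
        return self.weights, topk_ids
```

CachedWeightProvider implements the same prepare() contract but may copy missing experts from CPU first; callers never branch on which provider is active.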

Integration at FusedMoEModularMethod.apply(): replaces direct layer.w13_weight access with provider.prepare(topk_ids). The provider returns GPU-resident weight tensors and remapped topk_ids (slot indices). The kernel doesn't know or care where weights came from.

torch.compile compatibility

  • prepare() decorated with @torch.compiler.disable — cache management stays outside compiled regions
  • GPU slot buffers are fixed-address (allocated once at init) — safe for CUDA graph capture in PR 2
  • ExpertWeightResult uses fixed attributes (no dict, no boolean flags) — avoids graph breaks
  • PR 1 requires --enforce-eager; PR 2 will add custom ops (following the pattern of #29941, "[offloader] v2: Hide weight onloading latency via prefetching") to remove this requirement
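
The fixed-attribute result shape can be sketched as a dataclass (field names here are assumptions inferred from the description; the real class holds torch tensors):

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ExpertWeightResult:
    """Fixed attributes only: the object's shape never varies between calls,
    so tracing sees a stable structure (no dicts with changing keys, no
    booleans steering control flow inside compiled regions)."""
    w13: Any                           # gate/up projection weights (GPU tensor)
    w2: Any                            # down projection weights (GPU tensor)
    topk_ids: Any                      # slot-remapped routing ids
    w13_scale: Optional[Any] = None    # FP8 scales when present, else None
    w2_scale: Optional[Any] = None
```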

Test results

Hardware:

| Component | Spec |
|-----------|------|
| GPU | NVIDIA RTX PRO 2000 (SM 12.0), 8 GB VRAM |
| CPU | Intel Core Ultra 9 285H, 16 cores |
| RAM | 62 GB DDR5 |
| CUDA | 12.8, PyTorch 2.10.0+cu128 |

Unit tests: 20/20 passing

tests/kernels/moe/test_expert_lru_cache.py — 20 tests:
  test_hit_miss_counters                     PASSED
  test_lru_eviction_order                    PASSED
  test_remapping_matches_internal_dict       PASSED
  test_output_dtype_matches_input[bf16]      PASSED
  test_output_dtype_matches_input[fp16]      PASSED
  test_gpu_buffer_matches_source_weights     PASSED
  test_gpu_buffer_correct_after_eviction     PASSED
  test_cpu_backing_is_pinned                 PASSED
  test_scale_buffers_allocated               PASSED
  test_no_scales_when_not_provided           PASSED
  test_scale_copied_with_weights             PASSED
  test_scale_updated_after_eviction          PASSED
  test_overflow_raises                       PASSED
  test_no_overflow_at_capacity               PASSED
  test_cache_size_equals_num_experts         PASSED
  test_cache_size_one_max_eviction           PASSED
  test_result_contains_buffer_references     PASSED

OLMoE-1B-7B-0924 (prior run, same cache logic):

| | Result |
|---|--------|
| Without cache (--moe-expert-cache-size 0) | OOM — 64 experts don't fit on 8 GB |
| With cache (--moe-expert-cache-size 16) | Loads successfully; 16/64 experts cached per layer |
| Decode | 5.5 tok/s sustained |
| Output | Coherent: "Quantum computing is a new way of computing..." |

Production validation (tinyserve, same techniques, different codebase):

| Metric | Result |
|--------|--------|
| Decode throughput | 30 tok/s (stable across context lengths) |
| vs HF device_map="auto" | 160x faster |
| Cache hit rate (temporal prediction) | 97-100% |
| Test suite | 325 tests |

Caveat: tinyserve numbers are single-stream on a laptop GPU. Multi-user batched inference on H100 will have different bottlenecks.

Changes

11 files, ~600 net additions (after deleting bypass code + multi-policy code)

| File | What |
|------|------|
| expert_weight_provider.py (new) | ExpertWeightProvider ABC, FullGPUProvider (passthrough), CachedWeightProvider (GPU LRU + CPU pinned backing), ExpertWeightResult dataclass |
| fused_moe_modular_method.py | Provider intercept in apply() |
| layer.py | Deleted bypass + CPU fallback; errors instead of silent downgrades |
| unquantized_fused_moe_method.py | CPU-pinned weight allocation, provider hook |
| fp8.py | Provider hook for FP8 scale support |
| offload.py | Config fields (LRU only) |
| moe_cache_policies.md | Simplified docs, references RFC #38256 |
| test_expert_lru_cache.py | 20 unit tests for CachedWeightProvider |

Deleted: cache_policy.py, lru_cache.py, test_cache_policy.py, CPU fallback code

How it works

moe_expert_cache_size == 0 (default):
  FullGPUProvider.prepare(topk_ids) → return layer weights unchanged
  → zero overhead (one if-check per layer)

moe_expert_cache_size > 0:
  CachedWeightProvider.prepare(topk_ids):
    unique_ids = topk_ids.unique().tolist()
    for each unique expert:
      hit  → OrderedDict.move_to_end() (O(1))
      miss → evict LRU, H2D copy from CPU pinned → GPU slot, update mapping
    remap topk_ids → slot indices via persistent GPU mapping tensor
  → fused_experts(slot_buffers[w1], slot_buffers[w2], remapped_ids, ...)
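
The hit/miss loop above can be modeled without tensors. The sketch below (a hypothetical LRUSlots class, not the PR's actual code) keeps the same OrderedDict bookkeeping, with the H2D copy reduced to a comment:

```python
from collections import OrderedDict


class LRUSlots:
    """Toy model of the slot bookkeeping: front of the OrderedDict is the LRU expert."""

    def __init__(self, capacity):
        self.free = list(range(capacity))
        self.expert_to_slot = {}
        self.lru = OrderedDict()
        self.hits = self.misses = 0

    def prepare(self, topk_ids):
        for eid in dict.fromkeys(topk_ids):       # unique ids, order-preserving
            if eid in self.expert_to_slot:
                self.lru.move_to_end(eid)         # O(1) hit: refresh recency
                self.hits += 1
            else:
                if self.free:
                    slot = self.free.pop()
                else:
                    evicted, _ = self.lru.popitem(last=False)  # O(1) eviction
                    slot = self.expert_to_slot.pop(evicted)
                # Real code: synchronous H2D copy of this expert's weights
                # from CPU pinned memory into the fixed-address GPU slot.
                self.expert_to_slot[eid] = slot
                self.lru[eid] = None
                self.misses += 1
        # Remap routed expert ids to GPU slot indices for the kernel.
        return [self.expert_to_slot[eid] for eid in topk_ids]
```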

Limitations (PR 1)

  • --enforce-eager required — CUDA graph compat deferred to PR 2
  • Synchronous H2D — no async pipeline
  • LRU only — other policies in follow-ups if data justifies them
  • Single-GPU only — TP stall behavior not characterized
  • BF16 + FP8 per-tensor only — other quant formats via tensor registry in PR 3
  • No CPU fallback — overflow raises error with guidance

Test plan

pytest tests/kernels/moe/test_expert_lru_cache.py -v
pytest tests/basic_correctness/test_moe_expert_cache.py -v -s

Planned follow-ups (RFC #38256)

| PR | Scope |
|----|-------|
| PR 2 | Async H2D + cross-layer temporal prediction + torch.compile compat |
| PR 3 | Disk tier, additional quant formats, EPLB integration, observability |

AI-assisted development (Claude Code). Architecture validated in tinyserve.

Essential Elements Checklist
  • Purpose
  • Test plan
  • Test results
  • Documentation update

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, a small and essential subset of tests that quickly catches errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify mergify bot added the frontend label Mar 16, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a dynamic LRU cache for MoE expert weights, a valuable feature for reducing GPU memory consumption. The implementation is well-structured, adding new configurations, a dedicated LRU cache class, and integrating it into the MoE layer. The new tests for correctness are also a great addition. My main feedback focuses on a performance issue within the LRU cache implementation itself, which could be optimized for better efficiency, especially with larger cache sizes.

Comment on lines +100 to +117
    for expert_id in unique_ids:
        if expert_id in self._expert_to_slot:
            self._lru_order.remove(expert_id)
            self._lru_order.append(expert_id)
            self.hits += 1
        else:
            if self._free_slots:
                slot = self._free_slots.pop()
            else:
                evicted = self._lru_order.pop(0)
                slot = self._expert_to_slot.pop(evicted)

            self._buf_w13[slot].copy_(self._cpu_w13[expert_id])
            self._buf_w2[slot].copy_(self._cpu_w2[expert_id])

            self._expert_to_slot[expert_id] = slot
            self._lru_order.append(expert_id)
            self.misses += 1
Contributor


Severity: high

The current LRU cache implementation uses a list for _lru_order, which results in O(N) complexity for remove() and pop(0) operations, where N is the cache capacity. This can become a performance bottleneck for larger cache sizes.

To improve performance to O(1) for these operations, I recommend refactoring the LRU logic to use collections.OrderedDict.

This would involve the following changes:

  1. In __init__, change _lru_order to an OrderedDict:

    from collections import OrderedDict
    
    # ...
    # LRU state (Python-only; must stay outside torch.compile).
    self._expert_to_slot: dict[int, int] = {}
    self._free_slots: list[int] = list(range(capacity))
    # Front = least-recently-used expert ID.
    self._lru_order: OrderedDict[int, None] = OrderedDict()
  2. Update the prepare method to use OrderedDict methods for efficient LRU management, as shown in the suggestion below.

Suggested change

Before:

    for expert_id in unique_ids:
        if expert_id in self._expert_to_slot:
            self._lru_order.remove(expert_id)
            self._lru_order.append(expert_id)
            self.hits += 1
        else:
            if self._free_slots:
                slot = self._free_slots.pop()
            else:
                evicted = self._lru_order.pop(0)
                slot = self._expert_to_slot.pop(evicted)
            self._buf_w13[slot].copy_(self._cpu_w13[expert_id])
            self._buf_w2[slot].copy_(self._cpu_w2[expert_id])
            self._expert_to_slot[expert_id] = slot
            self._lru_order.append(expert_id)
            self.misses += 1

After:

    for expert_id in unique_ids:
        if expert_id in self._expert_to_slot:
            self._lru_order.move_to_end(expert_id)
            self.hits += 1
        else:
            if self._free_slots:
                slot = self._free_slots.pop()
            else:
                evicted, _ = self._lru_order.popitem(last=False)
                slot = self._expert_to_slot.pop(evicted)
            self._buf_w13[slot].copy_(self._cpu_w13[expert_id])
            self._buf_w2[slot].copy_(self._cpu_w2[expert_id])
            self._expert_to_slot[expert_id] = slot
            self._lru_order[expert_id] = None
            self.misses += 1

Author


Fixed in 8fc9268 — replaced list-based _lru_order with collections.OrderedDict. move_to_end() for hits and popitem(last=False) for eviction are both O(1).

@mergify mergify bot added the ci/build label Mar 16, 2026
Contributor

@alvinttang alvinttang left a comment


This is a well-designed feature — the LRU expert cache is a natural approach for running MoE models that exceed GPU memory. The implementation is clean and the code is well-documented. Here's a detailed review:

1. Thread safety concern in ExpertLRUCache.prepare()

The prepare() method mutates _expert_to_slot, _free_slots, and _lru_order without synchronization. In vLLM's current architecture, the forward pass is single-threaded on the model runner, so this is fine today. But if vLLM ever moves to concurrent forward passes (e.g., disaggregated prefill/decode with shared model weights), this would race. Worth a comment noting the single-threaded assumption.

2. Synchronous H2D copies in prepare() are a latency bottleneck

Each cache miss does a synchronous copy_() from CPU pinned memory to GPU. For large expert weights (e.g., DeepSeek-V2's 160 experts with ~7M params each), a miss could take 1-2ms per expert. If multiple misses occur in one forward pass (common with top-k=6 routing), this serialized copy could add 5-10ms per layer.

Consider using torch.cuda.Stream for async H2D copies with an event-based sync, or batching all misses into a single torch.cat + copy. The current approach is correct but may significantly impact throughput in practice.

3. The mapping tensor in prepare() is recreated every call

mapping = torch.zeros(self._num_experts, dtype=torch.int64)
for expert_id, slot in self._expert_to_slot.items():
    mapping[expert_id] = slot
mapping = mapping.to(device=topk_ids.device)

This allocates a new CPU tensor, fills it with a Python loop, and transfers it to GPU on every forward pass. For a model with 160 experts and 60+ layers, this adds up. Consider keeping a persistent _mapping tensor on GPU and only updating the changed entries in-place.
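
The suggestion amounts to keeping one long-lived table and touching only the entries a miss changes. A torch-free sketch (hypothetical PersistentMapping class; the real version would be a persistent GPU int32 tensor updated in place):

```python
class PersistentMapping:
    """Long-lived expert→slot table; only changed entries are written per miss."""

    def __init__(self, num_experts):
        self.table = [-1] * num_experts   # -1 = expert not GPU-resident

    def on_miss(self, expert_id, slot, evicted_id=None):
        if evicted_id is not None:
            self.table[evicted_id] = -1   # invalidate the evicted expert's entry
        self.table[expert_id] = slot      # in-place update, O(1) per miss

    def remap(self, topk_ids):
        # No allocation, no Python-loop rebuild, no full-table transfer.
        return [self.table[eid] for eid in topk_ids]
```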

4. _forward_with_expert_cache bypasses several runner features

The cache forward path calls fused_experts() directly, bypassing the normal runner's handling of:

  • w13_bias / w2_bias (MoE layers with bias)
  • Expert-parallel scatter/gather
  • Scale tensors for quantized weights (w13_weight_scale, w2_weight_scale)
  • Custom activation functions beyond self.activation

The EP and quantization incompatibilities are documented, but the bias case isn't mentioned. If any MoE model uses bias terms, this path would silently produce wrong results.

5. Missing enforce_eager validation

The docstring says --enforce-eager is required, but I don't see validation that rejects moe_expert_cache_size > 0 when enforce_eager=False. The @torch.compiler.disable decorator on _forward_with_expert_cache helps, but if CUDA graphs are used at a higher level, the dynamically changing buffer contents would cause correctness issues. Consider adding a config validator that errors out if moe_expert_cache_size > 0 and not enforce_eager.

6. Memory accounting

When expert weights are allocated on CPU pinned memory, vLLM's GPU memory profiler won't account for them. This means gpu_memory_utilization calculations will over-estimate available KV cache memory by the amount of expert weight memory that was moved to CPU. The profiler may need to be made aware of the CPU pinned allocation to avoid OOM during KV cache allocation.

7. Tests are good but limited

The correctness test (compare_two_settings) verifies output token matching, which is the most important thing. Consider also testing:

  • Cache hit/miss counters (to verify the LRU logic is working)
  • Edge case: cache_size >= num_experts (all experts fit, no eviction)
  • Edge case: cache_size = 1 (maximum eviction pressure)

Overall this is a solid first implementation of MoE expert offloading. The main production concerns are the synchronous H2D copy latency and the missing enforce_eager validation.

@e1n00r
Author

e1n00r commented Mar 16, 2026

Thanks for the thorough review @alvinttang! Addressing each point:

1. Thread safety — Added a comment in ExpertLRUCache noting the single-threaded assumption (68c81df). You're right that if vLLM ever supports concurrent forwards with shared weights this would need a lock.

2. Synchronous H2D copies — Agreed, this is the main latency bottleneck. Async H2D with double-buffered CUDA streams (the "DBO scheduling" from RFC #33869) is the top item in the planned PR 2. Mentioning it here so it's on record.

3. Persistent mapping tensor — Implemented in 68c81df. _mapping is now a persistent [num_experts] GPU int32 tensor, updated in-place at each miss. Eliminates the per-call CPU allocation + Python loop + H2D transfer from the hot path.

4. Bias bypass — Guard added in 68c81df: _maybe_init_expert_lru_cache() checks moe_config.has_bias and logs a warning + returns early, so the cache is disabled rather than producing wrong results. A follow-up PR can wire bias tensors through (they're small, so CPU-pinning them is trivial) when a bias-using MoE model needs offloading.

5. enforce_eager guard — In the code since 68c81df. From FusedMoE.__init__():

if self._moe_expert_cache_size > 0 and (
    not vllm_config.model_config.enforce_eager
):
    logger.warning(
        "moe_expert_cache_size requires --enforce-eager; ..."
    )
    self._moe_expert_cache_size = 0

The cache is silently disabled (not just warned) when enforce_eager=False.

6. Memory accounting — Valid concern. The GPU profiler won't see CPU-pinned allocations, so it will over-allocate KV cache against memory that expert weights no longer occupy. This is actually a benefit (more KV cache headroom), not a hazard — the expert weights are no longer on GPU. But you're right that if someone relies on gpu_memory_utilization for precise sizing, the accounting is off. I'll add a note to the PR description.

7. Tests — 16 unit tests in tests/kernels/moe/test_expert_lru_cache.py (618392a), covering: hit/miss counters, LRU eviction correctness, slot remapping, GPU buffer content post-eviction, dtype preservation, CPU pinned backing store, FP8 per-slot scale buffering, and the no-scales path. Edge case capacity >= num_experts (no eviction pressure) is implicitly covered by _free_slots never emptying in those scenarios.

@e1n00r e1n00r force-pushed the feature/moe-expert-lru-cache branch 5 times, most recently from 4db08e9 to 618392a Compare March 16, 2026 21:22
@e1n00r e1n00r marked this pull request as ready for review March 17, 2026 07:38
@mergify

mergify bot commented Mar 17, 2026

Hi @e1n00r, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@e1n00r e1n00r force-pushed the feature/moe-expert-lru-cache branch 2 times, most recently from 29afd27 to 6af6bba Compare March 17, 2026 10:56
@mergify

mergify bot commented Mar 17, 2026

Hi @e1n00r, the pre-commit checks have failed.

@vlascik

vlascik commented Mar 17, 2026

Also check this paper: https://arxiv.org/html/2410.17954v1

Instead of LRU, they load with a predictor:

"ExpertFlow consists of three key components: the Routing Path Predictor, the Expert Cache Engine, and the Token Scheduler.

Leveraging the three synergistic components of our system, ExpertFlow achieves an average GPU memory savings of 75.4%, with peak savings reaching up to 93.72%, compared to GPU-only solutions. Furthermore, ExpertFlow attains an expert cache hit ratio of up to 91.96%, improving the hit ratio by an average of 27.65% over the LRU caching strategy. Additionally, ExpertFlow delivers a 2 to 10 times increase in inference speed."

@e1n00r
Author

e1n00r commented Mar 17, 2026

Also check this paper: https://arxiv.org/html/2410.17954v1

Instead of LRU, they load with a predictor:

"ExpertFlow consists of three key components: the Routing Path Predictor, the Expert Cache Engine, and the Token Scheduler.

Leveraging the three synergistic components of our system, ExpertFlow achieves an average GPU memory savings of 75.4%, with peak savings reaching up to 93.72%, compared to GPU-only solutions. Furthermore, ExpertFlow attains an expert cache hit ratio of up to 91.96%, improving the hit ratio by an average of 27.65% over the LRU caching strategy. Additionally, ExpertFlow delivers a 2 to 10 times increase in inference speed."

If I do that, we've essentially rebuilt PowerInfer, which is a well-established solution in its own right. It would also require training predictor models (just as PowerInfer does). If I were the target audience, I would just use that backend instead.
The point of vLLM, for me at least, is its scalability and wide support; adding this requirement would make the feature nigh useless.
Perhaps a middle ground: something that learns on the fly?
In any case, I would push anything of that magnitude to another PR.

@vlascik

vlascik commented Mar 17, 2026

Well, apparently there's quite a few options here:

  1. Frequency–recency hybrid (statistical scoring)
    A. Least Cache Priority (LCP) / exponential decay scoring. Combines frequency (μ) and recency gap (ν) into a single score
    B. ARC (Adaptive Replacement Cache)
    C. LFU

  2. Reuse-distance / stack-distance models
    D. LIRS (Low Inter-reference Recency Set)
    E. Reuse-distance–based admission control (general approach)

  3. Structure-aware (MoE-specific statistical policies)
    F. Layered-LRU (LLRU)
    G. Miss-rate–constrained caching (global statistical control)

  4. Partial / fractional caching (statistical resource allocation)
    H. Bit-sliced / fractional expert caching (DBSC)

  5. Admission-control + statistical filtering
    I. Probabilistic admission (TinyLFU-style ideas applied to MoE)

  6. Hybrid statistical + constraint-based approaches
    J. Multi-tier statistical caching

But these are not better than Predictor-based systems (e.g., ProMoE, ExpertFlow) and Learned replacement (e.g., FlashMoE ML policy).

Strong non-ML alternatives:
Best general: ARC, LIRS
Best MoE-specific: LLRU, LCP
Most novel: bit-sliced / fractional caching
Most promising direction (non-ML): Score-based caching

Of course, that's all for another PR; it's important to at least get the caching-strategy ball rolling, since the possible speedups seem massive. Maybe it would be nice to make the strategy pluggable?

@e1n00r e1n00r force-pushed the feature/moe-expert-lru-cache branch 2 times, most recently from 70e10ed to 41f367e Compare March 17, 2026 20:36
@e1n00r
Author

e1n00r commented Mar 17, 2026

Note on commit history (5 logical commits, DCO-signed with elnur.abdullaev@sonia.so):

  • 5960ae3 — Config + CLI: moe_expert_cache_size, moe_expert_cache_policy in OffloadConfig; cross-config validator in VllmConfig (enforce_eager check); --moe-expert-cache-size CLI arg; LLM Python API param
  • 13992d1 — Core cache: cache_policy.py (LRU/LFU/FIFO/SLRU via cachetools + custom SLRUPolicy); lru_cache.py (ExpertLRUCache with persistent GPU mapping tensor, 60-s DEBUG hit/miss log)
  • 1aa3034 — FusedMoE integration: _maybe_init_expert_lru_cache(), _forward_with_expert_cache(), _moe_forward_cpu() CPU fallback; 300-s INFO hit/miss log via --enable-logging-iteration-details
  • 7fb8da5 — Tests: 18 unit tests (test_expert_lru_cache.py), policy unit tests (test_cache_policy.py), end-to-end (test_moe_expert_cache.py), CI registration
  • bd27b29 — Docs: docs/features/moe_cache_policies.md

All features mentioned in earlier review comments remain unchanged:

  • Thread safety note in ExpertLRUCache
  • Persistent mapping tensor (O(1) GPU remap, no per-call allocation) ✓
  • Bias bypass guard ✓
  • enforce_eager guard (now a VllmConfig cross-validator) ✓
  • GPU memory profiler note in PR description ✓
  • Boundary capacity tests (capacity == num_experts, capacity == 1) ✓
  • Pluggable eviction policies (LRU/LFU/FIFO/SLRU) ✓

@mergify

mergify bot commented Mar 17, 2026

Documentation preview: https://vllm--37190.org.readthedocs.build/en/37190/

@mergify mergify bot added the documentation Improvements or additions to documentation label Mar 17, 2026
@e1n00r
Author

e1n00r commented Mar 18, 2026

The caching layer has been refactored with the strategy pattern: a new ExpertCachePolicy ABC (cache_policy.py) with four built-in policies selectable via --moe-expert-cache-policy:

| Policy | Best for | Eviction rule |
|--------|----------|---------------|
| lru (default) | decode-heavy, temporal locality | least recently used |
| lfu | highly skewed routing (same few experts always hot) | least frequently used |
| fifo | uniform routing, predictable eviction order | insertion order |
| slru | mixed prefill+decode workloads | two-tier: probationary → protected |

LRU, LFU, and FIFO are thin wrappers around cachetools so their implementations are battle-tested. SLRU is a custom 20%/80% two-tier split.
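
The two-tier idea can be sketched as follows (a hypothetical SLRU class; the real SLRUPolicy's split and promotion rules may differ in detail). A new expert enters a small probationary tier and is promoted to the protected tier only on a second access, so one-shot experts cannot flush the hot set:

```python
from collections import OrderedDict


class SLRU:
    """Toy segmented LRU: ~20% probationary tier, ~80% protected tier."""

    def __init__(self, capacity):
        self.prob_cap = max(1, capacity // 5)      # probationary tier size
        self.prot_cap = capacity - self.prob_cap   # protected tier size
        self.probationary = OrderedDict()
        self.protected = OrderedDict()

    def access(self, key):
        if key in self.protected:                  # hit in protected tier
            self.protected.move_to_end(key)
        elif key in self.probationary:             # second access: promote
            del self.probationary[key]
            if len(self.protected) >= self.prot_cap:
                # demote the protected tier's LRU entry back to probationary
                demoted, _ = self.protected.popitem(last=False)
                self._admit(demoted)
            self.protected[key] = None
        else:                                      # cold miss
            self._admit(key)

    def _admit(self, key):
        if len(self.probationary) >= self.prob_cap:
            self.probationary.popitem(last=False)  # evict probationary LRU
        self.probationary[key] = None
```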

The options you list (LCP, ARC, LIRS, reuse-distance, MoE routing-frequency-based) are all excellent follow-ons — especially the routing-frequency-based ones that exploit the known structure of MoE routing distributions. The strategy pattern makes adding new policies a ~30 LOC drop-in.

An ExpertFlow-style predictor (offline trained on routing sequences) is deferred — the cache needs to be stable and observable first so we can collect the routing statistics needed to train one. PRs 2–4 in the series will add async H2D, EPLB integration, and hit/miss telemetry export; the predictor is a natural PR 5 once that data is flowing.

e1n00r and others added 2 commits March 18, 2026 11:09
Introduces the `moe_expert_cache_size` and `moe_expert_cache_policy` fields
to `OffloadConfig`, a cross-config validator in `VllmConfig` that requires
`--enforce-eager` when the cache is enabled, and exposes both settings via
the `--moe-expert-cache-size` / `--moe-expert-cache-policy` CLI arguments
and the `LLM` Python API.

Signed-off-by: Elnur Abdullaev <elnur.abdullaev@sonia.so>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Introduces two new modules:

- `cache_policy.py`: `ExpertCachePolicy` ABC with LRU, LFU, FIFO, and SLRU
  implementations via `cachetools` and a pure-Python `SLRUPolicy`. A
  `create_cache_policy()` factory selects the policy by name.

- `lru_cache.py`: `ExpertLRUCache` — a fixed-capacity GPU scratch buffer
  backed by CPU pinned memory. On each forward pass, `prepare()` loads
  missing experts from CPU to GPU (H2D), evicts according to the chosen
  policy, and returns slot-remapped `topk_ids` via a persistent GPU mapping
  tensor (no per-call allocation). Hit/miss stats are logged at DEBUG level
  every 60 s via `VLLM_LOGGING_LEVEL=DEBUG`.

Signed-off-by: Elnur Abdullaev <elnur.abdullaev@sonia.so>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@e1n00r e1n00r force-pushed the feature/moe-expert-lru-cache branch 2 times, most recently from 00dbdd7 to dca8b80 Compare March 18, 2026 10:16
@mergify

mergify bot commented Mar 18, 2026

Hi @e1n00r, the pre-commit checks have failed.

e1n00r and others added 3 commits March 18, 2026 12:18
Wires the expert cache into the MoE layer:

- `fused_moe_method_base.py`: adds `supports_expert_lru_cache` property
  (default False) for quant methods to opt in.

- `layer.py`: initialises `_expert_lru_cache` in `__init__` (guards for EP,
  DP/SP, and enforce_eager), adds `_maybe_init_expert_lru_cache()` called
  after weight loading, and `_forward_with_expert_cache()` which handles the
  GPU fast path and a CPU fallback (`_moe_forward_cpu()`) for overflow batches.
  Per-layer hit/miss stats are emitted at INFO level every 300 s when
  `--enable-logging-iteration-details` is set.

- `unquantized_fused_moe_method.py`: allocates expert weights in CPU pinned
  memory when the cache is requested and calls `_maybe_init_expert_lru_cache`.

- `fp8.py`: sets `supports_expert_lru_cache = True` for the per-tensor FP8
  path; scale tensors are registered and kept in slot-indexed GPU buffers.

Signed-off-by: Elnur Abdullaev <elnur.abdullaev@sonia.so>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- `test_expert_lru_cache.py` (18 tests): cold miss, hit-on-repeat, LRU
  eviction, invalidate, slot remapping, dtype preservation, GPU buffer
  content correctness after eviction, pinned backing store, FP8 scale
  buffers, overflow guard, and boundary capacities (cache==num_experts,
  capacity==1).

- `test_cache_policy.py`: unit tests for LRU, LFU, FIFO, and SLRU policies
  via the `ExpertCachePolicy` ABC — hit/miss, eviction ordering, capacity
  boundary, and multi-policy parametrisation.

- `test_moe_expert_cache.py`: end-to-end correctness via vLLM's
  `compare_two_settings` (with vs without cache on a small MoE model).

- `.buildkite/test_areas/basic_correctness.yaml`: registers the new
  end-to-end test for CI.

Signed-off-by: Elnur Abdullaev <elnur.abdullaev@sonia.so>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Adds `docs/features/moe_cache_policies.md` describing the four eviction
policies (LRU, LFU, FIFO, SLRU), when to use each, CLI usage examples,
hardware requirements, and current limitations.

Signed-off-by: Elnur Abdullaev <elnur.abdullaev@sonia.so>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@e1n00r e1n00r force-pushed the feature/moe-expert-lru-cache branch from dca8b80 to bd27b29 Compare March 18, 2026 11:19
@e1n00r e1n00r changed the title [Feature][Offload] Add dynamic MoE expert LRU cache (--moe-expert-cache-size) [MoE][Offload] Run MoE models exceeding VRAM via expert CPU offloading with GPU cache (--moe-expert-cache-size) Mar 18, 2026
@e1n00r
Author

e1n00r commented Mar 26, 2026

@mgoin — friendly ping for review when you have a moment. This PR adds dynamic MoE expert CPU offloading with a GPU LRU cache (--moe-expert-cache-size N), enabling MoE models that exceed VRAM to run on smaller GPUs.

What's changed since the last round of comments (2026-03-18):

  • All 7 points from @alvinttang's review addressed (persistent mapping tensor, bias guard, enforce_eager validator, thread safety note)
  • Pluggable eviction policies (LRU/LFU/FIFO/SLRU) via --moe-expert-cache-policy
  • 18 unit tests + integration test via compare_two_settings()
  • Pre-commit passing, DCO signed

Scope is intentionally minimal (~500 LOC Python, no C++, no new kernels). Async H2D pipeline and cross-layer prefetch are planned for PR 2 (depends on this merge). The architecture is designed so that EP support, CUDA graphs, and additional quant formats are natural extensions — see the design notes in the PR description.

Happy to address any feedback.

…ject#38256)

Replace ExpertLRUCache + cache_policy.py with a clean
ExpertWeightProvider ABC. The cache is now a weight provider, not a
special forward path — the kernel does not know or care where weights
came from.

Key changes:
- New expert_weight_provider.py with CachedWeightProvider (LRU via
  OrderedDict, no cachetools dependency) and FullGPUProvider
- Delete cache_policy.py (no multi-policy: LRU only in PR1)
- Delete lru_cache.py (replaced by CachedWeightProvider)
- Provider intercept in FusedMoEModularMethod.apply(),
  UnquantizedFusedMoEMethod.forward_cuda(), and Fp8MoEMethod.apply()
- Remove _forward_with_expert_cache() and _moe_forward_cpu() from
  layer.py — all cache logic flows through apply() now
- Silent config downgrades replaced with raise ValueError
- Simplify offload.py policy field to Literal["lru"]
- Rewrite tests for CachedWeightProvider API (20/20 passing)
- Delete test_cache_policy.py
- Update docs to reference RFC vllm-project#38256

AI-assisted development (Claude Code)

Signed-off-by: Elnur Abdullaev <elnur.abdullaev@sonia.so>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@e1n00r
Author

e1n00r commented Mar 26, 2026

Refactored to match RFC #38256:

  • ExpertWeightProvider ABC replaces the _forward_with_expert_cache() bypass — all paths now go through runner.forward() → quant_method.apply()
  • CPU fallback deleted — overflow raises error with guidance
  • cachetools dependency removed — pure collections.OrderedDict
  • Simplified to LRU only (other policies in follow-ups if data justifies)
  • ExpertWeightResult uses fixed attributes (no dict, no bool flags) — torch.compile-safe
  • 20/20 unit tests passing

Note on end-to-end test: My hardware (RTX PRO 2000, SM 12.0) has a CUDA toolkit version mismatch (system nvcc 12.0, PyTorch cu128) that prevents compiling vLLM's flash_attn for this arch. The OLMoE-1B-7B results in the description are from a prior run with the same cache logic. Unit tests validate all cache paths independently. Happy to re-run end-to-end on CI or if someone with a standard GPU setup can test.

When unique experts per forward pass exceed cache capacity (common during
prefill with high top_k), truncate to capacity and log a warning instead
of raising RuntimeError. Decode always has exact results since top_k is
typically much smaller than capacity.

Signed-off-by: Elnur Abdullaev <elnur.abdullaev@sonia.so>
Co-authored-by: Claude <noreply@anthropic.com>