moe-offload-fixed#31938

Closed
wangyxbh wants to merge 3 commits into vllm-project:main from wangyxbh:moeoffload

Conversation


@wangyxbh wangyxbh commented Jan 8, 2026

PR: CPU Offload for Mixture-of-Experts (MoE) Inference in vLLM

Summary

This PR proposes a CPU Offload Module for MoE inference in vLLM, enabling a large portion of expert weights and computation to be dynamically offloaded to the CPU while keeping only a small, hot subset of experts cached on GPU.

The design supports:

  • Hybrid GPU–CPU execution
  • Pinned-memory–based weight streaming
  • Asynchronous GPU ↔ CPU interaction via callback
  • AVX / AMX–optimized CPU MoE kernels
  • Two execution modes:
    • DBO (Dual Batch Overlap)
    • Prefetch-based miss expert execution

This approach significantly reduces GPU memory footprint and improves scalability for large MoE models with thousands of experts.

Motivation

Problem Statement

Large MoE models (e.g., DeepSeek-style architectures) suffer from GPU memory pressure during inference:

  • Expert weights dominate memory usage
  • Only a small subset of experts is active per token

Existing approaches typically:

  • Require all expert weights on GPU
  • Perform naive CPU fallback with high synchronization overhead

Key Observations

  1. Expert access is sparse and skewed: Only a small fraction of experts are accessed per token, with certain experts being accessed more frequently than others
  2. CPU memory capacity is abundant: CPU RAM can easily accommodate the full set of expert weights that would otherwise overwhelm GPU memory
  3. Modern CPUs (AVX512 / AMX) can efficiently compute MoE FFNs: CPU-based computation can be competitive for expert forward passes when properly optimized
  4. GPU kernels can overlap compute with CPU execution if callbacks are used: Asynchronous execution patterns enable efficient resource utilization

Goals of This RFC

  • Reduce GPU memory usage without sacrificing throughput
  • Enable fine-grained expert caching
  • Overlap GPU compute, CPU compute, and data movement
  • Integrate cleanly with vLLM's execution model

Proposed Change

Architecture Overview

+-------------------+        callback        +----------------------+
|      GPU          |  <----------------->   |        CPU           |
|                   |  update cache experts  |                      |
|  Cached Experts   |  <-----------------   |  Offloaded Experts   |
|  (hot experts)    |                        |  (cold experts)      |
|                   |                        |                      |
| fused_experts()   |                        | AVX / AMX MoE FFN    |
+-------------------+                        +----------------------+

Core Design Philosophy

The core design principle is that the GPU no longer stores all expert weights for each layer, but instead caches only a limited number of hot experts. The CPU maintains the complete set of experts and dynamically determines which experts need to be copied to the GPU and which should be computed directly on the CPU based on actual token routing behavior.

The entire mechanism revolves around:

  • Expert cache management
  • Miss buffer handling
  • Copy policy decisions
  • CPU/GPU computation overlap

Key Components

  1. Python Offload Manager (CpuOffloadInfer): Orchestrates the offload process, manages expert cache state, and coordinates GPU-CPU interactions
  2. GPU Expert Cache: Limited-capacity cache storing hot experts on GPU
  3. Miss Expert Buffer (double-buffered): Temporary buffer for experts that miss the cache during forward passes
  4. CPU MoE Execution Engine: AVX/AMX-optimized kernels for computing expert forward passes on CPU
  5. GPU↔CPU Callback-based Synchronization: Asynchronous communication mechanism for coordinating GPU and CPU execution

Initialization Phase

During model initialization:

  • All MoE expert weights for each layer are fully loaded and permanently resident in CPU pinned memory
  • The GPU allocates an Expert Cache with capacity cache_expert_num for each layer, storing the most frequently accessed experts
  • The GPU cache is not static; experts are dynamically managed based on runtime token routing behavior

To track the state of experts in the GPU cache, the system maintains per-layer metadata:

  • cache_map: Maps expert IDs to their positions in the GPU cache
  • miss_map: Tracks which experts are currently in the miss buffer
  • policy_sort: Maintains priority ordering for expert replacement decisions
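
As an illustration, the metadata above might fit together as follows. This is a minimal plain-Python sketch: the field names mirror the PR's `cache_map` / `miss_map` / `policy_sort`, but the concrete structures and the `insert` eviction helper are assumptions, not the PR's actual code (which most likely stores these as tensors).

```python
from dataclasses import dataclass, field

@dataclass
class LayerCacheState:
    """Hypothetical per-layer cache metadata (sketch, not the PR's classes)."""
    cache_expert_num: int                             # GPU cache capacity
    cache_map: dict = field(default_factory=dict)     # expert id -> cache slot
    miss_map: dict = field(default_factory=dict)      # expert id -> miss-buffer slot
    policy_sort: list = field(default_factory=list)   # hottest experts first

    def is_cached(self, expert_id):
        return expert_id in self.cache_map

    def insert(self, expert_id):
        """Cache an expert, evicting the lowest-priority entry when full."""
        if len(self.cache_map) >= self.cache_expert_num:
            victim = self.policy_sort.pop()           # coldest expert is last
            slot = self.cache_map.pop(victim)         # reuse its cache slot
        else:
            slot = len(self.cache_map)
        self.cache_map[expert_id] = slot
        self.policy_sort.insert(0, expert_id)
```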

Forward Pass Execution Flow

Step 1: Expert Cache Policy Matching

At the start of a forward pass, the model has already obtained topk_ids for each token from the router. The system calls expert_cache_policy to match these topk_ids against the current layer's cache state.

This process outputs two key pieces of information:

  1. cpu_topk_ids: Which tokens' experts require CPU computation
  2. copy_map: The set of experts that need to be copied from CPU to GPU in this forward pass

Important: copy_map does not directly correspond to "experts copied to GPU cache". It is simply a list of experts that need to be copied in this pass, and their final destination depends on the execution mode.
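
In rough pseudocode, the matching step could look like this. It is a sketch under the assumption that the policy reduces to a cache-membership test; the real `expert_cache_policy` operates on GPU tensors and also ranks copy candidates before the prefetch budget is applied.

```python
def expert_cache_policy(topk_ids, cached_experts):
    """Split routed experts into GPU cache hits and misses (sketch).

    topk_ids: per-token lists of routed expert ids.
    cached_experts: set of expert ids currently resident in the GPU cache.
    Returns (cpu_topk_ids, copy_map) as described in the text above.
    """
    cpu_topk_ids, copy_map, seen = [], [], set()
    for token_experts in topk_ids:
        # experts not resident in the GPU cache miss and need handling
        misses = [e for e in token_experts if e not in cached_experts]
        cpu_topk_ids.append(misses)
        for e in misses:
            if e not in seen:          # each expert is copied at most once
                seen.add(e)
                copy_map.append(e)
    return cpu_topk_ids, copy_map

# two tokens routed to top-2 experts each; experts 0 and 2 are cached
cpu_ids, copy_map = expert_cache_policy([[0, 5], [5, 2]], cached_experts={0, 2})
```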

Step 2: Execution Mode Selection

The system operates in two primary execution modes:

DBO Mode (Dual Batch Overlap)

When the system is in DBO mode or in decode/small batch scenarios, the forward pass enters a fully parallel CPU-GPU execution path:

  • Experts in copy_map are asynchronously copied to the GPU Expert Cache for subsequent fused_experts computation
  • CPU immediately begins computing miss experts
  • CPU computation, GPU computation, and expert copying are deliberately placed in different execution threads
  • Overlap is achieved through vLLM's DBO scheduling mechanism: while the GPU computes fused experts for the current batch, the CPU is already working on miss experts for the next step or the same step, maximizing resource utilization and reducing decode latency
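
The overlap pattern can be illustrated with a toy thread-based sketch. This is not vLLM's DBO scheduler; `gpu_fn` and `cpu_fn` are hypothetical stand-ins for the fused-experts path and the CPU MoE FFN path, and the merge is modeled as an elementwise sum.

```python
from concurrent.futures import ThreadPoolExecutor

def forward_with_overlap(gpu_fn, cpu_fn, hidden):
    """Toy rendering of the CPU/GPU overlap pattern described above."""
    with ThreadPoolExecutor(max_workers=1) as pool:
        cpu_future = pool.submit(cpu_fn, hidden)  # miss experts start on a worker
        gpu_out = gpu_fn(hidden)                  # cached experts run concurrently
        cpu_out = cpu_future.result()             # join before the merge point
    return [g + c for g, c in zip(gpu_out, cpu_out)]

out = forward_with_overlap(
    gpu_fn=lambda h: [2 * x for x in h],  # stand-in for fused_experts
    cpu_fn=lambda h: [x + 1 for x in h],  # stand-in for the CPU MoE FFN
    hidden=[1.0, 2.0],
)
```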

Prefetch Mode

In Prefetch mode (typically for larger prefill batches), system behavior adjusts based on the number of tokens in the batch:

  • As token count increases, more experts are triggered in the forward pass

  • The system dynamically calculates n_copy to limit the maximum number of experts copied in this pass

  • If n_copy is less than the total number of experts:

    • CPU still participates in computation
    • Experts in copy_map are not placed in the GPU cache
    • Instead, they are copied to a dedicated Miss Expert Buffer (temp_layer)
    • GPU uses this temp buffer to execute fused_experts
    • CPU computes the remaining experts that were not copied
    • Results from both paths are merged at the output stage
  • When batch size is extremely large and n_copy covers all or nearly all experts:

    • The system automatically degrades to "full GPU mode"
    • CPU no longer participates in computation
    • All experts are copied and fused_experts computation is completed on the GPU side
    • This is not an additional branch logic, but a natural consequence of the Prefetch strategy when copy count reaches the threshold
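
The degradation behavior can be captured in a few lines. This is a sketch: the PR's actual `n_copy` heuristic depends on batch size and is assumed here as a plain `copy_budget` argument.

```python
def plan_prefetch(num_miss_experts, copy_budget):
    """Decide how many miss experts to copy this pass (sketch).

    When the budget covers every miss expert, the CPU path simply
    receives zero experts -- the "full GPU mode" of the text, without
    any dedicated branch.
    """
    n_copy = min(num_miss_experts, copy_budget)
    full_gpu_mode = (n_copy == num_miss_experts)
    return n_copy, full_gpu_mode
```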

Double-Buffered Miss Expert Buffer Management: To prevent miss experts from being overwritten during cross-layer execution, the system globally maintains only two Miss Expert Buffers, using layer_id % 2 for double-buffering:

  • Even-numbered layers use buffer 0
  • Odd-numbered layers use buffer 1

By coordinating with independent CUDA streams and events:

  • Copy and computation on the same buffer are strictly serialized
  • Different buffers can form a natural pipeline
  • Expert copying and computation for adjacent layers can interleave, enabling efficient pipelining
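
The buffer-selection rule itself is just layer parity; the CUDA stream and event plumbing is omitted in this sketch.

```python
NUM_MISS_BUFFERS = 2  # the two global miss-expert buffers

def miss_buffer_index(layer_id):
    """Even layers use buffer 0, odd layers buffer 1 (layer_id % 2)."""
    return layer_id % NUM_MISS_BUFFERS

# adjacent layers alternate buffers, so layer N+1's expert copies can
# start while layer N's compute is still draining the other buffer
schedule = [miss_buffer_index(layer) for layer in range(4)]
```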


Signed-off-by: wangyxbh <wangyxbh@qq.com>

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a CPU offload mechanism for Mixture-of-Experts (MoE) models, which is a significant and complex feature. The changes span across the build system (CMake), C++/CUDA kernels, and Python configuration and execution logic. The implementation includes optimizations like AVX/AMX for CPU kernels and asynchronous GPU-CPU interaction. My review focuses on critical issues related to process stability, resource utilization, and portability. I've identified a critical issue where memory allocation failures can terminate the entire process, and several high-severity issues related to inefficient busy-waiting in both CPU and GPU code, as well as hardcoded system configurations that limit portability. Addressing these points will improve the robustness and efficiency of the new MoE offload feature.

Comment on lines +67 to +71
if (buffer == nullptr) {
  std::cerr << "Memory allocation failed for buffer: " << name
            << " size: " << size << std::endl;
  exit(-1);
}


critical

Using exit(-1) on memory allocation failure is problematic for a library, as it terminates the entire process abruptly. This prevents the calling application from handling the error gracefully, such as by releasing resources or attempting recovery. It's better to throw an exception, like std::bad_alloc, to allow the caller to catch and manage the allocation failure. You will need to include the <new> header.

    if (buffer == nullptr) {
        throw std::bad_alloc();
    }

Comment on lines +104 to +106
while (thread_state_[i].status->load(std::memory_order_acquire) == ThreadStatus::WORKING) {
// Busy-wait for completion (consider condition variables for production)
}


high

This busy-wait loop on the main thread can consume 100% of a CPU core while waiting for worker threads to complete. This is highly inefficient. While this might be an intentional choice to minimize latency, it's generally better to use a more power-efficient synchronization mechanism. The comment already suggests condition variables, which would be ideal. A simpler improvement would be to yield the thread's time slice or sleep for a very short duration inside the loop to reduce CPU pressure, similar to what's done in worker_thread.

        while (thread_state_[i].status->load(std::memory_order_acquire) == ThreadStatus::WORKING) {
            std::this_thread::yield();
        }

Comment on lines +60 to +62
while (data->callback_completed == 0) {
  __threadfence_system(); // ensure the latest value is read from the CPU
}


high

This busy-wait loop inside sync_kernel can be problematic. While it's only a single thread due to the <<<1, 1>>> launch configuration, it will still occupy a GPU execution slot and spin, consuming resources while waiting for the CPU. If the CPU task is delayed for any reason (e.g., OS preemption, heavy load), this could lead to wasted GPU cycles and potential device-wide issues if not handled carefully. A more robust approach would be to use CUDA events for synchronization, which are more efficient and don't involve active spinning on the GPU.

Comment on lines +617 to +618
# the following cpu setting only for DCG's shenzhen AI servers
core_per_worker = 64 // tp_size


high

The CPU core affinity logic is hardcoded for a specific server configuration with 64 cores, as noted in the comment. This will cause issues or suboptimal performance on machines with a different number of CPU cores. This should be generalized to dynamically determine the number of available cores, for example by using psutil.cpu_count(logical=False).

Suggested change

- # the following cpu setting only for DCG's shenzhen AI servers
- core_per_worker = 64 // tp_size
+ # Determine cores per worker dynamically
+ total_cores = psutil.cpu_count(logical=False) or psutil.cpu_count(logical=True)
+ core_per_worker = total_cores // tp_size


mergify bot commented Jan 8, 2026

Hi @wangyxbh, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint


mergify bot commented Jan 8, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @wangyxbh.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 8, 2026

mergify bot commented Jan 9, 2026

Documentation preview: https://vllm--31938.org.readthedocs.build/en/31938/

@mergify mergify bot added the documentation Improvements or additions to documentation label Jan 9, 2026
data->submit_count += 1;
__threadfence_system(); // make the write visible to the CPU
}
}


Memory fence placement causes GPU-CPU synchronization race condition

High Severity

The __threadfence_system() in submit_kernel is called after writing gpu_signal = 1, not before. This means the CPU polling loop in cpu_polling_loop could observe gpu_signal == 1 before the data fields (layer_idx, batch_idx, num_tokens) are visible, causing the CPU to read stale or uninitialized values. The fence should be placed between writing the data fields and setting gpu_signal = 1 to ensure proper release-acquire semantics.


activation="silu",  # could be read from the layer config
global_num_experts=config['total_experts'],
apply_router_weight_on_input=False,
expert_map=miss_map,  # uses the layer's miss_map


Unused map_modified causes experts computed with invalid weights

High Severity

In forward_prefetch, map_modified is created and sets experts from n_copy onwards to -1, but then miss_map (the unmodified version) is passed to fused_experts as expert_map. Since only experts 0 to n_copy-1 are copied to temp_layer, using miss_map causes fused_experts to attempt computing experts with indices >= n_copy using uninitialized or stale weights in the temp buffer. The expert_map argument should use map_modified when n_copy < 256.


if n_copy < 256:
    map_modified = miss_map.clone()
    map_modified[n_copy:] = -1
mask = (topk_ids > n_copy) & (miss_map[topk_ids] >= 0)


Off-by-one error in prefetch expert threshold comparison

Medium Severity

The mask condition topk_ids > n_copy excludes expert with ID equal to n_copy from being sent to CPU. Since update_expert_cache copies experts with IDs 0 through n_copy-1, the expert at index n_copy is not copied to temp_layer but also won't be sent to CPU for computation. This should be topk_ids >= n_copy to correctly route all non-copied experts to CPU.
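
To make the boundary concrete, here is a plain-Python rendering of the mask (a hypothetical stand-in for the torch expression above, with made-up expert ids); the `fixed=False` variant reproduces the off-by-one, silently dropping exactly the expert whose ID equals `n_copy`.

```python
def cpu_route_mask(topk_ids, n_copy, miss_map, fixed=True):
    """Which routed experts must fall back to CPU computation (sketch).

    Experts 0..n_copy-1 are copied to temp_layer; everything else with a
    valid miss_map entry should go to the CPU. fixed=False uses the
    strict '>' comparison from the PR and exhibits the bug.
    """
    hits_cpu = (lambda e: e >= n_copy) if fixed else (lambda e: e > n_copy)
    return [hits_cpu(e) and miss_map.get(e, -1) >= 0 for e in topk_ids]

miss_map = {0: 0, 1: 1, 2: 2, 3: 3}  # every expert has a valid miss slot
buggy = cpu_route_mask([0, 1, 2, 3], n_copy=2, miss_map=miss_map, fixed=False)
good = cpu_route_mask([0, 1, 2, 3], n_copy=2, miss_map=miss_map, fixed=True)
```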


Unified entry point: select DBO or Prefetch mode based on configuration.
"""
if self.tp_rank == 1 and layer_id == 4:
    print(f"dbo_enable():{dbo_enabled()} and ntok = {hidden_states.shape[0]}")


Debug print statements left in production code

Low Severity

Multiple debug print statements appear to have been left in the codebase: conditional prints in forward_offload (line 548), _build_attn_metadata (line 1591), _dummy_run (lines 4254, 4350), and warmup loop (line 4911). These will produce noisy output during production inference and degrade performance.



# If the buffer is empty, return a zero tensor
if not temp_layer:
    return torch.zeros_like(hidden_states)


Inconsistent return value causes tuple unpacking error

High Severity

In forward_prefetch, when temp_layer is None, the function returns only a single tensor (torch.zeros_like(hidden_states)) instead of the expected tuple of two values. The caller in fp8.py unpacks the result as miss_output, expert_map = layer.cpu_offload_eng.forward_offload(...), which will raise a ValueError when this code path is hit. The return statement needs to include both the tensor and a cache_map value.



schedule.run_on_nodes(cpu_node)
memory.set_membind_nodes(cpu_node)
proc_handle.cpu_affinity(cpu_cores)


Hardcoded NUMA configuration fails on non-matching systems

Medium Severity

The NUMA binding code assumes a 64-core, 2-NUMA-node system. When moe_offload is enabled on systems with different configurations (fewer cores, different NUMA topology, single NUMA node), this will cause cpu_affinity to fail with invalid core IDs or set_membind_nodes to fail with non-existent NUMA node 1. There's no validation or graceful fallback for unsupported hardware configurations.


thread_num_ = std::min(max_thread_num_, task_num);

const int base = task_num / thread_num_;
const int remain = task_num % thread_num_;


Division by zero when task count is zero

Medium Severity

In do_work_stealing_job, when task_num is 0, thread_num_ becomes 0 via std::min(max_thread_num_, task_num). The subsequent division task_num / thread_num_ causes undefined behavior (division by zero). This can occur when num_tokens is 0 or when n_tokens / BM evaluates to 0 for small token counts. The function topk_sort_inplace is called unconditionally from Moe::forward before any size checks, propagating zero task_num to this code.


@LucasWilkinson

cc @mgoin although would probably be good to get input from @robertgshaw2-redhat @bnellnm @varun-sundar-rabindranath too

Signed-off-by: wangyxbh <wangyxbh@digitalchina.com>
@benchislett

Please create an RFC or change to a draft PR given the large scope of this PR.
