diff --git a/CMakeLists.txt b/CMakeLists.txt
index c46fb18d7bfe..8f1c0b55a954 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1096,6 +1096,303 @@ define_extension_target(
USE_SABI 3
WITH_SOABI)
+#
+# _offload_C extension
+#
+
+# Find OpenMP (required by offload module)
+if(VLLM_GPU_LANG STREQUAL "CUDA")
+ find_package(OpenMP)
+ if(OpenMP_CXX_FOUND)
+ message(STATUS "Found OpenMP: ${OpenMP_CXX_FLAGS}")
+ else()
+ message(WARNING "OpenMP not found, but may be required by offload module")
+ endif()
+endif()
+
+set(VLLM_OFFLOAD_EXT_SRC
+ "csrc/offload/forward_context.cpp"
+ "csrc/offload/moe.cpp"
+ "csrc/offload/primitives.cpp"
+ "csrc/offload/py_bindding.cpp")
+
+if(VLLM_GPU_LANG STREQUAL "CUDA")
+ list(APPEND VLLM_OFFLOAD_EXT_SRC
+ "csrc/offload/moe_kernel.cu")
+endif()
+
+set_gencode_flags_for_srcs(
+ SRCS "${VLLM_OFFLOAD_EXT_SRC}"
+ CUDA_ARCHS "${CUDA_ARCHS}")
+
+# Prepare include directories and libraries for _offload_C extension
+set(_OFFLOAD_C_INCLUDE_DIRS
+ "${CMAKE_CURRENT_SOURCE_DIR}/csrc/offload"
+)
+
+set(_OFFLOAD_C_LIBRARIES)
+# Add OpenMP if found (required by offload module)
+if(OpenMP_CXX_FOUND)
+ list(APPEND _OFFLOAD_C_LIBRARIES OpenMP::OpenMP_CXX)
+endif()
+
+# Add C++ specific compile flags for offload module (AVX512 and AMX optimizations)
+# Note: AVX512 and AMX are CPU optimizations, not GPU-specific, so always apply them
+set(_OFFLOAD_CXX_FLAGS)
+# Add AVX512 and AMX optimizations for CPU code in offload module
+# These flags match the standalone setup.py configuration
+list(APPEND _OFFLOAD_CXX_FLAGS
+ "-mavx512f"
+ "-mavx512bf16"
+ "-mamx-tile"
+ "-mamx-bf16"
+ "-fvisibility=hidden"
+)
+
+message(STATUS "Enabling offload extension.")
+message(STATUS "_offload_C extension sources: ${VLLM_OFFLOAD_EXT_SRC}")
+message(STATUS "_offload_C extension include directories: ${_OFFLOAD_C_INCLUDE_DIRS}")
+message(STATUS "_offload_C extension libraries: ${_OFFLOAD_C_LIBRARIES}")
+message(STATUS "_offload_C C++ compile flags: ${_OFFLOAD_CXX_FLAGS}")
+
+define_extension_target(
+ _offload_C
+ DESTINATION vllm
+ LANGUAGE ${VLLM_GPU_LANG}
+ SOURCES ${VLLM_OFFLOAD_EXT_SRC}
+ COMPILE_FLAGS ${VLLM_GPU_FLAGS}
+ ARCHITECTURES ${VLLM_GPU_ARCHES}
+ INCLUDE_DIRECTORIES ${_OFFLOAD_C_INCLUDE_DIRS}
+ LIBRARIES ${_OFFLOAD_C_LIBRARIES}
+ # Note: Not using USE_SABI 3 because pybind11 type casters need access to
+ # PyTorch's internal symbols which are restricted by Stable ABI.
+ # This extension uses pybind11 for class registration, which requires
+ # full symbol visibility.
+ WITH_SOABI)
+
+# Keep Stable ABI for the module, but *not* for CUDA/C++ files.
+# This prevents Py_LIMITED_API from affecting nvcc and C++ compiles.
+# PyBind11 and PyTorch require full Python API access, not limited API.
+target_compile_options(_offload_C PRIVATE
+ $<$:-UPy_LIMITED_API>
+ $<$:-UPy_LIMITED_API>)
+
+# Explicitly link torch_python library for pybind11 type casters
+# The pybind11 type caster symbols (e.g., type_caster) are in libtorch_python.so
+# Without this explicit link, the dynamic linker won't know to look for symbols in torch_python
+# Priority: 1) Use CMake target torch::torch_python (handles dependencies automatically)
+# 2) Find library file and link directly
+if(TARGET torch::torch_python)
+ get_target_property(_torch_python_lib torch::torch_python IMPORTED_LOCATION)
+ if(NOT _torch_python_lib)
+ get_target_property(_torch_python_lib torch::torch_python LOCATION)
+ endif()
+
+ if(_torch_python_lib AND EXISTS "${_torch_python_lib}")
+ message(STATUS "Linking _offload_C with torch::torch_python target (preferred): ${_torch_python_lib}")
+
+ # Use --whole-archive to force complete linking (fixes RTLD_LOCAL symbol isolation)
+ if(NOT WIN32)
+ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
+ # GNU linker: use --whole-archive to force complete linking
+ target_link_libraries(_offload_C PRIVATE
+ "-Wl,--no-as-needed"
+ "-Wl,--whole-archive"
+ "${_torch_python_lib}"
+ "-Wl,--no-whole-archive"
+ "-Wl,--as-needed"
+ )
+ message(STATUS "Applied --whole-archive to torch_python for GNU linker")
+ elseif(CMAKE_CXX_COMPILER_ID MATCHES "Clang")
+ # Clang linker: use -force_load (macOS) or --whole-archive (Linux)
+ if(APPLE)
+ target_link_libraries(_offload_C PRIVATE
+ "-Wl,-force_load,${_torch_python_lib}"
+ )
+ message(STATUS "Applied -force_load to torch_python for Clang on macOS")
+ else()
+ target_link_libraries(_offload_C PRIVATE
+ "-Wl,--no-as-needed"
+ "-Wl,--whole-archive"
+ "${_torch_python_lib}"
+ "-Wl,--no-whole-archive"
+ "-Wl,--as-needed"
+ )
+ message(STATUS "Applied --whole-archive to torch_python for Clang on Linux")
+ endif()
+ else()
+ # Fallback: normal linking
+ target_link_libraries(_offload_C PRIVATE "${_torch_python_lib}")
+ endif()
+ else()
+ # Windows: normal linking
+ target_link_libraries(_offload_C PRIVATE "${_torch_python_lib}")
+ endif()
+
+ # Store the library directory for RPATH
+ get_filename_component(TORCH_PYTHON_LIB_DIR ${_torch_python_lib} DIRECTORY)
+ if(TORCH_PYTHON_LIB_DIR)
+ set(_TORCH_PYTHON_RPATH_DIR "${TORCH_PYTHON_LIB_DIR}")
+ message(STATUS "Will add torch_python directory to RPATH: ${TORCH_PYTHON_LIB_DIR}")
+ endif()
+ else()
+ # Fallback to normal target linking if library path not found
+ target_link_libraries(_offload_C PRIVATE torch::torch_python)
+ message(STATUS "Linking _offload_C with torch::torch_python target (fallback)")
+ endif()
+else()
+ # Fallback: find library file
+ find_library(TORCH_PYTHON_LIB
+ NAMES torch_python
+ PATHS
+ ${CMAKE_PREFIX_PATH}
+ ${Python_SITELIB}
+ PATH_SUFFIXES
+ lib
+ lib64
+ torch/lib
+ )
+
+ # Also try to find it from Python's torch package
+ # Use platform-agnostic detection for library extension
+ if(NOT TORCH_PYTHON_LIB)
+ execute_process(
+ COMMAND ${Python_EXECUTABLE} -c "import torch, os, sys; ext='dll' if sys.platform=='win32' else ('dylib' if sys.platform=='darwin' else 'so'); lib_path = os.path.join(os.path.dirname(torch.__file__), 'lib', f'libtorch_python.{ext}'); print(lib_path if os.path.exists(lib_path) else '')"
+ OUTPUT_VARIABLE TORCH_PYTHON_LIB_PATH
+ OUTPUT_STRIP_TRAILING_WHITESPACE
+ ERROR_QUIET
+ )
+ if(TORCH_PYTHON_LIB_PATH AND EXISTS "${TORCH_PYTHON_LIB_PATH}")
+ set(TORCH_PYTHON_LIB "${TORCH_PYTHON_LIB_PATH}")
+ message(STATUS "Found torch_python from Python package: ${TORCH_PYTHON_LIB}")
+ else()
+ # Try torch.utils.cmake_prefix_path as fallback
+ execute_process(
+ COMMAND ${Python_EXECUTABLE} -c "import torch; print(torch.utils.cmake_prefix_path)"
+ OUTPUT_VARIABLE TORCH_CMAKE_PREFIX_PATH
+ OUTPUT_STRIP_TRAILING_WHITESPACE
+ ERROR_QUIET
+ )
+ if(TORCH_CMAKE_PREFIX_PATH)
+ find_library(TORCH_PYTHON_LIB_FALLBACK
+ NAMES torch_python
+ PATHS ${TORCH_CMAKE_PREFIX_PATH}
+ PATH_SUFFIXES lib lib64 torch/lib
+ NO_DEFAULT_PATH
+ )
+ if(TORCH_PYTHON_LIB_FALLBACK)
+ set(TORCH_PYTHON_LIB "${TORCH_PYTHON_LIB_FALLBACK}")
+ message(STATUS "Found torch_python via torch.utils.cmake_prefix_path: ${TORCH_PYTHON_LIB}")
+ else()
+ set(TORCH_PYTHON_LIB "")
+ endif()
+ else()
+ set(TORCH_PYTHON_LIB "")
+ endif()
+ endif()
+ endif()
+
+ if(TORCH_PYTHON_LIB AND EXISTS "${TORCH_PYTHON_LIB}")
+ # Link the library file directly
+ # Important: Use --whole-archive to force complete linking (fixes RTLD_LOCAL symbol isolation)
+ message(STATUS "Linking _offload_C with torch_python: ${TORCH_PYTHON_LIB}")
+
+ # Use --whole-archive to force complete linking (fixes RTLD_LOCAL symbol isolation)
+ if(NOT WIN32)
+ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
+ # GNU linker: use --whole-archive to force complete linking
+ target_link_libraries(_offload_C PRIVATE
+ "-Wl,--no-as-needed"
+ "-Wl,--whole-archive"
+ "${TORCH_PYTHON_LIB}"
+ "-Wl,--no-whole-archive"
+ "-Wl,--as-needed"
+ )
+ message(STATUS "Applied --whole-archive to torch_python for GNU linker")
+ elseif(CMAKE_CXX_COMPILER_ID MATCHES "Clang")
+ # Clang linker: use -force_load (macOS) or --whole-archive (Linux)
+ if(APPLE)
+ target_link_libraries(_offload_C PRIVATE
+ "-Wl,-force_load,${TORCH_PYTHON_LIB}"
+ )
+ message(STATUS "Applied -force_load to torch_python for Clang on macOS")
+ else()
+ target_link_libraries(_offload_C PRIVATE
+ "-Wl,--no-as-needed"
+ "-Wl,--whole-archive"
+ "${TORCH_PYTHON_LIB}"
+ "-Wl,--no-whole-archive"
+ "-Wl,--as-needed"
+ )
+ message(STATUS "Applied --whole-archive to torch_python for Clang on Linux")
+ endif()
+ else()
+ # Fallback: normal linking
+ target_link_libraries(_offload_C PRIVATE ${TORCH_PYTHON_LIB})
+ endif()
+ else()
+ # Windows: normal linking
+ target_link_libraries(_offload_C PRIVATE ${TORCH_PYTHON_LIB})
+ endif()
+
+ # Store the library directory for RPATH (will be merged later)
+ get_filename_component(TORCH_PYTHON_LIB_DIR ${TORCH_PYTHON_LIB} DIRECTORY)
+ if(TORCH_PYTHON_LIB_DIR)
+ set(_TORCH_PYTHON_RPATH_DIR "${TORCH_PYTHON_LIB_DIR}")
+ message(STATUS "Will add torch_python directory to RPATH: ${TORCH_PYTHON_LIB_DIR}")
+ endif()
+ else()
+ message(WARNING "torch_python library not found. Pybind11 type casters may not work correctly.")
+ message(WARNING "Trying to use --allow-shlib-undefined as fallback...")
+ # Set linker flags to allow undefined symbols (they will be resolved at runtime from torch)
+ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
+ target_link_options(_offload_C PRIVATE
+ "-Wl,--allow-shlib-undefined"
+ )
+ endif()
+ endif()
+endif()
+
+# Add C++ specific compile options for offload module
+if(_OFFLOAD_CXX_FLAGS)
+ target_compile_options(_offload_C PRIVATE
+ $<$:${_OFFLOAD_CXX_FLAGS}>)
+endif()
+
+# Set RPATH to ensure all required libraries can be found at runtime
+# Collect all RPATH directories and merge them (don't overwrite!)
+set(_OFFLOAD_RPATH_DIRS)
+
+# Add torch_python RPATH if it was set
+if(DEFINED _TORCH_PYTHON_RPATH_DIR AND _TORCH_PYTHON_RPATH_DIR)
+ list(APPEND _OFFLOAD_RPATH_DIRS "${_TORCH_PYTHON_RPATH_DIR}")
+endif()
+
+# Set all RPATH directories at once (merge, don't overwrite)
+if(_OFFLOAD_RPATH_DIRS)
+ # Remove duplicates
+ list(REMOVE_DUPLICATES _OFFLOAD_RPATH_DIRS)
+ # Convert list to platform-specific separator
+ if(UNIX AND NOT APPLE)
+ # Linux: colon-separated
+ string(REPLACE ";" ":" _OFFLOAD_RPATH_STRING "${_OFFLOAD_RPATH_DIRS}")
+ elseif(APPLE)
+ # macOS: colon-separated
+ string(REPLACE ";" ":" _OFFLOAD_RPATH_STRING "${_OFFLOAD_RPATH_DIRS}")
+ else()
+ # Windows: semicolon-separated
+ string(REPLACE ";" ";" _OFFLOAD_RPATH_STRING "${_OFFLOAD_RPATH_DIRS}")
+ endif()
+
+ # Set all properties in one call to avoid redundancy
+ set_target_properties(_offload_C PROPERTIES
+ INSTALL_RPATH "${_OFFLOAD_RPATH_STRING}"
+ BUILD_WITH_INSTALL_RPATH TRUE
+ INSTALL_RPATH_USE_LINK_PATH TRUE
+ )
+ message(STATUS "Setting RPATH for _offload_C to include: ${_OFFLOAD_RPATH_STRING}")
+endif()
+
if(VLLM_GPU_LANG STREQUAL "HIP")
#
# _rocm_C extension
diff --git a/README.md b/README.md
index 705fbcb9150b..475f7e3e1347 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,7 @@
-
-
-
-
-
-
+
+

+
Easy, fast, and cheap LLM serving for everyone
@@ -14,66 +11,136 @@ Easy, fast, and cheap LLM serving for everyone
| Documentation | Blog | Paper | Twitter/X | User Forum | Developer Slack |
-🔥 We have built a vllm website to help you get started with vllm. Please visit [vllm.ai](https://vllm.ai) to learn more.
-For events, please visit [vllm.ai/events](https://vllm.ai/events) to join us.
-
---
## About
-vLLM is a fast and easy-to-use library for LLM inference and serving.
+This is a vLLM fork based on v0.14.0 with **MoE Offload** feature, enabling efficient CPU offloading for Mixture-of-Experts (MoE) model inference.
+
+## Design Overview
+
+### Core Design Philosophy
+
+The core design principle is that the GPU no longer stores all expert weights for each layer, but instead caches only a limited number of hot experts. The CPU maintains the complete set of experts and dynamically determines which experts need to be copied to the GPU and which should be computed directly on the CPU based on actual token routing behavior.
+
+The entire mechanism revolves around:
+- Expert cache management
+- Miss buffer handling
+- Copy policy decisions
+- CPU/GPU computation overlap
+
+### Key Components
+
+1. **Python Offload Manager (CpuOffloadInfer)**: Orchestrates the offload process, manages expert cache state, and coordinates GPU-CPU interactions
+2. **GPU Expert Cache**: Limited-capacity cache storing hot experts on GPU
+3. **Miss Expert Buffer (double-buffered)**: Temporary buffer for experts that miss the cache during forward passes
+4. **CPU MoE Execution Engine**: AVX/AMX-optimized kernels for computing expert forward passes on CPU
+5. **GPU↔CPU Callback-based Synchronization**: Asynchronous communication mechanism for coordinating GPU and CPU execution
+
+### Initialization Phase
+
+During model initialization:
+- All MoE expert weights for each layer are fully loaded and permanently resident in CPU pinned memory
+- The GPU allocates an Expert Cache with capacity `cache_expert_num` for each layer, storing the most frequently accessed experts
+- The GPU cache is not static; experts are dynamically managed based on runtime token routing behavior
+
+To track the state of experts in the GPU cache, the system maintains per-layer metadata:
+- `cache_map`: Maps expert IDs to their positions in the GPU cache
+- `miss_map`: Tracks which experts are currently in the miss buffer
+- `policy_sort`: Maintains priority ordering for expert replacement decisions
+
+### Forward Pass Execution Flow
+
+#### Step 1: Expert Cache Policy Matching
+
+At the start of a forward pass, the model has already obtained `topk_ids` for each token from the router. The system calls `expert_cache_policy` to match these `topk_ids` against the current layer's cache state.
-Originally developed in the [Sky Computing Lab](https://sky.cs.berkeley.edu) at UC Berkeley, vLLM has evolved into a community-driven project with contributions from both academia and industry.
+This process outputs two key pieces of information:
+1. `cpu_topk_ids`: Which tokens' experts require CPU computation
+2. `copy_map`: The set of experts that need to be copied from CPU to GPU in this forward pass
-vLLM is fast with:
+**Important**: `copy_map` does not directly correspond to "experts copied to GPU cache". It is simply a list of experts that need to be copied in this pass, and their final destination depends on the execution mode.
-- State-of-the-art serving throughput
-- Efficient management of attention key and value memory with [**PagedAttention**](https://blog.vllm.ai/2023/06/20/vllm.html)
-- Continuous batching of incoming requests
-- Fast model execution with CUDA/HIP graph
-- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [AutoRound](https://arxiv.org/abs/2309.05516), INT4, INT8, and FP8
-- Optimized CUDA kernels, including integration with FlashAttention and FlashInfer
-- Speculative decoding
-- Chunked prefill
+#### Step 2: Execution Mode Selection
-vLLM is flexible and easy to use with:
+The system operates in two primary execution modes:
-- Seamless integration with popular Hugging Face models
-- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
-- Tensor, pipeline, data and expert parallelism support for distributed inference
-- Streaming outputs
-- OpenAI-compatible API server
-- Support for NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, Arm CPUs, and TPU. Additionally, support for diverse hardware plugins such as Intel Gaudi, IBM Spyre and Huawei Ascend.
-- Prefix caching support
-- Multi-LoRA support
+**DBO Mode (Dual Batch Overlap)**
-vLLM seamlessly supports most popular open-source models on HuggingFace, including:
+When the system is in DBO mode or in decode/small batch scenarios, the forward pass enters a fully parallel CPU-GPU execution path:
-- Transformer-like LLMs (e.g., Llama)
-- Mixture-of-Expert LLMs (e.g., Mixtral, Deepseek-V2 and V3)
-- Embedding Models (e.g., E5-Mistral)
-- Multi-modal LLMs (e.g., LLaVA)
+- Experts in `copy_map` are asynchronously copied to the GPU Expert Cache for subsequent `fused_experts` computation
+- CPU immediately begins computing miss experts
+- CPU computation, GPU computation, and expert copying are deliberately placed in different execution threads
+- Overlap is achieved through vLLM's DBO scheduling mechanism: while the GPU computes fused experts for the current batch, the CPU is already working on miss experts for the next step or the same step, maximizing resource utilization and reducing decode latency
-Find the full list of supported models [here](https://docs.vllm.ai/en/latest/models/supported_models.html).
+**Prefetch Mode**
-## Getting Started
+In Prefetch mode (typically for larger prefill batches), system behavior adjusts based on the number of tokens in the batch:
-Install vLLM with `pip` or [from source](https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html#build-wheel-from-source):
+- As token count increases, more experts are triggered in the forward pass
+- The system dynamically calculates `n_copy` to limit the maximum number of experts copied in this pass
+- If `n_copy` is less than the total number of experts:
+ - CPU still participates in computation
+ - Experts in `copy_map` are not placed in the GPU cache
+ - Instead, they are copied to a dedicated Miss Expert Buffer (`temp_layer`)
+ - GPU uses this temp buffer to execute `fused_experts`
+ - CPU computes the remaining experts that were not copied
+ - Results from both paths are merged at the output stage
+- When batch size is extremely large and `n_copy` covers all or nearly all experts:
+ - The system automatically degrades to "full GPU mode"
+ - CPU no longer participates in computation
+ - All experts are copied and `fused_experts` computation is completed on the GPU side
+ - This is not an additional branch logic, but a natural consequence of the Prefetch strategy when copy count reaches the threshold
+
+**Double-Buffered Miss Expert Buffer Management**: To prevent miss experts from being overwritten during cross-layer execution, the system globally maintains only two Miss Expert Buffers, using `layer_id % 2` for double-buffering:
+- Even-numbered layers use buffer 0
+- Odd-numbered layers use buffer 1
+
+By coordinating with independent CUDA streams and events:
+- Copy and computation on the same buffer are strictly serialized
+- Different buffers can form a natural pipeline
+- Expert copying and computation for adjacent layers can interleave, enabling efficient pipelining
+
+## Installation
+
+Install this version in development mode:
+
+```bash
+pip install -e .
+```
+
+## Usage
+
+### Example 1: 4 GPU Setup (TP=4)
```bash
-pip install vllm
+CUDA_VISIBLE_DEVICES='2,3,4,5' vllm serve /home/models/DeepSeek-R1/ \
+--trust-remote-code --max-num-seqs 4 --tensor_parallel_size 4 --distributed-executor-backend "mp" \
+--compilation-config '{"cudagraph_capture_sizes": [1,2,4]}' \
+--enable-dbo --dbo-decode-token-threshold 2 --dbo-prefill-token-threshold 16384 --max-model-len 16384 --no-enable-chunked-prefill --no-enable-prefix-caching --moe-offload \
+--moe-offload-cache-expert-num 32 --moe-offload-cache-topk 2 --moe-offload-update-expert-num 2 --moe-offload-context-num-threads 14
```
-Visit our [documentation](https://docs.vllm.ai/en/latest/) to learn more.
+### Example 2: 8 GPU Setup (TP=8)
-- [Installation](https://docs.vllm.ai/en/latest/getting_started/installation.html)
-- [Quickstart](https://docs.vllm.ai/en/latest/getting_started/quickstart.html)
-- [List of Supported Models](https://docs.vllm.ai/en/latest/models/supported_models.html)
+```bash
+CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' vllm serve /home/models/DeepSeek-R1/ \
+--trust-remote-code --max-num-seqs 8 --tensor_parallel_size 8 --distributed-executor-backend "mp" \
+--compilation-config '{"cudagraph_capture_sizes": [1,2,4,8]}' \
+--enable-dbo --dbo-decode-token-threshold 2 --dbo-prefill-token-threshold 16384 --max-model-len 16384 --no-enable-chunked-prefill --no-enable-prefix-caching --moe-offload \
+--moe-offload-cache-expert-num 104 --moe-offload-cache-topk 2 --moe-offload-update-expert-num 2 --moe-offload-context-num-threads 6
+```
-## Contributing
+### MoE Offload Parameters
-We welcome and value any contributions and collaborations.
-Please check out [Contributing to vLLM](https://docs.vllm.ai/en/latest/contributing/index.html) for how to get involved.
+| Parameter | Description | Default | Recommended Values |
+|-----------|-------------|---------|-------------------|
+| `--moe-offload` | Enable MoE offload mode | `false` | Required to enable |
+| `--moe-offload-cache-expert-num` | Number of MoE experts cached per layer on GPU | - | TP=4: 32, TP=8: 104 |
+| `--moe-offload-cache-topk` | CPU cache computation strategy | `2` | 2 |
+| `--moe-offload-update-expert-num` | Number of experts updated in CPU MoE | `2` | 2 |
+| `--moe-offload-context-num-threads` | Number of threads per process for CPU computation | - | TP=4: 12-14, TP=8: 6 |
## Citation
@@ -88,16 +155,3 @@ If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs
}
```
-## Contact Us
-
-
-- For technical questions and feature requests, please use GitHub [Issues](https://github.com/vllm-project/vllm/issues)
-- For discussing with fellow users, please use the [vLLM Forum](https://discuss.vllm.ai)
-- For coordinating contributions and development, please use [Slack](https://slack.vllm.ai)
-- For security disclosures, please use GitHub's [Security Advisories](https://github.com/vllm-project/vllm/security/advisories) feature
-- For collaborations and partnerships, please contact us at [collaboration@vllm.ai](mailto:collaboration@vllm.ai)
-
-
-## Media Kit
-
-- If you wish to use vLLM's logo, please refer to [our media kit repo](https://github.com/vllm-project/media-kit)
diff --git a/csrc/offload/forward_context.cpp b/csrc/offload/forward_context.cpp
new file mode 100644
index 000000000000..fbd6bc838ac4
--- /dev/null
+++ b/csrc/offload/forward_context.cpp
@@ -0,0 +1,145 @@
+#include "forward_context.h"
+#include
+#include
+#include
+#include
+
+// Constructor: Initialize thread pool
+ForwardContext::ForwardContext(int max_thread_num)
+ : max_thread_num_(max_thread_num),
+ thread_state_(max_thread_num_),
+ workers_(max_thread_num_) {
+
+ // Initialize thread state for all potential threads
+ for (int i = 0; i < max_thread_num_; ++i) {
+ thread_state_[i].curr = std::make_unique>(0);
+ thread_state_[i].status = std::make_unique>(ThreadStatus::WAITING);
+ thread_state_[i].end = 0;
+ }
+
+ // Launch worker threads (thread 0 is reserved for main thread)
+ for (int i = 1; i < max_thread_num_; ++i) {
+ workers_[i] = std::thread(&ForwardContext::worker_thread, this, i);
+ }
+}
+
+// Destructor: Cleanup threads and memory
+ForwardContext::~ForwardContext() {
+ // Signal all threads to exit
+ for (int i = 0; i < max_thread_num_; ++i) {
+ thread_state_[i].status->store(ThreadStatus::EXIT, std::memory_order_release);
+ }
+
+ // Wait for worker threads to finish
+ for (int i = 1; i < max_thread_num_; ++i) {
+ if (workers_[i].joinable()) {
+ workers_[i].join();
+ }
+ }
+
+ // Free all cached memory buffers
+ for (auto& entry : memoryMap) {
+ void* buffer = std::get<0>(entry.second);
+ if (buffer) {
+ free(buffer); // Use free() for aligned_alloc() memory
+ }
+ }
+ memoryMap.clear();
+}
+
+// Memory buffer management with LRU-style reuse
+void* ForwardContext::getBuffer(const std::string& name, size_t size, size_t alignment) {
+ if (name.empty() || size == 0) {
+ return nullptr;
+ }
+
+ auto it = memoryMap.find(name);
+ if (it != memoryMap.end()) {
+ if (std::get<1>(it->second) >= size) {
+ return std::get<0>(it->second); // Reuse existing buffer
+ } else {
+ free(std::get<0>(it->second)); // Free insufficient buffer
+ }
+ }
+
+ // Allocate new aligned buffer
+ void* buffer = std::aligned_alloc(alignment, size);
+ if (buffer == nullptr) {
+ std::cerr << "Memory allocation failed for buffer: " << name
+ << " size: " << size << std::endl;
+ exit(-1);
+ }
+
+ memoryMap[name] = std::make_tuple(buffer, size);
+ return buffer;
+}
+
+// Main entry point for parallel task execution
+void ForwardContext::do_work_stealing_job(int task_num, std::function compute_func) {
+ compute_func_ = compute_func;
+ thread_num_ = std::min(max_thread_num_, task_num);
+
+ const int base = task_num / thread_num_;
+ const int remain = task_num % thread_num_;
+
+ // Configure main thread's range
+ thread_state_[0].end = base + (0 < remain);
+ thread_state_[0].curr->store(0, std::memory_order_relaxed);
+
+ // Configure and activate worker threads
+ for (int i = 1; i < thread_num_; ++i) {
+ thread_state_[i].curr->store(thread_state_[i - 1].end, std::memory_order_relaxed);
+ thread_state_[i].end = thread_state_[i - 1].end + base + (i < remain);
+ thread_state_[i].status->store(ThreadStatus::WORKING, std::memory_order_release);
+ }
+
+ // Activate main thread last (after workers are ready)
+ thread_state_[0].status->store(ThreadStatus::WORKING, std::memory_order_release);
+
+ // Main thread processes its task range (thread_id = 0)
+ process_tasks(0);
+
+ // Wait for all worker threads to complete
+ for (int i = 1; i < thread_num_; ++i) {
+ while (thread_state_[i].status->load(std::memory_order_acquire) == ThreadStatus::WORKING) {
+ // Busy-wait for completion (consider condition variables for production)
+ }
+ }
+}
+
+// Process tasks for a specific thread ID
+void ForwardContext::process_tasks(int thread_id) {
+ while (true) {
+ int task_id = thread_state_[thread_id].curr->fetch_add(1, std::memory_order_acq_rel);
+ if (task_id >= thread_state_[thread_id].end) {
+ break;
+ }
+ compute_func_(thread_id, task_id); // Pass thread_id to compute function
+ }
+
+ thread_state_[thread_id].status->store(ThreadStatus::WAITING, std::memory_order_release);
+}
+
+// Worker thread main loop (for threads 1..n)
+void ForwardContext::worker_thread(int thread_id) {
+ auto start = std::chrono::steady_clock::now();
+
+ while (true) {
+ ThreadStatus status = thread_state_[thread_id].status->load(std::memory_order_acquire);
+
+ if (status == ThreadStatus::WORKING) {
+ // Process tasks and reset idle timer
+ process_tasks(thread_id);
+ start = std::chrono::steady_clock::now();
+ } else if (status == ThreadStatus::WAITING) {
+ // Check if we should sleep to reduce CPU usage
+ auto now = std::chrono::steady_clock::now();
+ auto duration = std::chrono::duration_cast(now - start).count();
+ if (duration > 50) {
+ std::this_thread::sleep_for(std::chrono::milliseconds(1));
+ }
+ } else if (status == ThreadStatus::EXIT) {
+ return;
+ }
+ }
+}
\ No newline at end of file
diff --git a/csrc/offload/forward_context.h b/csrc/offload/forward_context.h
new file mode 100644
index 000000000000..72eeb06aa303
--- /dev/null
+++ b/csrc/offload/forward_context.h
@@ -0,0 +1,72 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+// Enumerates possible states of a worker thread
+enum class ThreadStatus {
+ WORKING, // Thread is actively processing tasks
+ WAITING, // Thread is idle and waiting for work
+ EXIT // Thread should terminate
+};
+
+// Per-thread state information for task distribution
+struct ThreadState {
+ std::unique_ptr> status; // Current thread status
+ std::unique_ptr> curr; // Current task index
+ int end; // End index (exclusive)
+};
+
+// Manages a pool of worker threads for parallel task execution
+// Provides work-stealing scheduling and memory buffer management
+class ForwardContext {
+public:
+ // Constructor: creates thread pool with specified maximum threads
+ explicit ForwardContext(int max_thread_num);
+
+ // Destructor: stops all threads and frees memory
+ ~ForwardContext();
+
+ // Disallow copy and move operations
+ ForwardContext(const ForwardContext&) = delete;
+ ForwardContext& operator=(const ForwardContext&) = delete;
+ ForwardContext(ForwardContext&&) = delete;
+ ForwardContext& operator=(ForwardContext&&) = delete;
+
+ // Get or allocate an aligned memory buffer
+ void* getBuffer(const std::string& name, size_t size, size_t alignment = 64);
+
+ // Execute tasks using work-stealing scheduling
+ void do_work_stealing_job(int task_num, std::function compute_func);
+
+private:
+ // Process tasks for a specific thread (main loop for each thread)
+ void process_tasks(int thread_id);
+
+ // Worker thread entry point
+ void worker_thread(int thread_id);
+
+ // Memory buffer cache for reuse
+ std::unordered_map> memoryMap;
+
+ // Active number of threads for current job
+ int thread_num_ = 0;
+
+ // Maximum number of threads in pool
+ int max_thread_num_ = 0;
+
+ // Per-thread state information
+ std::vector thread_state_;
+
+ // Task computation function: f(thread_id, task_id)
+ std::function compute_func_;
+
+ // Worker thread objects (index 0 is empty - main thread runs directly)
+ std::vector workers_;
+};
\ No newline at end of file
diff --git a/csrc/offload/moe.cpp b/csrc/offload/moe.cpp
new file mode 100644
index 000000000000..861b2a605a70
--- /dev/null
+++ b/csrc/offload/moe.cpp
@@ -0,0 +1,572 @@
+#include "moe.h"
+#ifndef __NVCC__
+#include
+#endif
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+MOEConfig::MOEConfig(int tp_rank, int tp_size, int expert_num, int num_experts_per_tok,
+ int hidden_size, int intermediate_size, int max_batch_token,
+ int cache_expert_num, int block_size, int cache_topk, int update_expert_num,
+ int forward_context_num_threads)
+ : tp_rank(tp_rank), tp_size(tp_size), expert_num(expert_num),
+ num_experts_per_tok(num_experts_per_tok), hidden_size(hidden_size),
+ intermediate_size(intermediate_size), max_batch_token(max_batch_token),
+ cache_expert_num(cache_expert_num), block_size(block_size),
+ cache_topk(cache_topk), update_expert_num(update_expert_num),
+ forward_context_num_threads(forward_context_num_threads) {}
+
+
+Moe::Moe(float8_e4m3_t* w13_weights, float8_e4m3_t* w2_weights,
+ float* w13_scales, float* w2_scales, int layer_id,
+ const MOEConfig& config, ForwardContext* ctx)
+ : config_(config), layer_id_(layer_id), m_w13_weights(w13_weights), m_w2_weights(w2_weights),
+ m_w13_scale(w13_scales), m_w2_scale(w2_scales), ctx(ctx){
+
+ // 验证指针非空
+ if (!w13_weights || !w2_weights || !w13_scales || !w2_scales) {
+ throw std::invalid_argument("Weight/scale pointers cannot be null");
+ }
+
+ if (!ctx) {
+ this->ctx = new ForwardContext(config.forward_context_num_threads);
+ std::cout << "ForwardContext created" << std::endl;
+ owns_ctx_ = true;
+ if (!set_tiledata_use()) {
+ throw std::runtime_error("Failed to enable AMX tile data. Ensure CPU supports AMX.");
+ }
+ } else {
+ this->ctx = ctx;
+ owns_ctx_ = false;
+ }
+
+}
+
+Moe::~Moe() {
+ if (owns_ctx_ && ctx) {
+ delete ctx;
+ ctx = nullptr;
+ }
+}
+
+AsyncState::AsyncState()
+ : gpu_signal(0),
+ callback_completed(1),
+ layer_idx(0),
+ batch_idx(0),
+ num_tokens(0),
+ sync_count(0),
+ submit_count(0),
+ complete_count(0) {}
+
+
+static float act_fn(float x) {
+ return x / (1.0f + expf(-x)); // expf,fabsf
+}
+
+void Moe::topk_sort_inplace(int *topk_ids, float *topk_weights, int n_tokens,
+ int num_experts_per_tok, int num_experts) {
+ //std::cout<<"topk_Sort"<do_work_stealing_job(n_tokens, [&](int thread_id, int idx){
+ int token_id = idx;
+ int write_idx = 0;
+ for (int j = 0; j < top_k; ++j) {
+ int id = topk_ids[token_id * top_k + j];
+ if (id != -1) {
+ topk_ids[token_id * top_k + write_idx] = id;
+ if (j != write_idx) {
+ topk_weights[token_id * top_k + write_idx] =
+ topk_weights[token_id * top_k + j];
+ }
+ ++write_idx;
+ if (write_idx == num_experts) break;
+ }
+ }
+ for (int j = write_idx; j < top_k; ++j) {
+ topk_ids[token_id * top_k + j] = -1;
+ }
+ });
+}
+
+void Moe::packet_input(bfloat16_t *input, bfloat16_t *padding_buf, std::vector ids, int stride)
+{
+ static const int MB = 32;
+ static const int KB = 32;
+ int hidden_size = config_.hidden_size;
+ ctx->do_work_stealing_job(ids.size(), [&](int thread_id, int idx){
+ for(int k = 0; k < stride; k+=KB)
+ {
+ memcpy(padding_buf + idx / MB * MB * hidden_size + idx % MB * KB + k * MB, input + ids[idx] * stride + k, KB * sizeof(bfloat16_t));
+ }
+ });
+
+}
+
+// 合并专家输出
+void Moe::combine_expert_output(float *output, float *expert_output,
+ const std::vector& ids,
+ const std::vector& weights,
+ int stride) {
+ constexpr int kStep = 16;
+ const int n = ids.size();
+ ctx->do_work_stealing_job(n, [&](int thread_id, int idx) {
+ int token_id = ids[idx];
+ __m512 weight_vec = _mm512_set1_ps(weights[idx]);
+ float *out_ptr = output + token_id * stride;
+ float *data = expert_output + idx * stride;
+
+ for (int i = 0; i < stride; i += kStep) {
+ __m512 v = _mm512_loadu_ps(data + i);
+ __m512 o = _mm512_loadu_ps(out_ptr + i);
+ __m512 c = _mm512_fmadd_ps(v, weight_vec, o);
+ _mm512_storeu_ps(out_ptr + i, c);
+ }
+ });
+
+}
+
+
+void Moe::forward_single_expert(bfloat16_t *input, float* output, int expert_id, int n_tokens)
+{
+ std::string nvtx_msg = "forwardSingleExpert_expert_" + std::to_string(expert_id) +
+ "_ntoks_" + std::to_string(n_tokens);
+ nvtxRangePushA(nvtx_msg.c_str());
+
+ constexpr int BM = 32;
+ constexpr int BN = 32;
+
+ int intermediate_size = config_.intermediate_size;
+ int hidden_size = config_.hidden_size;
+ int block_size = config_.block_size;
+ //std::cout << "n_tokens = " << n_tokens << std::endl;
+ float* gate_up_output = static_cast(ctx->getBuffer("gate_up_output", n_tokens * intermediate_size * 2 * sizeof(float)));
+ bfloat16_t* m_down_input = static_cast(ctx->getBuffer("down_input", n_tokens * intermediate_size * sizeof(bfloat16_t)));
+
+ memset(gate_up_output, 0, n_tokens * intermediate_size * 2 * sizeof(float));
+ memset(output, 0, n_tokens * hidden_size * sizeof(float));
+
+ int nth = intermediate_size * 2 / BN; // 每个任务计算input BM行, weight BN列 7168 * 2 gate/up 跨度 7168 * 32
+ ctx->do_work_stealing_job((n_tokens / BM) * nth , [&](int thread_id, int idx){
+ // idx / nth ==> 下一组数据
+ // input idx / nth * BM * hidden_size OK!
+ // weight/scale OK!
+ // out idx/ nth * BM * intermediate_size * 2
+ bfloat16_t* input_ptr = input + idx / nth * BM * hidden_size;
+ float8_e4m3_t* weights = m_w13_weights + expert_id * intermediate_size * hidden_size * 2 + idx % nth * BN * hidden_size;
+ float* scale = m_w13_scale + expert_id * intermediate_size/block_size * hidden_size/block_size * 2 + (idx % nth * BN/ block_size) * (hidden_size/block_size);
+ float* out = gate_up_output + (idx / nth) * BM * intermediate_size * 2 + idx % nth * BM * BN;
+ amx_gemm_block_32_K_32(input_ptr, weights, scale, out, hidden_size, 32);
+ });
+
+ //dump_martix(gate_up_output, 32, 8, "gate_up_output", BM);
+ //dump_martix(gate_up_output + BM * intermediate_size * 2, 2, 8, "gate_up_output", BM);
+ // 每个任务处理一行内容
+ // src = idx / BN * BN * intermediate_size * 2 + idx %BN * intermediate_size
+ // dst = idx / BN * BN * intermediate_size + idx %BN * intermediate_size
+ // 64行out 0 -> 0 32
+ // 1 -> 1 33
+ // 2 -> 2 34
+ // 32 -> 64 96
+ // 33 -> 65
+
+ ctx->do_work_stealing_job(n_tokens, [&](int thread_id, int idx){
+ //int start = idx / 32 * 32 * intermediate_size * 2 + idx % 32 * intermediate_size;
+ //int end = start + intermediate_size;
+ //for (int j = start; j < end; j++)
+ //{
+ //m_down_input[j] = bfloat16_t::from_float(act_fn(gate_up_output[j]) * gate_up_output[j + intermediate_size * 32]);
+ //}
+ float* gate = gate_up_output + idx / BN * BN * intermediate_size * 2 + idx % BN * intermediate_size;
+ float* up = gate + intermediate_size * BN;
+ bfloat16_t* out = m_down_input + idx / BN * BN * intermediate_size + idx % BN * intermediate_size;
+ for( int j = 0; j < intermediate_size; j++)
+ {
+ out[j] = bfloat16_t::from_float(act_fn(gate[j]) * up[j]);
+ }
+ });
+
+ // 每个任务计算input BM行, weight BN列
+ nth = hidden_size / BN;
+ ctx->do_work_stealing_job((n_tokens / BM) * nth, [&](int thread_id, int idx) {
+ bfloat16_t* down_input_ptr = m_down_input + idx / nth * BM * intermediate_size;
+ float* down_output = output + idx / nth * BM * hidden_size + idx % nth * BN;
+
+ float8_e4m3_t* weights = m_w2_weights + expert_id * intermediate_size * hidden_size + idx % nth * BN * intermediate_size;
+ float* scale = m_w2_scale + expert_id * (intermediate_size/block_size) * (hidden_size / block_size) + (idx % nth * BN/ block_size) * (intermediate_size / block_size);
+
+ amx_gemm_block_32_K_32(down_input_ptr, weights, scale, down_output, intermediate_size, hidden_size);
+ });
+
+ nvtxRangePop();
+}
+
+
+void Moe::forward_experts(bfloat16_t *input, int *topk_ids, float *topk_weights, bfloat16_t *output, int n_tokens, ForwardContext *ctx)
+{
+ std::string nvtx_msg = "forwardExperts_ntoks_" + std::to_string(n_tokens);
+ nvtxRangePushA(nvtx_msg.c_str());
+
+ const int expert_num = config_.expert_num;
+ const int top_k = config_.num_experts_per_tok;
+ const int hidden_size = config_.hidden_size;
+ constexpr int MB = 32;
+
+ std::vector ids[expert_num]; // index for each expert
+ std::vector weights[expert_num]; // weight for each expert
+ float *output_fp32_buf = (float *)ctx->getBuffer("output_fp32", n_tokens * hidden_size * sizeof(float));
+ memset(output_fp32_buf, 0, n_tokens * hidden_size * sizeof(float));
+ for (int i = 0; i < n_tokens; ++i) {
+ for (int j = 0; j < top_k; ++j) {
+ int expert_id = topk_ids[i * top_k + j];
+ if (expert_id < 0) break;
+ ids[expert_id].push_back(i);
+ weights[expert_id].push_back(topk_weights[i * top_k + j]);
+ }
+ }
+
+
+ for(int expert_id = 0; expert_id < expert_num; expert_id++)
+ {
+ if(ids[expert_id].size()==0) continue;
+ int padding_len = (ids[expert_id].size() + MB - 1) / MB * MB;
+ bfloat16_t * input_packet = static_cast(ctx->getBuffer("input_packet", padding_len * config_.hidden_size * sizeof(bfloat16_t)));
+ float * m_down_output = static_cast(ctx->getBuffer("m_down_output", padding_len * config_.hidden_size * sizeof(float)));
+ packet_input(input, input_packet, ids[expert_id], config_.hidden_size);
+
+ forward_single_expert(input_packet, m_down_output, expert_id, padding_len);
+
+ combine_expert_output(output_fp32_buf, m_down_output, ids[expert_id], weights[expert_id], config_.hidden_size);
+ }
+
+ ctx->do_work_stealing_job(n_tokens, [&](int thread_id, int idx){
+ fp32_to_bf16(output_fp32_buf + idx * hidden_size, output + idx * hidden_size, hidden_size);
+ });
+
+ nvtxRangePop();
+}
+
+void Moe::forward_sparse(bfloat16_t *input, int *topk_ids, float *topk_weights, bfloat16_t *output, int n_tokens, ForwardContext *ctx)
+{
+ std::string nvtx_msg = "forwardSparse_ntoks_" + std::to_string(n_tokens);
+ nvtxRangePushA(nvtx_msg.c_str());
+
+ int task_num = 0;
+ int intermediate_size = config_.intermediate_size;
+ int hidden_size = config_.hidden_size;
+ int top_k = config_.num_experts_per_tok;
+ int block_size = config_.block_size;
+
+
+
+
+ float *m_gate_up_output = (float*)ctx->getBuffer("m_gate_up_output", n_tokens * top_k * hidden_size * 2 * sizeof(float));
+ bfloat16_t *m_down_input = (bfloat16_t*)ctx->getBuffer("m_down_input", n_tokens * top_k * intermediate_size * sizeof(bfloat16_t));
+ float *m_down_output = (float *)ctx->getBuffer("down_output", n_tokens * top_k * hidden_size * sizeof(float));
+
+ int* input_ids = (int*)ctx->getBuffer("input_ids", n_tokens * top_k * sizeof(int));
+ int* expert_ids = (int*)ctx->getBuffer("expert_ids", n_tokens * top_k * sizeof(int));
+ int* task_ids = (int*)ctx->getBuffer("task_ids", n_tokens * top_k * sizeof(int));
+ int* n_act = (int*)ctx->getBuffer("n_act", n_tokens * sizeof(int));
+
+ for(int i = 0; i < n_tokens; i++) {
+ int act = 0;
+ for(int j = 0; j < top_k; j++)
+ {
+ if(topk_ids[i * top_k + j] == -1) break;
+ input_ids[task_num] = i;
+ expert_ids[task_num] = topk_ids[i * top_k + j];
+ task_ids[i * top_k + j] = task_num;
+ act ++;
+ task_num++;
+ }
+ n_act[i] = act;
+ }
+
+
+ if(task_num != 0)
+ {
+ uint64_t stride = block_size;
+ uint64_t nth = 2 * intermediate_size / stride;
+ uint64_t weight_size = hidden_size * intermediate_size * 2;
+ uint64_t weight_stride_size = stride * hidden_size;
+ uint64_t scale_size = (hidden_size / block_size) * (intermediate_size / block_size) * 2;
+ uint64_t scale_stride_size = hidden_size / block_size;
+
+
+ ctx->do_work_stealing_job(nth * task_num, [&](int thread_id, int idx){
+ uint64_t task_id = idx / nth;
+ uint64_t ith = idx % nth;
+
+ bfloat16_t* input_ptr = input + input_ids[task_id] * hidden_size;
+ float8_e4m3_t* weights = m_w13_weights + expert_ids[task_id] * weight_size + ith * weight_stride_size;
+ float* scale = m_w13_scale + expert_ids[task_id] * scale_size + ith * scale_stride_size;
+ float* gate_up_output = m_gate_up_output + task_id * intermediate_size * 2 + ith * stride;
+
+ gemv_anni_grouped(input_ptr, (const uint8_t *)weights, scale, gate_up_output, stride, hidden_size, block_size);
+
+ });
+
+
+ nth = intermediate_size / stride;
+
+ ctx->do_work_stealing_job(nth * task_num, [&](int thread_id, int idx){
+ uint64_t task_id = idx / nth;
+ uint64_t ith = idx % nth;
+
+ bfloat16_t* down_input_ptr = m_down_input + task_id * intermediate_size + ith * stride;
+ float* gate_up_output = m_gate_up_output + task_id * intermediate_size * 2 + ith * stride;
+
+ for (uint64_t j = 0; j < stride; j++)
+ {
+ down_input_ptr[j] = bfloat16_t::from_float(act_fn(gate_up_output[j]) * gate_up_output[j + intermediate_size]);
+ }
+ });
+
+
+ weight_size = hidden_size * intermediate_size;
+ weight_stride_size = stride * intermediate_size;
+ scale_size = (hidden_size / block_size) * (intermediate_size / block_size);
+ scale_stride_size = intermediate_size / block_size;
+
+
+ nth = hidden_size / stride;
+ ctx->do_work_stealing_job(nth * task_num, [&](int thread_id, int idx) {
+ uint32_t task_id = idx / nth;
+ uint64_t ith = idx % nth;
+
+ bfloat16_t* down_input_ptr = m_down_input + task_id * intermediate_size;
+ float8_e4m3_t* weights = m_w2_weights + expert_ids[task_id] * weight_size + ith * weight_stride_size;
+ float* scale = m_w2_scale + expert_ids[task_id] * scale_size + ith * scale_stride_size;
+ float* down_output = m_down_output + task_id * hidden_size + ith * stride;
+
+ gemv_anni_grouped(down_input_ptr, (const uint8_t *)weights, scale, down_output, stride, intermediate_size, block_size);
+
+ });
+
+ }
+
+
+
+
+ ctx->do_work_stealing_job(n_tokens, [&](int thread_id, int idx) {
+ __m512 vw[8];
+ int active_num = n_act[idx];
+
+ for(int t = 0; t < active_num; t++)
+ {
+ vw[t] = _mm512_set1_ps(topk_weights[idx * top_k + t]);
+ }
+
+ for (int m = 0; m < config_.hidden_size; m+=16)
+ {
+ __m512 vo = _mm512_setzero_ps();
+ for (int j = 0; j < active_num; j++) {
+ __m512 vi = _mm512_load_ps(m_down_output + m + task_ids[idx * top_k + j] * hidden_size);
+ vo = _mm512_fmadd_ps(vi, vw[j], vo);
+ }
+ _mm256_storeu_si256((__m256i_u*)(output + idx * hidden_size + m) , (__m256i)_mm512_cvtneps_pbh(vo));
+ }
+ });
+
+
+ nvtxRangePop();
+}
+
+void Moe::forward(bfloat16_t *input, int *topk_ids,
+ float *topk_weights, bfloat16_t *output,
+ int num_tokens) {
+ std::string nvtx_msg = "forwardMoE_layer_" + std::to_string(layer_id_) +
+ "_ntoks_" + std::to_string(num_tokens);
+ nvtxRangePushA(nvtx_msg.c_str());
+
+
+ topk_sort_inplace(topk_ids, topk_weights, num_tokens, config_.num_experts_per_tok, config_.expert_num);
+ if (num_tokens < 128) {
+ forward_sparse(input, topk_ids, topk_weights, output, num_tokens, ctx);
+ } else {
+ forward_experts(input, topk_ids, topk_weights, output, num_tokens, ctx);
+ }
+
+ nvtxRangePop();
+}
+
+void Moe::forward(torch::Tensor input, torch::Tensor topk_ids,
+ torch::Tensor topk_weights, torch::Tensor output,
+ int num_tokens) {
+ TORCH_CHECK(input.device().is_cpu(), "Input must be CPU tensor");
+ TORCH_CHECK(topk_ids.dtype() == torch::kInt32, "topk_ids must be int32");
+ TORCH_CHECK(num_tokens <= config_.max_batch_token, "num_tokens exceeds max_batch_token");
+
+
+
+ auto* input_ptr = reinterpret_cast(input.data_ptr());
+ auto* topk_ids_ptr = topk_ids.data_ptr();
+ auto* topk_weights_ptr = topk_weights.data_ptr();
+ auto* output_ptr = reinterpret_cast(output.data_ptr());
+
+
+ forward(input_ptr, topk_ids_ptr, topk_weights_ptr, output_ptr, num_tokens);
+}
+
+
+
+
+// ==================== MoeOffloadEngine 实现 ====================
+MoeOffloadEngine::MoeOffloadEngine(const MOEConfig& config)
+ : config_(config),
+ cpu_state_(nullptr),
+ gpu_state_(nullptr),
+ shutdown_(false),
+ async_state_initialized_(false) {
+
+ try {
+ if (!set_tiledata_use()) {
+ throw std::runtime_error("Failed to enable AMX tile data. Ensure CPU supports AMX.");
+ }
+ forward_context_ = std::make_unique(config.forward_context_num_threads);
+ } catch (const std::exception& e) {
+ std::cerr << "[ERROR] MoeOffloadEngine: Failed to initialize ForwardContext: " << e.what() << std::endl;
+ throw;
+ }
+
+ try {
+
+ size_t output_size = config_.max_batch_token * config_.hidden_size * 2; // BFloat16 = 2 bytes
+ size_t topk_ids_size = config_.max_batch_token * config_.num_experts_per_tok * sizeof(int32_t);
+ size_t topk_weights_size = config_.max_batch_token * config_.num_experts_per_tok * sizeof(float);
+ size_t hidden_states_size = config_.max_batch_token * config_.hidden_size * 2; // BFloat16 = 2 bytes
+ size_t total_size = output_size + topk_ids_size + topk_weights_size + hidden_states_size;
+
+ output_ = torch::zeros({config_.max_batch_token, config_.hidden_size}, torch::TensorOptions().device(torch::kCPU).dtype(torch::kBFloat16).pinned_memory(true));
+
+ topk_ids_ = torch::zeros({config_.max_batch_token, config_.num_experts_per_tok}, torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt32).pinned_memory(true));
+
+ topk_weights_ = torch::zeros({config_.max_batch_token, config_.num_experts_per_tok}, torch::TensorOptions().device(torch::kCPU).dtype(torch::kFloat32).pinned_memory(true));
+
+ hidden_states_ = torch::zeros({config_.max_batch_token, config_.hidden_size}, torch::TensorOptions().device(torch::kCPU).dtype(torch::kBFloat16).pinned_memory(true));
+
+ initialize_async_state();
+ } catch (const std::bad_alloc& e) {
+ std::cerr << "[ERROR] MoeOffloadEngine: Memory allocation failed (std::bad_alloc)" << std::endl;
+ std::cerr << " max_batch_token=" << config_.max_batch_token
+ << ", hidden_size=" << config_.hidden_size
+ << ", num_experts_per_tok=" << config_.num_experts_per_tok << std::endl;
+ std::cerr << " Total pinned memory required: "
+ << (config_.max_batch_token * config_.hidden_size * 2 * 2 + // 2 BFloat16 tensors, 2 bytes each
+ config_.max_batch_token * config_.num_experts_per_tok * (sizeof(int32_t) + sizeof(float))) / (1024.0 * 1024.0)
+ << " MB" << std::endl;
+ std::cerr << " Suggestion: Reduce max_batch_token or check available system memory" << std::endl;
+ throw std::runtime_error("Failed to allocate pinned memory for MoeOffloadEngine: " + std::string(e.what()));
+ } catch (const std::exception& e) {
+ std::cerr << "[ERROR] MoeOffloadEngine: Exception during initialization: " << e.what() << std::endl;
+ throw;
+ }
+}
+
+MoeOffloadEngine::~MoeOffloadEngine() {
+ if (async_state_initialized_) {
+ shutdown_.store(true);
+ if (polling_thread_.joinable()) {
+ polling_thread_.join();
+ }
+ cleanup_async_state();
+ }
+
+}
+
+void MoeOffloadEngine::create_cpu_moe_layer(torch::Tensor w13_weight,
+ torch::Tensor w2_weight,
+ torch::Tensor w13_scale,
+ torch::Tensor w2_scale,
+ int layer_id) {
+ TORCH_CHECK(layer_id >= 0, "layer_id must be >= 0");
+ TORCH_CHECK(w13_weight.device().is_cpu(), "w13_weight must be CPU tensor");
+ TORCH_CHECK(w2_weight.device().is_cpu(), "w2_weight must be CPU tensor");
+ TORCH_CHECK(w13_scale.device().is_cpu(), "w13_scale must be CPU tensor");
+ TORCH_CHECK(w2_scale.device().is_cpu(), "w2_scale must be CPU tensor");
+
+
+ auto* w13_ptr = reinterpret_cast(w13_weight.data_ptr());
+ auto* w2_ptr = reinterpret_cast(w2_weight.data_ptr());
+ auto* w13_scale_ptr = w13_scale.data_ptr();
+ auto* w2_scale_ptr = w2_scale.data_ptr();
+
+ try {
+
+ ForwardContext* ctx = forward_context_.get();
+
+ auto result = moe_layers_.emplace(
+ std::piecewise_construct,
+ std::forward_as_tuple(layer_id),
+ std::forward_as_tuple(w13_ptr, w2_ptr, w13_scale_ptr, w2_scale_ptr, layer_id, config_, ctx)
+ );
+
+ TORCH_CHECK(result.second, "Layer with id ", layer_id, " already exists");
+ } catch (const std::bad_alloc& e) {
+ std::cerr << "[ERROR] MoeOffloadEngine::create_cpu_moe_layer: Memory allocation failed (std::bad_alloc)" << std::endl;
+ std::cerr << " layer_id=" << layer_id << std::endl;
+ std::cerr << " This may be due to ForwardContext thread pool allocation failure" << std::endl;
+ throw std::runtime_error("Failed to create CPU MoE layer: " + std::string(e.what()));
+ } catch (const std::exception& e) {
+ std::cerr << "[ERROR] MoeOffloadEngine::create_cpu_moe_layer: Exception: " << e.what() << std::endl;
+ std::cerr << " layer_id=" << layer_id << std::endl;
+ throw;
+ }
+}
+
+Moe* MoeOffloadEngine::get_cpu_moe_layer(int layer_id) {
+ auto it = moe_layers_.find(layer_id);
+ return (it != moe_layers_.end()) ? &it->second : nullptr;
+}
+
+const Moe* MoeOffloadEngine::get_cpu_moe_layer(int layer_id) const {
+ auto it = moe_layers_.find(layer_id);
+ return (it != moe_layers_.end()) ? &it->second : nullptr;
+}
+
+void MoeOffloadEngine::call(int layer_idx, int batch_idx, int num_tokens) {
+ std::string nvtx_msg = "MoeOffloadEngine_call_layer_" + std::to_string(layer_idx) +
+ "_batch_" + std::to_string(batch_idx) +
+ "_ntoks_" + std::to_string(num_tokens);
+ nvtxRangePushA(nvtx_msg.c_str());
+
+ auto layer = get_cpu_moe_layer(layer_idx);
+ TORCH_CHECK(layer != nullptr, "Layer ", layer_idx, " not found");
+
+ layer->forward(hidden_states_, topk_ids_, topk_weights_, output_, num_tokens);
+
+ nvtxRangePop();
+}
+
+void MoeOffloadEngine::update_expert_cache(torch::Tensor w13_cache, torch::Tensor w2_cache,
+ torch::Tensor w13_scale_cache, torch::Tensor w2_scale_cache,
+ torch::Tensor map, int layer_id, int num_experts) {
+ auto layer = get_cpu_moe_layer(layer_id);
+ TORCH_CHECK(layer != nullptr, "Layer ", layer_id, " not found");
+ layer->update_expert_cache(w13_cache, w2_cache, w13_scale_cache, w2_scale_cache, map, num_experts);
+}
+
+void MoeOffloadEngine::cpu_polling_loop() {
+ std::cout << "cpu_polling_loop thread Start!" << std::endl;
+ while (!shutdown_.load(std::memory_order_acquire)) {
+ if (cpu_state_->gpu_signal == 1) {
+ nvtxRangePushA("cpu_callback_func");
+ call(cpu_state_->layer_idx, cpu_state_->batch_idx, cpu_state_->num_tokens);
+ nvtxRangePop();
+
+ std::atomic_thread_fence(std::memory_order_seq_cst);
+ cpu_state_->callback_completed = 1;
+ cpu_state_->complete_count += 1;
+ cpu_state_->gpu_signal = 0;
+ }
+ std::this_thread::sleep_for(std::chrono::microseconds(10));
+ }
+}
diff --git a/csrc/offload/moe.h b/csrc/offload/moe.h
new file mode 100644
index 000000000000..307b4e08ff9d
--- /dev/null
+++ b/csrc/offload/moe.h
@@ -0,0 +1,174 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "primitives.h"
+#include "forward_context.h"
+
+#ifndef cudaErrorNotInitialized
+#define cudaErrorNotInitialized ((cudaError_t)1000) // Arbitrary unused error code
+#endif
+
+// MOE配置结构体
+struct MOEConfig {
+ int tp_rank = 0;
+ int tp_size = 1;
+ int expert_num = 0;
+ int num_experts_per_tok = 0;
+ int hidden_size = 0;
+ int intermediate_size = 0;
+ int max_batch_token = 0;
+ int cache_expert_num = 0;
+ int block_size = 0;
+ int cache_topk = 0;
+ int update_expert_num = 0;
+ int forward_context_num_threads = 14; // ForwardContext 线程数
+ bool normTopKProb = false;
+ int nGroup = 0;
+ int topKGroup = 0;
+
+ MOEConfig() = default;
+ MOEConfig(int tp_rank, int tp_size, int expert_num, int num_experts_per_tok,
+ int hidden_size, int intermediate_size, int max_batch_token,
+ int cache_expert_num, int block_size, int cache_topk, int update_expert_num,
+ int forward_context_num_threads = 14);
+};
+
+struct AsyncState {
+ // FIXED: 移除 volatile,使用 memory fence 保证同步
+ int32_t gpu_signal;
+ int32_t callback_completed;
+ int32_t layer_idx;
+ int32_t batch_idx;
+ int32_t num_tokens;
+ int32_t sync_count;
+ int32_t submit_count;
+ int32_t complete_count;
+
+ AsyncState();
+};
+// MoE主类
+class Moe {
+public:
+ // 构造函数:接收指针和形状信息
+ Moe(float8_e4m3_t* w13_weights, float8_e4m3_t* w2_weights,
+ float* w13_scales, float* w2_scales, int layer_id,
+ const MOEConfig& config, ForwardContext* ctx = nullptr);
+
+ ~Moe();
+ int layer_id() const { return layer_id_;}
+ // 前向接口:接收原始指针
+ void forward(bfloat16_t* input, int* topk_ids, float* topk_weights,
+ bfloat16_t* output, int num_tokens);
+
+ // 前向接口:接收 Tensor(重载)
+ void forward(torch::Tensor input, torch::Tensor topk_ids,
+ torch::Tensor topk_weights, torch::Tensor output, int num_tokens);
+
+ // CUDA 方法:更新 expert cache
+ void update_expert_cache(torch::Tensor w13_cache, torch::Tensor w2_cache,
+ torch::Tensor w13_scale_cache, torch::Tensor w2_scale_cache,
+ torch::Tensor map, int64_t num_experts);
+private:
+ // 核心方法
+ void topk_sort_inplace(int *topk_ids, float *topk_weights, int n_tokens,
+ int num_experts_per_tok, int num_experts);
+
+ void packet_input(bfloat16_t *input, bfloat16_t *padding_buf, std::vector ids, int stride);
+
+ void combine_expert_output(float *output, float *expert_output,
+ const std::vector& ids,
+ const std::vector& weights,
+ int stride);
+
+ void forward_single_expert(bfloat16_t *input, float* output, int expert_id, int n_tokens);
+
+ void forward_experts(bfloat16_t *input, int *topk_ids, float *topk_weights, bfloat16_t *output, int n_tokens, ForwardContext *ctx);
+
+ void forward_sparse(bfloat16_t *input, int *topk_ids, float *topk_weights, bfloat16_t *output, int n_tokens, ForwardContext *ctx);
+
+ // 成员变量
+ MOEConfig config_;
+ int layer_id_;
+
+ // 权重指针
+ float8_e4m3_t* m_w13_weights = nullptr;
+ float8_e4m3_t* m_w2_weights = nullptr;
+ float* m_w13_scale = nullptr;
+ float* m_w2_scale = nullptr;
+
+ // 权重buffer大小(用于边界检查)
+ int64_t w13_weight_size_ = 0;
+ int64_t w2_weight_size_ = 0;
+ int64_t w13_scale_size_ = 0;
+ int64_t w2_scale_size_ = 0;
+
+ ForwardContext* ctx = nullptr;
+ bool owns_ctx_ = false;
+};
+
+
+class MoeOffloadEngine {
+ public:
+ explicit MoeOffloadEngine(const MOEConfig& config);
+ ~MoeOffloadEngine();
+
+ uintptr_t ptr() { return reinterpret_cast(this); }
+
+ void create_cpu_moe_layer(torch::Tensor w13_weight, torch::Tensor w2_weight,
+ torch::Tensor w13_scale, torch::Tensor w2_scale,
+ int layer_id);
+
+ Moe* get_cpu_moe_layer(int layer_id);
+ const Moe* get_cpu_moe_layer(int layer_id) const;
+
+ void update_expert_cache(torch::Tensor w13_cache, torch::Tensor w2_cache,
+ torch::Tensor w13_scale_cache, torch::Tensor w2_scale_cache,
+ torch::Tensor map, int layer_id, int num_experts);
+
+ void expert_cache_policy(torch::Tensor cache_map, torch::Tensor miss_map,
+ torch::Tensor policy_sort, torch::Tensor topk_ids,
+ torch::Tensor cpu_topk, torch::Tensor copy_map);
+
+ void get_output(torch::Tensor gpu_output);
+ void set_input(torch::Tensor gpu_hidden_states, torch::Tensor gpu_topk_ids,
+ torch::Tensor gpu_topk_weights);
+
+ void call(int layer_idx, int batch_idx, int num_tokens);
+ // FIXED: 移除重复类名限定
+ void initialize_async_state();
+ void cleanup_async_state();
+ void cpu_polling_loop();
+
+ // FIXED: 添加参数类型
+ cudaError_t submit(int layer_idx, int batch_idx, int num_tokens);
+ cudaError_t sync();
+
+ private:
+ MOEConfig config_;
+ // FIXED: 添加成员变量名
+ std::unordered_map moe_layers_;
+
+ std::unique_ptr forward_context_;
+
+ // CPU input/output buffers
+ torch::Tensor output_;
+ torch::Tensor topk_ids_;
+ torch::Tensor topk_weights_;
+ torch::Tensor hidden_states_;
+
+ // Async state
+ AsyncState* cpu_state_;
+ AsyncState* gpu_state_;
+ std::thread polling_thread_;
+ std::atomic shutdown_;
+ bool async_state_initialized_;
+
+ };
+
\ No newline at end of file
diff --git a/csrc/offload/moe_kernel.cu b/csrc/offload/moe_kernel.cu
new file mode 100644
index 000000000000..805213ec4c15
--- /dev/null
+++ b/csrc/offload/moe_kernel.cu
@@ -0,0 +1,596 @@
+#include
+#ifndef cudaErrorNotInitialized
+#define cudaErrorNotInitialized ((cudaError_t)1000) // Arbitrary unused error code
+#endif
+#include
+#include "moe.h"
+#include
+#include
+#include
+#include
+
+// ==================== CUDA Kernels ====================
+constexpr int MAX_NUM_EXPERT = 256;
+constexpr int MAX_CACHE_EXPERT = 128;
+constexpr int BYTES_PER_VEC = 16;
+constexpr int BN = 32;
+constexpr int BK = 32;
+constexpr int BM = 16;
+constexpr int warpSize = 32;
+
+__device__ __forceinline__ void cp_async_16B(void* smem_dst, const void* gmem_src) {
+ #if __CUDA_ARCH__ >= 800
+ uint32_t smem_addr = static_cast(__cvta_generic_to_shared(smem_dst));
+ unsigned long long gmem_addr = (unsigned long long)gmem_src;
+ asm volatile("cp.async.ca.shared.global [%0], [%1], 16;\n" :: "r"(smem_addr), "l"(gmem_addr));
+ #else
+ // fallback
+ *reinterpret_cast(smem_dst) = *reinterpret_cast(gmem_src);
+ #endif
+ }
+
+ __device__ __forceinline__ void cp_async_commit() {
+ #if __CUDA_ARCH__ >= 800
+ asm volatile("cp.async.commit_group;\n" ::: "memory");
+ #endif
+ }
+
+ template
+ __device__ __forceinline__ void cp_async_wait_impl() {
+ #if __CUDA_ARCH__ >= 800
+ asm volatile("cp.async.wait_group %0;\n" :: "n"(n) : "memory");
+ #endif
+ }
+
+
+__global__ void submit_kernel(AsyncState* data, int layer_idx, int batch_idx, int num_tokens) {
+ if (blockIdx.x == 0 && threadIdx.x == 0) {
+ data->layer_idx = layer_idx;
+ data->batch_idx = batch_idx;
+ data->num_tokens = num_tokens;
+ data->callback_completed = 0;
+ data->gpu_signal = 1;
+ data->submit_count += 1;
+ __threadfence_system(); // 确保写入对 CPU 可见
+ }
+}
+
+__global__ void sync_kernel(AsyncState* data) {
+ if (blockIdx.x == 0 && threadIdx.x == 0) {
+ while (data->callback_completed == 0) {
+ __threadfence_system(); // 确保从 CPU 读取最新值
+ }
+ data->sync_count += 1;
+ }
+}
+
+__global__ void cache_policy_kernel(
+ int* cache_map, // [256] 输入映射表
+ int* miss_map,
+ int* copy_map,
+ int* sort, // [128] 输入输出排序数组
+ int* topk, // [32] 输入topk数组
+ int* cpu_topk, // [32] 输出处理后topk
+ const int C, // total num of cached expert < 128
+ const int N, // total num of expert < 256
+ const int active_expert_num, // total num of expert in topk < 64
+ const int K,
+ const int update_expert_num
+) {
+ const int tid = threadIdx.x;
+
+ using BlockScan = cub::BlockScan;
+ __shared__ typename BlockScan::TempStorage temp_buf;
+ // 共享内存分配
+ __shared__ bool s_in_topk[MAX_NUM_EXPERT]; // topk标记
+ __shared__ bool s_in_match[MAX_NUM_EXPERT]; // match标记
+ __shared__ int s_miss_list[MAX_NUM_EXPERT]; // miss列表
+ __shared__ int s_cache_map[MAX_NUM_EXPERT];
+ __shared__ int s_sort[MAX_CACHE_EXPERT];
+
+ // 初始化共享内存
+ s_in_topk[tid] = false;
+ s_in_match[tid] = false;
+ s_miss_list[tid] = -1;
+ s_cache_map[tid] = cache_map[tid];
+
+ copy_map[tid] = -1;
+ int top_k_id = -1;
+ bool is_match=true;
+ int val = 0;
+ __syncthreads();
+
+ // 阶段1: 标记topk元素
+ if (tid < active_expert_num) {
+ top_k_id = topk[tid];
+ if (tid % 8 < K){
+ s_in_topk[top_k_id] = true;
+ }
+ }
+ __syncthreads();
+
+
+ // 阶段2: 计算miss/match
+ int offset;
+ bool match = s_cache_map[tid] > -1;
+ bool active = s_in_topk[tid];
+ const bool is_miss = active & !match;
+ s_in_match[tid] = active & match;
+
+ {
+ BlockScan(temp_buf).ExclusiveSum(is_miss ? 1 : 0, offset);
+ if(is_miss){
+ s_miss_list[offset] = tid;
+ }
+ }
+
+ // 阶段3: 重新排序,match的数据移动到末尾
+ if(tid < C){
+ val = sort[tid];
+ //printf("tid:%3d, val:%3d\n", tid, val);
+ is_match = s_in_match[val];
+ }
+ int unused_offset;
+ int used_offset;
+
+ BlockScan(temp_buf).ExclusiveSum(is_match ? 0 : 1, unused_offset);
+ BlockScan(temp_buf).ExclusiveSum(is_match ? 1 : 0, used_offset);
+
+ int new_pos = is_match ? (C - 1 - used_offset) : unused_offset;
+ if(tid < C){
+ //printf("tid:%3d, val:%3d, pos:%3d\n", tid, val, new_pos);
+ s_sort[new_pos] = val;
+ }
+ __syncthreads();
+
+ // 处理swap数据,
+ if(tid < update_expert_num){
+ int miss_id = s_miss_list[tid];
+ int evict_id = s_sort[tid];
+ int new_pos = s_cache_map[evict_id];
+
+ s_sort[tid] = miss_id >= 0 ? miss_id : evict_id;
+
+ if(miss_id >= 0)
+ {
+ //printf("missid: %d\n", miss_id);
+ int evict_pos = miss_map[miss_id];
+ copy_map[miss_id] = new_pos;
+
+ s_cache_map[miss_id] = new_pos;
+ s_cache_map[evict_id] = -1;
+ miss_map[evict_id] = evict_pos;
+ miss_map[miss_id] = -1;
+
+ s_in_match[miss_id] = true;
+ }
+ }
+ __syncthreads();
+
+ // 阶段5: 更新load expert后的sort
+ if(tid < C){
+ val = s_sort[tid];
+ is_match = s_in_match[val];
+ }
+ BlockScan(temp_buf).ExclusiveSum(is_match ? 0 : 1, unused_offset);
+ BlockScan(temp_buf).ExclusiveSum(is_match ? 1 : 0, used_offset);
+
+ new_pos = is_match ? (C - 1 - used_offset) : unused_offset;
+
+ if (tid < C){
+ sort[new_pos] = val;
+ }
+
+ // 计算topk
+ if(tid < active_expert_num)
+ {
+ int flag = (s_cache_map[top_k_id] >= 0);
+ cpu_topk[tid] = flag ? -1 : top_k_id;
+ }
+ cache_map[tid] = s_cache_map[tid];
+ __syncthreads();
+}
+
+// ==================== CpuMoeLayer CUDA 方法 ====================
+
+__global__ void update_expert_cache_kernel(
+ const uint8_t* __restrict__ src_w13,
+ const uint8_t* __restrict__ src_w2,
+ const float* __restrict__ src_scales_w13,
+ const float* __restrict__ src_scales_w2,
+ uint8_t* __restrict__ dst_w13,
+ uint8_t* __restrict__ dst_w2,
+ float* __restrict__ dst_scales_w13,
+ float* __restrict__ dst_scales_w2,
+ const int* __restrict__ map,
+ const int64_t w13_bytes_per_expert,
+ const int64_t w2_bytes_per_expert,
+ const int64_t w13_scale_bytes_per_expert,
+ const int64_t w2_scale_bytes_per_expert,
+ const int num_experts,
+ const int W13_N,
+ const int W13_K,
+ const int W2_N,
+ const int W2_K,
+ const int W13_scale_N,
+ const int W13_scale_K,
+ const int W2_scale_N,
+ const int W2_scale_K,
+ const int map_size, // 新增:map 数组大小
+ const int cache_expert_num, // 新增:缓存中的专家数量
+ const int source_expert_num // 新增:源数据中的专家数量
+)
+{
+ const int64_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+ const int64_t grid_stride = gridDim.x * blockDim.x * BYTES_PER_VEC;
+ const int64_t thread_stride = blockDim.x * gridDim.x;
+
+ const uint8_t* src_w13_scales_bytes = reinterpret_cast(src_scales_w13);
+ const uint8_t* src_w2_scales_bytes = reinterpret_cast(src_scales_w2);
+ uint8_t* dst_w13_scales_bytes = reinterpret_cast(dst_scales_w13);
+ uint8_t* dst_w2_scales_bytes = reinterpret_cast(dst_scales_w2);
+ const int block_id = blockIdx.x;
+ const int64_t tid_warp = threadIdx.x % warpSize;
+ const int64_t warp_id = threadIdx.x / warpSize;
+ const int64_t n_warp = blockDim.x / warpSize;
+ const int64_t thread_offset = tid_warp * BYTES_PER_VEC;
+ const int64_t start_tile = block_id * n_warp + warp_id;
+
+ __shared__ int16_t tile[8][2][32][8];
+ int tmp;
+ int k_mid;
+ int k_outer;
+ int n_outer;
+ int n_dst;
+ int k_dst;
+ int16_t out16[8];
+ const int lane_row = (tid_warp >> 2) + ((tid_warp & 3) << 3);
+
+
+ auto copy_chunk = [&](const uint8_t* __restrict__ src_base,
+ uint8_t* __restrict__ dst_base,
+ const int64_t chunk_size,
+ const int N,
+ const int K
+ ) {
+ for (int64_t byte_offset = tid * BYTES_PER_VEC; byte_offset < chunk_size; byte_offset += grid_stride) {
+ const float4* s = reinterpret_cast(src_base + byte_offset);
+ float4* d = reinterpret_cast(dst_base + byte_offset);
+ *d = *s;
+ }
+ };
+
+
+ auto copy_chunk_vnni = [&](const uint8_t* __restrict__ src_base,
+ uint8_t* __restrict__ dst_base,
+ const int64_t chunk_size,
+ const int N,
+ const int K
+ ) {
+ const int tile_size = BM * BN; // 512 字节
+ const int64_t n_tiles = chunk_size / tile_size;
+ const int64_t k_outer_dim = (int64_t)K / (int64_t)32; // BK=32
+ const int64_t tile_stride = gridDim.x * blockDim.x * BYTES_PER_VEC / tile_size;
+
+ int buf_id = 0;
+ int next_buf = buf_id ^ 1;
+
+ // 第一个 tile 的异步拷贝
+ const int4* s = reinterpret_cast(src_base + start_tile * tile_size + thread_offset);
+ cp_async_16B(reinterpret_cast(&tile[warp_id][buf_id][tid_warp]), s);
+ cp_async_commit();
+ cp_async_wait_impl<0>();
+ __syncwarp();
+
+ int64_t tile_id = start_tile;
+ for(; tile_id < n_tiles - tile_stride; tile_id += tile_stride){
+ // 先进行下一个 tile 的异步拷贝
+ const int4* s = reinterpret_cast(src_base + (tile_id + tile_stride) * tile_size + thread_offset);
+ cp_async_16B(reinterpret_cast(&tile[warp_id][next_buf][tid_warp]), s);
+ cp_async_commit();
+ cp_async_wait_impl<1>();
+
+ // 进行转换(从当前 buf_id 读取)
+ #pragma unroll
+ for(int i = 0; i < 8; i++){
+ out16[i] = tile[warp_id][buf_id][(i << 2) + (tid_warp & 3)][tid_warp >> 2];
+ }
+
+ // 计算索引(基于当前 tile_id)
+ tmp = tile_id << 3;
+ k_mid = tmp & 15;
+ tmp >>= 4;
+ k_outer = tmp % k_outer_dim;
+ n_outer = tmp / k_outer_dim;
+ n_dst = (n_outer << 5) + lane_row;
+ k_dst = (k_outer << 5) + (k_mid << 1);
+ int4* d = reinterpret_cast(dst_base + n_dst * K + k_dst);
+ const int4 data = *reinterpret_cast(&out16[0]);
+ *d = data;
+
+ // 流水线更替 buffer_id
+ buf_id = buf_id ^ 1;
+ next_buf = next_buf ^ 1;
+ __syncwarp();
+ }
+
+ // 处理最后一个 tile
+ cp_async_wait_impl<0>();
+ #pragma unroll
+ for(int i = 0; i < 8; i++){
+ out16[i] = tile[warp_id][buf_id][(i << 2) + (tid_warp & 3)][tid_warp >> 2];
+ }
+ tmp = tile_id << 3;
+ k_mid = tmp & 15;
+ tmp >>= 4;
+ k_outer = tmp % k_outer_dim;
+ n_outer = tmp / k_outer_dim;
+ n_dst = (n_outer << 5) + lane_row;
+ k_dst = (k_outer << 5) + (k_mid << 1);
+ const int4 data = *reinterpret_cast(&out16[0]);
+ int4* d = reinterpret_cast(dst_base + n_dst * K + k_dst);
+ *d = data;
+ };
+
+ for (int e = 0; e < num_experts; ++e) {
+ const int dst_idx = map[e];
+ if (dst_idx < 0) continue;
+ copy_chunk_vnni(src_w13 + e * w13_bytes_per_expert, dst_w13 + dst_idx * w13_bytes_per_expert, w13_bytes_per_expert,W13_N,W13_K);
+ copy_chunk_vnni(src_w2 + e * w2_bytes_per_expert, dst_w2 + dst_idx * w2_bytes_per_expert, w2_bytes_per_expert,W2_N,W2_K);
+ copy_chunk(src_w13_scales_bytes + e * w13_scale_bytes_per_expert, dst_w13_scales_bytes + dst_idx * w13_scale_bytes_per_expert, w13_scale_bytes_per_expert,W13_scale_N,W13_scale_K);
+ copy_chunk(src_w2_scales_bytes + e * w2_scale_bytes_per_expert, dst_w2_scales_bytes + dst_idx * w2_scale_bytes_per_expert, w2_scale_bytes_per_expert,W2_scale_N,W2_scale_K);
+ }
+}
+
+
+void Moe::update_expert_cache(
+ torch::Tensor w13_cache,
+ torch::Tensor w2_cache,
+ torch::Tensor w13_scale_cache,
+ torch::Tensor w2_scale_cache,
+ torch::Tensor map,
+ int64_t num_experts)
+{
+ // Sanity Checks
+ TORCH_CHECK(w13_cache.is_cuda() && w2_cache.is_cuda() && map.is_cuda());
+ TORCH_CHECK(w13_scale_cache.is_cuda() && w2_scale_cache.is_cuda());
+
+ // Compute bytes per expert
+ const int64_t w13_bytes_per_expert = w13_cache.size(1) * w13_cache.size(2) * w13_cache.element_size();
+ const int64_t w2_bytes_per_expert = w2_cache.size(1) * w2_cache.size(2) * w2_cache.element_size();
+ const int64_t w13_scale_bytes_per_expert = w13_scale_cache.size(1) * w13_scale_cache.size(2) * w13_scale_cache.element_size();
+ const int64_t w2_scale_bytes_per_expert = w2_scale_cache.size(1) * w2_scale_cache.size(2) * w2_scale_cache.element_size();
+ const int64_t actual_num_experts = std::min(map.size(0), num_experts);
+
+
+ // 2. 检查缓存大小(第一维是专家数量)
+ const int64_t cache_expert_num = w13_cache.size(0);
+ // Dynamic Launch Configuration
+ constexpr int64_t BLOCK_SIZE = 256;
+ const int64_t GRID_SIZE = 4;
+
+ const int W13_N = w13_cache.size(1);
+ const int W13_K = w13_cache.size(2);
+ const int W2_N = w2_cache.size(1);
+ const int W2_K = w2_cache.size(2);
+ const int W13_scale_N = w13_scale_cache.size(1);
+ const int W13_scale_K = w13_scale_cache.size(2);
+ const int W2_scale_N = w2_scale_cache.size(1);
+ const int W2_scale_K = w2_scale_cache.size(2);
+
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+ update_expert_cache_kernel<<>>(
+ reinterpret_cast(m_w13_weights),
+ reinterpret_cast(m_w2_weights),
+ reinterpret_cast(m_w13_scale),
+ reinterpret_cast(m_w2_scale),
+ reinterpret_cast(w13_cache.data_ptr()),
+ reinterpret_cast(w2_cache.data_ptr()),
+ w13_scale_cache.data_ptr(),
+ w2_scale_cache.data_ptr(),
+ map.data_ptr(),
+ w13_bytes_per_expert, w2_bytes_per_expert,
+ w13_scale_bytes_per_expert, w2_scale_bytes_per_expert,
+ num_experts,
+ W13_N,
+ W13_K,
+ W2_N,
+ W2_K,
+ W13_scale_N,
+ W13_scale_K,
+ W2_scale_N,
+ W2_scale_K,
+ static_cast(map.size(0)), // map_size
+ static_cast(cache_expert_num), // cache_expert_num
+ config_.expert_num // source_expert_num
+ );
+
+ cudaError_t launch_err = cudaGetLastError();
+ if (launch_err != cudaSuccess) {
+ fprintf(stderr, "[ERROR] Kernel launch failed: %s\n", cudaGetErrorString(launch_err));
+ TORCH_CHECK(false, "Kernel launch failed: ", cudaGetErrorString(launch_err));
+ }
+
+}
+
+void MoeOffloadEngine::expert_cache_policy(
+ torch::Tensor cache_map,
+ torch::Tensor miss_map,
+ torch::Tensor policy_sort,
+ torch::Tensor topk_ids,
+ torch::Tensor cpu_topk,
+ torch::Tensor copy_map
+){
+ TORCH_CHECK(cache_map.device().is_cuda(), "cache_map mast be CUDA");
+ TORCH_CHECK(miss_map.device().is_cuda(), "miss_map mast be CUDA");
+ TORCH_CHECK(policy_sort.device().is_cuda(), "policy_sort mast be CUDA");
+ TORCH_CHECK(topk_ids.device().is_cuda(), "topk_ids mast be CUDA");
+ TORCH_CHECK(cpu_topk.device().is_cuda(), "cpu_topk mast be CUDA");
+ TORCH_CHECK(copy_map.device().is_cuda(), "copy_map mast be CUDA");
+
+ const int threads = 256;
+ const int blocks = 1;
+ auto num_tokens = topk_ids.size(0);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ cache_policy_kernel<<>>(
+ cache_map.data_ptr(),
+ miss_map.data_ptr(),
+ copy_map.data_ptr(),
+ policy_sort.data_ptr(),
+ topk_ids.data_ptr(),
+ cpu_topk.data_ptr(),
+ config_.cache_expert_num,
+ config_.expert_num,
+ num_tokens * config_.num_experts_per_tok,
+ config_.cache_topk,
+ config_.update_expert_num
+ );
+}
+
+// ==================== MoeOffloadEngine CUDA 方法 ====================
+
+void MoeOffloadEngine::initialize_async_state() {
+ if (async_state_initialized_) {
+ throw std::runtime_error("Async state already initialized");
+ }
+
+ int device;
+ cudaError_t err = cudaGetDevice(&device);
+ if (err != cudaSuccess) {
+ throw std::runtime_error("Failed to get CUDA device: " +
+ std::string(cudaGetErrorString(err)));
+ }
+
+ cudaDeviceProp prop;
+ err = cudaGetDeviceProperties(&prop, device);
+ if (err != cudaSuccess) {
+ throw std::runtime_error("Failed to get device properties: " +
+ std::string(cudaGetErrorString(err)));
+ }
+
+ if (!prop.canMapHostMemory) {
+ throw std::runtime_error("GPU does not support mapped host memory");
+ }
+
+ // 分配零拷贝内存
+ err = cudaHostAlloc((void**)&cpu_state_, sizeof(AsyncState), cudaHostAllocMapped);
+ if (err != cudaSuccess) {
+ throw std::runtime_error("cudaHostAlloc failed: " +
+ std::string(cudaGetErrorString(err)));
+ }
+
+ // 获取设备指针
+ err = cudaHostGetDevicePointer((void**)&gpu_state_, (void*)cpu_state_, 0);
+ if (err != cudaSuccess) {
+ cudaFreeHost(cpu_state_);
+ cpu_state_ = nullptr;
+ throw std::runtime_error("cudaHostGetDevicePointer failed: " +
+ std::string(cudaGetErrorString(err)));
+ }
+
+ // 初始化状态
+ *cpu_state_ = AsyncState(); // 使用构造函数初始化
+
+ // 启动轮询线程
+ try {
+ polling_thread_ = std::thread(&MoeOffloadEngine::cpu_polling_loop, this);
+ } catch (const std::system_error& e) {
+ cudaFreeHost(cpu_state_);
+ cpu_state_ = nullptr;
+ gpu_state_ = nullptr;
+ throw std::runtime_error("Failed to create thread: " + std::string(e.what()));
+ }
+
+ async_state_initialized_ = true;
+ fprintf(stderr, "[MoeOffloadEngine] Async state initialized: gpu_state=%p cpu_state=%p\n",
+ gpu_state_, cpu_state_);
+}
+
+void MoeOffloadEngine::cleanup_async_state() {
+ if (cpu_state_) {
+ cudaFreeHost(cpu_state_);
+ cpu_state_ = nullptr;
+ gpu_state_ = nullptr;
+ }
+}
+
+cudaError_t MoeOffloadEngine::submit(int layer_idx, int batch_idx, int num_tokens) {
+ nvtxRangePushA("MoeOffloadEngine::submit");
+ if (!async_state_initialized_) {
+ fprintf(stderr, "[ERROR] submit_kernel engine not initialized\n");
+ return cudaErrorNotInitialized;
+ }
+
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ submit_kernel<<<1, 1, 0, stream>>>(gpu_state_, layer_idx, batch_idx, num_tokens);
+ cudaError_t launch_err = cudaGetLastError();
+ if (launch_err != cudaSuccess) {
+ fprintf(stderr, "[ERROR] submit_kernel launch failed: %s\n",
+ cudaGetErrorString(launch_err));
+ }
+
+ nvtxRangePop();
+ return launch_err;
+}
+
+cudaError_t MoeOffloadEngine::sync() {
+ nvtxRangePushA("MoeOffloadEngine::sync");
+ if (!async_state_initialized_) {
+ return cudaErrorNotInitialized;
+ }
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ sync_kernel<<<1, 1, 0, stream>>>(gpu_state_);
+ cudaError_t launch_err = cudaGetLastError();
+ if (launch_err != cudaSuccess) {
+ fprintf(stderr, "[ERROR] wait_callback_completion_kernel launch failed: %s\n",
+ cudaGetErrorString(launch_err));
+ nvtxRangePop();
+ return launch_err;
+ }
+ nvtxRangePop();
+ return launch_err;
+}
+
+void MoeOffloadEngine::get_output(torch::Tensor gpu_output) {
+ TORCH_CHECK(gpu_output.device().is_cuda(), "Output must be on CUDA device");
+ int64_t n = gpu_output.size(0);
+ TORCH_CHECK(output_.size(0) >= n, "CPU output buffer too small");
+ //gpu_output.copy_(output_.slice(0, 0, n), true);
+
+ size_t copy_bytes = std::min(gpu_output.nbytes(), output_.nbytes());
+ const cudaStream_t copyStream = at::cuda::getCurrentCUDAStream(gpu_output.device().index());
+ AT_CUDA_CHECK(cudaMemcpyAsync(gpu_output.data_ptr(),
+ output_.data_ptr(),
+ copy_bytes,
+ cudaMemcpyHostToDevice,
+ copyStream));
+
+}
+
+void MoeOffloadEngine::set_input(torch::Tensor gpu_hidden_states,
+ torch::Tensor gpu_topk_ids,
+ torch::Tensor gpu_topk_weights) {
+ int64_t n = gpu_hidden_states.size(0);
+ TORCH_CHECK(gpu_hidden_states.device().is_cuda(), "Input must be on CUDA");
+ TORCH_CHECK(gpu_topk_ids.device().is_cuda(), "topk_ids must be on CUDA");
+ TORCH_CHECK(gpu_topk_weights.device().is_cuda(), "topk_weights must be on CUDA");
+ TORCH_CHECK(n <= config_.max_batch_token, "max_batch_token");
+ //hidden_states_.slice(0, 0, n).copy_(gpu_hidden_states.slice(0, 0, n), true);
+ //topk_ids_.slice(0, 0, n).copy_(gpu_topk_ids.slice(0, 0, n), true);
+ //topk_weights_.slice(0, 0, n).copy_(gpu_topk_weights.slice(0, 0, n), true);
+
+
+ size_t input_copy_bytes = std::min(gpu_hidden_states.nbytes(), hidden_states_.nbytes());
+ size_t ids_copy_bytes = std::min(gpu_topk_ids.nbytes(), topk_ids_.nbytes());
+ size_t weights_copy_bytes = std::min(gpu_topk_weights.nbytes(), topk_weights_.nbytes());
+
+ // 3. 取目标 stream(优先使用传入的 stream)
+ const cudaStream_t copyStream = at::cuda::getCurrentCUDAStream(gpu_hidden_states.device().index());
+
+ // 4. 异步拷贝(仅拷贝适配部分)
+ cudaMemcpyAsync(hidden_states_.data_ptr(), gpu_hidden_states.data_ptr(), input_copy_bytes, cudaMemcpyDeviceToHost, copyStream);
+ cudaMemcpyAsync(topk_ids_.data_ptr(), gpu_topk_ids.data_ptr(), ids_copy_bytes, cudaMemcpyDeviceToHost, copyStream);
+ cudaMemcpyAsync(topk_weights_.data_ptr(), gpu_topk_weights.data_ptr(), weights_copy_bytes, cudaMemcpyDeviceToHost, copyStream);
+
+}
\ No newline at end of file
diff --git a/csrc/offload/primitives.cpp b/csrc/offload/primitives.cpp
new file mode 100644
index 000000000000..2fbd98cef9e7
--- /dev/null
+++ b/csrc/offload/primitives.cpp
@@ -0,0 +1,308 @@
+#include "primitives.h"
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+// ==================== 数据类型实现 ====================
+
+bfloat16_t bfloat16_t::from_float(float f) {
+ uint32_t fp32 = *reinterpret_cast(&f);
+ uint32_t round_bias = (fp32 >> 16) & 0x1;
+ return bfloat16_t(static_cast((fp32 + round_bias) >> 16));
+}
+
+float bfloat16_t::to_float() const {
+ uint32_t fp32 = static_cast(bits) << 16;
+ return *reinterpret_cast(&fp32);
+}
+
+float8_e4m3_t float8_e4m3_t::from_float(float f) {
+ if (f == 0.0f) return float8_e4m3_t(0);
+ if (std::isnan(f) || std::isinf(f)) return float8_e4m3_t(0x7F);
+
+ const bool sign = f < 0;
+ const float abs_f = sign ? -f : f;
+
+ if (abs_f > 448.0f) return float8_e4m3_t((sign << 7) | 0x7F);
+ if (abs_f < 0.0078125f) return float8_e4m3_t(sign << 7);
+
+ int exponent;
+ float normalized = std::frexp(abs_f, &exponent);
+ exponent--;
+
+ uint8_t e4m3_exp, e4m3_mant;
+ if (exponent < -6) {
+ e4m3_exp = 0;
+ const float subnormal_val = abs_f * 512.0f;
+ e4m3_mant = static_cast(subnormal_val + 0.5f);
+ if (e4m3_mant > 7) e4m3_mant = 7;
+ } else {
+ e4m3_exp = static_cast(exponent + 7);
+ const float mantissa_val = (normalized - 0.5f) * 16.0f;
+ e4m3_mant = static_cast(mantissa_val + 0.5f);
+
+ const float fraction = mantissa_val - e4m3_mant;
+ const bool round_up = (fraction > 0.5f) ||
+ (fraction == 0.5f && (e4m3_mant & 1));
+ if (round_up) {
+ e4m3_mant++;
+ if (e4m3_mant == 8) {
+ e4m3_mant = 0;
+ e4m3_exp++;
+ }
+ }
+ }
+ return float8_e4m3_t((sign << 7) | (e4m3_exp << 3) | e4m3_mant);
+}
+
+float float8_e4m3_t::to_float() const {
+ const uint8_t sign = bits >> 7;
+ const uint8_t exponent_bits = (bits >> 3) & 0x0F;
+ const uint8_t mantissa_bits = bits & 0x07;
+
+ if (exponent_bits == 0 && mantissa_bits == 0) {
+ union { uint32_t u; float f; } u = { sign ? 0x80000000u : 0x00000000u };
+ return u.f;
+ }
+ if (exponent_bits == 0x0F) {
+ union { uint32_t u; float f; } u = { 0x7FC00000u };
+ return u.f;
+ }
+
+ const float sign_f = sign ? -1.0f : 1.0f;
+ if (exponent_bits == 0) {
+ const float mantissa_f = static_cast(mantissa_bits) / 8.0f;
+ return sign_f * mantissa_f * 0.015625f;
+ }
+
+ const float mantissa_f = 1.0f + static_cast(mantissa_bits) / 8.0f;
+ return sign_f * mantissa_f * std::exp2f(static_cast(exponent_bits) - 7.0f);
+}
+
+void fp32_to_bf16(const float* __restrict f32_in,
+ bfloat16_t* __restrict bf16_out, int len)
+{
+ const int step = 16; // 每次处理 16 个 float
+ for (int i = 0; i < len; i += step)
+ {
+ __m512 v = _mm512_loadu_ps(f32_in + i); // 16×float
+ _mm256_storeu_si256((__m256i*)(bf16_out + i), (__m256i)_mm512_cvtneps_pbh(v));
+ }
+}
+
+
+// ==================== AMX配置实现 ====================
+
+bool set_tiledata_use()
+{
+ if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA))
+ {
+ printf("\n Failed to enable XFEATURE_XTILEDATA \n\n");
+ return false;
+ }
+ else
+ {
+ //printf("\n TILE DATA USE SET - OK \n\n");
+ return true;
+ }
+ return true;
+}
+
+
+
+
+// 初始化tile配置
+void init_tile_config(__tile_config *cfg, int rows, int colsb) {
+ memset(cfg, 0, sizeof(__tile_config));
+ cfg->palette_id = 1;
+ cfg->start_row = 0;
+ for (int i = 0; i < 8; i++) {
+ cfg->colsb[i] = colsb; // 每行字节数:BF16为64(32个元素*2字节)
+ cfg->rows[i] = rows; // 行数:最大16
+ }
+}
+
+
+
+// ==================== AMX核心计算实现 ====================
+
+inline __m512i fp8x32_to_bf16(__m256i in8) {
+ __m512i fp8_ext = _mm512_cvtepu8_epi16(in8);
+ fp8_ext = _mm512_add_epi16(fp8_ext, mm780);
+ fp8_ext = _mm512_slli_epi16(fp8_ext, 4);
+ fp8_ext = _mm512_and_si512(fp8_ext, mm87f0);
+ fp8_ext = _mm512_add_epi16(fp8_ext, mm3c00);
+ return fp8_ext;
+}
+
+
+void amx_gemm_block_32_K_32(
+ const bfloat16_t *A, const float8_e4m3_t *B, float *scale,
+ float *C, int K, int ldc
+) {
+ // 根据layout计算4个子块地址
+ __tile_config cfg;
+ init_tile_config(&cfg, 16, 64);
+ _tile_loadconfig(&cfg);
+
+ float *C00, *C01, *C10, *C11;
+
+ C00 = C;
+ C01 = C + 16;
+ C10 = C + ldc * 16;
+ C11 = C + ldc * 16 + 16;
+
+ alignas(64) bfloat16_t B_bf16[32 * 32];
+
+ for (int kk = 0; kk < K; kk += 128) {
+ __m512 scale_vec = _mm512_set1_ps(scale[kk / 128]);
+ alignas(64) float scale_buf[16 * 64];
+ _tile_zero(0); _tile_zero(1); _tile_zero(2); _tile_zero(3);
+
+ for (int k = kk; k < kk + 128 && k < K; k += 32) {
+ const float8_e4m3_t *B_fp8 = B + k * tile_block_size * 2 / 32;
+ const bfloat16_t *A_ptr = A + k / 32 * tile_block_size * 2;
+
+ #pragma GCC unroll 32
+ for (int i = 0; i < 32; i++) {
+ __m256i fp8_row = _mm256_loadu_si256((const __m256i *)(B_fp8 + i * 32));
+ __m512i bf16_row = fp8x32_to_bf16(fp8_row);
+ _mm512_store_epi32(B_bf16 + i * 32, bf16_row);
+ }
+ asm volatile("" ::: "memory");
+ _tile_loadd(4, A_ptr, 64);
+ _tile_loadd(5, A_ptr + tile_block_size, 64);
+ _tile_loadd(6, B_bf16, 128);
+ _tile_loadd(7, B_bf16 + 32, 128);
+
+ _tile_dpbf16ps(0, 4, 6);
+ _tile_dpbf16ps(1, 4, 7);
+ _tile_dpbf16ps(2, 5, 6);
+ _tile_dpbf16ps(3, 5, 7);
+ }
+ _tile_stored(0, scale_buf, 16 * sizeof(float));
+ _tile_stored(1, scale_buf + 16 * 16, 16 * sizeof(float));
+ _tile_stored(2, scale_buf + 32 * 16, 16 * sizeof(float));
+ _tile_stored(3, scale_buf + 48 * 16, 16 * sizeof(float));
+
+ #pragma GCC unroll 16
+ for (int i = 0; i < 16; i++) {
+ __m512 vec00 = _mm512_loadu_ps(scale_buf + i * 16);
+ __m512 vec01 = _mm512_loadu_ps(scale_buf + (i + 16) * 16);
+ __m512 vec10 = _mm512_loadu_ps(scale_buf + (i + 32) * 16);
+ __m512 vec11 = _mm512_loadu_ps(scale_buf + (i + 48) * 16);
+
+ __m512 c00_vec = _mm512_loadu_ps(C00 + i * ldc);
+ __m512 c01_vec = _mm512_loadu_ps(C01 + i * ldc);
+ __m512 c10_vec = _mm512_loadu_ps(C10 + i * ldc);
+ __m512 c11_vec = _mm512_loadu_ps(C11 + i * ldc);
+
+ c00_vec = _mm512_fmadd_ps(vec00, scale_vec, c00_vec);
+ c01_vec = _mm512_fmadd_ps(vec01, scale_vec, c01_vec);
+ c10_vec = _mm512_fmadd_ps(vec10, scale_vec, c10_vec);
+ c11_vec = _mm512_fmadd_ps(vec11, scale_vec, c11_vec);
+
+ _mm512_storeu_ps(C00 + i * ldc, c00_vec);
+ _mm512_storeu_ps(C01 + i * ldc, c01_vec);
+ _mm512_storeu_ps(C10 + i * ldc, c10_vec);
+ _mm512_storeu_ps(C11 + i * ldc, c11_vec);
+ }
+ }
+}
+
+const static int BLOCK_K = 32;
+const static int BLOCK_M = 32;
+const static int BLOCK_SIZE = BLOCK_K * BLOCK_M;
+void gemv_anni_grouped(const bfloat16_t* B, const uint8_t* A, const float* AS,
+ float* C, int M, int K, int block_size) {
+ const int m_blocks = M / BLOCK_M; // M // 32
+ const int k_blocks = K / BLOCK_K; // K // 16
+ const int AS_col_stride = K / 128; // K // 128
+
+ for(int m = 0; m < m_blocks; m+=4) {
+ __m512 Cv[8] = {};
+ const float* AS_row = AS + m * AS_col_stride;
+
+ for(int kg = 0; kg < k_blocks; kg += 4) {
+ __m512 sum[8] = {};
+ const uint8_t* A_base = A + (kg + m * k_blocks) * BLOCK_SIZE;
+
+ #pragma GCC unroll 4
+ for(int kk=0; kk<4; kk++)
+ {
+ __m512i b_block = _mm512_loadu_si512((const __m512i*)(B + (kg + kk) * BLOCK_K));
+
+ #pragma GCC unroll 16
+ for(int ch = 0; ch < 16; ++ch){
+ __m512i b_vec = _mm512_permutexvar_epi32(_mm512_set1_epi32(ch), b_block);
+ _mm_prefetch((const char*)(A_base + (ch + kk * 16) * 64 + BLOCK_SIZE), _MM_HINT_T0);
+ _mm_prefetch((const char*)(A_base + (ch + kk * 16) * 64 + (1 * k_blocks) * BLOCK_SIZE + BLOCK_SIZE/2), _MM_HINT_T0);
+ _mm_prefetch((const char*)(A_base + (ch + kk * 16) * 64 + (2 * k_blocks) * BLOCK_SIZE + BLOCK_SIZE/2), _MM_HINT_T0);
+ _mm_prefetch((const char*)(A_base + (ch + kk * 16) * 64 + (3 * k_blocks) * BLOCK_SIZE + BLOCK_SIZE/2), _MM_HINT_T0);
+
+ __m512i block12 = _mm512_loadu_si512((const __m512i*)(A_base + (ch + kk * 16) * 64));
+ __m512i block34 = _mm512_loadu_si512((const __m512i*)(A_base + (ch + kk * 16) * 64 + k_blocks * BLOCK_SIZE));
+ __m512i block56 = _mm512_loadu_si512((const __m512i*)(A_base + (ch + kk * 16) * 64 + 2 * k_blocks * BLOCK_SIZE));
+ __m512i block78 = _mm512_loadu_si512((const __m512i*)(A_base + (ch + kk * 16) * 64 + 3 * k_blocks * BLOCK_SIZE));
+
+ __m256i v1 = _mm512_extracti64x4_epi64(block12, 0);
+ __m256i v2 = _mm512_extracti64x4_epi64(block12, 1);
+ __m256i v3 = _mm512_extracti64x4_epi64(block34, 0);
+ __m256i v4 = _mm512_extracti64x4_epi64(block34, 1);
+ __m256i v5 = _mm512_extracti64x4_epi64(block56, 0);
+ __m256i v6 = _mm512_extracti64x4_epi64(block56, 1);
+ __m256i v7 = _mm512_extracti64x4_epi64(block78, 0);
+ __m256i v8 = _mm512_extracti64x4_epi64(block78, 1);
+
+ __m512bh e1 = (__m512bh)fp8x32_to_bf16(v1);
+ __m512bh e2 = (__m512bh)fp8x32_to_bf16(v2);
+ __m512bh e3 = (__m512bh)fp8x32_to_bf16(v3);
+ __m512bh e4 = (__m512bh)fp8x32_to_bf16(v4);
+ __m512bh e5 = (__m512bh)fp8x32_to_bf16(v5);
+ __m512bh e6 = (__m512bh)fp8x32_to_bf16(v6);
+ __m512bh e7 = (__m512bh)fp8x32_to_bf16(v7);
+ __m512bh e8 = (__m512bh)fp8x32_to_bf16(v8);
+
+ sum[0] = _mm512_dpbf16_ps(sum[0], e1, (__m512bh)b_vec);
+ sum[1] = _mm512_dpbf16_ps(sum[1], e2, (__m512bh)b_vec);
+ sum[2] = _mm512_dpbf16_ps(sum[2], e3, (__m512bh)b_vec);
+ sum[3] = _mm512_dpbf16_ps(sum[3], e4, (__m512bh)b_vec);
+ sum[4] = _mm512_dpbf16_ps(sum[4], e5, (__m512bh)b_vec);
+ sum[5] = _mm512_dpbf16_ps(sum[5], e6, (__m512bh)b_vec);
+ sum[6] = _mm512_dpbf16_ps(sum[6], e7, (__m512bh)b_vec);
+ sum[7] = _mm512_dpbf16_ps(sum[7], e8, (__m512bh)b_vec);
+ }
+ }
+
+ float wscale = *AS_row;
+ __m512 scale_vec = _mm512_set1_ps(wscale);
+ Cv[0] = _mm512_fmadd_ps(sum[0], scale_vec, Cv[0]);
+ Cv[1] = _mm512_fmadd_ps(sum[1], scale_vec, Cv[1]);
+ Cv[2] = _mm512_fmadd_ps(sum[2], scale_vec, Cv[2]);
+ Cv[3] = _mm512_fmadd_ps(sum[3], scale_vec, Cv[3]);
+ Cv[4] = _mm512_fmadd_ps(sum[4], scale_vec, Cv[4]);
+ Cv[5] = _mm512_fmadd_ps(sum[5], scale_vec, Cv[5]);
+ Cv[6] = _mm512_fmadd_ps(sum[6], scale_vec, Cv[6]);
+ Cv[7] = _mm512_fmadd_ps(sum[7], scale_vec, Cv[7]);
+ AS_row += 1;
+ }
+
+ for(int i=0; i<8; ++i)
+ {
+ _mm512_stream_ps(C + m * BLOCK_M + BLOCK_M/2 * i, Cv[i]);
+ }
+ }
+ _mm_sfence();
+
+}
\ No newline at end of file
diff --git a/csrc/offload/primitives.h b/csrc/offload/primitives.h
new file mode 100644
index 000000000000..80d82fc68c42
--- /dev/null
+++ b/csrc/offload/primitives.h
@@ -0,0 +1,107 @@
+#pragma once
+
+#ifndef __NVCC__
+#include
+#endif
+#include
+#include
+#include
+#include
+
+// ==================== 数据类型定义 ====================
+
+struct bfloat16_t {
+ uint16_t bits;
+
+ bfloat16_t() : bits(0) {}
+ explicit bfloat16_t(uint16_t b) : bits(b) {}
+
+ static bfloat16_t from_float(float f);
+ float to_float() const;
+};
+
+struct float8_e4m3_t {
+ uint8_t bits;
+
+ float8_e4m3_t() : bits(0) {}
+ explicit float8_e4m3_t(uint8_t b) : bits(b) {}
+
+ static float8_e4m3_t from_float(float f);
+ float to_float() const;
+};
+
+// ==================== AMX常量与配置 ====================
+
+#if !defined(__CUDACC__) && !defined(__CUDA_ARCH__)
+static const __m512i mm780 = _mm512_set1_epi16(0x780);
+static const __m512i mm87f0 = _mm512_set1_epi16(0x87f0);
+static const __m512i mm3c00 = _mm512_set1_epi16(0x3c00);
+#endif
+
+#define ARCH_GET_XCOMP_PERM 0x1022
+#define ARCH_REQ_XCOMP_PERM 0x1023
+#define XFEATURE_XTILECFG 17
+#define XFEATURE_XTILEDATA 18
+
+bool set_tiledata_use();
+
+typedef struct {
+ uint8_t palette_id;
+ uint8_t start_row;
+ uint8_t reserved[14];
+ uint16_t colsb[16];
+ uint8_t rows[16];
+} __tile_config __attribute__((aligned(64)));
+
+void init_tile_config(__tile_config *cfg, int rows, int colsb);
+
+// ==================== 调试辅助 ====================
+
+template
+struct has_to_float : std::false_type {};
+
+template
+struct has_to_float().to_float())>> : std::true_type {};
+
+template
+void dump_martix(const T *A, int M, int N, const char *prefix, int ldc) {
+ std::cout << prefix << ":" << std::endl;
+ int stride = ldc == 0 ? N : ldc;
+ for (int i = 0; i < M; i++) {
+ std::cout << i << ": ";
+ for (int j = 0; j < N; j++) {
+ if constexpr (has_to_float::value)
+ std::cout << A[i * stride + j].to_float() << ", ";
+ else
+ std::cout << A[i * stride + j] << ", ";
+ }
+ std::cout << std::endl;
+ }
+ std::cout << "=====================================" << std::endl;
+}
+
+// ==================== AMX核心计算 ====================
+
+enum class CLayout { RowMajor, Contiguous };
+
+static const int tile_block_size = 32 * 16;
+static const int tile_c_block_size = 16 * 16;
+
+
+
+void amx_gemm_block_32_K_32(
+ const bfloat16_t *A,
+ const float8_e4m3_t *B,
+ float *scale,
+ float *C, // C矩阵起始地址
+ int K,
+ int ldc
+);
+void gemv_anni_grouped(const bfloat16_t* B,const uint8_t* A, const float* AS,
+ float* C, int M, int K, int block_size);
+// FP8->BF16批量转换 (32个元素)
+#if !defined(__CUDACC__) && !defined(__CUDA_ARCH__) && !defined(__NVCC__)
+inline __m512i fp8x32_to_bf16(__m256i in8);
+#endif
+void fp32_to_bf16(const float* __restrict f32_in,
+ bfloat16_t* __restrict bf16_out, int len);
\ No newline at end of file
diff --git a/csrc/offload/py_bindding.cpp b/csrc/offload/py_bindding.cpp
new file mode 100644
index 000000000000..5cee92bede12
--- /dev/null
+++ b/csrc/offload/py_bindding.cpp
@@ -0,0 +1,168 @@
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "moe.h"
+#include "primitives.h"
+
+
+torch::Tensor
+cpu_moe_sync(torch::Tensor output, int64_t moe_engine_ptr) {
+ auto* moe_engine = reinterpret_cast(moe_engine_ptr);
+ moe_engine->sync();
+ moe_engine->get_output(output);
+ return output;
+}
+
+torch::Tensor
+cpu_moe_sync_meta(torch::Tensor output, int64_t moe_engine_ptr){
+ return output;
+}
+
+void cpu_moe_submit(torch::Tensor hidden_states, torch::Tensor topk_ids, torch::Tensor topk_weights,
+ int64_t moe_engine_ptr, int64_t layer_id, int64_t batch_idx) {
+ auto* moe_engine = reinterpret_cast(moe_engine_ptr);
+ moe_engine->set_input(hidden_states, topk_ids, topk_weights);
+ moe_engine->submit(layer_id, batch_idx, hidden_states.size(0));
+}
+
+void cpu_moe_submit_meta(torch::Tensor hidden_states, torch::Tensor topk_ids, torch::Tensor topk_weights,
+ int64_t moe_engine_ptr, int64_t layer_id, int64_t batch_idx) {
+
+}
+
+
+std::tuple
+expert_cache_policy(torch::Tensor& topk_ids, torch::Tensor& cache_map,
+ torch::Tensor& miss_map, torch::Tensor& policy_sort,
+ int64_t moe_engine_ptr){
+ auto* moe_engine = reinterpret_cast(moe_engine_ptr);
+ auto copy_map = torch::zeros_like(cache_map);
+ auto cpu_topk = torch::zeros_like(topk_ids);
+ moe_engine->expert_cache_policy(cache_map, miss_map, policy_sort, topk_ids,
+ cpu_topk, copy_map);
+ return std::make_tuple(cpu_topk, copy_map);
+}
+
+std::tuple
+expert_cache_policy_meta(torch::Tensor& topk_ids, torch::Tensor& cache_map,
+ torch::Tensor& miss_map, torch::Tensor& policy_sort,
+ int64_t moe_engine_ptr){
+ auto copy_map = torch::zeros_like(cache_map);
+ auto cpu_topk = torch::zeros_like(topk_ids);
+ return std::make_tuple(cpu_topk, copy_map);
+}
+
+void update_expert_cache(
+ torch::Tensor w13_cache, torch::Tensor w2_cache,
+ torch::Tensor w13_scale_cache, torch::Tensor w2_scale_cache,
+ torch::Tensor map, int64_t num_experts, int64_t layer_id, int64_t moe_engine_ptr) {
+
+ auto* moe_engine = reinterpret_cast(moe_engine_ptr);
+ moe_engine->update_expert_cache(w13_cache, w2_cache, w13_scale_cache,
+ w2_scale_cache, map, layer_id, num_experts);
+}
+
+void update_expert_cache_meta(
+ torch::Tensor w13_cache, torch::Tensor w2_cache,
+ torch::Tensor w13_scale_cache, torch::Tensor w2_scale_cache,
+ torch::Tensor map, int64_t num_experts, int64_t layer_id, int64_t moe_engine_ptr) {
+}
+
+namespace py = pybind11;
+
+// ========== 算子注册 ==========
+TORCH_LIBRARY(moe_offload_ops, m) {
+ m.def("expert_cache_policy(Tensor topk_ids, Tensor cache_map, Tensor miss_map, "
+ "Tensor policy_sort, int moe_engine_ptr) -> (Tensor, Tensor)");
+
+ m.def("update_expert_cache(Tensor w13_cache, Tensor w2_cache, "
+ "Tensor w13_scale_cache, Tensor w2_scale_cache, Tensor map, "
+ "int num_experts, int layer_id, int moe_engine_ptr) -> ()");
+
+ m.def("cpu_moe_submit(Tensor hidden_states, Tensor topk_ids, Tensor topk_weights, "
+ "int moe_engine_ptr, int layer_id, int batch_idx) -> ()");
+
+ m.def("cpu_moe_sync(Tensor(a!) output, int moe_engine_ptr) -> Tensor");
+}
+
+// ========== Python绑定 ==========
+PYBIND11_MODULE(_offload_C, m) {
+ m.doc() = "MoE Offload Engine (Minimal C++ Interface)";
+
+ py::class_(m, "MOEConfig")
+ .def(py::init(),
+ py::arg("tp_rank"), py::arg("tp_size"), py::arg("expert_num"),
+ py::arg("num_experts_per_tok"), py::arg("hidden_size"),
+ py::arg("intermediate_size"), py::arg("max_batch_token"),
+ py::arg("cache_expert_num"), py::arg("block_size"),
+ py::arg("cache_topk"), py::arg("update_expert_num"),
+ py::arg("forward_context_num_threads") = 14);
+
+ py::class_(m, "Moe")
+ .def(py::init([](uint64_t w13_weights_ptr, uint64_t w2_weights_ptr,
+ uint64_t w13_scales_ptr, uint64_t w2_scales_ptr,
+ int layer_id,
+ MOEConfig config) {
+ return new Moe(
+ reinterpret_cast(w13_weights_ptr),
+ reinterpret_cast(w2_weights_ptr),
+ reinterpret_cast(w13_scales_ptr),
+ reinterpret_cast(w2_scales_ptr),
+ layer_id,
+ config);
+ }), py::arg("w13_weights_ptr"), py::arg("w2_weights_ptr"),
+ py::arg("w13_scales_ptr"), py::arg("w2_scales_ptr"),
+ py::arg("layer_id"),
+ py::arg("config"))
+
+ .def("forward", [](Moe& self, uint64_t input_ptr, uint64_t topk_ids_ptr,
+ uint64_t topk_weights_ptr, uint64_t output_ptr,
+ int num_tokens) {
+
+ self.forward(
+ reinterpret_cast(input_ptr),
+ reinterpret_cast(topk_ids_ptr),
+ reinterpret_cast(topk_weights_ptr),
+ reinterpret_cast(output_ptr),
+ num_tokens);
+
+ }, py::arg("input_ptr"), py::arg("topk_ids_ptr"),
+ py::arg("topk_weights_ptr"), py::arg("output_ptr"),
+ py::arg("num_tokens"));
+
+
+ py::class_(m, "MoeOffloadEngine")
+ .def(py::init())
+ .def("create_layer", &MoeOffloadEngine::create_cpu_moe_layer)
+ .def("ptr", &MoeOffloadEngine::ptr);
+
+ m.def("set_tiledata_use", &set_tiledata_use, "Enable AMX-Tile feature");
+}
+
+TORCH_LIBRARY_IMPL(moe_offload_ops, CUDA, m) {
+ m.impl("update_expert_cache", &update_expert_cache);
+ m.impl("expert_cache_policy", &expert_cache_policy);
+ m.impl("cpu_moe_submit", &cpu_moe_submit);
+ m.impl("cpu_moe_sync", &cpu_moe_sync);
+}
+
+// ========== Meta / FakeTensor / Dynamo ==========
+TORCH_LIBRARY_IMPL(moe_offload_ops, Meta, m) {
+ m.impl("update_expert_cache", &update_expert_cache_meta);
+ m.impl("expert_cache_policy", &expert_cache_policy_meta);
+ m.impl("cpu_moe_submit", &cpu_moe_submit_meta);
+ m.impl("cpu_moe_sync", &cpu_moe_sync_meta);
+}
+
+// CompositeExplicitAutogradNonFunctional
+TORCH_LIBRARY_IMPL(moe_offload_ops, CompositeExplicitAutogradNonFunctional, m) {
+ m.impl("update_expert_cache", &update_expert_cache_meta);
+ m.impl("expert_cache_policy", &expert_cache_policy_meta);
+ m.impl("cpu_moe_submit", &cpu_moe_submit_meta);
+ m.impl("cpu_moe_sync", &cpu_moe_sync_meta);
+}
diff --git a/docs/assets/logos/digital-china-logo.png b/docs/assets/logos/digital-china-logo.png
new file mode 100644
index 000000000000..830cb9136b79
Binary files /dev/null and b/docs/assets/logos/digital-china-logo.png differ
diff --git a/requirements/common.txt b/requirements/common.txt
index 43f4a8676d79..196d6a05f892 100644
--- a/requirements/common.txt
+++ b/requirements/common.txt
@@ -52,3 +52,4 @@ openai-harmony >= 0.0.3 # Required for gpt-oss
anthropic == 0.71.0
model-hosting-container-standards >= 0.1.10, < 1.0.0
mcp
+py-libnuma # Required for NUMA support
diff --git a/setup.py b/setup.py
index 595397264283..88a3a586f079 100644
--- a/setup.py
+++ b/setup.py
@@ -81,7 +81,9 @@ def is_freethreaded():
class CMakeExtension(Extension):
def __init__(self, name: str, cmake_lists_dir: str = ".", **kwa) -> None:
- super().__init__(name, sources=[], py_limited_api=not is_freethreaded(), **kwa)
+ # Extract py_limited_api from kwargs if provided, otherwise use default
+ py_limited_api = kwa.pop("py_limited_api", not is_freethreaded())
+ super().__init__(name, sources=[], py_limited_api=py_limited_api, **kwa)
self.cmake_lists_dir = os.path.abspath(cmake_lists_dir)
@@ -842,6 +844,9 @@ def _read_requirements(filename: str) -> list[str]:
ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
if _is_cuda():
+ # _offload_C doesn't use py_limited_api because it needs full symbol visibility
+ # for pybind11 type casters to work with PyTorch
+ ext_modules.append(CMakeExtension(name="vllm._offload_C", py_limited_api=False))
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C"))
if envs.VLLM_USE_PRECOMPILED or (
CUDA_HOME and get_nvcc_cuda_version() >= Version("12.3")
diff --git a/vllm/config/model.py b/vllm/config/model.py
index 8e28e34bffaf..ebc1cc592ac0 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -5,7 +5,7 @@
from collections.abc import Callable
from dataclasses import InitVar, field
from functools import cached_property
-from typing import TYPE_CHECKING, Any, Literal, cast, get_args
+from typing import TYPE_CHECKING, Any, Literal, cast, get_args,Optional
import torch
from pydantic import ConfigDict, Field, field_validator, model_validator
@@ -304,6 +304,12 @@ class ModelConfig:
skip_mm_profiling: InitVar[bool | None] = None
video_pruning_rate: InitVar[float | None] = None
+ moe_offload: Optional[bool] = False
+ moe_offload_cache_expert_num: Optional[int] = 32
+ moe_offload_cache_topk: Optional[int] = 2
+ moe_offload_update_expert_num: Optional[int] = 2
+ moe_offload_context_num_threads: Optional[int] = 14
+
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 1cf741006a6a..aec363c11974 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -917,7 +917,7 @@ def has_blocked_weights():
f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
)
- if self.parallel_config.use_ubatching:
+ if self.parallel_config.use_ubatching and not self.model_config.moe_offload:
a2a_backend = self.parallel_config.all2all_backend
assert a2a_backend in [
"deepep_low_latency",
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 98f1cfbd5922..8f122e5a13b8 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -22,6 +22,7 @@
cast,
get_args,
get_origin,
+ Optional,
)
import huggingface_hub
@@ -373,6 +374,11 @@ class EngineArgs:
kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
seed: int = ModelConfig.seed
max_model_len: int | None = ModelConfig.max_model_len
+ moe_offload: Optional[bool] = ModelConfig.moe_offload
+ moe_offload_cache_expert_num: Optional[int] = ModelConfig.moe_offload_cache_expert_num
+ moe_offload_cache_topk: Optional[int] = ModelConfig.moe_offload_cache_topk
+ moe_offload_update_expert_num: Optional[int] = ModelConfig.moe_offload_update_expert_num
+ moe_offload_context_num_threads: Optional[int] = ModelConfig.moe_offload_context_num_threads
cudagraph_capture_sizes: list[int] | None = (
CompilationConfig.cudagraph_capture_sizes
)
@@ -646,6 +652,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"--tokenizer-revision", **model_kwargs["tokenizer_revision"]
)
model_group.add_argument("--max-model-len", **model_kwargs["max_model_len"])
+ model_group.add_argument("--moe-offload", **model_kwargs["moe_offload"])
+ model_group.add_argument("--moe-offload-cache-expert-num", **model_kwargs["moe_offload_cache_expert_num"])
+ model_group.add_argument("--moe-offload-cache-topk", **model_kwargs["moe_offload_cache_topk"])
+ model_group.add_argument("--moe-offload-update-expert-num", **model_kwargs["moe_offload_update_expert_num"])
+ model_group.add_argument("--moe-offload-context-num-threads", **model_kwargs["moe_offload_context_num_threads"])
model_group.add_argument("--quantization", "-q", **model_kwargs["quantization"])
model_group.add_argument("--enforce-eager", **model_kwargs["enforce_eager"])
model_group.add_argument("--max-logprobs", **model_kwargs["max_logprobs"])
@@ -1222,6 +1233,11 @@ def create_model_config(self) -> ModelConfig:
hf_overrides=self.hf_overrides,
tokenizer_revision=self.tokenizer_revision,
max_model_len=self.max_model_len,
+ moe_offload=self.moe_offload,
+ moe_offload_cache_expert_num=self.moe_offload_cache_expert_num,
+ moe_offload_cache_topk=self.moe_offload_cache_topk,
+ moe_offload_update_expert_num=self.moe_offload_update_expert_num,
+ moe_offload_context_num_threads=self.moe_offload_context_num_threads,
quantization=self.quantization,
enforce_eager=self.enforce_eager,
max_logprobs=self.max_logprobs,
diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index d1223ad83fbc..f37db5d6150c 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -102,7 +102,7 @@ def make(
num_tokens_across_dp_cpu: torch.Tensor,
) -> "DPMetadata":
assert num_tokens_across_dp_cpu is not None
- assert parallel_config.data_parallel_size > 1
+ #assert parallel_config.data_parallel_size > 1
assert parallel_config.is_moe_model is not False
dp_rank = parallel_config.data_parallel_rank
batchsize = num_tokens
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index fb441963a97d..0ede2e82b430 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -4,7 +4,7 @@
from collections.abc import Callable, Iterable
from contextlib import nullcontext
from enum import Enum
-from typing import Literal, cast, get_args, overload
+from typing import Literal, Optional, cast, get_args, overload
import torch
import torch.nn.functional as F
@@ -51,6 +51,9 @@
)
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
+
+from vllm.model_executor.layers.fused_moe.moe_offload import CpuOffloadInfer
+
if current_platform.is_cuda_alike():
from .fused_moe import eplb_map_to_physical_and_record
else:
@@ -310,6 +313,7 @@ class FusedMoE(CustomOp):
enable_eplb: Whether to enable expert parallelism load balancer.
router_logits_dtype: Data type for router logits buffers.
"""
+ cpu_offload_eng: Optional[CpuOffloadInfer] = None
def __init__(
self,
@@ -417,6 +421,7 @@ def __init__(
raise ValueError("Duplicate layer name: {}".format(prefix))
compilation_config.static_forward_context[prefix] = self
self.layer_name = prefix
+ self.layer_idx = int(prefix.split(sep='.')[-3])
self.enable_eplb = enable_eplb
self.expert_load_view: torch.Tensor | None = None
@@ -637,6 +642,25 @@ def _get_quant_method() -> FusedMoEMethodBase:
):
moe_quant_params["intermediate_size_full"] = intermediate_size
+ """
+ offload
+ """
+ self.moe_offload = vllm_config.model_config.moe_offload
+ self.cache_expert_num = vllm_config.model_config.moe_offload_cache_expert_num
+ if FusedMoE.cpu_offload_eng is None and self.moe_offload:
+ FusedMoE.cpu_offload_eng = CpuOffloadInfer(
+ total_expert_num=self.global_num_experts,
+ cache_expert_num=self.cache_expert_num,
+ top_k=top_k,
+ hidden_size=hidden_size,
+ intermediate_size=self.intermediate_size_per_partition,
+ max_batch_tokens=16384,
+ tp_rank=self.tp_rank,
+ tp_size=self.tp_size,
+ )
+
+ """ end of offload"""
+
self.quant_method.create_weights(layer=self, **moe_quant_params)
# Chunked all2all staging tensor
diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py
index 79168948f04a..ed356fb12ab9 100644
--- a/vllm/model_executor/layers/fused_moe/modular_kernel.py
+++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -682,12 +682,13 @@ def __init__(
fused_experts: FusedMoEPermuteExpertsUnpermute,
shared_experts: torch.nn.Module | None = None,
moe_parallel_config: FusedMoEParallelConfig | None = None,
+ moe_offload: bool | None = False,
):
super().__init__()
self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts
self.shared_experts = shared_experts
-
+ self.moe_offload = moe_offload
# prefer an explicit FusedMoEParallelConfig when available (from
# FusedMoE layers / tests).
# if not provided, assume this kernel is
@@ -920,11 +921,12 @@ def _prepare(
The _prepare method is a wrapper around self.prepare_finalize.prepare
that handles DBO and async.
"""
- if not self.prepare_finalize.supports_async():
+ if not self.prepare_finalize.supports_async() :
# We shouldn't be running an a2a kernel that doesn't
# support async prepare/finalize
# TODO(lucas): enable in follow-up
- assert not dbo_enabled()
+ if not self.moe_offload:
+ assert not dbo_enabled()
(
a1q,
@@ -1096,7 +1098,8 @@ def _finalize(
shared_output: torch.Tensor | None = None
if not self.prepare_finalize.supports_async():
- assert not dbo_enabled()
+ if not self.moe_offload:
+ assert not dbo_enabled()
self.prepare_finalize.finalize(
output,
diff --git a/vllm/model_executor/layers/fused_moe/moe_offload.py b/vllm/model_executor/layers/fused_moe/moe_offload.py
new file mode 100644
index 000000000000..9e926c1d021e
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/moe_offload.py
@@ -0,0 +1,552 @@
+import torch
+from torch import nn, Tensor
+from typing import Dict, List, Optional, Any, Tuple
+import warnings
+from contextlib import contextmanager
+from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEQuantConfig, fp8_w8a8_moe_quant_config)
+from vllm import _offload_C as moe_offload
+from vllm.forward_context import get_forward_context, set_forward_context
+from vllm.config import get_current_vllm_config
+import gc
+
+from vllm.v1.worker.ubatching import (
+ dbo_current_ubatch_id, dbo_enabled, dbo_switch_to_comm,
+ dbo_switch_to_compute, dbo_switch_to_comm_sync,
+ dbo_yield_and_switch_from_comm_to_compute,
+ dbo_yield_and_switch_from_compute_to_comm)
+
+
+
+class CpuOffloadInfer:
+ """
+ CPU Offload 管理器:支持双层 Miss Expert Buffer
+ 支持 DB0 模式和 Prefetch 模式
+ """
+ def __init__(self, total_expert_num, cache_expert_num, top_k, hidden_size, intermediate_size, max_batch_tokens, tp_rank, tp_size):
+
+ vllm_config = get_current_vllm_config()
+ self.total_expert_num = total_expert_num
+ self.cache_expert_num = cache_expert_num
+ self.top_k = top_k
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ #self.max_batch_tokens = 1024
+ self.weight_block_size = [128,128]
+ self.block_quant = True
+ self.tp_rank = tp_rank
+ self.tp_size = tp_size
+ self.moe_offload_cache_topk = vllm_config.model_config.moe_offload_cache_topk
+ self.moe_offload_update_expert_num = vllm_config.model_config.moe_offload_update_expert_num
+ self.moe_offload_context_num_threads = vllm_config.model_config.moe_offload_context_num_threads
+
+ config = moe_offload.MOEConfig(
+ tp_rank, # tp_rank
+ tp_size, # tp_size
+ total_expert_num, # expert_num
+ top_k, # num_experts_per_tok
+ hidden_size, # hidden_size
+ intermediate_size, # intermediate_size
+ max_batch_tokens, # max_batch_token
+ cache_expert_num, # cache_expert_num
+ 128, # block_size
+ self.moe_offload_cache_topk, # cache_topk
+ self.moe_offload_update_expert_num, # update_expert_num
+ self.moe_offload_context_num_threads, # context_num_threads
+ )
+ self.engine = moe_offload.MoeOffloadEngine(config)
+
+ # ==================== 核心数据结构 ====================
+ # 存储各层原始 CPU 权重
+ self.cpu_weights: Dict[int, Dict[str, Tensor]] = {}
+
+ # 存储各层配置信息
+ self.layer_configs: Dict[int, Dict[str, Any]] = {}
+
+ # 已处理层的记录
+ self._processed_layers: set = set()
+
+ # Miss Expert Buffer,索引 0 和 1 对应两个 layer
+ self.temp_layer: List[Optional[nn.Module]] = [None, None]
+
+ # layer_id 到 buffer 索引的映射
+ self._layer_to_buffer_idx: Dict[int, int] = {}
+ # 存储 Fp8MoEMethod 引用,用于获取 quant_config 等参数
+ self._moe_method_refs: Dict[int, Any] = {}
+
+ # ==================== DB0 相关 ====================
+
+
+
+ self.sync_token = None
+
+ # 是否启用 DB0 模式
+ self.dbo_enabled: bool = False
+
+ # Cache maps(用于 DB0)
+ self.cache_maps: Dict[int, torch.Tensor] = {}
+ self.miss_maps: Dict[int, torch.Tensor] = {}
+ self.policy_sorts: Dict[int, torch.Tensor] = {}
+
+ # GPU 缓存引用(在 layer 对象中)
+ self.w13_caches: Dict[int, Tensor] = {}
+ self.w2_caches: Dict[int, Tensor] = {}
+
+ self.update_streams: Optional[torch.cuda.Stream] = None
+ self.compute_event: List[Optional[torch.cuda.Event]] = [None, None]
+ self.copy_event: List[Optional[torch.cuda.Event]] = [None, None]
+
+ def init_expert_hit_stats(self):
+ self.expert_miss = 0
+ self.copy_count = 0
+ self.expert_count = 0
+
+ def update_expert_hit_stats(self, topk_ids:torch.Tensor, copy_map:torch.Tensor):
+ self.expert_miss += topk_ids.sum().cpu().item()
+
+
+ def setup_cache_maps(self, layer_id: int, total_experts: int, cache_num: int) -> None:
+ """
+ 初始化指定层的 cache 映射结构
+ Args:
+ layer_id: 层ID
+ total_experts: 总专家数量
+ cache_num: cache中的专家数量
+
+ 说明:
+ - cache_maps: 专家ID → cache位置,初始前cache_num个专家映射到0~cache_num-1,其余为-1
+ - miss_maps: 专家ID → temp位置,初始前cache_num个专家设为-1(在cache中),其余映射到0~miss_num-1
+ - policy_sorts: cache位置 → 优先级,长度为cache_num,初始值为0~cache_num-1
+ """
+ device = torch.device('cuda')
+ miss_num = total_experts - cache_num
+
+ # cache_maps: [total_experts],前cache_num个在cache中
+ cache_map = torch.full((total_experts,), -1, dtype=torch.int32, device=device)
+ if cache_num > 0:
+ cache_map[:cache_num] = torch.arange(cache_num, dtype=torch.int32, device=device)
+ self.cache_maps[layer_id] = cache_map
+
+ # ✅ miss_maps: [total_experts],不在cache中的都在temp中
+ miss_map = torch.full((total_experts,), -1, dtype=torch.int32, device=device)
+ if miss_num > 0:
+ # 对于不在cache中的专家(从cache_num开始),映射到0~miss_num-1
+ miss_map[cache_num:] = torch.arange(miss_num, dtype=torch.int32, device=device)
+ self.miss_maps[layer_id] = miss_map
+
+ # policy_sorts: [cache_num],初始优先级顺序
+ if cache_num > 0:
+ policy_sort = torch.arange(cache_num, dtype=torch.int32, device=device)
+ else:
+ policy_sort = torch.empty(0, dtype=torch.int32, device=device)
+ self.policy_sorts[layer_id] = policy_sort
+
+ # ==================== 初始化方法 ====================
+ def setup_layer_cache(self, layer: nn.Module) -> None:
+ """
+ 初始化 Layer 的 Offload 配置
+
+ Args:
+ layer: MoE 层实例
+ layer_id: 层唯一标识
+ num_cache_experts: GPU 缓存的专家数量
+ total_experts: 总专家数
+ moe_method: Fp8MoEMethod 实例
+ buffer_idx: 手动指定 miss buffer 索引(0 或 1)
+ """
+ layer_id = layer.layer_idx
+ buffer_idx = layer_id % 2
+
+ if layer_id in self._processed_layers:
+ warnings.warn(f"Layer {layer_id} 已处理,跳过")
+ return
+
+ if not getattr(layer, 'moe_offload', False):
+ warnings.warn(f"Layer {layer_id} 未开启 offloading,跳过")
+ return
+
+ # 1. 保存 moe_method 引用
+ #self._moe_method_refs[layer_id] = moe_method
+
+ # 2. 保存 CPU 权重
+ weight_names = self._get_weight_names(layer)
+ cpu_state_dict = {}
+ for name in weight_names:
+ param = getattr(layer, name)
+ if name == 'w13_weight' or name == 'w2_weight':
+ M = param.data.shape[0]
+ N = param.data.shape[1]
+ K = param.data.shape[2]
+ param.data = param.data.reshape(M,N//32,32,K//32,16,2).permute(0,1,3,4,2,5).contiguous().pin_memory()
+ cpu_state_dict[name] = param.data
+ else:
+ cpu_state_dict[name] = param.data
+
+ self.cpu_weights[layer_id] = cpu_state_dict
+
+ # 3. 分配 Miss Buffer 索引
+ if buffer_idx is None:
+ used_indices = set(self._layer_to_buffer_idx.values())
+ available = {0, 1} - used_indices
+ if not available:
+ raise RuntimeError("Miss Expert Buffer 已满(最多支持2个layer)")
+ buffer_idx = min(available)
+
+ self._layer_to_buffer_idx[layer_id] = buffer_idx
+
+ # 4. 创建 Miss Buffer
+ self._create_temp_layer(layer_id, weight_names)
+
+ # 5. 创建 GPU 缓存
+ self._create_gpu_cache(layer, layer_id, self.cache_expert_num, weight_names)
+
+ # 6. 记录配置
+ self.layer_configs[layer_id] = {
+ 'num_cache_experts': self.cache_expert_num,
+ 'total_experts': self.total_expert_num,
+ 'buffer_idx': buffer_idx,
+ 'weight_names': weight_names,
+ 'intermediate_size': layer.intermediate_size_per_partition,
+ 'hidden_size': layer.hidden_size,
+ 'dtype': layer.w13_weight.dtype,
+ '_layer_ref': layer
+ }
+
+ # 7. 在 layer 上保存引用
+ layer.cpu_offload_layer_id = layer_id
+ layer.cpu_offload_manager = self
+ '''
+ if layer.expert_map is None:
+ layer.expert_map = torch.arange(0, self.total_expert_num, dtype=torch.int32, device='cuda')
+ layer.expert_map[self.cache_expert_num:] = -1
+ '''
+ if layer._expert_map is None:
+ layer._expert_map = torch.arange(0, self.total_expert_num, dtype=torch.int32, device='cuda')
+ layer._expert_map[self.cache_expert_num:] = -1
+
+ self._processed_layers.add(layer_id)
+
+
+ self.setup_cache_maps(layer_id=layer_id, total_experts=self.total_expert_num, cache_num=self.cache_expert_num)
+
+ self.engine.create_layer(
+ self.cpu_weights[layer_id]['w13_weight'], # intptr_t gateUpWeights
+ self.cpu_weights[layer_id]['w2_weight'], # intptr_t downWeights
+ self.cpu_weights[layer_id]['w13_weight_scale_inv'], # intptr_t gateUpScales
+ self.cpu_weights[layer_id]['w2_weight_scale_inv'], # intptr_t downScales
+ layer_id,
+ )
+
+ self._preload_experts(layer_id, self.cache_expert_num)
+
+ if self.tp_rank == 0:
+ print(f"✓ Layer {layer_id}: Miss Buffer={buffer_idx}, "
+ f"GPU Cache={self.cache_expert_num}/{self.total_expert_num} experts")
+
+ def _create_temp_layer(
+ self,
+ layer_id: int,
+ weight_names: List[str],
+ ) -> None:
+ """
+ 创建 Temp Layer(全局仅2个)
+ ✅ 所有参数(权重+缩放系数)都预分配空tensor
+ ✅ 后续由GPU kernel填充实际数据
+ """
+ buffer_idx = layer_id % 2
+ if self.temp_layer[buffer_idx] is not None:
+ return
+
+ miss_num_experts = self.total_expert_num - self.cache_expert_num
+ if miss_num_experts <= 0:
+ self.temp_layer[buffer_idx] = None
+ return
+
+ temp_mod = nn.Module()
+ device = torch.device('cuda')
+
+ for name in weight_names:
+ ref_shape = self.cpu_weights[layer_id][name].shape
+
+ new_shape = (miss_num_experts,) + ref_shape[1:]
+ if name == 'w13_weight' or name == 'w2_weight':
+ N = new_shape[1] * 32
+ K = new_shape[2] * 32
+ new_shape = (miss_num_experts,N,K)
+
+ # 注册到 temp layer
+ temp_mod.register_parameter(
+ name,
+ nn.Parameter(
+ torch.zeros(new_shape, dtype=self.cpu_weights[layer_id][name].dtype, device=device),
+ requires_grad=False
+ )
+ )
+
+ # ✅ quant_config 引用 temp_layer 中的空tensor
+ temp_mod.quant_config = fp8_w8a8_moe_quant_config(
+ w1_scale=(temp_mod.w13_weight_scale_inv if self.block_quant else temp_mod.w13_weight_scale),
+ w2_scale=(temp_mod.w2_weight_scale_inv if self.block_quant else temp_mod.w2_weight_scale),
+ a1_scale=getattr(temp_mod, 'w13_input_scale', None), # ✅ 安全获取
+ a2_scale=getattr(temp_mod, 'w2_input_scale', None),
+ block_shape=self.weight_block_size,
+ )
+
+ self.temp_layer[buffer_idx] = temp_mod
+ if self.tp_rank == 0:
+ print(f"✓ Temp Layer {buffer_idx} 预分配成功,{miss_num_experts} experts")
+
+ def _get_temp_layer(self, buffer_idx: int) -> Optional[nn.Module]:
+ """
+ 获取已创建的 temp_layer
+ ✅ forward时仅获取,不创建
+ """
+ if buffer_idx not in [0, 1]:
+ raise ValueError(f"buffer_idx must be 0 or 1, got {buffer_idx}")
+ return self.temp_layer[buffer_idx]
+ def _create_gpu_cache(
+ self,
+ layer: nn.Module,
+ layer_id: int,
+ num_cache_experts: int,
+ weight_names: List[str],
+ ) -> None:
+ """创建 GPU 缓存并替换 layer 参数"""
+ for name in weight_names:
+ param = getattr(layer, name)
+ orig_shape = param.shape
+ if name == 'w13_weight' or name == 'w2_weight':
+ N = orig_shape[1] * 32
+ K = orig_shape[2] * 32
+ new_shape = (num_cache_experts, N, K)
+ else:
+ new_shape = (num_cache_experts,) + orig_shape[1:]
+ gpu_tensor = torch.empty(new_shape, dtype=param.dtype, device='cuda')
+ param.data = gpu_tensor
+
+ def _preload_experts(self, layer_id: int, num_experts: int) -> None:
+ """初始加载前 N 个 expert 到 GPU 缓存"""
+ cpu_state = self.cpu_weights[layer_id]
+ config = self.layer_configs[layer_id]
+
+ # ✅ 获取 layer 引用
+ layer = config.get('_layer_ref')
+ copy_map = self.cache_maps[layer_id]
+ if not layer:
+ return
+
+ # ✅ 统一处理所有参数:w13_weight, w2_weight, scales, input_scales
+ weight_names = config.get('weight_names', [])
+
+
+
+ for name in weight_names:
+ if name not in cpu_state:
+ continue
+
+ src = cpu_state[name]
+ param = getattr(layer, name)
+ # 计算要复制的数量
+ num_to_copy = min(num_experts, src.shape[0])
+
+ if name == 'w13_weight' or name == 'w2_weight':
+ M = src.shape[0]
+ N = src.shape[1] * 32
+ K = src.shape[2] * 32
+ tmp = src[:num_to_copy].permute(0,1,4,2,3,5).reshape(num_to_copy,N,K).contiguous().pin_memory()
+ else:
+ tmp = src
+
+ if num_to_copy > 0:
+ param.data[:num_to_copy].copy_(tmp[:num_to_copy])
+ if name == 'w13_weight' or name == 'w2_weight':
+ del tmp
+ gc.collect()
+
+
+ def _get_weight_names(self, layer: nn.Module) -> List[str]:
+ """自动检测所有相关权重名称"""
+ names = []
+ for name in ['w13_weight', 'w2_weight',
+ 'w13_weight_scale', 'w2_weight_scale',
+ 'w13_weight_scale_inv', 'w2_weight_scale_inv',
+ 'w13_input_scale', 'w2_input_scale']:
+ if hasattr(layer, name) and getattr(layer, name) is not None:
+ names.append(name)
+ return names
+
+ def get_miss_buffer(self, buffer_idx: int) -> Dict[str, Tensor]:
+ """通过 0 或 1 索引获取 Miss Expert Buffer"""
+ if buffer_idx not in self.miss_buffers:
+ raise KeyError(f"Miss buffer 索引 {buffer_idx} 未初始化")
+ return self.miss_buffers[buffer_idx]
+
+ def forward_dbo(
+ self,
+ hidden_states: Tensor,
+ topk_weights: Tensor,
+ topk_ids: Tensor,
+ layer_id: int,
+ ) -> [Tensor, Tensor]:
+ """
+ 深度绑定优化 (DBO) 前向计算
+ 使用自定义 CPU 算子进行异步计算
+ """
+
+ # 1. 执行缓存策略
+ cache_map = self.cache_maps.get(layer_id, [])
+ config = self.layer_configs[layer_id]
+
+ layer = config.get('_layer_ref')
+
+ if not layer:
+ raise RuntimeError(f"Layer {layer_id} 引用丢失")
+
+ cpu_topk_ids, copy_map = torch.ops.moe_offload_ops.expert_cache_policy(
+ topk_ids,
+ self.cache_maps[layer_id],
+ self.miss_maps[layer_id],
+ self.policy_sorts[layer_id],
+ self.engine.ptr()
+ )
+
+ if dbo_enabled():
+ dbo_switch_to_comm_sync()
+
+ # 3. 提交 CPU 任务
+ torch.ops.moe_offload_ops.cpu_moe_submit(hidden_states, cpu_topk_ids, topk_weights, self.engine.ptr(), layer_id, 0)
+
+ # 4. 更新专家缓存
+ torch.ops.moe_offload_ops.update_expert_cache(
+ layer.w13_weight,
+ layer.w2_weight,
+ layer.w13_weight_scale_inv,
+ layer.w2_weight_scale_inv,
+ copy_map,
+ self.total_expert_num,
+ layer_id,
+ self.engine.ptr()
+ )
+
+ cpu_output = torch.zeros_like(hidden_states)
+ cpu_output = torch.ops.moe_offload_ops.cpu_moe_sync(cpu_output, self.engine.ptr())
+
+ # 5. 切换到另一个线程的计算,计算结束当前线程等待传输完成
+ if dbo_enabled():
+ dbo_yield_and_switch_from_comm_to_compute()
+
+ return cpu_output, cache_map
+
+
+ def forward_prefetch(
+ self,
+ hidden_states: Tensor,
+ topk_weights: Tensor,
+ topk_ids: Tensor,
+ layer_id: int,
+ ) -> [Tensor, Tensor]:
+ """
+ 预取 (Prefetch) 模式前向计算
+ 使用 Miss Expert Buffer 进行 miss expert 计算
+ """
+ n_tok = hidden_states.shape[0]
+ if n_tok > 8192:
+ n_copy = 256
+ elif n_tok >= 1024:
+ n_copy = (n_tok // 1000) * 16 + 112
+ else:
+ n_copy = 80
+
+ # ✅ 使用 layer_id % 2 获取 buffer 索引
+ buffer_idx = layer_id % 2
+ temp_layer = self._get_temp_layer(buffer_idx)
+ cache_map = self.cache_maps.get(layer_id, [])
+ config = self.layer_configs[layer_id]
+ miss_map = self.miss_maps.get(layer_id, [])
+
+ # 如果 buffer 为空,返回零张量
+ if not temp_layer:
+ return torch.zeros_like(hidden_states)
+
+ # 为每个buffer创建独立的stream和事件
+ if self.update_streams is None:
+ self.update_streams = torch.cuda.Stream()
+
+ if self.compute_event[buffer_idx] is None:
+ self.compute_event[buffer_idx] = torch.cuda.Event()
+ self.copy_event[buffer_idx] = torch.cuda.Event()
+
+ current_stream = torch.cuda.current_stream()
+
+ # 3. 当前层的update需要等待同buffer的前一个fused_experts完成
+ self.update_streams.wait_event(self.compute_event[buffer_idx])
+
+ # 1. 在专用stream中执行update_expert_cache
+ with torch.cuda.stream(self.update_streams):
+ torch.ops.moe_offload_ops.update_expert_cache(
+ temp_layer.w13_weight,
+ temp_layer.w2_weight,
+ temp_layer.w13_weight_scale_inv,
+ temp_layer.w2_weight_scale_inv,
+ miss_map,
+ n_copy,
+ layer_id,
+ self.engine.ptr()
+ )
+ # 记录update完成事件
+ self.copy_event[buffer_idx].record(self.update_streams)
+
+ # 2. fused_experts依赖update_expert_cache的执行结束
+ current_stream.wait_event(self.copy_event[buffer_idx])
+
+ # 4. 执行 fused_experts(使用 temp_layer 中的权重)
+ # 注意:这里直接调用 fused_experts 函数,而不是 moe_method.apply()
+ if n_copy < 256:
+ map_modified = miss_map.clone()
+ map_modified[n_copy:] = -1
+ mask = (topk_ids > n_copy) & (miss_map[topk_ids] >= 0)
+ cpu_topk_ids = torch.where(mask, topk_ids, -1)
+ torch.ops.moe_offload_ops.cpu_moe_submit(hidden_states, cpu_topk_ids, topk_weights, self.engine.ptr(), layer_id, 0)
+
+ from vllm.model_executor.layers.fused_moe import fused_experts
+
+ miss_output = fused_experts(
+ hidden_states=hidden_states,
+ w1=temp_layer.w13_weight, # 来自 miss buffer
+ w2=temp_layer.w2_weight, # 来自 miss buffer
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ inplace=False,
+ activation="silu", # 可从 layer 配置获取
+ global_num_experts=config['total_experts'],
+ apply_router_weight_on_input=False,
+ expert_map=miss_map, # 使用 layer 的 miss_map
+ quant_config=temp_layer.quant_config,
+ allow_deep_gemm=False,
+ )
+
+ cpu_output = torch.zeros_like(hidden_states)
+ if n_copy < 256:
+ cpu_output = torch.ops.moe_offload_ops.cpu_moe_sync(cpu_output, self.engine.ptr())
+ # 记录fused_experts完成事件,供下一个同buffer层使用
+ self.compute_event[buffer_idx].record(current_stream)
+
+ return miss_output + cpu_output, cache_map
+
+
+ def forward_offload(
+ self,
+ hidden_states: Tensor,
+ topk_weights: Tensor,
+ topk_ids: Tensor,
+ layer_id: int,
+ ) -> Tensor:
+ """
+ 统一入口:根据配置选择 DBO 或 Prefetch 模式
+ """
+ if self.tp_rank == 1 and layer_id == 4:
+ print(f"dbo_enable():{dbo_enabled()} and ntok = {hidden_states.shape[0]}")
+ if dbo_enabled() or hidden_states.shape[0] == 1:
+ return self.forward_dbo(hidden_states, topk_weights, topk_ids, layer_id)
+ else:
+ return self.forward_prefetch(hidden_states, topk_weights, topk_ids, layer_id)
diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py
index f5c3b9af611f..69451f3477cd 100644
--- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py
+++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py
@@ -354,5 +354,6 @@ def make_fp8_moe_kernel(
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
TritonExperts(quant_config=moe_quant_config),
+ moe_offload=layer.moe_offload,
)
return kernel, use_inplace
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 2879315a6886..96d11e67b950 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -678,6 +678,20 @@ def create_weights(
layer.orig_dtype = params_dtype
layer.weight_block_size = None
+ if getattr(layer, 'moe_offload', False):
+ device = torch.device('cpu')
+ pin_memory = True
+ else:
+ device = torch.get_current_device()
+ pin_memory = False
+
+ def create_tensor(shape, dtype):
+ if device == torch.device('cpu'):
+ t = torch.empty(shape, dtype=dtype, device=device, pin_memory=True)
+ else:
+ t = torch.empty(shape, dtype=dtype, device=device)
+ return t
+
assert self.quant_config.is_checkpoint_fp8_serialized
params_dtype = torch.float8_e4m3fn
@@ -709,10 +723,10 @@ def create_weights(
# WEIGHTS
w13_weight = torch.nn.Parameter(
- torch.empty(
+ create_tensor(shape=(
num_experts,
2 * intermediate_size_per_partition,
- hidden_size,
+ hidden_size),
dtype=params_dtype,
),
requires_grad=False,
@@ -721,10 +735,11 @@ def create_weights(
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
- torch.empty(
- num_experts,
- hidden_size,
- intermediate_size_per_partition,
+ create_tensor(
+ shape=(
+ num_experts,
+ hidden_size,
+ intermediate_size_per_partition),
dtype=params_dtype,
),
requires_grad=False,
@@ -735,20 +750,22 @@ def create_weights(
# WEIGHT_SCALES
if not self.block_quant:
# For per-tensor quant, the scales are per expert and weight.
- w13_scale_data = torch.ones(num_experts, 2, dtype=torch.float32)
- w2_scale_data = torch.ones(num_experts, dtype=torch.float32)
+ w13_scale_data = create_tensor(shape=(num_experts, 2), dtype=torch.float32)
+ w2_scale_data = create_tensor(shape=(num_experts,), dtype=torch.float32)
else:
# For block quant, the scales are per block (typically 128x128).
- w13_scale_data = torch.ones(
- num_experts,
- 2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
- (hidden_size + block_k - 1) // block_k,
+ w13_scale_data = create_tensor(
+ shape=(
+ num_experts,
+ 2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
+ (hidden_size + block_k - 1) // block_k),
dtype=torch.float32,
)
- w2_scale_data = torch.ones(
- num_experts,
- (hidden_size + block_n - 1) // block_n,
- (intermediate_size_per_partition + block_k - 1) // block_k,
+ w2_scale_data = create_tensor(
+ shape=(
+ num_experts,
+ (hidden_size + block_n - 1) // block_n,
+ (intermediate_size_per_partition + block_k - 1) // block_k),
dtype=torch.float32,
)
w13_weight_scale = torch.nn.Parameter(w13_scale_data, requires_grad=False)
@@ -771,13 +788,13 @@ def create_weights(
if self.quant_config.activation_scheme == "static":
assert not self.block_quant
w13_input_scale = torch.nn.Parameter(
- torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+ create_tensor(shape=(num_experts,), dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w13_input_scale", w13_input_scale)
set_weight_attrs(w13_input_scale, extra_weight_attrs)
w2_input_scale = torch.nn.Parameter(
- torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+ create_tensor(shape=(num_experts,), dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w2_input_scale", w2_input_scale)
set_weight_attrs(w2_input_scale, extra_weight_attrs)
@@ -836,6 +853,8 @@ def process_weights_after_loading(self, layer: Module) -> None:
w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
w13_input_scale = layer.w13_input_scale
w2_input_scale = layer.w2_input_scale
+ if layer.moe_offload:
+ layer.cpu_offload_eng.setup_layer_cache(layer)
# MI300x and MI325x use FNUZ format for FP8. Convert if needed.
if current_platform.is_fp8_fnuz():
@@ -1055,6 +1074,11 @@ def apply(
hidden_states=x,
router_logits=router_logits,
)
+ if layer.moe_offload:
+ miss_output, expert_map = layer.cpu_offload_eng.forward_offload(x, topk_weights, topk_ids, layer.layer_idx)
+ else:
+ miss_output = torch.zeros_like(x)
+ expert_map = None
assert self.kernel is not None
result = self.kernel(
@@ -1066,10 +1090,12 @@ def apply(
inplace=self.use_inplace,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
- expert_map=layer.expert_map,
+ expert_map=expert_map if layer.moe_offload else layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
+ result = result + miss_output
+
return result
diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py
index 08d7a851ac9a..65efc8ce98cf 100644
--- a/vllm/model_executor/model_loader/utils.py
+++ b/vllm/model_executor/model_loader/utils.py
@@ -104,8 +104,11 @@ def process_weights_after_loading(
# to be on the global target device. This scope is for the
# case where cpu offloading is used, where we will move the
# parameters onto device for processing and back off after.
- with device_loading_context(module, target_device):
+ if model_config.moe_offload:
quant_method.process_weights_after_loading(module)
+ else:
+ with device_loading_context(module, target_device):
+ quant_method.process_weights_after_loading(module)
# Initialize post-load attention weights for both Attention and MLA.
# NOTE: Happens after other modules so we can easily decompress weights.
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 2ee2740a51ba..f744fc408411 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -229,10 +229,12 @@
get_per_layer_parameters,
infer_global_hyperparameters,
split_decodes_and_prefills,
+ AttentionCGSupport,
)
from vllm.v1.kv_cache_interface import AttentionSpec
+
class QueryLenSupport(Enum):
"""Defines the level of query length support for an attention backend's
decode pipeline.
@@ -519,6 +521,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
# Use `query_len_support` (above) to set this automatically
# when speculative decoding is enabled.
reorder_batch_threshold: int = 1
+ cudagraph_support: ClassVar[AttentionCGSupport] = \
+ AttentionCGSupport.UNIFORM_BATCH
@staticmethod
def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int:
@@ -770,8 +774,7 @@ def build_for_cudagraph_capture(
"""
m = common_attn_metadata
assert m.num_reqs <= (m.num_actual_tokens * self.reorder_batch_threshold), (
- "MLA only supports decode-only full CUDAGraph capture. "
- "Make sure all cudagraph capture sizes <= max_num_seq."
+ f"m.num_reqs: {m.num_reqs}, m.num_actual_tokens: {m.num_actual_tokens}, self.reorder_batch_threshold: {self.reorder_batch_threshold}"
)
assert m.max_query_len <= self.reorder_batch_threshold # decode only
diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py
index 7b427b4a6cde..21c16624e246 100644
--- a/vllm/v1/executor/multiproc_executor.py
+++ b/vllm/v1/executor/multiproc_executor.py
@@ -599,6 +599,35 @@ def make_worker_process(
proc.start()
writer.close()
+
+ # CPU binding for moe_offload
+ if vllm_config.model_config.moe_offload:
+ from numa import schedule, memory
+ import torch.distributed as dist
+ import psutil
+ import os
+ rank_local = rank
+
+ parallel_config = vllm_config.parallel_config
+ dp_rank = parallel_config.data_parallel_rank
+ tp_size = parallel_config.tensor_parallel_size
+ rank_local += dp_rank * tp_size
+ proc_handle = psutil.Process(proc.pid)
+
+ # the following cpu setting only for DCG's shenzhen AI servers
+ core_per_worker = 64 // tp_size
+ node_split = tp_size // 2
+ if rank_local < node_split:
+ cpu_node = 0
+ else:
+ cpu_node = 1
+
+ cpu_cores = list(range(rank_local * core_per_worker, rank_local * core_per_worker + core_per_worker))
+
+ schedule.run_on_nodes(cpu_node)
+ memory.set_membind_nodes(cpu_node)
+ proc_handle.cpu_affinity(cpu_cores)
+
# Keep death_writer open in parent - when parent exits,
# death_reader in child will get EOFError
return UnreadyWorkerProcHandle(proc, rank, reader, death_writer)
diff --git a/vllm/v1/worker/dp_utils.py b/vllm/v1/worker/dp_utils.py
index 82de0cba9194..22a5111a640c 100644
--- a/vllm/v1/worker/dp_utils.py
+++ b/vllm/v1/worker/dp_utils.py
@@ -6,7 +6,7 @@
import torch
import torch.distributed as dist
-from vllm.config import ParallelConfig
+from vllm.config import ParallelConfig,VllmConfig
from vllm.distributed.parallel_state import get_dp_group
from vllm.logger import init_logger
from vllm.v1.worker.ubatch_utils import (
@@ -179,6 +179,7 @@ def coordinate_batch_across_dp(
uniform_decode: bool | None = None,
num_scheduled_tokens_per_request: np.ndarray | None = None,
cudagraph_mode: int = 0,
+ moe_offload : bool = False,
) -> tuple[bool, torch.Tensor | None, int]:
"""
Coordinates amongst all DP ranks to determine if and how the full batch
@@ -208,7 +209,7 @@ def coordinate_batch_across_dp(
]
"""
- if parallel_config.data_parallel_size == 1:
+ if parallel_config.data_parallel_size == 1 and not moe_offload:
# Early exit.
return False, None, cudagraph_mode
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 07d5c282c036..35881e760b93 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1587,6 +1587,8 @@ def _build_attention_metadata(
:return: tuple[attn_metadata, spec_decode_common_attn_metadata]
"""
# Attention metadata is not needed for attention free models
+ if num_tokens <= 1:
+ print(f"num_tokens: {num_tokens}, num_reqs: {num_reqs}, max_query_len: {max_query_len}, num_tokens_padded: {num_tokens_padded}, num_reqs_padded: {num_reqs_padded}, ubatch_slices: {ubatch_slices}")
if len(self.kv_cache_config.kv_cache_groups) == 0:
return {}, None
@@ -3018,7 +3020,7 @@ def _determine_batch_execution_and_padding(
# Extra coordination when running data-parallel since we need to coordinate
# across ranks
should_ubatch, num_tokens_across_dp = False, None
- if self.vllm_config.parallel_config.data_parallel_size > 1:
+ if self.vllm_config.parallel_config.data_parallel_size > 1 or self.vllm_config.model_config.moe_offload:
# Disable DP padding when running eager to avoid excessive padding when
# running prefills. This lets us set cudagraph_mode="NONE" on the prefiller
# in a P/D setup and still use CUDA graphs (enabled by this padding) on the
@@ -3037,6 +3039,7 @@ def _determine_batch_execution_and_padding(
uniform_decode=uniform_decode,
num_scheduled_tokens_per_request=num_scheduled_tokens_np,
cudagraph_mode=cudagraph_mode.value,
+ moe_offload=self.vllm_config.model_config.moe_offload,
)
)
@@ -4248,6 +4251,7 @@ def _dummy_run(
remove_lora: If False, dummy LoRAs are not destroyed after the run
activate_lora: If False, dummy_run is performed without LoRAs.
"""
+ print(f"entering dummy_run : num_tokens:{num_tokens},cudagraph_runtime_mode:{cudagraph_runtime_mode},uniform_decode:{uniform_decode},create_mixed_batch:{create_mixed_batch}")
if supports_mm_encoder_only(self.model):
# The current dummy run only covers LM execution, so we can skip it.
# mm encoder dummy run may need to add in the future.
@@ -4343,6 +4347,8 @@ def _dummy_run(
num_reqs_padded = (
batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
)
+ print(f"num_reqs:{num_reqs}, num_reqs_padded:{num_reqs_padded},batch_desc.num_reqs:{batch_desc.num_reqs}")
+
ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
should_ubatch,
num_scheduled_tokens,
@@ -4902,6 +4908,7 @@ def _capture_cudagraphs(
# different from the case where `FULL` implies capture
# attention while `PIECEWISE` implies no attention.
force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL
+ print(f"warmup_size : {self.compilation_config.cudagraph_num_of_warmups}")
self._dummy_run(
num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
@@ -5057,6 +5064,13 @@ def _check_and_update_cudagraph_mode(
# Flexible resolve the cudagraph mode
cudagraph_mode = self.compilation_config.cudagraph_mode
assert cudagraph_mode is not None
+
+ # Flexible resolve the cudagraph mode for moe_offload
+ if self.vllm_config.model_config.moe_offload:
+ print("moe_offload gpu_model_runner initialize_cudagraph_capture")
+ self.compilation_config.cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY
+ cudagraph_mode = self.compilation_config.cudagraph_mode
+
# check cudagraph for mixed batch is supported
if (
cudagraph_mode.mixed_mode() == CUDAGraphMode.FULL
@@ -5091,7 +5105,7 @@ def _check_and_update_cudagraph_mode(
# check that if we are doing decode full-cudagraphs it is supported
if (
cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
- and min_cg_support == AttentionCGSupport.NEVER
+ and min_cg_support == AttentionCGSupport.NEVER and not self.vllm_config.model_config.moe_offload
):
msg = (
f"CUDAGraphMode.{cudagraph_mode.name} is not supported "
@@ -5124,7 +5138,7 @@ def _check_and_update_cudagraph_mode(
if (
cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
and self.uniform_decode_query_len > 1
- and min_cg_support.value < AttentionCGSupport.UNIFORM_BATCH.value
+ and min_cg_support.value < AttentionCGSupport.UNIFORM_BATCH.value and not self.vllm_config.model_config.moe_offload
):
msg = (
f"CUDAGraphMode.{cudagraph_mode.name} is not supported"
@@ -5147,7 +5161,7 @@ def _check_and_update_cudagraph_mode(
# even after automatic downgrades
if (
cudagraph_mode.has_full_cudagraphs()
- and min_cg_support == AttentionCGSupport.NEVER
+ and min_cg_support == AttentionCGSupport.NEVER and not self.vllm_config.model_config.moe_offload
):
raise ValueError(
f"CUDAGraphMode.{cudagraph_mode.name} is not "
@@ -5166,7 +5180,7 @@ def _check_and_update_cudagraph_mode(
if (
cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
and cudagraph_mode.separate_routine()
- and self.uniform_decode_query_len > 1
+ and self.uniform_decode_query_len > 1
):
self.compilation_config.adjust_cudagraph_sizes_for_spec_decode(
self.uniform_decode_query_len, self.parallel_config.tensor_parallel_size
diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py
index af09129e67b1..80f93ae20ab9 100644
--- a/vllm/v1/worker/gpu_ubatch_wrapper.py
+++ b/vllm/v1/worker/gpu_ubatch_wrapper.py
@@ -418,7 +418,8 @@ def __call__(self, *args, **kwargs):
dp_metadata = forward_context.dp_metadata
# We shouldn't be here unless we are running with multiple DP ranks
- assert dp_metadata is not None
+ if not self.vllm_config.model_config.moe_offload:
+ assert dp_metadata is not None
ubatch_dp_metadata = []
for ubatch_slice in ubatch_slices:
dp_size = self.vllm_config.parallel_config.data_parallel_size