[PD] Add custom memory pool option to support Mooncake PD with NVLink #7264
New file (46 added lines) introducing `MooncakeNVLinkAllocator`:

```python
import os
import threading
from importlib import resources
from typing import Dict, Final, Optional

import torch
from torch.cuda.memory import CUDAPluggableAllocator


class MooncakeNVLinkAllocator:
    _instances: Dict[torch.device, CUDAPluggableAllocator] = {}
    _lock: Final = threading.Lock()

    @classmethod
    def _get_so_path(cls) -> str:
        """Dynamically locate hook.so in the mooncake package installation"""
        try:
            # Attempt to locate package resource
            with resources.path("mooncake", "hook.so") as so_path:
                if so_path.exists():
                    return str(so_path)
        except (ImportError, FileNotFoundError, TypeError):
            pass

        # Fallback strategy: check in package location via import metadata
        try:
            import mooncake

            base_path = os.path.dirname(os.path.abspath(mooncake.__file__))
            so_path = os.path.join(base_path, "hook.so")
            if os.path.exists(so_path):
                return so_path
        except (ImportError, FileNotFoundError, TypeError):
            raise ImportError(
                "SGLANG_MOONCAKE_CUSTOM_MEM_POOL require mooncake-transfer-engine >= 0.3.3.post2."
            )

    @classmethod
    def get_allocator(cls, device: torch.device) -> CUDAPluggableAllocator:
        with cls._lock:
            if device not in cls._instances:
                so_path = cls._get_so_path()
                cls._instances[device] = CUDAPluggableAllocator(
                    so_path, "mc_nvlink_malloc", "mc_nvlink_free"
                )
            return cls._instances[device]
```
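For context, a minimal sketch of how this allocator is meant to be consumed, mirroring the KV-pool changes further down in this diff; the env-var gate, device choice, and tensor shape here are illustrative and not part of the new file:

```python
import os

import torch

from sglang.srt.disaggregation.mooncake.memory_pool import MooncakeNVLinkAllocator

# Assumed gating: the pool changes below enable this path only when
# SGLANG_MOONCAKE_CUSTOM_MEM_POOL is set to "true" (case-insensitive).
if os.environ.get("SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "").lower() == "true":
    device = torch.device("cuda")
    allocator = MooncakeNVLinkAllocator.get_allocator(device)
    # Wrap the pluggable allocator in a MemPool so allocations made inside
    # use_mem_pool() go through mc_nvlink_malloc / mc_nvlink_free.
    mem_pool = torch.cuda.MemPool(allocator.allocator())
    with torch.cuda.use_mem_pool(mem_pool):
        kv_buffer = torch.zeros((4096, 8, 128), device=device)  # example shape
```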
Changes to the existing KV-cache memory pool module (the MHA and MLA token-to-KV pool classes):
```diff
@@ -26,6 +2,8 @@
 import abc
 import logging
+import os
+from contextlib import nullcontext
 from typing import List, Optional, Tuple, Union

 import numpy as np
```
```diff
@@ -260,6 +262,23 @@ def __init__(
         self.head_num = head_num
         self.head_dim = head_dim

+        # for disagg with nvlink
+        is_custom_mem_pool_enabled = os.environ.get("SGLANG_MOONCAKE_CUSTOM_MEM_POOL")
+        self.enable_custom_mem_pool = (
+            is_custom_mem_pool_enabled is not None
+            and is_custom_mem_pool_enabled.lower() == "true"
+        )
+        if self.enable_custom_mem_pool:
+            from sglang.srt.disaggregation.mooncake.memory_pool import (
+                MooncakeNVLinkAllocator,
+            )
+
+            allocator = MooncakeNVLinkAllocator.get_allocator(self.device)
+            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
+        else:
+            self.custom_mem_pool = None
+
         self._create_buffers()

         # used for chunked cpu-offloading
```
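One detail of the gating above is that the option is honored only when the variable is the literal string "true" (any case); values like "1" do not enable the pool. A tiny self-contained check of that parsing logic (the helper name is illustrative, not something this PR adds):

```python
# Mirrors the gating expression in the diff above; `_pool_enabled` is an
# illustrative helper, not a function introduced by this PR.
def _pool_enabled(value):
    return value is not None and value.lower() == "true"

assert _pool_enabled("true") and _pool_enabled("TRUE")
assert not _pool_enabled("1") and not _pool_enabled(None)
```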
```diff
@@ -277,22 +296,27 @@ def _create_buffers(self):
         with self.memory_saver_adapter.region():
             # [size, head_num, head_dim] for each layer
             # The padded slot 0 is used for writing dummy outputs from padded tokens.
-            self.k_buffer = [
-                torch.zeros(
-                    (self.size + self.page_size, self.head_num, self.head_dim),
-                    dtype=self.store_dtype,
-                    device=self.device,
-                )
-                for _ in range(self.layer_num)
-            ]
-            self.v_buffer = [
-                torch.zeros(
-                    (self.size + self.page_size, self.head_num, self.head_dim),
-                    dtype=self.store_dtype,
-                    device=self.device,
-                )
-                for _ in range(self.layer_num)
-            ]
+            self.k_buffer = []
+            self.v_buffer = []
+
+            with (
+                torch.cuda.use_mem_pool(self.custom_mem_pool)
+                if self.enable_custom_mem_pool
+                else nullcontext()
+            ):
+                for _ in range(self.layer_num):
+                    k = torch.zeros(
+                        (self.size + self.page_size, self.head_num, self.head_dim),
+                        dtype=self.store_dtype,
+                        device=self.device,
+                    )
+                    v = torch.zeros(
+                        (self.size + self.page_size, self.head_num, self.head_dim),
+                        dtype=self.store_dtype,
+                        device=self.device,
+                    )
+                    self.k_buffer.append(k)
+                    self.v_buffer.append(v)

         self.data_ptrs = torch.tensor(
             [x.data_ptr() for x in self.k_buffer + self.v_buffer],
```
```diff
@@ -349,6 +373,9 @@ def get_contiguous_buf_infos(self):
         ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens

+    def maybe_get_custom_mem_pool(self):
+        return self.custom_mem_pool
+
     def get_cpu_copy(self, indices):
         torch.cuda.synchronize()
         kv_cache_cpu = []
```
```diff
@@ -569,16 +596,38 @@
         self.kv_lora_rank = kv_lora_rank
         self.qk_rope_head_dim = qk_rope_head_dim

+        # for disagg with nvlink
+        is_custom_mem_pool_enabled = os.environ.get("SGLANG_MOONCAKE_CUSTOM_MEM_POOL")
+        self.enable_custom_mem_pool = (
+            is_custom_mem_pool_enabled is not None
+            and is_custom_mem_pool_enabled.lower() == "true"
+        )
+        if self.enable_custom_mem_pool:
+            from sglang.srt.disaggregation.mooncake.memory_pool import (
+                MooncakeNVLinkAllocator,
+            )
+
+            allocator = MooncakeNVLinkAllocator.get_allocator(self.device)
+            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
```
Inline review thread on the `torch.cuda.MemPool` line:

Reviewer (collaborator): Shall we put such common logic into the base class?

Author: Sure, I think @ByronHsu will need this too. Maybe it would be better to implement this in another PR after I talk with him, so that it will be less painful for him to sync the code.

Reviewer: OK, looks reasonable.
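A rough sketch of what the suggested refactor could look like, i.e. hoisting the env-var check and pool construction into a shared helper that both pool classes call; the helper name and placement are hypothetical and not part of this PR:

```python
import os
from typing import Optional

import torch


def maybe_create_mooncake_mem_pool(device: torch.device) -> Optional["torch.cuda.MemPool"]:
    """Hypothetical shared helper: return a Mooncake-backed MemPool when
    SGLANG_MOONCAKE_CUSTOM_MEM_POOL=true, otherwise None."""
    flag = os.environ.get("SGLANG_MOONCAKE_CUSTOM_MEM_POOL")
    if flag is None or flag.lower() != "true":
        return None

    # Lazy import, mirroring the PR, so the mooncake dependency stays optional.
    from sglang.srt.disaggregation.mooncake.memory_pool import MooncakeNVLinkAllocator

    allocator = MooncakeNVLinkAllocator.get_allocator(device)
    return torch.cuda.MemPool(allocator.allocator())
```

With such a helper, both `__init__` methods in this diff could reduce to `self.custom_mem_pool = maybe_create_mooncake_mem_pool(self.device)` and derive the enable flag from whether the result is `None`.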
The diff hunk continues:

```diff
+        else:
+            self.custom_mem_pool = None
+
         with self.memory_saver_adapter.region():
             # The padded slot 0 is used for writing dummy outputs from padded tokens.
-            self.kv_buffer = [
-                torch.zeros(
-                    (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
-                    dtype=self.store_dtype,
-                    device=device,
-                )
-                for _ in range(layer_num)
-            ]
+            self.kv_buffer = []
+
+            with (
+                torch.cuda.use_mem_pool(self.custom_mem_pool)
+                if self.custom_mem_pool
+                else nullcontext()
+            ):
+                for _ in range(layer_num):
+                    kv = torch.zeros(
+                        (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
+                        dtype=self.store_dtype,
+                        device=device,
+                    )
+                    self.kv_buffer.append(kv)

         self.layer_transfer_counter = None
```
```diff
@@ -604,6 +653,9 @@ def get_contiguous_buf_infos(self):
         ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens

+    def maybe_get_custom_mem_pool(self):
+        return self.custom_mem_pool
+
     def get_key_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
             self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
```
Review thread on the `hook.so` name used by the allocator lookup:

Reviewer: nit: wondering whether the name needs to be more distinguishable (but anyway this is unrelated to the SGLang side).

Author: The name issue will be optimized in a future release, since v0.3.3.post2 has already been released. We will fix this before v0.3.4 and make a PR for it at that time.

Reviewer: nit: then maybe add a TODO comment to the `MooncakeNVLinkAllocator` class.
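If that suggestion is adopted, the note might look like the following sketch (the wording is illustrative; the only grounded facts are the generic `hook.so` name and the plan to revisit it before mooncake v0.3.4):

```python
class MooncakeNVLinkAllocator:
    # TODO: hook.so is a generically named library shipped by
    # mooncake-transfer-engine; revisit this lookup once a more
    # distinguishable name lands (planned before mooncake v0.3.4).
    ...
```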