diff --git a/python/sglang/srt/disaggregation/mooncake/memory_pool.py b/python/sglang/srt/disaggregation/mooncake/memory_pool.py
new file mode 100644
index 00000000000..6e8edaf927d
--- /dev/null
+++ b/python/sglang/srt/disaggregation/mooncake/memory_pool.py
@@ -0,0 +1,50 @@
+import os
+import threading
+from importlib import resources
+from typing import Dict, Final, Optional
+
+import torch
+from torch.cuda.memory import CUDAPluggableAllocator
+
+
+# TODO(shangming): move this class into mooncake's package for more general use cases
+class MooncakeNVLinkAllocator:
+    _instances: Dict[torch.device, CUDAPluggableAllocator] = {}
+    _lock: Final = threading.Lock()
+
+    @classmethod
+    def _get_so_path(cls) -> str:
+        """Dynamically locate hook.so in the mooncake package installation"""
+        try:
+            # Attempt to locate package resource
+            with resources.path("mooncake", "hook.so") as so_path:
+                if so_path.exists():
+                    return str(so_path)
+        except (ImportError, FileNotFoundError, TypeError):
+            pass
+
+        # Fallback strategy: check in package location via import metadata
+        try:
+            import mooncake
+
+            base_path = os.path.dirname(os.path.abspath(mooncake.__file__))
+            so_path = os.path.join(base_path, "hook.so")
+            if os.path.exists(so_path):
+                return so_path
+        except (ImportError, FileNotFoundError, TypeError):
+            pass
+
+        # hook.so could not be located in the installed mooncake package
+        raise ImportError(
+            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL requires mooncake-transfer-engine >= 0.3.3.post2."
+        )
+
+    @classmethod
+    def get_allocator(cls, device: torch.device) -> CUDAPluggableAllocator:
+        with cls._lock:
+            if device not in cls._instances:
+                so_path = cls._get_so_path()
+                cls._instances[device] = CUDAPluggableAllocator(
+                    so_path, "mc_nvlink_malloc", "mc_nvlink_free"
+                )
+            return cls._instances[device]
diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py
index 7dd6a6bec45..cd46846ff66 100644
--- a/python/sglang/srt/disaggregation/utils.py
+++ b/python/sglang/srt/disaggregation/utils.py
@@ -6,6 +6,7 @@
 import threading
 import warnings
 from collections import deque
+from contextlib import nullcontext
 from enum import Enum
 from typing import TYPE_CHECKING, List, Optional
 
@@ -84,24 +85,37 @@ def free(self, free_index: int):
 
 
 class MetadataBuffers:
-    def __init__(self, size: int, max_top_logprobs_num: int = 128):
-        # TODO: abort top_logprobs_num > 128 in PD
-
-        # We transfer the metadata of first output token to decode
-        # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
-        self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device="cpu")
-        self.output_token_logprobs_val = torch.zeros(
-            (size, 16), dtype=torch.float32, device="cpu"
-        )
-        self.output_token_logprobs_idx = torch.zeros(
-            (size, 16), dtype=torch.int32, device="cpu"
-        )
-        self.output_top_logprobs_val = torch.zeros(
-            (size, max_top_logprobs_num), dtype=torch.float32, device="cpu"
-        )
-        self.output_top_logprobs_idx = torch.zeros(
-            (size, max_top_logprobs_num), dtype=torch.int32, device="cpu"
-        )
+    def __init__(
+        self,
+        size: int,
+        max_top_logprobs_num: int = 128,
+        custom_mem_pool: Optional[torch.cuda.MemPool] = None,
+    ):
+        self.custom_mem_pool = custom_mem_pool
+        device = "cuda" if self.custom_mem_pool else "cpu"
+
+        with (
+            torch.cuda.use_mem_pool(self.custom_mem_pool)
+            if self.custom_mem_pool
+            else nullcontext()
+        ):
+            # TODO: abort top_logprobs_num > 128 in PD
+
+            # We transfer the metadata of first output token to decode
+            # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
+            self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
+            self.output_token_logprobs_val = torch.zeros(
+                (size, 16), dtype=torch.float32, device=device
+            )
+            self.output_token_logprobs_idx = torch.zeros(
+                (size, 16), dtype=torch.int32, device=device
+            )
+            self.output_top_logprobs_val = torch.zeros(
+                (size, max_top_logprobs_num), dtype=torch.float32, device=device
+            )
+            self.output_top_logprobs_idx = torch.zeros(
+                (size, max_top_logprobs_num), dtype=torch.int32, device=device
+            )
 
     def get_buf_infos(self):
         ptrs = [
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 6b5a03b8222..7ab872d7081 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -622,7 +622,10 @@ def init_disaggregation(self):
             self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                 buffer_size
             )
-            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
+            self.disagg_metadata_buffers = MetadataBuffers(
+                buffer_size,
+                custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
+            )
 
             # The decode requests polling kv cache
             self.disagg_decode_transfer_queue = DecodeTransferQueue(
@@ -669,7 +672,10 @@ def init_disaggregation(self):
             self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                 buffer_size
             )
-            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
+            self.disagg_metadata_buffers = MetadataBuffers(
+                buffer_size,
+                custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
+            )
 
             self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                 token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index 1e823be10a8..d426093df2c 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -26,6 +26,8 @@
 
 import abc
 import logging
+import os
+from contextlib import nullcontext
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
@@ -34,7 +36,7 @@
 import triton.language as tl
 
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import debug_timing, is_cuda, next_power_of_2
+from sglang.srt.utils import debug_timing, get_bool_env_var, is_cuda, next_power_of_2
 
 logger = logging.getLogger(__name__)
 
@@ -260,6 +262,22 @@ def __init__(
 
         self.head_num = head_num
         self.head_dim = head_dim
+
+        # for disagg with nvlink
+        self.enable_custom_mem_pool = get_bool_env_var(
+            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
+        )
+        if self.enable_custom_mem_pool:
+            from sglang.srt.disaggregation.mooncake.memory_pool import (
+                MooncakeNVLinkAllocator,
+            )
+
+            # TODO(shangming): abstract custom allocator class for more backends
+            allocator = MooncakeNVLinkAllocator.get_allocator(self.device)
+            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
+        else:
+            self.custom_mem_pool = None
+
         self._create_buffers()
 
         # used for chunked cpu-offloading
@@ -275,24 +293,29 @@ def __init__(
 
     def _create_buffers(self):
         with self.memory_saver_adapter.region():
-            # [size, head_num, head_dim] for each layer
-            # The padded slot 0 is used for writing dummy outputs from padded tokens.
-            self.k_buffer = [
-                torch.zeros(
-                    (self.size + self.page_size, self.head_num, self.head_dim),
-                    dtype=self.store_dtype,
-                    device=self.device,
-                )
-                for _ in range(self.layer_num)
-            ]
-            self.v_buffer = [
-                torch.zeros(
-                    (self.size + self.page_size, self.head_num, self.head_dim),
-                    dtype=self.store_dtype,
-                    device=self.device,
-                )
-                for _ in range(self.layer_num)
-            ]
+            with (
+                torch.cuda.use_mem_pool(self.custom_mem_pool)
+                if self.enable_custom_mem_pool
+                else nullcontext()
+            ):
+                # [size, head_num, head_dim] for each layer
+                # The padded slot 0 is used for writing dummy outputs from padded tokens.
+                self.k_buffer = [
+                    torch.zeros(
+                        (self.size + self.page_size, self.head_num, self.head_dim),
+                        dtype=self.store_dtype,
+                        device=self.device,
+                    )
+                    for _ in range(self.layer_num)
+                ]
+                self.v_buffer = [
+                    torch.zeros(
+                        (self.size + self.page_size, self.head_num, self.head_dim),
+                        dtype=self.store_dtype,
+                        device=self.device,
+                    )
+                    for _ in range(self.layer_num)
+                ]
 
         self.data_ptrs = torch.tensor(
             [x.data_ptr() for x in self.k_buffer + self.v_buffer],
@@ -349,6 +372,9 @@ def get_contiguous_buf_infos(self):
         ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
+    def maybe_get_custom_mem_pool(self):
+        return self.custom_mem_pool
+
     def get_cpu_copy(self, indices):
         torch.cuda.synchronize()
         kv_cache_cpu = []
@@ -569,16 +595,36 @@ def __init__(
         self.kv_lora_rank = kv_lora_rank
         self.qk_rope_head_dim = qk_rope_head_dim
 
+        # for disagg with nvlink
+        self.enable_custom_mem_pool = get_bool_env_var(
+            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
+        )
+        if self.enable_custom_mem_pool:
+            from sglang.srt.disaggregation.mooncake.memory_pool import (
+                MooncakeNVLinkAllocator,
+            )
+
+            # TODO(shangming): abstract custom allocator class for more backends
+            allocator = MooncakeNVLinkAllocator.get_allocator(self.device)
+            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
+        else:
+            self.custom_mem_pool = None
+
         with self.memory_saver_adapter.region():
-            # The padded slot 0 is used for writing dummy outputs from padded tokens.
-            self.kv_buffer = [
-                torch.zeros(
-                    (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
-                    dtype=self.store_dtype,
-                    device=device,
-                )
-                for _ in range(layer_num)
-            ]
+            with (
+                torch.cuda.use_mem_pool(self.custom_mem_pool)
+                if self.custom_mem_pool
+                else nullcontext()
+            ):
+                # The padded slot 0 is used for writing dummy outputs from padded tokens.
+                self.kv_buffer = [
+                    torch.zeros(
+                        (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
+                        dtype=self.store_dtype,
+                        device=device,
+                    )
+                    for _ in range(layer_num)
+                ]
 
         self.layer_transfer_counter = None
 
@@ -604,6 +650,9 @@ def get_contiguous_buf_infos(self):
         ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
+    def maybe_get_custom_mem_pool(self):
+        return self.custom_mem_pool
+
     def get_key_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
             self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
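
Note (not part of the patch): a minimal usage sketch of the allocation path that SGLANG_MOONCAKE_CUSTOM_MEM_POOL enables, assuming mooncake-transfer-engine >= 0.3.3.post2 ships hook.so and that the installed PyTorch exposes torch.cuda.MemPool and torch.cuda.use_mem_pool as used above; the tensor shape and variable names below are illustrative only.

    import torch

    from sglang.srt.disaggregation.mooncake.memory_pool import MooncakeNVLinkAllocator

    device = torch.device("cuda", 0)

    # Route allocations through Mooncake's NVLink hooks so the backing pages
    # can be registered and moved by the transfer engine.
    allocator = MooncakeNVLinkAllocator.get_allocator(device)
    pool = torch.cuda.MemPool(allocator.allocator())

    with torch.cuda.use_mem_pool(pool):
        # Tensors created inside this context are served from the custom pool,
        # mirroring what _create_buffers() and MetadataBuffers.__init__ do above.
        kv_like = torch.zeros((1024, 8, 128), dtype=torch.float16, device=device)

The scheduler wires the same pool into MetadataBuffers via maybe_get_custom_mem_pool(), so the first-token metadata buffers land in transferable GPU memory whenever the KV cache does.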