Merged
47 changes: 47 additions & 0 deletions python/sglang/srt/disaggregation/mooncake/memory_pool.py
@@ -0,0 +1,47 @@
import os
import threading
from importlib import resources
from typing import Dict, Final, Optional

import torch
from torch.cuda.memory import CUDAPluggableAllocator


class CustomAllocator:
Collaborator: nit: wondering whether this should be done on the mooncake side or the sglang side. Suppose one day another user of mooncake wants to do something similar with their KV cache; then it seems they would directly reuse this code.

Collaborator (author): ditto. Will fix this in v0.3.4.

_instances: Dict[torch.device, CUDAPluggableAllocator] = {}
_lock: Final = threading.Lock()

@classmethod
def _get_so_path(cls) -> str:
"""Dynamically locate hook.so in the mooncake package installation"""
try:
# Attempt to locate package resource
with resources.path("mooncake", "hook.so") as so_path:
Collaborator: nit: wondering whether the name needs to be more distinguishable (but anyway this is unrelated to the SGLang side).

Collaborator (author): The name will be improved in a future release, since v0.3.3.post2 has already been released. Will fix this before v0.3.4 and open a PR for it at that time.

Collaborator: nit: then maybe add a TODO comment to the MooncakeNVLinkAllocator class.

if so_path.exists():
return str(so_path)
except (ImportError, FileNotFoundError, TypeError):
pass

# Fallback strategy: check in package location via import metadata
try:
import mooncake

base_path = os.path.dirname(os.path.abspath(mooncake.__file__))
so_path = os.path.join(base_path, "hook.so")
if os.path.exists(so_path):
return so_path
        except (ImportError, FileNotFoundError, TypeError):
            pass

        raise ImportError(
            "hook.so not found in mooncake package. "
            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL requires mooncake-transfer-engine >= 0.3.3.post2."
        )

@classmethod
def get_allocator(cls, device: torch.device) -> CUDAPluggableAllocator:
with cls._lock:
if device not in cls._instances:
so_path = cls._get_so_path()
cls._instances[device] = CUDAPluggableAllocator(
so_path, "mc_nvlink_malloc", "mc_nvlink_free"
)
return cls._instances[device]
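A minimal usage sketch of how this allocator is expected to be wired in (it mirrors the memory_pool.py changes later in this PR; the gating check and the tensor shape here are illustrative only):

import os

import torch

from sglang.srt.disaggregation.mooncake.memory_pool import CustomAllocator

if os.environ.get("SGLANG_MOONCAKE_CUSTOM_MEM_POOL"):
    device = torch.device("cuda", torch.cuda.current_device())
    # One CUDAPluggableAllocator per device, cached inside CustomAllocator.
    allocator = CustomAllocator.get_allocator(device)
    # Wrap it in a MemPool so allocations made inside the context go through
    # mc_nvlink_malloc / mc_nvlink_free from mooncake's hook.so.
    pool = torch.cuda.MemPool(allocator.allocator())
    with torch.cuda.use_mem_pool(pool):
        kv_buffer = torch.zeros((1024, 8, 128), dtype=torch.float16, device=device)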
71 changes: 53 additions & 18 deletions python/sglang/srt/disaggregation/utils.py
@@ -84,24 +84,59 @@ def free(self, free_index: int):


class MetadataBuffers:
def __init__(self, size: int, max_top_logprobs_num: int = 128):
# TODO: abort top_logprobs_num > 128 in PD

# We transfer the metadata of first output token to decode
# The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device="cpu")
self.output_token_logprobs_val = torch.zeros(
(size, 16), dtype=torch.float32, device="cpu"
)
self.output_token_logprobs_idx = torch.zeros(
(size, 16), dtype=torch.int32, device="cpu"
)
self.output_top_logprobs_val = torch.zeros(
(size, max_top_logprobs_num), dtype=torch.float32, device="cpu"
)
self.output_top_logprobs_idx = torch.zeros(
(size, max_top_logprobs_num), dtype=torch.int32, device="cpu"
)
def __init__(
self,
size: int,
max_top_logprobs_num: int = 128,
custom_mem_pool: torch.cuda.MemPool = None,
):
self.custom_mem_pool = custom_mem_pool

if self.custom_mem_pool is not None:
self.output_ids = None
self.output_token_logprobs_val = None
self.output_token_logprobs_idx = None
self.output_top_logprobs_val = None
self.output_top_logprobs_idx = None

with torch.cuda.use_mem_pool(self.custom_mem_pool):
# TODO: abort top_logprobs_num > 128 in PD

# We transfer the metadata of first output token to decode
# The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
self.output_ids = torch.zeros(
(size, 16), dtype=torch.int32, device="cuda"
)
self.output_token_logprobs_val = torch.zeros(
(size, 16), dtype=torch.float32, device="cuda"
)
self.output_token_logprobs_idx = torch.zeros(
(size, 16), dtype=torch.int32, device="cuda"
)
self.output_top_logprobs_val = torch.zeros(
(size, max_top_logprobs_num), dtype=torch.float32, device="cuda"
)
self.output_top_logprobs_idx = torch.zeros(
(size, max_top_logprobs_num), dtype=torch.int32, device="cuda"
)
else:
# TODO: abort top_logprobs_num > 128 in PD

# We transfer the metadata of first output token to decode
# The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device="cpu")
self.output_token_logprobs_val = torch.zeros(
(size, 16), dtype=torch.float32, device="cpu"
)
self.output_token_logprobs_idx = torch.zeros(
(size, 16), dtype=torch.int32, device="cpu"
)
self.output_top_logprobs_val = torch.zeros(
(size, max_top_logprobs_num), dtype=torch.float32, device="cpu"
)
self.output_top_logprobs_idx = torch.zeros(
(size, max_top_logprobs_num), dtype=torch.int32, device="cpu"
)
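A quick illustrative check of the 64-byte RDMA padding mentioned in the comments above (a sketch, not part of the diff):

import torch

# 16 int32 elements * 4 bytes each = 64 bytes, the minimum RDMA transfer size
# referenced in the comments above.
row = torch.zeros(16, dtype=torch.int32)
assert row.numel() * row.element_size() == 64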

def get_buf_infos(self):
ptrs = [
10 changes: 8 additions & 2 deletions python/sglang/srt/managers/scheduler.py
@@ -622,7 +622,10 @@ def init_disaggregation(self):
self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
buffer_size
)
self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
self.disagg_metadata_buffers = MetadataBuffers(
buffer_size,
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
)

# The decode requests polling kv cache
self.disagg_decode_transfer_queue = DecodeTransferQueue(
@@ -669,7 +672,10 @@ def init_disaggregation(self):
self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
buffer_size
)
self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
self.disagg_metadata_buffers = MetadataBuffers(
buffer_size,
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
)

self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
113 changes: 89 additions & 24 deletions python/sglang/srt/mem_cache/memory_pool.py
@@ -26,6 +26,7 @@

import abc
import logging
import os
from typing import List, Optional, Tuple, Union

import numpy as np
@@ -260,6 +261,19 @@ def __init__(

self.head_num = head_num
self.head_dim = head_dim

# for disagg
self.enable_custom_mem_pool = os.environ.get(
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL", False
)
if self.enable_custom_mem_pool:
from sglang.srt.disaggregation.mooncake.memory_pool import CustomAllocator

allocator = CustomAllocator.get_allocator(self.device)
self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
else:
self.custom_mem_pool = None

self._create_buffers()

# used for chunked cpu-offloading
@@ -277,22 +291,42 @@ def _create_buffers(self):
with self.memory_saver_adapter.region():
# [size, head_num, head_dim] for each layer
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.k_buffer = [
torch.zeros(
(self.size + self.page_size, self.head_num, self.head_dim),
dtype=self.store_dtype,
device=self.device,
)
for _ in range(self.layer_num)
]
self.v_buffer = [
torch.zeros(
(self.size + self.page_size, self.head_num, self.head_dim),
dtype=self.store_dtype,
device=self.device,
)
for _ in range(self.layer_num)
]
if self.enable_custom_mem_pool:
self.k_buffer = []
self.v_buffer = []

assert self.custom_mem_pool is not None
with torch.cuda.use_mem_pool(self.custom_mem_pool):
for _ in range(self.layer_num):
k = torch.zeros(
(self.size + self.page_size, self.head_num, self.head_dim),
dtype=self.store_dtype,
device=self.device,
)
v = torch.zeros(
(self.size + self.page_size, self.head_num, self.head_dim),
dtype=self.store_dtype,
device=self.device,
)
self.k_buffer.append(k)
self.v_buffer.append(v)
else:
self.k_buffer = [
torch.zeros(
(self.size + self.page_size, self.head_num, self.head_dim),
dtype=self.store_dtype,
device=self.device,
)
for _ in range(self.layer_num)
]
self.v_buffer = [
torch.zeros(
(self.size + self.page_size, self.head_num, self.head_dim),
dtype=self.store_dtype,
device=self.device,
)
for _ in range(self.layer_num)
]

self.data_ptrs = torch.tensor(
[x.data_ptr() for x in self.k_buffer + self.v_buffer],
@@ -349,6 +383,9 @@ def get_contiguous_buf_infos(self):
]
return kv_data_ptrs, kv_data_lens, kv_item_lens

def maybe_get_custom_mem_pool(self):
return self.custom_mem_pool

def get_cpu_copy(self, indices):
torch.cuda.synchronize()
kv_cache_cpu = []
@@ -569,16 +606,41 @@ def __init__(
self.kv_lora_rank = kv_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim

# for disagg with nvlink
self.enable_custom_mem_pool = os.environ.get(
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL", False
)
if self.enable_custom_mem_pool:
from sglang.srt.disaggregation.mooncake.memory_pool import CustomAllocator

allocator = CustomAllocator.get_allocator(self.device)
self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
Collaborator: Shall we put such common logic into a base class?

Collaborator (author): Sure, I think @ByronHsu will need this too. Maybe it would be better to implement this in another PR after I talk with him, so that it will be less painful for him to sync the code.

Collaborator: OK, looks reasonable.
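A possible shape for the shared base-class helper discussed in this thread (hypothetical names such as KVCacheBase and _maybe_init_custom_mem_pool; a sketch under the assumption that both KV pool classes inherit from a common base, not part of this PR):

import os

import torch


class KVCacheBase:
    """Hypothetical common base for the MHA and MLA KV pools."""

    def _maybe_init_custom_mem_pool(self, device: torch.device) -> None:
        # Gate on the same environment variable as the per-class code in this PR.
        self.enable_custom_mem_pool = bool(
            os.environ.get("SGLANG_MOONCAKE_CUSTOM_MEM_POOL")
        )
        if self.enable_custom_mem_pool:
            from sglang.srt.disaggregation.mooncake.memory_pool import CustomAllocator

            allocator = CustomAllocator.get_allocator(device)
            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
        else:
            self.custom_mem_pool = None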

else:
self.custom_mem_pool = None

with self.memory_saver_adapter.region():
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.kv_buffer = [
torch.zeros(
(size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
dtype=self.store_dtype,
device=device,
)
for _ in range(layer_num)
]
if self.enable_custom_mem_pool:
self.kv_buffer = []

assert self.custom_mem_pool is not None
with torch.cuda.use_mem_pool(self.custom_mem_pool):
for _ in range(layer_num):
kv = torch.zeros(
(size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
dtype=self.store_dtype,
device=device,
)
self.kv_buffer.append(kv)
else:
self.kv_buffer = [
torch.zeros(
(size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
dtype=self.store_dtype,
device=device,
)
for _ in range(layer_num)
]

self.layer_transfer_counter = None

@@ -604,6 +666,9 @@ def get_contiguous_buf_infos(self):
]
return kv_data_ptrs, kv_data_lens, kv_item_lens

def maybe_get_custom_mem_pool(self):
return self.custom_mem_pool

def get_key_buffer(self, layer_id: int):
if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)