Merged
Changes from 5 commits
46 changes: 46 additions & 0 deletions python/sglang/srt/disaggregation/mooncake/memory_pool.py
@@ -0,0 +1,46 @@
import os
import threading
from importlib import resources
from typing import Dict, Final, Optional

import torch
from torch.cuda.memory import CUDAPluggableAllocator


class MooncakeNVLinkAllocator:
    _instances: Dict[torch.device, CUDAPluggableAllocator] = {}
    _lock: Final = threading.Lock()

    @classmethod
    def _get_so_path(cls) -> str:
        """Dynamically locate hook.so in the mooncake package installation"""
        try:
            # Attempt to locate package resource
            with resources.path("mooncake", "hook.so") as so_path:
Collaborator: nit: wondering whether the name needs to be more distinguishable (but anyway this is unrelated to the SGLang side)

Collaborator (Author): The name issue will be optimized in a future release, since v0.3.3.post2 has already been released. We will fix this before v0.3.4 and open a PR for it at that time.

Collaborator: nit: then maybe add a TODO comment to the MooncakeNVLinkAllocator class
                if so_path.exists():
                    return str(so_path)
        except (ImportError, FileNotFoundError, TypeError):
            pass

        # Fallback strategy: check in package location via import metadata
        try:
            import mooncake

            base_path = os.path.dirname(os.path.abspath(mooncake.__file__))
            so_path = os.path.join(base_path, "hook.so")
            if os.path.exists(so_path):
                return so_path
        except (ImportError, FileNotFoundError, TypeError):
            raise ImportError(
                "SGLANG_MOONCAKE_CUSTOM_MEM_POOL require mooncake-transfer-engine >= 0.3.3.post2."
            )

    @classmethod
    def get_allocator(cls, device: torch.device) -> CUDAPluggableAllocator:
        with cls._lock:
            if device not in cls._instances:
                so_path = cls._get_so_path()
                cls._instances[device] = CUDAPluggableAllocator(
                    so_path, "mc_nvlink_malloc", "mc_nvlink_free"
                )
            return cls._instances[device]
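For readers following along, here is a minimal usage sketch (not part of the diff) of how this allocator class is consumed by the KV-cache pools later in this PR. The device and tensor shape are illustrative, and the snippet assumes a CUDA machine with the mooncake package (including hook.so) installed.

import torch

from sglang.srt.disaggregation.mooncake.memory_pool import MooncakeNVLinkAllocator

device = torch.device("cuda:0")
allocator = MooncakeNVLinkAllocator.get_allocator(device)  # cached per device
mem_pool = torch.cuda.MemPool(allocator.allocator())       # pool backed by the pluggable allocator

with torch.cuda.use_mem_pool(mem_pool):
    # Allocations in this scope go through hook.so's mc_nvlink_malloc /
    # mc_nvlink_free instead of PyTorch's default caching allocator.
    buf = torch.zeros((1024, 16), dtype=torch.int32, device=device)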
67 changes: 49 additions & 18 deletions python/sglang/srt/disaggregation/utils.py
@@ -6,6 +6,7 @@
import threading
import warnings
from collections import deque
from contextlib import nullcontext
from enum import Enum
from typing import TYPE_CHECKING, List, Optional

@@ -84,24 +85,54 @@ def free(self, free_index: int):


class MetadataBuffers:
def __init__(self, size: int, max_top_logprobs_num: int = 128):
# TODO: abort top_logprobs_num > 128 in PD

# We transfer the metadata of first output token to decode
# The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device="cpu")
self.output_token_logprobs_val = torch.zeros(
(size, 16), dtype=torch.float32, device="cpu"
)
self.output_token_logprobs_idx = torch.zeros(
(size, 16), dtype=torch.int32, device="cpu"
)
self.output_top_logprobs_val = torch.zeros(
(size, max_top_logprobs_num), dtype=torch.float32, device="cpu"
)
self.output_top_logprobs_idx = torch.zeros(
(size, max_top_logprobs_num), dtype=torch.int32, device="cpu"
)
def __init__(
self,
size: int,
max_top_logprobs_num: int = 128,
custom_mem_pool: torch.cuda.MemPool = None,
):
self.custom_mem_pool = custom_mem_pool

self.output_ids = None
self.output_token_logprobs_val = None
self.output_token_logprobs_idx = None
self.output_top_logprobs_val = None
self.output_top_logprobs_idx = None

with (
torch.cuda.use_mem_pool(self.custom_mem_pool)
if self.custom_mem_pool
else nullcontext()
):
# TODO: abort top_logprobs_num > 128 in PD

# We transfer the metadata of first output token to decode
# The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
self.output_ids = torch.zeros(
(size, 16),
dtype=torch.int32,
device="cuda" if self.custom_mem_pool else "cpu",
)
self.output_token_logprobs_val = torch.zeros(
(size, 16),
dtype=torch.float32,
device="cuda" if self.custom_mem_pool else "cpu",
)
self.output_token_logprobs_idx = torch.zeros(
(size, 16),
dtype=torch.int32,
device="cuda" if self.custom_mem_pool else "cpu",
)
self.output_top_logprobs_val = torch.zeros(
(size, max_top_logprobs_num),
dtype=torch.float32,
device="cuda" if self.custom_mem_pool else "cpu",
)
self.output_top_logprobs_idx = torch.zeros(
(size, max_top_logprobs_num),
dtype=torch.int32,
device="cuda" if self.custom_mem_pool else "cpu",
)

def get_buf_infos(self):
ptrs = [
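The pattern in this hunk — entering torch.cuda.use_mem_pool only when a pool is provided, otherwise a nullcontext, and switching the tensor device to match — is the crux of the MetadataBuffers change. A standalone sketch of that pattern (the helper name is illustrative, not code from this PR):

from contextlib import nullcontext

import torch


def alloc_metadata_buffer(size: int, custom_mem_pool=None):
    # With a custom pool: allocate on CUDA inside the pool so the buffer is
    # managed by the NVLink-aware allocator. Without one: a plain CPU tensor.
    ctx = (
        torch.cuda.use_mem_pool(custom_mem_pool)
        if custom_mem_pool is not None
        else nullcontext()
    )
    device = "cuda" if custom_mem_pool is not None else "cpu"
    with ctx:
        # Padded to 16 int32 entries per row to meet the ~64-byte RDMA minimum.
        return torch.zeros((size, 16), dtype=torch.int32, device=device)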
10 changes: 8 additions & 2 deletions python/sglang/srt/managers/scheduler.py
@@ -622,7 +622,10 @@ def init_disaggregation(self):
self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
buffer_size
)
self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
self.disagg_metadata_buffers = MetadataBuffers(
buffer_size,
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
)

# The decode requests polling kv cache
self.disagg_decode_transfer_queue = DecodeTransferQueue(
@@ -669,7 +672,10 @@ def init_disaggregation(self):
self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
buffer_size
)
self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
self.disagg_metadata_buffers = MetadataBuffers(
buffer_size,
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
)

self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
100 changes: 76 additions & 24 deletions python/sglang/srt/mem_cache/memory_pool.py
@@ -26,6 +26,8 @@

import abc
import logging
import os
from contextlib import nullcontext
from typing import List, Optional, Tuple, Union

import numpy as np
@@ -260,6 +262,23 @@ def __init__(

self.head_num = head_num
self.head_dim = head_dim

# for disagg with nvlink
is_custom_mem_pool_enabled = os.environ.get("SGLANG_MOONCAKE_CUSTOM_MEM_POOL")
self.enable_custom_mem_pool = (
is_custom_mem_pool_enabled is not None
and is_custom_mem_pool_enabled.lower() == "true"
)
if self.enable_custom_mem_pool:
from sglang.srt.disaggregation.mooncake.memory_pool import (
MooncakeNVLinkAllocator,
)

allocator = MooncakeNVLinkAllocator.get_allocator(self.device)
self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
else:
self.custom_mem_pool = None

self._create_buffers()

# used for chunked cpu-offloading
@@ -277,22 +296,27 @@ def _create_buffers(self):
with self.memory_saver_adapter.region():
# [size, head_num, head_dim] for each layer
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.k_buffer = [
torch.zeros(
(self.size + self.page_size, self.head_num, self.head_dim),
dtype=self.store_dtype,
device=self.device,
)
for _ in range(self.layer_num)
]
self.v_buffer = [
torch.zeros(
(self.size + self.page_size, self.head_num, self.head_dim),
dtype=self.store_dtype,
device=self.device,
)
for _ in range(self.layer_num)
]
self.k_buffer = []
self.v_buffer = []

with (
torch.cuda.use_mem_pool(self.custom_mem_pool)
if self.enable_custom_mem_pool
else nullcontext()
):
for _ in range(self.layer_num):
k = torch.zeros(
(self.size + self.page_size, self.head_num, self.head_dim),
dtype=self.store_dtype,
device=self.device,
)
v = torch.zeros(
(self.size + self.page_size, self.head_num, self.head_dim),
dtype=self.store_dtype,
device=self.device,
)
self.k_buffer.append(k)
self.v_buffer.append(v)

self.data_ptrs = torch.tensor(
[x.data_ptr() for x in self.k_buffer + self.v_buffer],
@@ -349,6 +373,9 @@ def get_contiguous_buf_infos(self):
]
return kv_data_ptrs, kv_data_lens, kv_item_lens

def maybe_get_custom_mem_pool(self):
return self.custom_mem_pool

def get_cpu_copy(self, indices):
torch.cuda.synchronize()
kv_cache_cpu = []
@@ -569,16 +596,38 @@ def __init__(
self.kv_lora_rank = kv_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim

# for disagg with nvlink
is_custom_mem_pool_enabled = os.environ.get("SGLANG_MOONCAKE_CUSTOM_MEM_POOL")
self.enable_custom_mem_pool = (
is_custom_mem_pool_enabled is not None
and is_custom_mem_pool_enabled.lower() == "true"
)
if self.enable_custom_mem_pool:
from sglang.srt.disaggregation.mooncake.memory_pool import (
MooncakeNVLinkAllocator,
)

allocator = MooncakeNVLinkAllocator.get_allocator(self.device)
self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
Collaborator: Shall we put such common logic into a base class?

Collaborator (Author): Sure, I think @ByronHsu will need this too. It might be better to implement this in another PR after I talk with him, so that it will be less painful for him to sync the code.

Collaborator: OK, looks reasonable.
else:
self.custom_mem_pool = None

with self.memory_saver_adapter.region():
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.kv_buffer = [
torch.zeros(
(size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
dtype=self.store_dtype,
device=device,
)
for _ in range(layer_num)
]
self.kv_buffer = []

with (
torch.cuda.use_mem_pool(self.custom_mem_pool)
if self.custom_mem_pool
else nullcontext()
):
for _ in range(layer_num):
kv = torch.zeros(
(size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
dtype=self.store_dtype,
device=device,
)
self.kv_buffer.append(kv)

self.layer_transfer_counter = None

@@ -604,6 +653,9 @@ def get_contiguous_buf_infos(self):
]
return kv_data_ptrs, kv_data_lens, kv_item_lens

def maybe_get_custom_mem_pool(self):
return self.custom_mem_pool

def get_key_buffer(self, layer_id: int):
if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
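As a follow-up to the review thread above, the SGLANG_MOONCAKE_CUSTOM_MEM_POOL setup duplicated across the two KV-cache pool __init__ methods in this file could be hoisted into a shared helper. A hedged sketch of one way to do that — the helper name is hypothetical and not something this PR adds:

import os
from typing import Optional

import torch


def maybe_create_custom_mem_pool(device: torch.device) -> Optional[torch.cuda.MemPool]:
    # Mirrors the per-class logic in this diff: only build the pool when the
    # env var is explicitly set to "true".
    if os.environ.get("SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "").lower() != "true":
        return None

    from sglang.srt.disaggregation.mooncake.memory_pool import MooncakeNVLinkAllocator

    allocator = MooncakeNVLinkAllocator.get_allocator(device)
    return torch.cuda.MemPool(allocator.allocator())

Each pool class would then set self.custom_mem_pool = maybe_create_custom_mem_pool(self.device) and keep the existing nullcontext-guarded allocation blocks unchanged.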