Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
e1570fa
[Disagg][NIXL] Add staging buffer support for heterogeneous TP KV tra…
YAMY1234 Apr 6, 2026
04f3af9
Fix NIXL staging buffer: use cuMemCreate via custom_mem_pool
YAMY1234 Apr 7, 2026
897a402
Refactor NIXL staging buffer: reduce code intrusion, align with moonc…
YAMY1234 Apr 7, 2026
ebc4288
Decouple NIXL staging from mooncake: use register_fn callback pattern
YAMY1234 Apr 7, 2026
44d3c7c
Restore original type: ignore, getattr, and is_dummy() conventions
YAMY1234 Apr 7, 2026
7b25518
Unify staging handler: fix is_dummy bug, generalize PrefillStagingStr…
YAMY1234 Apr 7, 2026
615b513
Remove section divider comments added by refactor
YAMY1234 Apr 7, 2026
ea9fe3d
Rename decode_info -> dst_info for consistency
YAMY1234 Apr 7, 2026
4676ff7
Extract _send_kv_for_req to reduce add_transfer_request complexity
YAMY1234 Apr 7, 2026
8065a95
Extract staging notification handlers from update_transfer_status
YAMY1234 Apr 8, 2026
3f3fb05
Consolidate staging functions and extract handle_chunk_arrived to com…
YAMY1234 Apr 8, 2026
466ebad
Revert _send_kv_for_req extraction, keep inline dispatch in add_trans…
YAMY1234 Apr 8, 2026
dbb47c4
Revert unnecessary changes in staging_handler.py
YAMY1234 Apr 8, 2026
0a04f45
Restore original field access style in add_transfer_request
YAMY1234 Apr 8, 2026
19d7f2d
Move handle_watermark_msg and handle_staging_rsp to common staging_ha…
YAMY1234 Apr 8, 2026
d72e9e6
style: format staging buffer code with black
YAMY1234 Apr 10, 2026
8aedb1b
fix: limit notification tag split to handle underscores in agent_name
YAMY1234 Apr 10, 2026
c9c0d1c
Merge branch 'main' into feat/nixl-staging-buffer-independent
YAMY1234 Apr 16, 2026
2e4bc12
refactor(nixl staging): address review feedback
YAMY1234 Apr 17, 2026
73cb3bb
Merge upstream main into feat/nixl-staging-buffer-async
YAMY1234 May 10, 2026
0ebc318
nixl staging: per-worker buffer + dispatch in transfer_worker (moonca…
YAMY1234 May 11, 2026
c2da8d0
nixl staging: fix prefetch_staging_reqs is_dummy access for NIXL backend
YAMY1234 May 11, 2026
8c32315
nixl staging: address review feedback (imports / docs / per-room clea…
YAMY1234 May 11, 2026
bee6bfd
nixl staging: shorten review-feedback comments
YAMY1234 May 11, 2026
bcb6a69
Merge branch 'main' into feat/nixl-staging-buffer-independent
ShangmingCai May 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 130 additions & 23 deletions python/sglang/srt/disaggregation/common/staging_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,11 @@ class PrefillStagingContext:
watermark_cv: threading.Condition = dataclasses.field(
default_factory=threading.Condition
)
# (room, chunk_idx, session_id) keys for chunks already requested.
prefetch_requested: set = dataclasses.field(default_factory=set)
# Rooms that have already had their full prefetch fan-out triggered. Used
# to short-circuit per-room prefetch entry on every chunk after the first.
prefetched_rooms: set = dataclasses.field(default_factory=set)
prefetch_sockets: dict = dataclasses.field(default_factory=dict)


Expand Down Expand Up @@ -184,6 +188,38 @@ def is_staging_room(self, room: int) -> bool:
"""Check if a room is registered for staging scatter."""
return room in self._room_to_decode_req

def handle_chunk_arrived(
    self,
    room: int,
    chunk_idx: int,
    page_start: int,
    num_pages: int,
    writer_id: str,
    chunk_writer_counts: dict,
) -> bool:
    """Record one writer's arrival for a staging chunk; scatter once all are in.

    Shared by every transport that can announce a chunk (NIXL RDMA
    notification or ZMQ CHUNK_READY).

    Args:
        room: room the chunk belongs to.
        chunk_idx: index of the chunk within the room.
        page_start: first page covered by this writer's report.
        num_pages: number of pages covered by this writer's report.
        writer_id: identifier of the reporting writer.
        chunk_writer_counts: caller-owned ``room -> chunk_idx -> list`` mapping
            of ``(page_start, num_pages, writer_id)`` arrivals; mutated here.

    Returns:
        True when this arrival completed the chunk and scatter was submitted,
        False otherwise (including when the room is not registered).
    """
    # The arrival is recorded before the room lookup, so reports for a room
    # that is not registered are still kept in the caller's mapping.
    arrivals = chunk_writer_counts[room][chunk_idx]
    arrivals.append((page_start, num_pages, writer_id))

    decode_req = self._room_to_decode_req.get(room)
    if decode_req is None:
        logger.warning(
            "Staging chunk arrived for unregistered room=%s chunk=%d, skipping",
            room,
            chunk_idx,
        )
        return False

    arrived = len(arrivals)
    expected = self.num_writers_for(decode_req)
    if arrived < expected:
        return False

    # Scatter uses the page range carried by the arrival that completed the set.
    self.submit_chunk_scatter(room, chunk_idx, page_start, num_pages)
    del chunk_writer_counts[room][chunk_idx]
    return True

def submit_last_scatter_async(self, room: int) -> bool:
"""Submit scatter for the last chunk when all ranks report Success.

Expand Down Expand Up @@ -367,8 +403,46 @@ def is_watermark_ready(
return prev_round < wm_round or (prev_round == wm_round and alloc_end <= wm_tail)


def handle_watermark_msg(staging_ctx, msg_parts) -> None:
"""Process a WATERMARK message and update remote watermark tracking.

Args:
staging_ctx: object exposing ``watermark_cv`` (a condition variable) and
``remote_watermarks`` (a dict keyed by session id).
msg_parts: sequence of bytes frames; index 1 is the round, index 2 the
tail, and the optional index 3 the session id. Index 0 is not read here.

The stored ``(round, tail)`` pair for the session is replaced only when the
incoming pair compares strictly greater (tuple ordering); a missing entry is
treated as ``(0, 0)``.
"""
# Numeric fields are ASCII-encoded decimal strings.
wm_round = int(msg_parts[1].decode("ascii"))
wm_tail = int(msg_parts[2].decode("ascii"))
# A message without a session frame is tracked under the empty-string key.
wm_session = msg_parts[3].decode("ascii") if len(msg_parts) > 3 else ""
# The read-compare-write and the notification happen while holding watermark_cv.
with staging_ctx.watermark_cv:
prev = staging_ctx.remote_watermarks.get(wm_session, (0, 0))
# Stale or duplicate watermarks never move the stored value backwards.
if (wm_round, wm_tail) > prev:
staging_ctx.remote_watermarks[wm_session] = (
wm_round,
wm_tail,
)
# NOTE(review): indentation is lost in this view, so it cannot be told from
# here whether notify_all() runs only when the watermark advanced or on every
# message -- confirm against the real file.
staging_ctx.watermark_cv.notify_all()


def handle_staging_rsp(msg_parts, transfer_infos: dict) -> None:
    """Apply a STAGING_RSP message to the matching transfer info.

    Args:
        msg_parts: sequence of bytes frames. Indices 1-5 are ASCII-encoded
            integers (room, chunk index, offset, round, end); index 6 is the
            ASCII-encoded session id. Index 0 is not read here.
        transfer_infos: ``room -> session id -> transfer info`` mapping. The
            matched entry's ``staging`` attribute is created on first use and
            updated through ``set_chunk``.

    A message whose room/session pair has no transfer info is logged and
    otherwise ignored.
    """

    def _int_frame(index: int) -> int:
        return int(msg_parts[index].decode("ascii"))

    room = _int_frame(1)
    chunk_idx = _int_frame(2)
    offset = _int_frame(3)
    alloc_round = _int_frame(4)
    alloc_end = _int_frame(5)
    session = msg_parts[6].decode("ascii")

    tinfo = transfer_infos.get(room, {}).get(session)
    if tinfo is None:
        logger.warning(
            "STAGING_RSP RECV but tinfo=None room=%s chunk=%d session=%s",
            room,
            chunk_idx,
            session,
        )
        return

    # Staging state is allocated on the first response for this transfer.
    if tinfo.staging is None:
        tinfo.staging = StagingTransferInfo()
    tinfo.staging.set_chunk(chunk_idx, offset, alloc_round, alloc_end)


# ======================================================================
# Mooncake-specific staging protocol and utilities
# Staging data structures and protocol utilities
# ======================================================================


Expand Down Expand Up @@ -434,9 +508,17 @@ def check_ready(
req,
kv_chunk_index_start: int,
num_chunk_pages: int,
session_id: Optional[str] = None,
) -> Tuple[bool, int, int, int, int]:
"""Check if staging offset and watermark are ready for this chunk.

Args:
req: transfer request with a ``.staging`` attribute.
kv_chunk_index_start: page-level start index for this chunk.
num_chunk_pages: number of pages in this chunk.
session_id: identifier used for watermark lookup. Falls back to
``req.mooncake_session_id`` when *None* (mooncake compat).

Returns (ready, chunk_idx, offset, round, end).
offset == ALLOC_OVERSIZED means permanent failure (fall back to slice).
offset == -1 means allocation pending (re-enqueue).
Expand All @@ -462,9 +544,9 @@ def check_ready(
c_round = stg.rounds[chunk_idx]
c_end = stg.ends[chunk_idx]

if not self.kv_manager._is_watermark_ready(
req.mooncake_session_id, c_round, c_end
):
if session_id is None:
session_id = req.mooncake_session_id
if not self.kv_manager._is_watermark_ready(session_id, c_round, c_end):
return (False, chunk_idx, c_offset, c_round, c_end)

return (True, chunk_idx, c_offset, c_round, c_end)
Expand Down Expand Up @@ -499,21 +581,15 @@ def transfer(
) from e


def init_staging_buffers(engine, kv_args, count: int) -> list:
"""Create prefill-side staging buffers and register them with the engine.
def _get_custom_mem_pool(device: str):
"""Get custom memory pool for staging buffer allocation (backend-agnostic).

Returns list of StagingBuffer instances.
Returns (custom_mem_pool, pool_type) tuple. custom_mem_pool may be None
if no custom pool is configured.
"""
from sglang.srt.disaggregation.common.staging_buffer import StagingBuffer
from sglang.srt.disaggregation.mooncake.utils import (
init_mooncake_custom_mem_pool,
)
from sglang.srt.environ import envs

size_mb = envs.SGLANG_DISAGG_STAGING_BUFFER_SIZE_MB.get()
size_bytes = size_mb * 1024 * 1024
gpu_id = kv_args.gpu_id
device = f"cuda:{gpu_id}"

_, custom_mem_pool, pool_type = init_mooncake_custom_mem_pool(device)
if custom_mem_pool is None:
Expand All @@ -522,34 +598,59 @@ def init_staging_buffers(engine, kv_args, count: int) -> list:
"This works for all GPU architectures. "
"For NVLink/MNNVL transport, set SGLANG_MOONCAKE_CUSTOM_MEM_POOL."
)
return custom_mem_pool, pool_type


def init_staging_buffers(register_fn, kv_args, count: int) -> list:
    """Create prefill-side staging buffers and register them with the transport.

    Args:
        register_fn: callable(ptr: int, size: int) that registers a memory
            region with the transport backend.
        kv_args: KVArgs with gpu_id.
        count: number of staging buffers to create.

    Returns list of StagingBuffer instances.
    """
    from sglang.srt.disaggregation.common.staging_buffer import StagingBuffer
    from sglang.srt.environ import envs

    # The environment value is in MiB; StagingBuffer is given a byte count.
    size_mb = envs.SGLANG_DISAGG_STAGING_BUFFER_SIZE_MB.get()
    size_bytes = size_mb * 1024 * 1024
    gpu_id = kv_args.gpu_id
    device = f"cuda:{gpu_id}"

    custom_mem_pool, _pool_type = _get_custom_mem_pool(device)

    buffers = []
    for _ in range(count):
        buf = StagingBuffer(size_bytes, device, gpu_id, custom_mem_pool=custom_mem_pool)
        # Registration goes only through the injected callback; this function
        # has no transport engine of its own to call.
        register_fn(buf.get_ptr(), buf.get_size())
        buffers.append(buf)
    return buffers


def init_staging_allocator(register_fn, kv_args):
    """Create decode-side staging ring-buffer allocator and register with transport.

    Args:
        register_fn: callable(ptr: int, size: int) that registers a memory
            region with the transport backend.
        kv_args: KVArgs with gpu_id.

    Returns a StagingAllocator instance.
    """
    from sglang.srt.disaggregation.common.staging_buffer import StagingAllocator
    from sglang.srt.environ import envs

    # The environment value is in MiB; StagingAllocator is given a byte count.
    pool_size_mb = envs.SGLANG_DISAGG_STAGING_POOL_SIZE_MB.get()
    pool_size_bytes = pool_size_mb * 1024 * 1024
    gpu_id = kv_args.gpu_id
    device = f"cuda:{gpu_id}"

    custom_mem_pool, _pool_type = _get_custom_mem_pool(device)
    allocator = StagingAllocator(pool_size_bytes, device, gpu_id, custom_mem_pool)
    # The whole pool is registered once, through the injected callback.
    register_fn(allocator.get_base_ptr(), allocator.get_total_size())
    return allocator


Expand Down Expand Up @@ -696,7 +797,13 @@ def prefetch_staging_reqs(
full_chunk_pages = max(1, cps // page_size)

for session_id, tinfo in transfer_infos[room].items():
if tinfo.is_dummy:
# mooncake exposes is_dummy as a dataclass bool field, NIXL exposes it
# as a method (it consults decode_prefix_len). Normalize via callable()
# so this shared helper works for either backend; treating a bound
# method as truthy (the previous behavior) silently dropped every
# STAGING_REQ on NIXL and deadlocked the prefill transfer worker.
is_dummy_attr = tinfo.is_dummy
if is_dummy_attr() if callable(is_dummy_attr) else is_dummy_attr:
continue
total_pages = len(tinfo.dst_kv_indices)
if total_pages == 0:
Expand Down
81 changes: 24 additions & 57 deletions python/sglang/srt/disaggregation/mooncake/conn.py
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Mooncake part LGTM

Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,9 @@ def _init_staging_buffers(self, count: int):
)

self._staging_ctx.buffers = init_staging_buffers(
self.engine, self.kv_args, count
lambda ptr, size: self.engine.batch_register([ptr], [size]),
self.kv_args,
count,
)
self.kv_buffer_tensors = None

Expand All @@ -303,7 +305,10 @@ def _init_staging_allocator(self):
init_staging_allocator,
)

self._staging_ctx.allocator = init_staging_allocator(self.engine, self.kv_args)
self._staging_ctx.allocator = init_staging_allocator(
lambda ptr, size: self.engine.batch_register([ptr], [size]),
self.kv_args,
)
self.kv_buffer_tensors = None

def _handle_staging_req(self, msg):
Expand Down Expand Up @@ -1399,47 +1404,19 @@ def bootstrap_thread():
room = waiting_req_bytes[0].decode("ascii")
# Staging: decode reports consumption watermark back to prefill
if room == "WATERMARK":
wm_round = int(waiting_req_bytes[1].decode("ascii"))
wm_tail = int(waiting_req_bytes[2].decode("ascii"))
wm_session = (
waiting_req_bytes[3].decode("ascii")
if len(waiting_req_bytes) > 3
else ""
from sglang.srt.disaggregation.common.staging_handler import (
handle_watermark_msg,
)
with self._staging_ctx.watermark_cv:
prev = self._staging_ctx.remote_watermarks.get(
wm_session, (0, 0)
)
if (wm_round, wm_tail) > prev:
self._staging_ctx.remote_watermarks[wm_session] = (
wm_round,
wm_tail,
)
self._staging_ctx.watermark_cv.notify_all()

handle_watermark_msg(self._staging_ctx, waiting_req_bytes)
continue
# Staging: decode replies with allocated staging offset
if room == "STAGING_RSP":
stg_room = int(waiting_req_bytes[1].decode("ascii"))
stg_chunk_idx = int(waiting_req_bytes[2].decode("ascii"))
stg_offset = int(waiting_req_bytes[3].decode("ascii"))
stg_round = int(waiting_req_bytes[4].decode("ascii"))
stg_end = int(waiting_req_bytes[5].decode("ascii"))
stg_session = waiting_req_bytes[6].decode("ascii")
room_infos = self.transfer_infos.get(stg_room, {})
tinfo = room_infos.get(stg_session)
if tinfo is not None:
if tinfo.staging is None:
tinfo.staging = StagingTransferInfo()
tinfo.staging.set_chunk(
stg_chunk_idx, stg_offset, stg_round, stg_end
)
else:
logger.warning(
"STAGING_RSP RECV but tinfo=None room=%s chunk=%d session=%s",
stg_room,
stg_chunk_idx,
stg_session,
)
from sglang.srt.disaggregation.common.staging_handler import (
handle_staging_rsp,
)

handle_staging_rsp(waiting_req_bytes, self.transfer_infos)
continue
mooncake_session_id = waiting_req_bytes[3].decode("ascii")
if room == "None":
Expand Down Expand Up @@ -1493,28 +1470,18 @@ def decode_thread():
page_start = int(msg[3].decode("ascii"))
num_pages = int(msg[4].decode("ascii"))
session_id = msg[5].decode("ascii")
self._chunk_writer_counts[room][chunk_idx].append(
(page_start, num_pages, session_id)
)
handler = self._staging_handler
assert (
handler is not None
), "CHUNK_READY received before staging handler initialized"
writers_arrived = len(self._chunk_writer_counts[room][chunk_idx])
decode_req = handler._room_to_decode_req.get(room)
if decode_req is None:
logger.warning(
"CHUNK_READY received for unregistered room=%s chunk=%d, skipping",
room,
chunk_idx,
)
continue
num_writers = handler.num_writers_for(decode_req)
if writers_arrived >= num_writers:
handler.submit_chunk_scatter(
room, chunk_idx, page_start, num_pages
)
del self._chunk_writer_counts[room][chunk_idx]
handler.handle_chunk_arrived(
room,
chunk_idx,
page_start,
num_pages,
session_id,
self._chunk_writer_counts,
)
continue

# Staging: prefill pre-requests staging allocation before forward
Expand Down
Loading
Loading