diff --git a/python/sglang/srt/disaggregation/common/staging_handler.py b/python/sglang/srt/disaggregation/common/staging_handler.py index c05e8151177f..4cc0dc906541 100644 --- a/python/sglang/srt/disaggregation/common/staging_handler.py +++ b/python/sglang/srt/disaggregation/common/staging_handler.py @@ -45,7 +45,11 @@ class PrefillStagingContext: watermark_cv: threading.Condition = dataclasses.field( default_factory=threading.Condition ) + # (room, chunk_idx, session_id) keys for chunks already requested. prefetch_requested: set = dataclasses.field(default_factory=set) + # Rooms that have already had their full prefetch fan-out triggered. Used + # to short-circuit per-room prefetch entry on every chunk after the first. + prefetched_rooms: set = dataclasses.field(default_factory=set) prefetch_sockets: dict = dataclasses.field(default_factory=dict) @@ -184,6 +188,38 @@ def is_staging_room(self, room: int) -> bool: """Check if a room is registered for staging scatter.""" return room in self._room_to_decode_req + def handle_chunk_arrived( + self, + room: int, + chunk_idx: int, + page_start: int, + num_pages: int, + writer_id: str, + chunk_writer_counts: dict, + ) -> bool: + """Process a staging chunk arrival from any transport (NIXL RDMA notif or ZMQ CHUNK_READY). + + Accumulates writer arrivals in *chunk_writer_counts* and submits scatter + once all writers for this chunk have reported in. Returns True if scatter + was submitted. + """ + chunk_writer_counts[room][chunk_idx].append((page_start, num_pages, writer_id)) + decode_req = self._room_to_decode_req.get(room) + if decode_req is None: + logger.warning( + "Staging chunk arrived for unregistered room=%s chunk=%d, skipping", + room, + chunk_idx, + ) + return False + writers_arrived = len(chunk_writer_counts[room][chunk_idx]) + num_writers = self.num_writers_for(decode_req) + if writers_arrived >= num_writers: + self.submit_chunk_scatter(room, chunk_idx, page_start, num_pages) + del chunk_writer_counts[room][chunk_idx] + return True + return False + def submit_last_scatter_async(self, room: int) -> bool: """Submit scatter for the last chunk when all ranks report Success. @@ -367,8 +403,46 @@ def is_watermark_ready( return prev_round < wm_round or (prev_round == wm_round and alloc_end <= wm_tail) +def handle_watermark_msg(staging_ctx, msg_parts) -> None: + """Process a WATERMARK message and update remote watermark tracking.""" + wm_round = int(msg_parts[1].decode("ascii")) + wm_tail = int(msg_parts[2].decode("ascii")) + wm_session = msg_parts[3].decode("ascii") if len(msg_parts) > 3 else "" + with staging_ctx.watermark_cv: + prev = staging_ctx.remote_watermarks.get(wm_session, (0, 0)) + if (wm_round, wm_tail) > prev: + staging_ctx.remote_watermarks[wm_session] = ( + wm_round, + wm_tail, + ) + staging_ctx.watermark_cv.notify_all() + + +def handle_staging_rsp(msg_parts, transfer_infos: dict) -> None: + """Process a STAGING_RSP message and update transfer info with allocation.""" + stg_room = int(msg_parts[1].decode("ascii")) + stg_chunk_idx = int(msg_parts[2].decode("ascii")) + stg_offset = int(msg_parts[3].decode("ascii")) + stg_round = int(msg_parts[4].decode("ascii")) + stg_end = int(msg_parts[5].decode("ascii")) + stg_session = msg_parts[6].decode("ascii") + room_infos = transfer_infos.get(stg_room, {}) + tinfo = room_infos.get(stg_session) + if tinfo is not None: + if tinfo.staging is None: + tinfo.staging = StagingTransferInfo() + tinfo.staging.set_chunk(stg_chunk_idx, stg_offset, stg_round, stg_end) + else: + logger.warning( + "STAGING_RSP RECV but tinfo=None room=%s chunk=%d session=%s", + stg_room, + stg_chunk_idx, + stg_session, + ) + + # ====================================================================== -# Mooncake-specific staging protocol and utilities +# Staging data structures and protocol utilities # ====================================================================== @@ -434,9 +508,17 @@ def check_ready( req, kv_chunk_index_start: int, num_chunk_pages: int, + session_id: Optional[str] = None, ) -> Tuple[bool, int, int, int, int]: """Check if staging offset and watermark are ready for this chunk. + Args: + req: transfer request with a ``.staging`` attribute. + kv_chunk_index_start: page-level start index for this chunk. + num_chunk_pages: number of pages in this chunk. + session_id: identifier used for watermark lookup. Falls back to + ``req.mooncake_session_id`` when *None* (mooncake compat). + Returns (ready, chunk_idx, offset, round, end). offset == ALLOC_OVERSIZED means permanent failure (fall back to slice). offset == -1 means allocation pending (re-enqueue). @@ -462,9 +544,9 @@ def check_ready( c_round = stg.rounds[chunk_idx] c_end = stg.ends[chunk_idx] - if not self.kv_manager._is_watermark_ready( - req.mooncake_session_id, c_round, c_end - ): + if session_id is None: + session_id = req.mooncake_session_id + if not self.kv_manager._is_watermark_ready(session_id, c_round, c_end): return (False, chunk_idx, c_offset, c_round, c_end) return (True, chunk_idx, c_offset, c_round, c_end) @@ -499,21 +581,15 @@ def transfer( ) from e -def init_staging_buffers(engine, kv_args, count: int) -> list: - """Create prefill-side staging buffers and register them with the engine. +def _get_custom_mem_pool(device: str): + """Get custom memory pool for staging buffer allocation (backend-agnostic). - Returns list of StagingBuffer instances. + Returns (custom_mem_pool, pool_type) tuple. custom_mem_pool may be None + if no custom pool is configured. """ - from sglang.srt.disaggregation.common.staging_buffer import StagingBuffer from sglang.srt.disaggregation.mooncake.utils import ( init_mooncake_custom_mem_pool, ) - from sglang.srt.environ import envs - - size_mb = envs.SGLANG_DISAGG_STAGING_BUFFER_SIZE_MB.get() - size_bytes = size_mb * 1024 * 1024 - gpu_id = kv_args.gpu_id - device = f"cuda:{gpu_id}" _, custom_mem_pool, pool_type = init_mooncake_custom_mem_pool(device) if custom_mem_pool is None: @@ -522,24 +598,49 @@ def init_staging_buffers(engine, kv_args, count: int) -> list: "This works for all GPU architectures. " "For NVLink/MNNVL transport, set SGLANG_MOONCAKE_CUSTOM_MEM_POOL." ) + return custom_mem_pool, pool_type + + +def init_staging_buffers(register_fn, kv_args, count: int) -> list: + """Create prefill-side staging buffers and register them with the transport. + + Args: + register_fn: callable(ptr: int, size: int) that registers a memory + region with the transport backend. + kv_args: KVArgs with gpu_id. + count: number of staging buffers to create. + + Returns list of StagingBuffer instances. + """ + from sglang.srt.disaggregation.common.staging_buffer import StagingBuffer + from sglang.srt.environ import envs + + size_mb = envs.SGLANG_DISAGG_STAGING_BUFFER_SIZE_MB.get() + size_bytes = size_mb * 1024 * 1024 + gpu_id = kv_args.gpu_id + device = f"cuda:{gpu_id}" + + custom_mem_pool, _ = _get_custom_mem_pool(device) buffers = [] for _ in range(count): buf = StagingBuffer(size_bytes, device, gpu_id, custom_mem_pool=custom_mem_pool) - engine.batch_register([buf.get_ptr()], [buf.get_size()]) + register_fn(buf.get_ptr(), buf.get_size()) buffers.append(buf) return buffers -def init_staging_allocator(engine, kv_args): - """Create decode-side staging ring-buffer allocator and register with engine. +def init_staging_allocator(register_fn, kv_args): + """Create decode-side staging ring-buffer allocator and register with transport. + + Args: + register_fn: callable(ptr: int, size: int) that registers a memory + region with the transport backend. + kv_args: KVArgs with gpu_id. Returns a StagingAllocator instance. """ from sglang.srt.disaggregation.common.staging_buffer import StagingAllocator - from sglang.srt.disaggregation.mooncake.utils import ( - init_mooncake_custom_mem_pool, - ) from sglang.srt.environ import envs pool_size_mb = envs.SGLANG_DISAGG_STAGING_POOL_SIZE_MB.get() @@ -547,9 +648,9 @@ def init_staging_allocator(engine, kv_args): gpu_id = kv_args.gpu_id device = f"cuda:{gpu_id}" - _, custom_mem_pool, _ = init_mooncake_custom_mem_pool(device) + custom_mem_pool, _ = _get_custom_mem_pool(device) allocator = StagingAllocator(pool_size_bytes, device, gpu_id, custom_mem_pool) - engine.batch_register([allocator.get_base_ptr()], [allocator.get_total_size()]) + register_fn(allocator.get_base_ptr(), allocator.get_total_size()) return allocator @@ -696,7 +797,13 @@ def prefetch_staging_reqs( full_chunk_pages = max(1, cps // page_size) for session_id, tinfo in transfer_infos[room].items(): - if tinfo.is_dummy: + # mooncake exposes is_dummy as a dataclass bool field, NIXL exposes it + # as a method (it consults decode_prefix_len). Normalize via callable() + # so this shared helper works for either backend; treating a bound + # method as truthy (the previous behavior) silently dropped every + # STAGING_REQ on NIXL and deadlocked the prefill transfer worker. + is_dummy_attr = tinfo.is_dummy + if is_dummy_attr() if callable(is_dummy_attr) else is_dummy_attr: continue total_pages = len(tinfo.dst_kv_indices) if total_pages == 0: diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index 7b26ec0aa66b..634f2eae5b79 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -294,7 +294,9 @@ def _init_staging_buffers(self, count: int): ) self._staging_ctx.buffers = init_staging_buffers( - self.engine, self.kv_args, count + lambda ptr, size: self.engine.batch_register([ptr], [size]), + self.kv_args, + count, ) self.kv_buffer_tensors = None @@ -303,7 +305,10 @@ def _init_staging_allocator(self): init_staging_allocator, ) - self._staging_ctx.allocator = init_staging_allocator(self.engine, self.kv_args) + self._staging_ctx.allocator = init_staging_allocator( + lambda ptr, size: self.engine.batch_register([ptr], [size]), + self.kv_args, + ) self.kv_buffer_tensors = None def _handle_staging_req(self, msg): @@ -1399,47 +1404,19 @@ def bootstrap_thread(): room = waiting_req_bytes[0].decode("ascii") # Staging: decode reports consumption watermark back to prefill if room == "WATERMARK": - wm_round = int(waiting_req_bytes[1].decode("ascii")) - wm_tail = int(waiting_req_bytes[2].decode("ascii")) - wm_session = ( - waiting_req_bytes[3].decode("ascii") - if len(waiting_req_bytes) > 3 - else "" + from sglang.srt.disaggregation.common.staging_handler import ( + handle_watermark_msg, ) - with self._staging_ctx.watermark_cv: - prev = self._staging_ctx.remote_watermarks.get( - wm_session, (0, 0) - ) - if (wm_round, wm_tail) > prev: - self._staging_ctx.remote_watermarks[wm_session] = ( - wm_round, - wm_tail, - ) - self._staging_ctx.watermark_cv.notify_all() + + handle_watermark_msg(self._staging_ctx, waiting_req_bytes) continue # Staging: decode replies with allocated staging offset if room == "STAGING_RSP": - stg_room = int(waiting_req_bytes[1].decode("ascii")) - stg_chunk_idx = int(waiting_req_bytes[2].decode("ascii")) - stg_offset = int(waiting_req_bytes[3].decode("ascii")) - stg_round = int(waiting_req_bytes[4].decode("ascii")) - stg_end = int(waiting_req_bytes[5].decode("ascii")) - stg_session = waiting_req_bytes[6].decode("ascii") - room_infos = self.transfer_infos.get(stg_room, {}) - tinfo = room_infos.get(stg_session) - if tinfo is not None: - if tinfo.staging is None: - tinfo.staging = StagingTransferInfo() - tinfo.staging.set_chunk( - stg_chunk_idx, stg_offset, stg_round, stg_end - ) - else: - logger.warning( - "STAGING_RSP RECV but tinfo=None room=%s chunk=%d session=%s", - stg_room, - stg_chunk_idx, - stg_session, - ) + from sglang.srt.disaggregation.common.staging_handler import ( + handle_staging_rsp, + ) + + handle_staging_rsp(waiting_req_bytes, self.transfer_infos) continue mooncake_session_id = waiting_req_bytes[3].decode("ascii") if room == "None": @@ -1493,28 +1470,18 @@ def decode_thread(): page_start = int(msg[3].decode("ascii")) num_pages = int(msg[4].decode("ascii")) session_id = msg[5].decode("ascii") - self._chunk_writer_counts[room][chunk_idx].append( - (page_start, num_pages, session_id) - ) handler = self._staging_handler assert ( handler is not None ), "CHUNK_READY received before staging handler initialized" - writers_arrived = len(self._chunk_writer_counts[room][chunk_idx]) - decode_req = handler._room_to_decode_req.get(room) - if decode_req is None: - logger.warning( - "CHUNK_READY received for unregistered room=%s chunk=%d, skipping", - room, - chunk_idx, - ) - continue - num_writers = handler.num_writers_for(decode_req) - if writers_arrived >= num_writers: - handler.submit_chunk_scatter( - room, chunk_idx, page_start, num_pages - ) - del self._chunk_writer_counts[room][chunk_idx] + handler.handle_chunk_arrived( + room, + chunk_idx, + page_start, + num_pages, + session_id, + self._chunk_writer_counts, + ) continue # Staging: prefill pre-requests staging allocation before forward diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index 81a68ebb4fee..5fbbef59f8c3 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -8,11 +8,14 @@ import time import uuid from collections import defaultdict -from typing import Dict, List, Optional, Set +from typing import TYPE_CHECKING, Dict, List, Optional, Set import numpy as np import numpy.typing as npt +if TYPE_CHECKING: + from sglang.srt.disaggregation.common.staging_handler import StagingTransferInfo + from sglang.srt.disaggregation.base.conn import KVArgs, KVPoll, StateType from sglang.srt.disaggregation.common.conn import ( CommonKVBootstrapServer, @@ -20,6 +23,7 @@ CommonKVReceiver, CommonKVSender, ) +from sglang.srt.disaggregation.common.staging_handler import StagingRegisterInfo from sglang.srt.disaggregation.common.utils import ( FastQueue, group_concurrent_contiguous, @@ -66,6 +70,9 @@ class TransferInfo: required_dst_info_num: int dst_state_indices: List[List[int]] decode_prefix_len: Optional[int] = None # for decode radix cache + # NOTE: optional staging field; populated via STAGING_RSP. Keep at the + # end so positional construction in from_zmq() continues to work. + staging: Optional["StagingTransferInfo"] = None def is_dummy(self): # A transfer is "dummy" only for CP non-authoritative ranks. @@ -126,6 +133,9 @@ class KVArgsRegisterInfo: dst_kv_item_len: int dst_state_item_lens: List[List[int]] = dataclasses.field(default_factory=list) dst_state_dim_per_tensor: List[List[int]] = dataclasses.field(default_factory=list) + # Keep last: optional, parsed from a variable-length tail of the ZMQ + # frame in from_zmq() below, so positional construction stays stable. + staging: Optional["StagingRegisterInfo"] = None @classmethod def from_zmq(cls, msg: List[bytes]): @@ -154,6 +164,7 @@ def from_zmq(cls, msg: List[bytes]): dst_kv_item_len=int(msg[11].decode("ascii")), dst_state_item_lens=dst_state_item_lens, dst_state_dim_per_tensor=dst_state_dim_per_tensor, + staging=StagingRegisterInfo.from_zmq_fields(msg, 14), ) @@ -256,27 +267,210 @@ def __init__( self.register_buffer_to_engine() + self.enable_staging = envs.SGLANG_DISAGG_STAGING_BUFFER.get() + self.kv_buffer_tensors = None + if self.disaggregation_mode == DisaggregationMode.PREFILL: transfer_queue_size = envs.SGLANG_DISAGGREGATION_QUEUE_SIZE.get() self.transfer_queues: List[FastQueue] = [ FastQueue() for _ in range(transfer_queue_size) ] self.exceptions: Dict[int, Exception] = {} - for queue in self.transfer_queues: + # Mirror mooncake: one staging buffer per worker queue, all + # built before workers spawn so each worker owns a private + # buffer (no cross-worker contention on the staging ring). + if self.enable_staging: + self._init_staging_prefill_ctx() + self._init_staging_buffers(len(self.transfer_queues)) + for i, queue in enumerate(self.transfer_queues): + staging_buffer = ( + self._staging_ctx.buffers[i] + if self.enable_staging and self._staging_ctx.buffers + else None + ) threading.Thread( - target=self.transfer_worker, args=(queue,), daemon=True + target=self.transfer_worker, + args=(queue, staging_buffer), + daemon=True, ).start() self._start_bootstrap_thread() elif self.disaggregation_mode == DisaggregationMode.DECODE: self.transfer_statuses: Dict[int, TransferStatus] = defaultdict( TransferStatus ) + if self.enable_staging: + self._init_staging_decode_ctx() + self._staging_handler = None + self._chunk_writer_counts: dict = defaultdict(lambda: defaultdict(list)) + self._start_decode_staging_thread() self._start_heartbeat_checker_thread() else: raise ValueError( f"Unsupported DisaggregationMode: {self.disaggregation_mode}" ) + def _init_staging_prefill_ctx(self): + from sglang.srt.disaggregation.common.staging_handler import ( + PrefillStagingContext, + ) + + self._staging_ctx = PrefillStagingContext() + + def _init_staging_decode_ctx(self): + from sglang.srt.disaggregation.common.staging_handler import ( + DecodeStagingContext, + ) + + self._staging_ctx = DecodeStagingContext() + self._init_staging_allocator() + + def _init_staging_buffers(self, count: int): + from sglang.srt.disaggregation.common.staging_handler import ( + init_staging_buffers, + ) + + gpu_id = self.kv_args.gpu_id + self._staging_ctx.buffers = init_staging_buffers( + lambda ptr, size: self._register_staging_memory(ptr, size, gpu_id), + self.kv_args, + count, + ) + + def _init_staging_allocator(self): + from sglang.srt.disaggregation.common.staging_handler import ( + init_staging_allocator, + ) + + gpu_id = self.kv_args.gpu_id + self._staging_ctx.allocator = init_staging_allocator( + lambda ptr, size: self._register_staging_memory(ptr, size, gpu_id), + self.kv_args, + ) + + def _register_staging_memory(self, ptr: int, size: int, gpu_id: int): + """Register a staging buffer with the NIXL agent.""" + addrs = [(ptr, size, gpu_id, "")] + descs = self.agent.register_memory(addrs, "VRAM") + if not descs: + raise RuntimeError( + f"NIXL memory registration failed for staging buffer " + f"(ptr=0x{ptr:x}, size={size})" + ) + + def set_kv_buffer_tensors(self, k_buffers: list, v_buffers: list, page_size: int): + # NOTE: matches mooncake behavior -- staging buffers are now + # created in __init__ (per-worker), independent of the kv + # tensors. This setter only stashes the tensor metadata used by + # send_kvcache_staged(). + self.kv_buffer_tensors = { + "k_buffers": k_buffers, + "v_buffers": v_buffers, + "page_size": page_size, + } + + def register_staging_room_bootstrap(self, room, bootstrap_infos, receiver): + self._staging_ctx.room_bootstrap[room] = bootstrap_infos + self._staging_ctx.room_receivers[room] = receiver + + def _is_watermark_ready( + self, agent_name: str, alloc_round: int, alloc_end: int + ) -> bool: + from sglang.srt.disaggregation.common.staging_handler import ( + is_watermark_ready, + ) + + return is_watermark_ready(self._staging_ctx, agent_name, alloc_round, alloc_end) + + def _start_decode_staging_thread(self): + """Start a thread on the decode side to recv STAGING_REQ from prefill via ZMQ.""" + + def decode_staging_thread(): + while True: + msg = self.server_socket.recv_multipart() + if msg[0] == b"STAGING_REQ": + self._handle_staging_req(msg) + continue + logger.warning( + "decode_staging_thread: unexpected message tag %s", + msg[0][:20], + ) + + threading.Thread(target=decode_staging_thread, daemon=True).start() + + def _handle_staging_req(self, msg): + from sglang.srt.disaggregation.common.staging_handler import ( + handle_staging_req, + ) + + room = int(msg[1].decode("ascii")) + session_id = msg[4].decode("ascii") + handler = self._staging_handler + assert ( + handler is not None + ), "STAGING_REQ received before staging handler initialized" + decode_req = handler._room_to_decode_req.get(room) + if decode_req is None: + logger.warning( + "STAGING_REQ received for unregistered room=%s, skipping", + room, + ) + return + prefill_tp = decode_req.kv_receiver.prefill_info.attn_tp_size + handle_staging_req( + msg, + self._staging_ctx.allocator, + self.kv_args, + self.attn_tp_size, + prefill_tp, + getattr(self, "kv_buffer_tensors", None), + self._staging_ctx.room_receivers, + self._staging_ctx.room_bootstrap, + ) + + receiver = self._staging_ctx.room_receivers.get(room) + if receiver is not None: + handler.register_wm_subscriber(receiver, session_id) + + def _prefetch_staging_reqs(self, room: int): + """Send STAGING_REQ for all chunks before the prefill forward starts. + + Idempotent per room: the first call for a given room does the full + fan-out (one STAGING_REQ per chunk per peer); subsequent calls return + immediately. This lets the caller invoke this on every chunk without + depending on a chunk_id == 0 sentinel. + """ + if not self.enable_staging or self.kv_buffer_tensors is None: + return + if room in self._staging_ctx.prefetched_rooms: + return + + room_infos = self.transfer_infos.get(room, {}) + needs_staging = any( + not tinfo.is_dummy() + and tinfo.agent_name in self.decode_kv_args_table + and self.decode_kv_args_table[tinfo.agent_name].decode_tp_size + != self.attn_tp_size + for tinfo in room_infos.values() + ) + if not needs_staging: + # Mark anyway so we don't re-evaluate the predicate every chunk. + self._staging_ctx.prefetched_rooms.add(room) + return + + from sglang.srt.disaggregation.common.staging_handler import ( + prefetch_staging_reqs, + ) + + prefetch_staging_reqs( + room, + self.transfer_infos, + self.kv_buffer_tensors, + self.server_args.chunked_prefill_size, + self._staging_ctx.prefetch_requested, + self._staging_ctx.prefetch_sockets, + ) + self._staging_ctx.prefetched_rooms.add(room) + def _start_heartbeat_checker_thread(self): """ Start the heartbeat checker thread for Decode worker. @@ -366,7 +560,12 @@ def _handle_node_failure(self, failed_bootstrap_addr): def check_status(self, bootstrap_room: int): return self.request_status.get(bootstrap_room, KVPoll.WaitingForInput) - def transfer_worker(self, queue: FastQueue): + def transfer_worker(self, queue: FastQueue, staging_buffer=None): + # Per-worker staging strategy: lazy-created on first chunk so we + # see kv_buffer_tensors (set by ModelRunner after engine init). + # Never cache on self -- multiple workers would race the ring. + staging_strategy = None + while True: kv_chunk: TransferKVChunk = queue.get() room = kv_chunk.room @@ -376,20 +575,34 @@ def transfer_worker(self, queue: FastQueue): assert room in self.transfer_infos + # Lazily build a per-worker staging strategy bound to this + # worker's private staging buffer (matches mooncake). + if ( + self.enable_staging + and staging_strategy is None + and staging_buffer is not None + ): + staging_strategy = self._try_create_staging_strategy(staging_buffer) + self.update_status(room, KVPoll.Transferring) reqs_to_be_processed = list(self.transfer_infos[room].values()) handles: List = [] + # Set when staging allocation/watermark is not yet ready and + # the chunk has been re-enqueued. We then break out of the + # per-req loop and `continue` the worker main loop without + # touching room status -- the next pop will retry. + staging_deferred = False + for req in reqs_to_be_processed: assert room == req.room if req.is_dummy(): continue assert req.agent_name in self.decode_kv_args_table - decode_tp_size = self.decode_kv_args_table[ - req.agent_name - ].decode_tp_size + dst_info = self.decode_kv_args_table[req.agent_name] + decode_tp_size = dst_info.decode_tp_size # Skip KV RDMA transfer when there are no pages to send # (e.g., decode-side radix cache matched the entire prefix). @@ -409,34 +622,69 @@ def transfer_worker(self, queue: FastQueue): : len(chunked_dst_kv_indice) ] - notif = f"{req.room}_kv_{kv_chunk.chunk_id}_{int(kv_chunk.is_last)}_{self.kv_args.engine_rank}" + # Decide which kv send path to use: + # 1. Staging (heterogeneous TP, both sides have + # registered staging, watermark/alloc ready) + # 2. send_kvcache (MLA or homogeneous TP) + # 3. send_kvcache_slice (heterogeneous TP fallback, + # or staging hard-failed for this chunk) + use_staging = ( + self.enable_staging + and staging_strategy is not None + and not self.is_mla_backend + and decode_tp_size != self.attn_tp_size + and dst_info.staging is not None + ) - if self.is_mla_backend or (decode_tp_size == self.attn_tp_size): - kv_xfer_handle = self.send_kvcache( - req.agent_name, - kv_chunk.prefill_kv_indices, - self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, - chunked_dst_kv_indice, - self.decode_kv_args_table[req.agent_name].gpu_id, - notif, + kv_xfer_handle = None + if use_staging: + kv_xfer_handle, deferred = self._do_staging_transfer( + staging_strategy, + kv_chunk, + req, + dst_info, + queue, ) - else: - kv_xfer_handle = self.send_kvcache_slice( - req.agent_name, - kv_chunk.prefill_kv_indices, - self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, - chunked_dst_kv_indice, - self.decode_kv_args_table[req.agent_name].gpu_id, - notif, - prefill_tp_size=self.attn_tp_size, - decode_tp_size=decode_tp_size, - decode_tp_rank=self.decode_kv_args_table[ - req.agent_name - ].decode_tp_rank, - dst_kv_item_len=self.decode_kv_args_table[ - req.agent_name - ].dst_kv_item_len, + if deferred: + # Chunk re-enqueued; stop processing remaining + # reqs for this chunk and let the worker loop + # pick it up again on the next pop. + staging_deferred = True + break + # kv_xfer_handle is None here means staging + # send_kvcache_staged() returned None (e.g. + # decode buffer too small) -- fall through to + # the slice path below. + + if kv_xfer_handle is None: + notif = ( + f"{req.room}_kv_{kv_chunk.chunk_id}" + f"_{int(kv_chunk.is_last)}_{self.kv_args.engine_rank}" ) + if self.is_mla_backend or ( + decode_tp_size == self.attn_tp_size + ): + kv_xfer_handle = self.send_kvcache( + req.agent_name, + kv_chunk.prefill_kv_indices, + dst_info.dst_kv_ptrs, + chunked_dst_kv_indice, + dst_info.gpu_id, + notif, + ) + else: + kv_xfer_handle = self.send_kvcache_slice( + req.agent_name, + kv_chunk.prefill_kv_indices, + dst_info.dst_kv_ptrs, + chunked_dst_kv_indice, + dst_info.gpu_id, + notif, + prefill_tp_size=self.attn_tp_size, + decode_tp_size=decode_tp_size, + decode_tp_rank=dst_info.decode_tp_rank, + dst_kv_item_len=dst_info.dst_kv_item_len, + ) handles.append(kv_xfer_handle) @@ -470,12 +718,16 @@ def transfer_worker(self, queue: FastQueue): aux_xfer_handle = self.send_aux( req.agent_name, kv_chunk.prefill_aux_index, - self.decode_kv_args_table[req.agent_name].dst_aux_ptrs, + dst_info.dst_aux_ptrs, req.dst_aux_index, aux_notif, ) handles.append(aux_xfer_handle) + if staging_deferred: + # Chunk has been re-enqueued; do not advance status. + continue + while handles: states = [self.agent.check_xfer_state(h) for h in handles] if any(s == "ERR" for s in states): @@ -486,6 +738,17 @@ def transfer_worker(self, queue: FastQueue): if kv_chunk.is_last: self.update_status(room, KVPoll.Success) + # Drop per-room state on Success (parity with mooncake + # transfer_worker; staging prefetch sets are NIXL-only). + self.transfer_infos.pop(room, None) + self.req_to_decode_prefix_len.pop(room, None) + if self.enable_staging and self._staging_ctx is not None: + self._staging_ctx.prefetched_rooms.discard(room) + self._staging_ctx.prefetch_requested = { + k + for k in self._staging_ctx.prefetch_requested + if k[0] != room + } else: self.update_status(room, KVPoll.Transferring) except Exception as e: @@ -825,6 +1088,198 @@ def make_req_array(addr_chunks, size, gpu): return xfer_handle + def send_kvcache_staged( + self, + peer_name: str, + prefill_kv_indices: npt.NDArray[np.int32], + dst_staging_ptr: int, + dst_staging_size: int, + dst_gpu_id: int, + dst_tp_rank: int, + dst_attn_tp_size: int, + dst_kv_item_len: int, + notif: str, + staging_buffer=None, + ): + """Transfer KV cache via staging buffers (gather -> bulk RDMA -> scatter on decode).""" + from sglang.srt.disaggregation.common.staging_buffer import ( + compute_head_slice_params, + compute_staging_layout, + gather_all_layers_to_staging, + resolve_total_kv_heads, + ) + + if self.kv_buffer_tensors is None or staging_buffer is None: + return None + + k_buffers = self.kv_buffer_tensors["k_buffers"] + v_buffers = self.kv_buffer_tensors["v_buffers"] + page_size = self.kv_buffer_tensors["page_size"] + num_layers = len(k_buffers) + head_dim = k_buffers[0].shape[-1] + dtype_size = k_buffers[0].element_size() + + total_kv_heads = resolve_total_kv_heads(self.kv_args, self.attn_tp_size) + + local_tp_rank = self.kv_args.engine_rank % self.attn_tp_size + src_head_start, num_heads_to_send, _, _ = compute_head_slice_params( + self.attn_tp_size, + dst_attn_tp_size, + local_tp_rank, + dst_tp_rank, + total_kv_heads, + ) + + num_tokens = len(prefill_kv_indices) * page_size + per_layer_bytes = num_tokens * num_heads_to_send * head_dim * dtype_size + per_rank_bytes = per_layer_bytes * num_layers * 2 + + num_writers, writer_rank_bytes, total_staging_needed = compute_staging_layout( + self.attn_tp_size, + dst_attn_tp_size, + dst_tp_rank, + total_kv_heads, + num_tokens, + head_dim * dtype_size, + num_layers, + ) + writer_idx = local_tp_rank % num_writers if num_writers > 1 else 0 + rank_offset = sum(writer_rank_bytes[:writer_idx]) + + if not staging_buffer.fits(per_rank_bytes): + logger.warning( + f"Prefill staging too small for {per_rank_bytes} bytes, falling back" + ) + return None + if dst_staging_size < total_staging_needed: + logger.warning( + f"Decode staging too small: need {total_staging_needed} bytes, " + f"have {dst_staging_size}, falling back" + ) + return None + + # gather_all_layers_to_staging() runs the gather kernel on its own + # dedicated stream and synchronizes that stream before returning, so + # the staging buffer is fully populated and visible to the NIC by the + # time we post the RDMA WRITE below. No extra sync needed (matches + # mooncake's send_kvcache_staged behavior). + gather_all_layers_to_staging( + k_buffers, + v_buffers, + prefill_kv_indices, + staging_buffer, + src_head_start, + num_heads_to_send, + page_size, + self.kv_args.gpu_id, + ) + + dst_write_ptr = dst_staging_ptr + rank_offset + src_reqs = np.array( + [[staging_buffer.get_ptr(), per_rank_bytes, self.kv_args.gpu_id]], + dtype=np.int64, + ) + dst_reqs = np.array( + [[dst_write_ptr, per_rank_bytes, dst_gpu_id]], dtype=np.int64 + ) + + src_descs = self.agent.get_xfer_descs(src_reqs, "VRAM") + dst_descs = self.agent.get_xfer_descs(dst_reqs, "VRAM") + + xfer_handle = self.agent.initialize_xfer( + "WRITE", src_descs, dst_descs, peer_name, notif.encode("ascii") + ) + if not xfer_handle: + raise RuntimeError( + f"[Staging] Failed to create NIXL bulk transfer " + f"(src=0x{staging_buffer.get_ptr():x}, dst=0x{dst_write_ptr:x}, " + f"size={per_rank_bytes})" + ) + state = self.agent.transfer(xfer_handle) + if state == "ERR": + raise RuntimeError("[Staging] NIXL bulk transfer failed to post") + return xfer_handle + + def _try_create_staging_strategy(self, staging_buffer): + """Create a per-worker PrefillStagingStrategy bound to ``staging_buffer``. + + Returns ``None`` if staging is disabled or kv tensors not yet set. + Caller is expected to keep the returned strategy as a worker-local + variable; never cache on ``self`` (multiple workers would race on + the underlying staging ring buffer). + """ + if not self.enable_staging or self.kv_buffer_tensors is None: + return None + from sglang.srt.disaggregation.common.staging_handler import ( + PrefillStagingStrategy, + ) + + return PrefillStagingStrategy(self, staging_buffer) + + def _do_staging_transfer( + self, + staging_strategy, + kv_chunk: "TransferKVChunk", + req: "TransferInfo", + dst_info: "KVArgsRegisterInfo", + queue: FastQueue, + ): + """Attempt staging transfer for one chunk. Returns (xfer_handle, deferred). + + Mirrors mooncake._do_staging_transfer semantics: + - staging not ready (watermark/alloc pending) -> ``queue.put(kv_chunk)`` + re-enqueue the chunk and return ``(None, True)``. Caller should + ``break`` out of the per-req loop and ``continue`` the worker + main loop without updating room status -- the chunk will be + retried on the next pop. + - oversized chunk (will never fit) -> raise RuntimeError. + - staging successfully posted -> return ``(handle, False)``. The + caller appends the handle to the per-chunk handle list and + busy-polls it to DONE alongside other handles. + - send_kvcache_staged returned None (decode buffer too small, + kv_buffer_tensors missing, etc.) -> return ``(None, False)``, + signalling the caller to fall back to send_kvcache_slice. + """ + page_start = kv_chunk.index_slice.start + num_pages = len(kv_chunk.prefill_kv_indices) + + ready, chunk_idx, c_offset, _, _ = staging_strategy.check_ready( + req, page_start, num_pages, session_id=req.agent_name + ) + if not ready: + from sglang.srt.disaggregation.common.staging_buffer import ( + StagingAllocator, + ) + + if c_offset == StagingAllocator.ALLOC_OVERSIZED: + raise RuntimeError( + f"[Staging] Chunk staging allocation permanently failed: " + f"chunk exceeds ring buffer total size " + f"(room={kv_chunk.room}). Increase " + f"SGLANG_DISAGG_STAGING_POOL_SIZE_MB." + ) + queue.put(kv_chunk) + return (None, True) + + notif_tag = ( + f"{req.room}_stg_{kv_chunk.chunk_id}_{int(kv_chunk.is_last)}" + f"_{self.kv_args.engine_rank}_{chunk_idx}" + f"_{page_start}_{num_pages}_{req.agent_name}" + ) + handle = self.send_kvcache_staged( + req.agent_name, + kv_chunk.prefill_kv_indices, + dst_info.staging.base_ptr + c_offset, + dst_info.staging.total_size - c_offset, + dst_info.gpu_id, + dst_info.decode_tp_rank, + dst_info.decode_tp_size, + dst_info.dst_kv_item_len, + notif_tag, + staging_buffer=staging_strategy.staging_buffer, + ) + return (handle, False) + def send_aux( self, peer_name: str, @@ -1124,6 +1579,19 @@ def add_transfer_request( assert self.disaggregation_mode == DisaggregationMode.PREFILL assert not is_last or (is_last and aux_index is not None) + # Prefetch STAGING_REQ to decode before enqueueing so decode has + # already allocated staging by the time the worker picks up the + # chunk. Internally a no-op when staging is disabled or no peer + # in this room needs heterogeneous-TP staging. + if self.enable_staging: + self._prefetch_staging_reqs(bootstrap_room) + + # Transfer is async: just enqueue the chunk; the per-queue worker + # (transfer_worker) does the actual gather + RDMA. Routing by + # ``room % N`` keeps every chunk of a given room on the same + # worker -- and therefore on the same private staging buffer -- + # which is required for the staging ring's offset/watermark + # state machine to advance correctly. shard_idx = bootstrap_room % len(self.transfer_queues) self.transfer_queues[shard_idx].put( TransferKVChunk( @@ -1142,45 +1610,132 @@ def update_transfer_status(self): # Process notifications from received transfers. notif_map = self.agent.get_new_notifs() for peer_name, messages in notif_map.items(): - # We could also check that self.bootstrap_info['agent_name'] matches - # the message sender. But the bootstrap room alone should be - # sufficient to map the status. for msg in messages: - components = msg.decode("ascii").split("_", 4) + # Notification tag layouts (underscore-separated): + # kv: {room}_kv_{chunk_id}_{is_last}_{pp_rank} -> 5 fields + # stg: {room}_stg_{chunk_id}_{is_last}_{pp_rank}_{chunk_idx} + # _{page_start}_{num_pages}_{agent_name} -> 9 fields + # aux: {room}_aux -> 2 fields + # state: {room}_state_{pp_rank} -> 3 fields + # maxsplit=8 keeps everything past the 8th underscore in the + # last component, so agent_name (which may itself contain + # underscores) lands intact in components[8] for the stg path. + components = msg.decode("ascii").split("_", 8) room = int(components[0]) - if components[1] == "kv": + tag = components[1] + if tag == "kv": chunk_id = int(components[2]) is_last = bool(int(components[3])) pp_rank = int(components[4]) if len(components) > 4 else 0 - # Track received chunks per pp_rank - self.transfer_statuses[room].received_kvs_per_pp[pp_rank].add( - chunk_id - ) - if is_last: - # Record expected chunk count for this pp_rank - self.transfer_statuses[room].expected_kvs_per_pp[pp_rank] = ( - chunk_id + 1 - ) - # Set num_pp_ranks_expected from table (or default to 1) - if self.transfer_statuses[room].num_pp_ranks_expected is None: - self.transfer_statuses[room].num_pp_ranks_expected = ( - self.required_prefill_response_num_table.get(room, 1) - ) - elif components[1] == "aux": - self.transfer_statuses[room].received_aux = True - # Handle "nokv" marker: no KV pages were sent for - # this pp_rank (decode-side radix cache hit). - if len(components) > 3 and components[2] == "nokv": - pp_rank = int(components[3]) - self.transfer_statuses[room].expected_kvs_per_pp[pp_rank] = 0 - if self.transfer_statuses[room].num_pp_ranks_expected is None: - self.transfer_statuses[room].num_pp_ranks_expected = ( - self.required_prefill_response_num_table.get(room, 1) - ) - elif components[1] == "state": + self._track_kv_arrival(room, chunk_id, is_last, pp_rank) + elif tag == "stg": + self._handle_stg_notification(components, room) + elif tag == "aux": + # main's "nokv" marker (decode-side radix cache hit): + # mark expected_kvs_per_pp[pp_rank] = 0 for this rank. + self._handle_aux_notification(room, components) + elif tag == "state": pp_rank = int(components[2]) if len(components) > 2 else 0 self.transfer_statuses[room].received_state_per_pp.add(pp_rank) + def _handle_stg_notification(self, components, room: int): + """Handle a staging RDMA notification tag. + + Format: {room}_stg_{chunk_id}_{is_last}_{pp_rank}_{chunk_idx}_{page_start}_{num_pages}_{agent_name} + """ + chunk_id = int(components[2]) + is_last = bool(int(components[3])) + pp_rank = int(components[4]) + chunk_idx = int(components[5]) + page_start = int(components[6]) + num_pages = int(components[7]) + agent_name = components[8] if len(components) > 8 else "" + self._track_kv_arrival(room, chunk_id, is_last, pp_rank) + self._handle_staging_chunk_arrived( + room, chunk_idx, page_start, num_pages, agent_name + ) + + def _handle_aux_notification(self, room: int, components: List[str]): + """Handle an aux notification and trigger last scatter if staging is complete. + + Notification tag layouts: + aux: {room}_aux -> 2 fields + aux (nokv): {room}_aux_nokv_{pp_rank} -> 4 fields + (decode-side radix cache hit; this pp_rank sent + no KV pages, so expected_kvs_per_pp[pp_rank] = 0) + """ + self.transfer_statuses[room].received_aux = True + # main's "nokv" marker (decode-side radix cache hit, see #19746). + if len(components) > 3 and components[2] == "nokv": + pp_rank = int(components[3]) + self.transfer_statuses[room].expected_kvs_per_pp[pp_rank] = 0 + if self.transfer_statuses[room].num_pp_ranks_expected is None: + self.transfer_statuses[room].num_pp_ranks_expected = ( + self.required_prefill_response_num_table.get(room, 1) + ) + if ( + self.enable_staging + and self._staging_handler is not None + and self._staging_handler.is_staging_room(room) + ): + self._maybe_submit_last_scatter(room) + + def _track_kv_arrival(self, room: int, chunk_id: int, is_last: bool, pp_rank: int): + """Update transfer status tracking for a kv chunk arrival.""" + self.transfer_statuses[room].received_kvs_per_pp[pp_rank].add(chunk_id) + if is_last: + self.transfer_statuses[room].expected_kvs_per_pp[pp_rank] = chunk_id + 1 + if self.transfer_statuses[room].num_pp_ranks_expected is None: + self.transfer_statuses[room].num_pp_ranks_expected = ( + self.required_prefill_response_num_table.get(room, 1) + ) + if ( + self.enable_staging + and self._staging_handler is not None + and self._staging_handler.is_staging_room(room) + ): + self._maybe_submit_last_scatter(room) + + def _handle_staging_chunk_arrived( + self, + room: int, + chunk_idx: int, + page_start: int, + num_pages: int, + agent_name: str, + ): + """Process a staging chunk arrival via RDMA notification.""" + handler = self._staging_handler + if handler is None: + return + handler.handle_chunk_arrived( + room, + chunk_idx, + page_start, + num_pages, + agent_name, + self._chunk_writer_counts, + ) + + def _maybe_submit_last_scatter(self, room: int): + """Check if all kv+aux transfers are done and submit last scatter if so.""" + status = self.transfer_statuses.get(room) + if status is None: + return + if not status.received_aux: + return + if status.num_pp_ranks_expected is None: + return + if len(status.expected_kvs_per_pp) < status.num_pp_ranks_expected: + return + for pp_rank, expected in status.expected_kvs_per_pp.items(): + if len(status.received_kvs_per_pp[pp_rank]) != expected: + return + handler = self._staging_handler + if handler is not None and handler.is_staging_room(room): + handler.submit_last_scatter_async(room) + self._chunk_writer_counts.pop(room, None) + def check_transfer_done(self, room: int): if room not in self.transfer_statuses: return False @@ -1194,6 +1749,27 @@ def bootstrap_thread(): logger.debug( f"Received multipart with total byte size {sum(len(x) for x in waiting_req_bytes)}" ) + + # Staging: decode reports consumption watermark back to prefill + if waiting_req_bytes[0] == b"WATERMARK": + if self.enable_staging: + from sglang.srt.disaggregation.common.staging_handler import ( + handle_watermark_msg, + ) + + handle_watermark_msg(self._staging_ctx, waiting_req_bytes) + continue + + # Staging: decode replies with allocated staging offset + if waiting_req_bytes[0] == b"STAGING_RSP": + if self.enable_staging: + from sglang.srt.disaggregation.common.staging_handler import ( + handle_staging_rsp, + ) + + handle_staging_rsp(waiting_req_bytes, self.transfer_infos) + continue + assert ( waiting_req_bytes[0] == GUARD ), f"First message should be {GUARD}. Foreign traffic?" @@ -1356,6 +1932,16 @@ def send_metadata( self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed) return + # Register staging room bootstrap info for staging handler + if ( + self.kv_mgr.enable_staging + and self.kv_mgr._staging_ctx.allocator is not None + ): + self.chunk_staging_infos = [] + self.kv_mgr.register_staging_room_bootstrap( + self.bootstrap_room, self.bootstrap_infos, self + ) + for bootstrap_info in self.bootstrap_infos: logger.debug( f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}" @@ -1453,6 +2039,18 @@ def _register_kv_args(self): getattr(self.kv_mgr.kv_args, "state_dim_per_tensor", []) or [], "I" ) + # Include staging allocator metadata if available + if ( + self.kv_mgr.enable_staging + and self.kv_mgr._staging_ctx.allocator is not None + ): + _alloc = self.kv_mgr._staging_ctx.allocator + packed_staging_base_ptr = struct.pack("Q", _alloc.get_base_ptr()) + staging_total_size_str = str(_alloc.get_total_size()).encode("ascii") + else: + packed_staging_base_ptr = b"" + staging_total_size_str = b"" + with lock: sock.send_multipart( [ @@ -1471,6 +2069,8 @@ def _register_kv_args(self): str(self.kv_mgr.kv_args.kv_item_lens[0]).encode("ascii"), packed_state_item_lens, packed_state_dim_per_tensor, + packed_staging_base_ptr, + staging_total_size_str, ] ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 79d9a7584f2c..30a04dadf447 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -3989,11 +3989,11 @@ def _handle_pd_disaggregation(self): if self.disaggregation_mode in ("prefill", "decode"): if ( envs.SGLANG_DISAGG_STAGING_BUFFER.get() - and self.disaggregation_transfer_backend != "mooncake" + and self.disaggregation_transfer_backend not in ("mooncake", "nixl") ): raise ValueError( f"SGLANG_DISAGG_STAGING_BUFFER requires " - f"disaggregation_transfer_backend='mooncake', " + f"disaggregation_transfer_backend='mooncake' or 'nixl', " f"got '{self.disaggregation_transfer_backend}'." )