[Disagg][NIXL] Add staging buffer support for heterogeneous TP KV transfer#22536
Conversation
…nsfer Implement GPU staging buffer for NIXL backend to enable bulk RDMA transfers under heterogeneous TP, reducing RDMA work requests by ~1000x compared to per-head scatter transfers. Key changes: - server_args.py: Allow SGLANG_DISAGG_STAGING_BUFFER with NIXL backend - Staging buffer lifecycle: _init_staging_prefill_ctx/decode_ctx, _init_staging_buffers, _init_staging_allocator, _register_staging_memory - Prefill side: send_kvcache_staged() gathers KV heads into staging buffer then posts a single bulk RDMA write; _prefetch_staging_reqs() pre-sends STAGING_REQ to decode before forward starts - Decode side: _start_decode_staging_thread() receives STAGING_REQ, allocates staging offsets, replies with STAGING_RSP; watermark-based flow control for ring buffer reuse - Notification handling: stg notifications trigger chunk scatter; _maybe_submit_last_scatter() for final scatter after all chunks arrive - KVArgsRegisterInfo: add staging_base_ptr, staging_total_size fields at msg[12]/msg[13] - KVReceiver: send staging allocator metadata during registration
NIXL staging buffers were allocated with cudaMalloc instead of cuMemCreate. Pass custom_mem_pool from init_mooncake_custom_mem_pool() to StagingBuffer/StagingAllocator constructors, matching Mooncake behavior.
…ake pattern - Extract staging transfer logic into helper methods - Delegate common operations to staging_handler.py - Remove unnecessary getattr/hasattr defensive checks - Simplify NixlKVReceiver staging registration
- Extract _get_custom_mem_pool() in staging_handler.py to centralize mooncake custom memory pool initialization - Change init_staging_buffers/init_staging_allocator to accept a register_fn callback instead of mooncake engine, making them transport-agnostic - NIXL now delegates to staging_handler instead of directly importing mooncake.utils.init_mooncake_custom_mem_pool - Replace hardcoded 8192 with DEFAULT_CHUNKED_PREFILL_SIZE constant Made-with: Cursor
Refactor commit accidentally removed defensive getattr() calls, # type: ignore comments, and changed is_dummy() method to @Property. These were pre-existing patterns not introduced by staging buffer code and should not be altered. Made-with: Cursor
…ategy, use StagingRegisterInfo in NIXL - Fix prefetch_staging_reqs() is_dummy compatibility: handle both bool field (mooncake) and method (NIXL) via callable() check - Remove misleading "Mooncake-specific" section header in staging_handler.py — most code is backend-agnostic - Generalize PrefillStagingStrategy.check_ready() with session_id param so NIXL can pass req.agent_name instead of mooncake_session_id - NIXL: use StagingRegisterInfo.from_zmq_fields() instead of manual staging_base_ptr/staging_total_size parsing - NIXL: delegate readiness check to PrefillStagingStrategy.check_ready() instead of inlining chunk_idx/offset/watermark logic Made-with: Cursor
These # --- section headers were not in the original codebase and add unnecessary noise. Made-with: Cursor
Made-with: Cursor
Consolidate the 3-way KV dispatch (same-TP, staging, slice-fallback) into a single _send_kv_for_req method, eliminating _send_kv_slice_fallback. Made-with: Cursor
Move inline stg tag parsing and aux staging checks into dedicated _handle_stg_notification and _handle_aux_notification methods, keeping the main notification dispatch loop concise. Made-with: Cursor
…mon handler - Group scattered staging sub-functions in nixl/conn.py: _get_staging_strategy + _do_staging_transfer now adjacent to send_kvcache_staged, _handle_watermark_msg + _handle_staging_rsp now adjacent to _maybe_submit_last_scatter - Extract DecodeStagingHandler.handle_chunk_arrived() in staging_handler.py, unifying the chunk writer tracking + scatter submission logic used by both NIXL (_handle_staging_chunk_arrived) and mooncake (CHUNK_READY handler) Made-with: Cursor
…fer_request Restore the original inline if/elif/else call style for send_kvcache and send_kvcache_slice, adding staging as a new first branch without changing the existing call patterns. Made-with: Cursor
- Remove DEFAULT_CHUNKED_PREFILL_SIZE constant, restore inline `or 8192` - Revert getattr(req, "staging", None) back to req.staging Made-with: Cursor
Use self.decode_kv_args_table[req.agent_name].xxx consistently instead of extracting dst_info, matching the upstream code style. Made-with: Cursor
…ndler Both NIXL and mooncake had identical WATERMARK and STAGING_RSP message handling logic. Extract into staging_handler.py as shared functions, reducing duplication across backends. Made-with: Cursor
Made-with: Cursor
There was a problem hiding this comment.
Code Review
This pull request refactors the staging buffer logic to be transport-agnostic and introduces staging support for the NIXL disaggregation backend. Key changes include moving staging-related message handling and buffer initialization to a common handler, updating the Mooncake backend to utilize these shared utilities, and implementing RDMA-based staging transfers for NIXL. Feedback suggests implementing a retry mechanism in the staging transfer logic to avoid premature fallbacks to less efficient slice-based transfers and improving the robustness of notification tag parsing to handle potential underscores in agent names.
|
Warning You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again! |
- Drop redundant `torch.cuda.current_stream().synchronize()` in
`send_kvcache_staged`. `gather_all_layers_to_staging` already syncs
its dedicated `_gather_stream` before returning, so the staging
buffer is fully populated and visible to the NIC by the time the
RDMA WRITE is posted (matches mooncake's behavior). Drops the
resulting unused `import torch`.
- Replace `Optional[object]` fields on `TransferInfo` /
`KVArgsRegisterInfo` with `Optional["StagingTransferInfo"]` /
`Optional["StagingRegisterInfo"]` (forward refs under
TYPE_CHECKING).
- Normalize `is_dummy` API: convert NIXL `TransferInfo.is_dummy()` to
an `@property`, matching mooncake's plain attribute. Updates the
two NIXL call sites and removes the `callable(tinfo.is_dummy)`
hack from `prefetch_staging_reqs`.
- Inline `_handle_watermark_msg` / `_handle_staging_rsp` wrappers in
NIXL `bootstrap_thread` so both backends call the common helpers
the same way.
- Extract `_dispatch_kv_transfer` helper from `add_transfer_request`
so each request appends exactly one kv handle, instead of three
different `handles.append` sites in nested branches.
- Comment `split("_", 8)` to document the per-tag layout and why the
maxsplit value lets `agent_name` (which can itself contain
underscores) survive the split intact.
- Replace `chunk_id == 0` prefetch sentinel with an explicit
`PrefillStagingContext.prefetched_rooms` set so
`_prefetch_staging_reqs` is idempotent per room and the
invariant ("fan out STAGING_REQ once per room") is local to the
function instead of relying on caller behavior.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
There was a problem hiding this comment.
The Mooncake part LGTM
|
@YAMY1234 could you rebase this? thanks |
There was a problem hiding this comment.
cc: @ishandhanani, we should be able to merge this PR once we get an approval from the nixl backend maintainer.
Bring in sgl-project#23967 (Nixl async transfer) and other main changes since the last merge. Conflicts were limited to python/sglang/srt/disaggregation/nixl/conn.py: 1. TransferInfo: kept main's `decode_prefix_len` field + `is_dummy()` method form, appended this PR's `staging` field at the tail. Updated 2 callers in this file from `req.is_dummy` to `req.is_dummy()`. 2. NixlKVManager.__init__ (PREFILL branch): kept this PR's `_init_staging_prefill_ctx()` AND main's `transfer_queues` / `transfer_worker` thread pool. Both run; staging ctx is initialized before workers spawn. 3. add_transfer_request: took main's async enqueue body (puts TransferKVChunk into transfer_queues[room % N], returns None) but kept this PR's `_prefetch_staging_reqs(bootstrap_room)` call before the enqueue. The staging dispatch (`_dispatch_kv_transfer`, `_do_staging_transfer`, `send_kvcache_staged`) is now temporarily dead code: enabling SGLANG_DISAGG_STAGING_BUFFER on NIXL has no effect until the next commit moves staging dispatch into `transfer_worker` (per the mooncake pattern). 4. update_transfer_status: kept this PR's tag-based dispatch (`_track_kv_arrival` / `_handle_stg_notification` / `_handle_aux_notification`) and merged main's "nokv" handling for decode-side radix cache hit (sgl-project#19746) into `_handle_aux_notification`. After this commit the staging buffer code path is preserved but unused; plain heterogeneous-TP transfers fall back to send_kvcache_slice via the new async worker. The next commit will wire staging into the worker (per-worker staging buffer + deferred re-enqueue on watermark not-ready, matching mooncake). Co-authored-by: Cursor <cursoragent@cursor.com>
…ke parity) After the previous merge of sgl-project#23967 (Nixl async transfer), staging buffer dispatch lived only in the now-deleted synchronous path of add_transfer_request, leaving SGLANG_DISAGG_STAGING_BUFFER a no-op on NIXL. This commit ports the staging dispatch into transfer_worker, 1:1 mirroring mooncake's per-worker staging design. 1. PREFILL __init__: build N staging buffers (one per transfer_queue) before workers spawn, and pass each worker its private buffer (NixlKVManager.__init__). Removes the lazy single-buffer creation in set_kv_buffer_tensors -- mooncake-style, staging buffers no longer depend on kv_buffer_tensors. 2. _try_create_staging_strategy(staging_buffer) replaces _get_staging_strategy. Returns a fresh PrefillStagingStrategy bound to the caller's staging buffer. The strategy MUST be a worker-local variable; never cache on self -- multiple workers would race on the same staging ring. 3. transfer_worker(queue, staging_buffer=None) now lazy-creates a per-worker staging_strategy on the first chunk it sees, then for each req in a chunk picks among: - staging (heterogeneous TP, both sides registered, watermark ready) -> _do_staging_transfer - send_kvcache (MLA / homogeneous TP) - send_kvcache_slice (heterogeneous TP, no staging or staging hard-failed for this chunk) When staging is not ready (watermark/alloc pending), _do_staging_transfer re-enqueues the chunk and signals `staging_deferred=True`; the worker breaks the per-req loop and `continue`s the main loop without advancing room status, so the chunk gets retried on the next pop. Same control-flow as mooncake.transfer_worker. 4. _do_staging_transfer reshaped to (handle, deferred) return tuple: - (None, True) -> chunk re-enqueued, caller should defer - (None, False) -> hard fallback, caller should try slice - (handle, False) -> staging RDMA posted; handle joins the per-chunk handle list and is busy-polled to DONE alongside aux/state handles. Oversized chunks (cannot ever fit) raise immediately. 5. _dispatch_kv_transfer (the old synchronous-path entry) is removed. add_transfer_request stays a thin enqueue + _prefetch_staging_reqs wrapper. Notes vs mooncake: - NIXL workers do NOT need an executor (no per-slice ThreadPoolExecutor); send_kvcache_slice posts a single bulk transfer. - NIXL workers do NOT send a separate ZMQ CHUNK_READY message: decode observes chunk arrival via the RDMA `stg_*` notification tag posted by send_kvcache_staged, which the decode-side receiver thread already handles. - Memory: staging pool grows N x (one per worker, default SGLANG_DISAGGREGATION_QUEUE_SIZE=4). Tunable via SGLANG_DISAGG_STAGING_POOL_SIZE_MB. Co-authored-by: Cursor <cursoragent@cursor.com>
|
need to fix lint as well |
The shared helper prefetch_staging_reqs() in common/staging_handler.py was written under the assumption that TransferInfo.is_dummy is a plain attribute / @Property on both backends. After merging upstream main (which introduced decode_prefix_len, sgl-project#19746), NIXL's TransferInfo.is_dummy was changed from @Property to a regular method to consult decode_prefix_len. NIXL's own conn.py call sites were updated to use is_dummy() but this shared helper was missed. Effect: tinfo.is_dummy evaluates to a bound-method object on NIXL, which is always truthy. The if branch is always taken, every STAGING_REQ is silently skipped, decode never allocates a staging chunk, STAGING_RSP never returns to prefill, the per-chunk staging info stays None, check_ready always returns not-ready, and the chunk is re-enqueued forever. The transfer worker spins in its dispatch loop and the prefill inflight queue never drains -- exactly the deadlock observed on lyris job 1733669 (11 inflight reqs, no further metrics, decode side never sees any stg_* notification). Mooncake is not affected because its TransferInfo.is_dummy is a real dataclass bool field. Fix: normalize via callable() in the shared helper so it works for both the mooncake (attribute) and NIXL (method) shapes. This is the minimal single-point fix and does not require touching either backend's existing TransferInfo definition or any other call site. Co-authored-by: Cursor <cursoragent@cursor.com>
…nup) Addresses the three actionable comments from @ShangmingCai on PR sgl-project#22536: 1. Drop duplicate StagingRegisterInfo import. The class was imported both in the TYPE_CHECKING block and again lazily inside KVArgsRegisterInfo. from_zmq(). Promote it to a module-level runtime import (no circular dep risk -- staging_handler.py only imports stdlib + torch) and remove the redundant lazy import. Keep StagingTransferInfo in TYPE_CHECKING because it is only referenced from a forward-ref annotation. 2. Add an explanatory comment to KVArgsRegisterInfo.staging noting why the optional staging field must remain the LAST field of the dataclass -- from_zmq() relies on positional construction and the staging payload is a variable-length tail of the ZMQ frame. 3. Clean up per-room state in the prefill transfer_worker when the last chunk reports Success. Without this, prefetched_rooms / prefetch_requested / transfer_infos / req_to_decode_prefix_len grew without bound across long-running services as new bootstrap rooms kept arriving (mooncake's transfer_worker already does the equivalent transfer_infos.pop on Success -- this brings NIXL to parity and additionally sweeps the staging-only prefetch sets). Also pick up an incidental black reformat of one call site in transfer_worker. Co-authored-by: Cursor <cursoragent@cursor.com>
Trim the two prose blocks added in the previous review-feedback commit down to a single sentence each, keeping the substance (why staging is last, mooncake-parity cleanup) without restating it across multiple lines. Co-authored-by: Cursor <cursoragent@cursor.com>
iyastreb
left a comment
There was a problem hiding this comment.
I have tested it with my banchmark, and these are results I have:
# p2d4, Best TTFT of 4 runs
num_prompts before after
1 34 32
2 54 50
4 57 53
8 70 59
16 88 74
32 150 115
64 226 154
128 425 238
256 749 355
512 2078 1437
# p4d4, Best TTFT of 4 runs
num_prompts before after
1 31 31
2 45 45
4 47 48
8 52 51
16 58 60
32 94 97
64 127 127
128 231 182
256 327 314
512 633 577
It speeds up significantly the heterogenous setup (p2d4), and even homogenous one on large dimensions (>64)
|
Thanks @iyastreb for the verification! Since it has been verified by Nixl team @ShangmingCai could you take a second look and merge it if it looks good to you? Thanks!😄 |
|
/rerun-test test/registered/distributed/test_disaggregation_different_tp.py |
|
🚀 |
|
/rerun-test test/registered/distributed/test_disaggregation_different_tp.py |
|
🚀 |
…nsfer (sgl-project#22536) Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Co-authored-by: Cursor <cursoragent@cursor.com> Co-authored-by: Shangming Cai <csmthu@gmail.com>
…nsfer (sgl-project#22536) Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Co-authored-by: Cursor <cursoragent@cursor.com> Co-authored-by: Shangming Cai <csmthu@gmail.com>
Motivation
NIXL disaggregated serving currently requires prefill and decode to use the same TP layout. When prefill uses TP4 and decode uses DEP4 (DP4+TP4+EP4), each prefill rank's KV cache must be split and sent to multiple decode ranks. Without staging buffers, the prefill side must issue
prefill_tp × decode_tpseparate RDMA transfers per chunk, saturating the RDMA descriptor table and adding significant per-transfer overhead.The staging buffer approach (already implemented for mooncake in #19890) consolidates KV heads into a contiguous staging region on prefill, issues a single bulk RDMA transfer per rank pair, and lets the decode side scatter from the staging buffer into the final KV cache pages asynchronously.
Collaborate with @Aphoh (The author of #18968)
Modification
This PR extends the existing staging buffer support from mooncake to NIXL. The core staging lifecycle logic (
staging_handler.py,staging_buffer.py) is already shared between backends — this PR adds the NIXL-specific integration and refactors mooncake to use the newly shared functions.Key differences from mooncake implementation:
stgtag), while mooncake uses ZMQCHUNK_READYmessages. The RDMA notification path avoids an extra network round-trip.nixl_agent.get_new_notifs(), RDMA notification processing and scatter submission stay on the main thread poll path. The background thread only handles ZMQSTAGING_REQmessages. In mooncake, both notification processing and scatter run on the background thread.init_staging_buffers(register_fn, ...)andinit_staging_allocator(register_fn, ...)now accept a genericregister_fncallback instead of a mooncake engine directly, enabling NIXL to register buffers vianixl_agent. Mooncake is refactored to use the same callback pattern.handle_watermark_msg(),handle_staging_rsp(), andhandle_chunk_arrived()are extracted from mooncake intostaging_handler.pyas common functions used by both backends.Files changed:
nixl/conn.py,common/staging_handler.py,mooncake/conn.pyAccuracy
GSM8K (full 1319 questions), Qwen3.5-397B-A17B-FP8, 1P1D TP4→DEP4, NIXL + staging:
Accuracy is consistent with the non-disaggregated baseline (~0.98), confirming staging buffer does not affect model correctness.
Performance
Setup: Qwen3.5-397B-A17B-FP8, 1P1D, GB200 Lyris cluster, ISL=1000, OSL=1000, sa-bench concurrency sweep 1→1024.
1. TP4→DEP4 Staging vs No-Staging (heterogeneous TP, measuring staging benefit)
At high concurrency (512–1024), staging delivers 29–64% throughput improvement and dramatically lower TTFT (1.3s vs 21s at c=512, 7.5s vs 67s at c=1024). The no-staging path saturates RDMA descriptors at high concurrency, causing TTFT to blow up. TPOT remains comparable across both configurations.
2. TP4→DEP4 Staging vs DEP4→DEP4 NIXL (heterogeneous vs homogeneous, measuring staging overhead)
TP4→DEP4 with staging has no systematic overhead compared to homogeneous DEP4→DEP4. TTFT and TPOT are comparable. At mid-to-high concurrency (128–512), TP4 prefill is actually more efficient than DEP4 prefill, so the heterogeneous layout outperforms homogeneous in those ranges.