
[P/D disagg] - support decode side radix cache #19746

Open
ishandhanani wants to merge 57 commits into main from ishan/add-radix-cache-decode

Conversation

@ishandhanani
Collaborator

@ishandhanani ishandhanani commented Mar 3, 2026

Summary

In PD disaggregation, the decode worker can now use radix cache to reuse shared prefixes and request only the delta KV from prefill instead of transferring the full prefix on every turn.

This is enabled with --disaggregation-decode-enable-radix-cache on the decode server.

For now, this path is supported only with --disaggregation-transfer-backend nixl. server_args.py now rejects other transfer backends early when the decode radix cache flag is enabled. Mooncake support will follow in a separate PR.

Main Changes

  • Decode scheduler
    • Match incoming requests against the decode-side radix tree.
    • Lock matched prefix nodes for the request lifetime.
    • Pre-allocate only the delta KV pages beyond the matched prefix.
  • Decode -> prefill protocol
    • Plumb decode_prefix_len from decode to prefill for the NIXL path.
    • Allow full-prefix hits where decode may need no KV pages transferred.
  • Prefill transfer path
    • Initialize the sender with only the unsent delta pages.
    • Keep the chunked transfer cursor monotonic when decode already has part of the prefix.
    • Skip empty non-last chunks so the sender/receiver chunk protocol stays consistent.
  • Correctness / cleanup
    • Align matched prefix length to page boundaries for paged KV allocators.
    • Guard lock release / cleanup paths for transfer-failure cases.
    • Batch finished prebuilt frees through the free-group path.
  • CLI / config
    • The user-facing switch is --disaggregation-decode-enable-radix-cache.
    • Current validation requires --disaggregation-transfer-backend nixl when that flag is set.
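The decode-scheduler steps above (match against the radix tree, align the hit to page boundaries, allocate only the delta) can be sketched as a small planning function. This is an illustrative model only — the function name, signature, and the 64-token page size are assumptions, not the actual sglang scheduler code:

```python
# Hypothetical sketch of the decode-side admission flow: floor-align the
# radix-tree hit to the KV page size (paged allocators can only reuse whole
# pages), then request only the remaining delta tokens from prefill.

def plan_decode_transfer(token_ids, radix_match_len, page_size=64):
    """Return (aligned_prefix_len, delta_len) for an incoming request.

    radix_match_len: raw number of tokens matched in the decode radix tree.
    """
    aligned_prefix_len = (radix_match_len // page_size) * page_size
    delta_len = len(token_ids) - aligned_prefix_len
    return aligned_prefix_len, delta_len

# Example: a 50_176-token prompt with a 50_000-token radix hit and 64-token
# pages reuses 49_984 tokens locally and pulls only 192 tokens from prefill.
prefix, delta = plan_decode_transfer(list(range(50_176)), 50_000)
assert prefix == 49_984 and delta == 192

# Full-prefix hit: decode may need no KV pages transferred at all.
assert plan_decode_transfer(list(range(128)), 128) == (128, 0)
```

The full-hit case corresponds to the "decode may need no KV pages transferred" bullet above.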

Interface

Enable decode radix cache on the decode worker with:

--disaggregation-mode decode --disaggregation-transfer-backend nixl --disaggregation-decode-enable-radix-cache

Prefill continues to run with --disaggregation-transfer-backend nixl.

Note: DP attention is still experimental here. The flag is allowed, but good cache hit rates require prefix-aware DP routing.
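Putting the interface together, a single-node launch might look like the following sketch. Only the disaggregation flags come from this PR; the model path, ports, and TP size are placeholder choices for illustration:

```shell
# Decode worker with decode-side radix cache enabled (flags from this PR;
# model path, port, and --tp-size are deployment-specific placeholders).
python -m sglang.launch_server \
  --model-path Qwen/Qwen3-32B \
  --disaggregation-mode decode \
  --disaggregation-transfer-backend nixl \
  --disaggregation-decode-enable-radix-cache \
  --tp-size 2 --port 30001

# Prefill worker stays on the nixl backend; no new flags needed.
python -m sglang.launch_server \
  --model-path Qwen/Qwen3-32B \
  --disaggregation-mode prefill \
  --disaggregation-transfer-backend nixl \
  --tp-size 2 --port 30000
```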

Benchmark

Setup

  • Hardware: 1x NVIDIA B200 node (8 GPUs), single-node PD disaggregation via NIXL
  • Model: Qwen/Qwen3-32B, FP8 KV cache, 3P1D, TP=2 per worker
  • Workload: 20 unique ~50K-token prefixes + ~4.5K suffix (~91% prefix reuse), 1000 requests, concurrency 128

Results

| Metric | Baseline | Decode Radix Cache | Improvement |
|---|---|---|---|
| Request throughput (req/s) | 1.21 | 1.59 | 1.32x |
| Output token throughput (tok/s) | 430 | 566 | 1.32x |
| TTFT p50 (s) | 73.2 | 9.0 | 8.1x |
| TTFT avg (s) | 77.7 | 31.6 | 2.5x |
| Request latency p50 (s) | 99.1 | 73.4 | 1.35x |
| ITL avg (ms) | 65.6 | 130.6 | 0.50x |
| Benchmark duration (s) | 827 | 628 | 1.32x |

Decode-side logs show the reason for the throughput gain: baseline decode ran near KV capacity (token_usage ~ 0.99) and only fit ~37 running requests, while decode radix cache reduced duplicate prefix residency (token_usage ~ 0.75) and fit roughly 104-126 running requests. The ITL regression is expected from the larger decode batch.
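The transfer savings implied by this workload can be checked with back-of-envelope arithmetic. The numbers below are an illustrative model, not measured data: they assume every repeat request fully hits the decode radix cache for its ~50K-token prefix.

```python
# Rough transfer-volume model for the benchmark workload above:
# 20 unique ~50K-token prefixes, ~4.5K suffix, 1000 requests.

PREFIX, SUFFIX = 50_000, 4_500
UNIQUE_PREFIXES, REQUESTS = 20, 1_000

baseline = REQUESTS * (PREFIX + SUFFIX)        # full prefix sent every request
cold = UNIQUE_PREFIXES * (PREFIX + SUFFIX)     # first request per prefix
warm = (REQUESTS - UNIQUE_PREFIXES) * SUFFIX   # delta-only afterwards
radix = cold + warm

print(f"baseline: {baseline / 1e6:.1f}M tokens transferred")
print(f"radix:    {radix / 1e6:.1f}M tokens transferred")
avoided = (REQUESTS - UNIQUE_PREFIXES) * PREFIX / baseline
print(f"avoided:  {avoided:.0%} of baseline KV traffic")
```

Under these assumptions roughly 90% of baseline KV traffic is avoided, which is consistent with the lower token_usage and larger running batch observed in the decode logs.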

Test Plan

  • Qwen3-0.6B local PD disagg sanity runs
  • MiniMax-M2.5 1P1D on B200
  • Qwen3-32B 3P1D on B200 (results above)
  • Guard decode radix cache behind nixl in server_args.py
  • Multi-node cross-host testing (RDMA transport)
  • Mooncake transfer backend support (separate PR)

@gemini-code-assist
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@ishandhanani ishandhanani changed the title from "[Draft] [P/D disagg] - support decode side radix cache" to "[P/D disagg] - support decode side radix cache" on Mar 3, 2026
@dongyibo

dongyibo commented Mar 3, 2026

@ishandhanani Can this feature be understood as follows:
In a multi-turn dialogue scenario, the first round takes tokens 1, 2, 3 as input and outputs tokens 4, 5, 6.
The second round takes tokens 1, 2, 3, 4, 5, 6, 7, 8, 9 as input and outputs tokens 10, 11, 12.

Current status of pd-disagg:
In the first round, for the decode worker, the generated tokens 4, 5, 6 are not cached, and the KV cache of the input tokens 1, 2, 3 is not saved.
In the second round, the prefill worker needs to send the KV cache for all tokens 1, 2, 3, 4, 5, 6, 7, 8, and 9 to the decode worker.

Based on this PR's implementation:
In the first round, the decode worker saves the KV cache for tokens 1, 2, 3, 4, 5, and 6.
In the second round, the prefill worker only needs to send the KV cache for tokens 7, 8, and 9 to the decode worker.
Is my understanding correct?

@ishandhanani
Collaborator Author

ishandhanani commented Mar 3, 2026

@ishandhanani Can this feature be understood as follows: In a multi-turn dialogue scenario, the first round takes tokens 1, 2, 3 as input and outputs tokens 4, 5, 6. The second round takes tokens 1, 2, 3, 4, 5, 6, 7, 8, 9 as input and outputs tokens 10, 11, 12.

Current status of pd-disagg: In the first round, for the decode worker, the generated tokens 4, 5, 6 are not cached, and the KV cache of the input tokens 1, 2, 3 is not saved. In the second round, the prefill worker needs to send the KV cache for all tokens 1, 2, 3, 4, 5, 6, 7, 8, and 9 to the decode worker.

Based on this PR's implementation: In the first round, the decode worker saves the KV cache for tokens 1, 2, 3, 4, 5, and 6. In the second round, the prefill worker only needs to send the KV cache for tokens 7, 8, and 9 to the decode worker. Is my understanding correct?

Yep. This is correct
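The confirmed behavior can be sketched with a toy cache — a set of token-prefix tuples standing in for the decode radix tree. Only the idea (which tokens cross the wire in round two) is from the PR; the code itself is a hypothetical illustration:

```python
# Minimal sketch of the multi-turn example above: the decode worker keeps
# KV for tokens it has already seen, so only the delta is transferred.

def tokens_to_transfer(input_ids, decode_cache):
    """Longest cached prefix stays local; only the delta is sent."""
    hit = 0
    while hit < len(input_ids) and tuple(input_ids[:hit + 1]) in decode_cache:
        hit += 1
    return input_ids[hit:]

cache = set()

# Round 1: input [1, 2, 3], output [4, 5, 6]; decode caches all six tokens.
round1 = [1, 2, 3, 4, 5, 6]
for i in range(1, len(round1) + 1):
    cache.add(tuple(round1[:i]))

# Round 2: input [1..9]; prefill only needs to send KV for [7, 8, 9].
assert tokens_to_transfer([1, 2, 3, 4, 5, 6, 7, 8, 9], cache) == [7, 8, 9]
```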

@ishandhanani
Collaborator Author

/gemini review


- Set req.prefix_indices in _pre_alloc so init_next_round_input(None)
  computes extend_input_len correctly from the cached prefix length.
  Without this, prepare_for_prebuilt runs a full-length extend instead
  of a delta extend.

- Always call inc_lock_ref on the matched node (even on empty match)
  to match aggregated scheduler behavior. Prevents lock_ref underflow
  when cache_finished_req unconditionally calls dec_lock_ref.

@ishandhanani
Collaborator Author

ishandhanani commented Mar 4, 2026

Next step is testing with a larger model on B200. The step after that (maybe in a follow-up) is to do the same for Mooncake.

@dongyibo

dongyibo commented Mar 4, 2026

@ishandhanani There seems to be a constraint here:
For multiple decode workers, such as when decode is run with DP, it's best if the same DP rank is used for the entire conversation; otherwise, the cached KV cache cannot be utilized?

@nananall

nananall commented Mar 4, 2026

Could you share the exact command you used to run this? I'd like to reproduce it and test it on my side.

@ishandhanani
Collaborator Author

ishandhanani commented Mar 4, 2026

@ishandhanani There seems to be a constraint here: For multiple decode workers, such as when decode is run with DP, it's best if the same DP rank is used for the entire conversation; otherwise, the cached KV cache cannot be utilized?

There are a few things here:

  1. When running with multiple decode workers (standard data parallelism of workers), I expect the router to pick the right decode worker based on KV load. The dynamo router handles this very well and performantly out of the box.
  2. For DP attention - agreed. Right now I have not added support; need to do this.

@cctry
Collaborator

cctry commented Apr 7, 2026

How are warmup requests handled? IIUC they have a fake bootstrap addr, so their KV is garbage and should not be inserted?

@cctry
Collaborator

cctry commented Apr 7, 2026

The ITL regression is expected from the larger decode batch.

please address all the performance issues. this is actually not expected.

@ishandhanani
Collaborator Author

The ITL regression is expected from the larger decode batch.

please address all the performance issues. this is actually not expected.

Can you explain why this is not expected? If we have a larger batch being driven through the decode we might take a slight hit on ITL. I might be missing something so please feel free to correct me

Remove the if/else branch for disable_radix_cache in pop_preallocated.
_match_prefix_and_lock works with both RadixCache and ChunkCache —
ChunkCache returns prefix_len=0 and lock/unlock are no-ops.
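The commit above relies on both cache types satisfying the same interface. A sketch of why the branch can be dropped — the class and method names here are illustrative stand-ins, not the actual sglang classes:

```python
# A ChunkCache-style stub returns an empty match and treats lock/unlock as
# no-ops, so one unified code path serves both it and a real radix cache.

class ChunkCacheStub:
    def match_prefix(self, token_ids):
        return [], None           # prefix_len = 0, no node to lock

    def inc_lock_ref(self, node):  # no-op
        pass

    def dec_lock_ref(self, node):  # no-op
        pass

def match_prefix_and_lock(cache, token_ids):
    prefix_indices, last_node = cache.match_prefix(token_ids)
    cache.inc_lock_ref(last_node)  # safe no-op for the chunk cache
    return len(prefix_indices), last_node

plen, node = match_prefix_and_lock(ChunkCacheStub(), [1, 2, 3])
assert plen == 0 and node is None
```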
@ishandhanani ishandhanani force-pushed the ishan/add-radix-cache-decode branch from 8a9a52a to 53294d6 on April 7, 2026 02:56
Add fail-fast guards for --disaggregation-decode-enable-radix-cache with:
- speculative decoding (checked in server_args)
- SWA hybrid models (checked in scheduler init, model-dependent)
- Mamba/SSM hybrid models (checked in scheduler init, model-dependent)
Running requests' pages are already inserted into the radix tree
during process_prebuilt -> cache_unfinished_req, so they should be
properly accounted for by check_memory. The guard was masking
potential accounting issues rather than fixing them.
@cctry
Collaborator

cctry commented Apr 7, 2026

The ITL regression is expected from the larger decode batch.

please address all the performance issues. this is actually not expected.

Can you explain why this is not expected? If we have a larger batch being driven through the decode we might take a slight hit on ITL. I might be missing something so please feel free to correct me

For that long common prefix, the memory traffic of loading kv cache does not increase proportionally to batch size.

Also, the 8x TTFT speedup only leads to 3x-ish concurrency improvements, which means the engine is not properly loaded. (8x looks sus to me)

@ShangmingCai
Collaborator

I have found a similar effort, #22234. Should we arrange a meeting and gather all these people who are interested in this feature and settle on a design together?

@ishandhanani
Collaborator Author

ishandhanani commented Apr 7, 2026

I have found a similar effort, #22234. Should we arrange a meeting and gather all these people who are interested in this feature and settle on a design together?

Oh nice. Yea that would be good. They do some things very well. Maybe over slack?

- Remove unnecessary running_batch guard in decode idle memory check:
  running requests' pages are already in the tree via process_prebuilt
  -> cache_unfinished_req.
- Move start_send_idx update after send() to prevent cursor advance
  on skipped/failed sends.
- Add TP sync guards in _update_handshake_waiters and pop_transferred
  to prevent gloo hangs from queue size divergence across TP ranks.
- Add diagnostics to Bug #14 KV cache full assertion in _pre_alloc.

Cursor fix and TP sync guards inspired by #22234 (yudian0504).
Validated on sa-b200 Qwen 32B 3P1D at concurrency 128 and 256.
        req.req_pool_idx, : len(req.fill_ids)
    ]
    kv_indices = kv_indices_orig[: req.kv_committed_len]

Contributor

Let's check the different scenarios:

  • Incremental transfer & success:
    • inc_lock_ref(pop_preallocated) → dec_lock_ref+inc_lock_ref(cache_unfinished_req) → dec_lock_ref(cache_finished_req)
  • Full transfer & success:
    • inc_lock_ref(get_new_prebuilt_batch) → dec_lock_ref+inc_lock_ref(cache_unfinished_req) → dec_lock_ref(cache_finished_req)
  • Incremental transfer & failure:
    • inc_lock_ref(pop_preallocated) → dec_lock_ref(cache_finished_req via abort and release)
  • Full transfer & failure:
    • no inc_lock_ref → dec_lock_ref(root_node) # no-op

Collaborator Author

Verified all four scenarios with unit tests in 4f2c34e (test/registered/unit/mem_cache/test_decode_radix_lock_ref.py). All pass.

Your analysis is correct. One nuance worth noting: inc_lock_ref and dec_lock_ref both use while node != root_node, so calls on root are no-ops. This means scenarios 2 and 4 (full transfer, where last_node = root) have trivially balanced lock_refs — the real balancing work happens in scenarios 1 and 3 (incremental transfer with actual prefix nodes).

The tests cover:

  1. Incremental transfer & success — exercises full inc → dec+inc → dec chain ✓
  2. Full transfer & success — confirms root is no-op, leaf inc/dec balances ✓
  3. Incremental transfer & failure — inc_lock_ref then cache_finished_req(is_insert=False) properly dec's ✓
  4. Full transfer & failure — dec_lock_ref(root) is safe no-op ✓
  5. Repeated incremental (5 iterations) — no cumulative lock_ref drift ✓
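The invariant behind these scenarios can be shown with a toy model: inc/dec walk from a node toward the root and skip the root itself, so calls on the root are no-ops and every success or failure path leaves lock_ref balanced. This is a hypothetical stand-in for sglang's radix tree, not the real class:

```python
# Toy lock_ref model: the while-loop stops at the root, so root calls are
# no-ops and paired inc/dec calls on any interior node cancel exactly.

class Node:
    def __init__(self, parent=None):
        self.parent = parent
        self.lock_ref = 0

def inc_lock_ref(node, root):
    while node is not root:
        node.lock_ref += 1
        node = node.parent

def dec_lock_ref(node, root):
    while node is not root:
        node.lock_ref -= 1
        node = node.parent

root = Node()
leaf = Node(parent=Node(parent=root))

# Incremental transfer (scenarios 1/3): inc at prealloc, dec at finish.
inc_lock_ref(leaf, root)
dec_lock_ref(leaf, root)
assert leaf.lock_ref == 0 and leaf.parent.lock_ref == 0

# Full transfer & failure (scenario 4): dec on root is a safe no-op.
dec_lock_ref(root, root)
assert root.lock_ref == 0
```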

Warmup requests use FAKE_BOOTSTRAP_HOST, so no actual KV transfer
occurs — the decode side runs forward passes on uninitialized memory.
With decode radix cache enabled, cache_unfinished_req / cache_finished_req
insert this garbage into the tree, where it could be matched by real
requests sharing the same prefix tokens.

Fix: call flush_cache after both warmup paths (custom --warmups and
the general _execute_server_warmup) complete in disagg mode. The flush
is safe because it runs before ServerStatus.Up is set, so no real
traffic can race with it.
@ishandhanani
Collaborator Author

How are warmup requests handled? IIUC they have a fake bootstrap addr, so their KV is garbage and should not be inserted?

Addressed in a6b6e8b.

flush_cache after both warmup paths (custom --warmups and the general _execute_server_warmup) complete in disagg mode. The flush is safe because it runs before ServerStatus.Up is set, so no real traffic can race with it.

Extract inline floor-alignment math into a shared helper in
disaggregation/utils.py. Used in two places: prefix_len alignment
and SWA window_start alignment.
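A minimal version of that shared helper might look like the following; the name and signature are assumptions for illustration, not copied from the PR:

```python
# Floor-align a length to the KV page size, as used for prefix_len and
# SWA window_start alignment (helper name and location assumed).

def floor_align(value: int, page_size: int) -> int:
    """Round value down to the nearest multiple of page_size."""
    return (value // page_size) * page_size

assert floor_align(50_000, 64) == 49_984  # prefix_len alignment
assert floor_align(4_096, 64) == 4_096    # already page-aligned
assert floor_align(63, 64) == 0           # below one page -> no reuse
```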
Revert cache_unfinished_req back to upstream: it now uses req.fill_ids
directly without truncation.  The decode-specific truncation to
kv_committed_len is done at the two decode call sites instead:
  - _pre_alloc: when fill_ids is first set
  - get_new_prebuilt_batch: after init_next_round_input resets fill_ids

This addresses review feedback (hzh0425, yudian0504) that modifying
RadixTree internals is too aggressive — the truncation is a
decode-disagg concern, not a tree concern.
Restore getattr(config, "rope_theta", 1000000) from main — our branch
had config.rope_parameters["rope_theta"] from a bad merge resolution,
which breaks on checkpoints without rope_parameters.
…-decode

# Conflicts:
#	python/sglang/srt/disaggregation/nixl/conn.py
Tests the four inc_lock_ref/dec_lock_ref scenarios identified by
yudian0504 in PR #19746 review:

1. Incremental transfer & success (prefix match > 0)
2. Full transfer & success (no prefix match, root)
3. Incremental transfer & failure (prefix match, is_insert=False)
4. Full transfer & failure (root, dec_lock_ref is no-op)
5. Repeated incremental transfers (no lock_ref leak)

Key invariant: after each scenario, protected_size()==0 and root
lock_ref==1 (unchanged).