feat: WIP NIXL transfer through CPU buffer for better performance with mixed TP sizes #18968

Draft
Aphoh wants to merge 5 commits into sgl-project:main from Aphoh:warnold/mixed-tp-disagg


@Aphoh Aphoh commented Feb 18, 2026

Motivation

Currently, NIXL handles transfers between nodes with different TP sizes by issuing one NIXL transaction per (token, layer, head) combination. This is very inefficient: under any reasonable load, KV transfer grinds the deployment to a halt.

Modifications

Introduced three main components

1. Triton kernels for copying KVs to host

These kernels transpose the layout from k_buffer[layer]: [num_slots, num_heads, head_dim] to [num_heads, num_layers, 2, num_tokens, head_dim], which lets us slice out heads and issue a single NIXL transfer for each head we need to send.
In benchmarks, these kernels reach more than 90% of the baseline memcpy bandwidth between host and device:

- D2H memcpy (baseline): 227.9 ms, 54.18 GB/s
- H2D memcpy (baseline): 222.7 ms, 55.44 GB/s
- Gather (GPU→CPU): 234.8 ms, 52.60 GB/s — 97% of D2H baseline
- Scatter (CPU→GPU): 240.2 ms, 51.41 GB/s — 93% of H2D baseline
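
To illustrate why the head-major staging layout helps, here is a minimal sketch in plain Python. It assumes the staging buffer is row-major with shape [num_heads, num_layers, 2, num_tokens, head_dim]; the helper names `staged_offset` and `head_range` and the sizes are hypothetical and are not taken from the PR. It shows that every element of a group of consecutive heads falls into one contiguous range, which is what allows one transfer per head slice.

```python
from typing import Tuple

Shape = Tuple[int, int, int, int, int]  # (num_heads, num_layers, 2, num_tokens, head_dim)


def staged_offset(head: int, layer: int, kv: int, token: int, dim: int, shape: Shape) -> int:
    """Flat element offset of one value in a row-major staging buffer."""
    _, num_layers, two, num_tokens, head_dim = shape
    return (((head * num_layers + layer) * two + kv) * num_tokens + token) * head_dim + dim


def head_range(head_start: int, head_count: int, shape: Shape) -> Tuple[int, int]:
    """Half-open [start, end) element range that holds the requested heads."""
    _, num_layers, two, num_tokens, head_dim = shape
    per_head = num_layers * two * num_tokens * head_dim
    return head_start * per_head, (head_start + head_count) * per_head


# Illustrative sizes only: 8 heads, 4 layers, K and V, 16 tokens, head_dim 128.
shape: Shape = (8, 4, 2, 16, 128)
start, end = head_range(head_start=2, head_count=2, shape=shape)

# First element of head 2 and last element of head 3 bound the same range.
first = staged_offset(2, 0, 0, 0, 0, shape)
last = staged_offset(3, 3, 1, 15, 127, shape)
```

With the original per-layer [num_slots, num_heads, head_dim] layout, the same heads would be spread over many small strided pieces, one per token and layer.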

2. A pinned CPU buffer allocator

This slot-based allocator reduces the amount of pinned CPU memory we need to allocate on the host.
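
As a rough sketch of how a shared buffer can be carved up for concurrent transfers, the class below hands out byte ranges from one fixed-size region and merges them back on release. This is not the PR's PinnedBufferPool: the class name `RangeAllocator`, the first-fit policy, and the method signatures are assumptions made for illustration, and no pinned memory is actually allocated here.

```python
from typing import List, Optional, Tuple


class RangeAllocator:
    """First-fit allocator over a fixed byte range (hypothetical stand-in for a pinned pool)."""

    def __init__(self, total_bytes: int) -> None:
        self.total_bytes = total_bytes
        # Free list of (offset, size) pairs, kept sorted by offset.
        self._free: List[Tuple[int, int]] = [(0, total_bytes)]

    def allocate(self, size: int) -> Optional[int]:
        """Return the offset of a free region of `size` bytes, or None if none fits."""
        if size <= 0:
            raise ValueError("size must be positive")
        for i, (off, length) in enumerate(self._free):
            if length >= size:
                if length == size:
                    del self._free[i]
                else:
                    self._free[i] = (off + size, length - size)
                return off
        return None

    def release(self, offset: int, size: int) -> None:
        """Return a region to the pool and merge it with adjacent free regions."""
        self._free.append((offset, size))
        self._free.sort()
        merged: List[Tuple[int, int]] = []
        for off, length in self._free:
            if merged and merged[-1][0] + merged[-1][1] == off:
                prev_off, prev_len = merged[-1]
                merged[-1] = (prev_off, prev_len + length)
            else:
                merged.append((off, length))
        self._free = merged


pool = RangeAllocator(total_bytes=1000)
a = pool.allocate(100)
b = pool.allocate(200)
pool.release(a, 100)      # the first 100 bytes become reusable
c = pool.allocate(50)     # reuses the released region
```

Because regions are released when a transfer completes, the pool only has to be as large as the peak amount of in-flight data, not the sum of all transfers.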

Accuracy Tests

 # Prefill server (e.g. TP=4)
  python -m sglang.launch_server \
      --model-path Qwen/Qwen3-8B \
      --disaggregation-mode prefill \
      --disaggregation-transfer-backend nixl \
      --disaggregation-bootstrap-port 8998 \
      --nixl-use-cpu-buffer \
      --tp 4 --port 30000 --disable-radix-cache

  # Decode server (e.g. TP=2, on different GPUs)
  python -m sglang.launch_server \
      --model-path Qwen/Qwen3-8B \
      --disaggregation-mode decode \
      --disaggregation-transfer-backend nixl \
      --nixl-use-cpu-buffer \
      --tp 2 --base-gpu-id 4 --port 30100 --disable-radix-cache

  # Send a request (same payload to both, same bootstrap_room)
  ROOM=$RANDOM
  curl -X POST http://127.0.0.1:30000/generate \
      -H "Content-Type: application/json" \
      -d '{"text":"The capital of France is","sampling_params":{"max_new_tokens":16},"bootstrap_host":"127.0.0.1","bootstrap_port":8998,"bootstrap_room":'$ROOM'}' &
  curl -X POST http://127.0.0.1:30100/generate \
      -H "Content-Type: application/json" \
      -d '{"text":"The capital of France is","sampling_params":{"max_new_tokens":16},"bootstrap_host":"127.0.0.1","bootstrap_port":8998,"bootstrap_room":'$ROOM'}'

The three tested configurations were equal TP (2,2), prefill TP > decode TP (4,2), and prefill TP < decode TP (2,4), all on Qwen3-8B.

Benchmarking and Profiling

bench_serving Results (Qwen/Qwen3-8B, 320 prompts, in=256, out=128, H200)

(4,2) — Prefill TP=4, Decode TP=2

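| Concurrency | Req/s (Base) | Req/s (PR) | Output tok/s (Base) | Output tok/s (PR) | TTFT ms (Base) | TTFT ms (PR) | TPOT ms (Base) | TPOT ms (PR) | E2E ms (Base) | E2E ms (PR) |
|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 2.93 | 3.59 | 161 | 232 | 132 | 34 | 3.84 | 3.82 | 337 | 278 |
| 4 | 4.96 | 12.79 | 272 | 828 | 530 | 50 | 3.98 | 4.10 | 744 | 311 |
| 8 | 7.69 | 18.61 | 491 | 1,188 | 727 | 115 | 4.07 | 4.42 | 983 | 395 |
| 32 | | 66.87 | | 4,329 | | 114 | | 5.23 | | 449 |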

(2,4) — Prefill TP=2, Decode TP=4

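| Concurrency | Req/s (Base) | Req/s (PR) | Output tok/s (Base) | Output tok/s (PR) | TTFT ms (Base) | TTFT ms (PR) | TPOT ms (Base) | TPOT ms (PR) | E2E ms (Base) | E2E ms (PR) |
|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 4.07 | 4.42 | 224 | 286 | 84 | 39 | 2.91 | 2.94 | 242 | 226 |
| 4 | 8.58 | 14.86 | 470 | 962 | 218 | 44 | 3.84 | 3.51 | 430 | 268 |
| 8 | 16.23 | 18.97 | 1,035 | 1,210 | 231 | 130 | 3.66 | 4.43 | 457 | 392 |
| 32 | | 67.94 | | 4,399 | | 98 | | 5.53 | | 447 |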
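
For readers who want the relative gains rather than the raw numbers, the short script below derives request-throughput and TTFT ratios from the rows above that have both a Base and a PR value (concurrency 1, 4, and 8). The numbers are copied from the tables; the helper name `summarize` is hypothetical and the ratios are simple arithmetic, not additional measurements.

```python
from typing import Dict, List, Tuple

# (concurrency, base req/s, PR req/s, base TTFT ms, PR TTFT ms)
Row = Tuple[int, float, float, float, float]

rows_4_2: List[Row] = [
    (1, 2.93, 3.59, 132, 34),
    (4, 4.96, 12.79, 530, 50),
    (8, 7.69, 18.61, 727, 115),
]
rows_2_4: List[Row] = [
    (1, 4.07, 4.42, 84, 39),
    (4, 8.58, 14.86, 218, 44),
    (8, 16.23, 18.97, 231, 130),
]


def summarize(rows: List[Row]) -> Dict[int, Tuple[float, float]]:
    """Map concurrency -> (req/s speedup, TTFT reduction factor), PR relative to Base."""
    return {
        conc: (pr_rps / base_rps, base_ttft / pr_ttft)
        for conc, base_rps, pr_rps, base_ttft, pr_ttft in rows
    }


for name, rows in (("(4,2)", rows_4_2), ("(2,4)", rows_2_4)):
    for conc, (rps_gain, ttft_gain) in summarize(rows).items():
        print(f"{name} concurrency={conc}: req/s x{rps_gain:.2f}, TTFT /{ttft_gain:.1f}")
```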

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist

Summary of Changes

Hello @Aphoh, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a substantial performance enhancement for NIXL-based KV cache transfers, particularly critical for scenarios involving mixed Tensor Parallelism configurations. By integrating custom Triton kernels and a dedicated pinned CPU buffer allocator, the system can now efficiently move KV cache data between GPU and host memory. This change drastically reduces the overhead associated with numerous small NIXL transactions, replacing them with fewer, larger, and more optimized transfers, leading to a more scalable and performant disaggregated inference setup.

Highlights

  • Optimized NIXL KV Transfer: Introduced a new mechanism for NIXL (NVIDIA Inference Xfer Library) KV cache transfers that leverages a pinned CPU buffer, significantly improving performance for mixed Tensor Parallelism (TP) sizes by reducing the number of individual NIXL transactions.
  • Triton Kernels for GPU-CPU Data Movement: Implemented highly efficient Triton kernels for gathering scattered KV data from GPU memory to a contiguous pinned CPU buffer, and scattering data from the pinned CPU buffer back to GPU KV cache. These kernels achieve near-peak PCIe bandwidth and have minimal GPU memory overhead.
  • Pinned CPU Buffer Allocator: Developed a shared, per-GPU pinned CPU buffer pool that manages memory allocations for KV transfers, preventing redundant allocations and optimizing host memory usage during disaggregated inference.
  • Mixed TP Support Enhancements: Improved handling of mixed TP configurations (e.g., prefill TP > decode TP or prefill TP < decode TP) by enabling batched transfers and correct head redistribution using the new CPU buffer approach, addressing previous performance bottlenecks.
  • FP8 Dtype Compatibility: Ensured the new Triton KV transfer kernels correctly handle FP8 (float8_e4m3fn) KV cache data, including robust dtype validation to prevent memory corruption issues.

Changelog
  • python/sglang/srt/disaggregation/base/conn.py
    • Updated KVArgs to include optional k_buffers, v_buffers, and head_dim for CPU buffer KV transfer.
    • Added Any to typing imports.
  • python/sglang/srt/disaggregation/common/conn.py
    • Modified warning messages for mixed TP sizes to recommend using --nixl-use-cpu-buffer for improved performance and correct head redistribution.
  • python/sglang/srt/disaggregation/decode.py
    • Added logic to populate KVArgs with k_buffers, v_buffers, and head_dim if nixl_use_cpu_buffer is enabled.
    • Adjusted bootstrap_room validation to mask the value to a signed 64-bit integer range, preventing potential overflow issues from u64 values generated by Dynamo.
  • python/sglang/srt/disaggregation/nixl/__init__.py
    • Imported the new PinnedBufferPool module.
  • python/sglang/srt/disaggregation/nixl/conn.py
    • Imported torch and PinnedBufferPool for CPU buffer management.
    • Introduced _import_triton_kv_transfer for lazy loading of Triton kernels.
    • Extended KVArgsRegisterInfo with dst_pinned_ptr and dst_pinned_size to communicate pinned buffer details.
    • Refactored TransferStatus to use sender_key (NIXL peer_name) instead of pp_rank for more accurate tracking of transfers from multiple prefill TP ranks.
    • Initialized Triton transfer-related attributes (nixl_use_cpu_buffer, triton_staging_buffer, _pinned_pool, triton_pinned_descs) in NixlKVManager.
    • Added _init_triton_transfer_buffers method to set up GPU staging buffers and the shared pinned buffer pool.
    • Implemented _expand_pages_to_slots to convert page indices to slot indices for Triton kernels.
    • Added send_kvcache_triton for performing single NIXL transfers using Triton gather kernels.
    • Introduced _send_kvcache_triton_batched to handle batched KV transfers for prefill_tp < decode_tp scenarios, gathering all heads once and then slicing for parallel NIXL transfers.
    • Implemented scatter_received_kv to scatter received KV data from the pinned buffer to the GPU KV cache using Triton kernels.
    • Modified add_transfer_request to incorporate the new Triton-based CPU buffer transfer logic, including special handling for batched mixed-TP transfers.
    • Added _send_kv_state for transferring page-indexed state data (NSA/SWA) via RDMA.
    • Updated NixlKVSender.poll to release allocated pinned buffer regions upon transfer completion or error.
    • Modified NixlKVReceiver to store kv_indices and trigger scatter_received_kv after NIXL transfer completion when CPU buffer transfer is enabled.
    • Updated _register_kv_args to include pinned buffer pointer and size in the registration message.
  • python/sglang/srt/disaggregation/nixl/pinned_buffer_pool.py
    • Added new file pinned_buffer_pool.py implementing a per-GPU singleton PinnedBufferPool.
    • The PinnedBufferPool pre-allocates a shared pinned CPU buffer and uses a range-based allocator for variable-sized concurrent transfers.
    • Includes methods for allocate, release, get_buffer_info, and register_with_nixl.
  • python/sglang/srt/disaggregation/prefill.py
    • Added logic to populate KVArgs with k_buffers, v_buffers, and head_dim if nixl_use_cpu_buffer is enabled.
  • python/sglang/srt/disaggregation/utils.py
    • Modified set_buf to mask the bootstrap_room value to a signed 64-bit integer range, ensuring consistency and preventing overflow when dealing with u64 values from Dynamo.
  • python/sglang/srt/layers/attention/triton_ops/kv_transfer.py
    • Added new file kv_transfer.py containing Triton kernels for efficient KV cache transfers.
    • Implemented _gather_kv_all_layers_kernel and gather_kv_to_pinned_all_layers to gather scattered GPU KV data to a contiguous pinned CPU buffer.
    • Implemented _scatter_kv_all_layers_from_pinned_kernel and scatter_kv_with_staging_all_layers to scatter contiguous pinned CPU data to scattered GPU KV cache.
    • These kernels operate on all layers in a single launch, use pointer tensors, and support dtype-agnostic byte copying, including FP8.
  • python/sglang/srt/models/deepseek_common/deepseek_weight_loader.py
    • Added checks to ensure param_name exists in params_dict before accessing, preventing errors for missing FP8 scale weights in certain quantization scenarios.
    • Added a check to skip FP8 conversion if weights are already in FP8 format.
  • python/sglang/srt/models/deepseek_nextn.py
    • Set self.quant_config to None for modelopt_fp4 quantization in DeepseekNextNModel to avoid potential issues.
  • python/sglang/srt/server_args.py
    • Added a new command-line argument, --nixl-use-cpu-buffer (boolean), to enable the new CPU buffer transfer method.
    • Added --nixl-cpu-buffer-size-gb (float) to configure the total size of the pinned CPU buffer for NIXL KV transfers.
  • sgl-kernel/benchmark/bench_kv_transfer.py
    • Added new file bench_kv_transfer.py to benchmark the performance of the new Triton KV transfer kernels.
    • Includes benchmarks for raw CUDA memcpy, and the gather/scatter kernels, measuring time and bandwidth.
  • sgl-kernel/tests/test_kv_transfer.py
    • Added new file test_kv_transfer.py with comprehensive unit tests for the Triton KV transfer kernels.
    • Includes tests for full and sliced head transfers, contiguous and sparse indices, large pools, long sequences, and roundtrip functionality.
    • Features specific tests for FP8 dtype compatibility and assertions for dtype mismatches to ensure correctness and prevent memory errors.
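
The changelog mentions converting page indices to slot indices before the Triton kernels run. A minimal sketch of that kind of expansion is shown below. It assumes that page p covers the slots p * page_size through (p + 1) * page_size - 1; this mapping, the function name `expand_pages_to_slots`, and its signature are illustrative assumptions and may differ from the PR's `_expand_pages_to_slots`.

```python
from typing import List


def expand_pages_to_slots(page_indices: List[int], page_size: int) -> List[int]:
    """Expand each page index into the slot indices it is assumed to cover."""
    if page_size <= 0:
        raise ValueError("page_size must be positive")
    slots: List[int] = []
    for page in page_indices:
        base = page * page_size
        slots.extend(range(base, base + page_size))
    return slots


slots = expand_pages_to_slots([2, 0], page_size=4)
```

With page_size equal to 1 the expansion is the identity, so the same code path would serve both paged and unpaged caches under this assumption.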
Activity
  • The pull request introduces significant new functionality and performance optimizations for NIXL KV transfers.
  • New Triton kernels and a pinned CPU buffer pool have been implemented and thoroughly tested.
  • Configuration options have been added to enable and control the new transfer mechanism.
  • Benchmarking scripts are included to evaluate the performance gains of these changes.

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces a significant performance enhancement for NIXL transfers, especially in mixed tensor-parallelism scenarios, by implementing a new transfer path via a pinned CPU buffer. This is achieved through new Triton kernels for efficient gather/scatter operations and a shared pinned buffer pool. The changes are extensive, well-structured, and include comprehensive tests and benchmarks, which is great. The new logic for handling mixed TP sizes and batched transfers appears sound. My review focuses on the correctness of the new implementation, and I've identified a minor issue in memory size calculation that should be addressed.

)

# Allocate GPU staging buffer (fixed size, 256MB by default)
staging_size_bytes = int(DEFAULT_TRITON_STAGING_BUFFER_SIZE_MB * 1e6)

medium

The calculation of staging_size_bytes uses a decimal prefix (1e6 for MB). For memory sizes, it's standard and more accurate to use binary prefixes (powers of 1024). Using 1e6 results in allocating ~5% less memory than specified. Please use 1024**2 for megabytes (MiB) to ensure correct memory allocation.

Suggested change
- staging_size_bytes = int(DEFAULT_TRITON_STAGING_BUFFER_SIZE_MB * 1e6)
+ staging_size_bytes = int(DEFAULT_TRITON_STAGING_BUFFER_SIZE_MB * 1024**2)

Comment on lines +436 to +438
pinned_size_bytes = int(
getattr(self._server_args, "nixl_cpu_buffer_size_gb", 16.0) * 1e9
)

medium

The calculation of pinned_size_bytes uses a decimal prefix (1e9 for GB). For memory sizes, it's standard and more accurate to use binary prefixes (powers of 1024). Using 1e9 results in allocating ~7% less memory than specified. Please use 1024**3 for gigabytes (GiB) to ensure correct memory allocation.

Suggested change
- pinned_size_bytes = int(
-     getattr(self._server_args, "nixl_cpu_buffer_size_gb", 16.0) * 1e9
- )
+ pinned_size_bytes = int(
+     getattr(self._server_args, "nixl_cpu_buffer_size_gb", 16.0) * 1024**3
+ )
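
The shortfall percentages quoted in these two comments can be checked with a few lines of arithmetic. The sketch below compares decimal and binary unit sizes, using the 256 MB staging default and the 16.0 GB pinned default that appear in the reviewed code; the helper name `shortfall_percent` is hypothetical.

```python
def shortfall_percent(decimal_unit: float, binary_unit: int) -> float:
    """Percentage by which a decimal unit is smaller than the binary unit."""
    return (1 - decimal_unit / binary_unit) * 100


mb_shortfall = shortfall_percent(1e6, 1024**2)  # MB vs MiB
gb_shortfall = shortfall_percent(1e9, 1024**3)  # GB vs GiB

# Byte counts for the defaults seen in the reviewed code.
staging_decimal = int(256 * 1e6)
staging_binary = int(256 * 1024**2)
pinned_decimal = int(16.0 * 1e9)
pinned_binary = int(16.0 * 1024**3)

print(f"MB vs MiB: {mb_shortfall:.2f}% smaller")
print(f"GB vs GiB: {gb_shortfall:.2f}% smaller")
```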

Comment on lines +447 to +448
f"staging={self.triton_staging_buffer.nbytes / 1e6:.2f}MB (GPU), "
f"shared_pinned_pool={pinned_size_bytes / 1e9:.2f}GB (CPU)"

medium

To be consistent with using binary prefixes for memory sizes (MiB/GiB), the logging should also use powers of 1024 for reporting. This avoids confusion between decimal (MB/GB) and binary (MiB/GiB) units and correctly reflects the allocated memory after applying the suggested changes for memory calculation.

Suggested change
- f"staging={self.triton_staging_buffer.nbytes / 1e6:.2f}MB (GPU), "
- f"shared_pinned_pool={pinned_size_bytes / 1e9:.2f}GB (CPU)"
+ f"staging={self.triton_staging_buffer.nbytes / (1024**2):.2f}MiB (GPU), "
+ f"shared_pinned_pool={pinned_size_bytes / (1024**3):.2f}GiB (CPU)"

@Aphoh Aphoh force-pushed the warnold/mixed-tp-disagg branch from 23b3a39 to 53bd1c1 Compare February 19, 2026 00:52
nvidia and others added 5 commits February 25, 2026 11:53
Triton gather/scatter kernels + shared pinned buffer reduce NIXL descriptor
count from O(tokens * layers) to O(1), achieving ~100% PCIe bandwidth.

Previously TransferStatus tracked KV notifications by pp_rank, but in
mixed TP (prefill_tp > decode_tp) multiple prefill TP ranks share the
same pp_rank (0 when PP=1). The decode expected N unique senders but
only ever saw one unique key, causing a 5-minute waiting_timeout.

Fix: track by peer_name (unique NIXL agent UUID per TP rank) so each
prefill TP rank is counted as a distinct sender. This mirrors the
approach in the dd-rebased-058 branch.
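
The effect of the key choice described in this commit message can be sketched with a small completion tracker. The class below is a hypothetical stand-in, not the PR's TransferStatus: it counts a transfer as complete once the expected number of distinct sender keys has been seen. The peer names used here are made-up placeholders for the per-rank NIXL agent identifiers.

```python
from typing import Set


class SenderTracker:
    """Counts distinct senders; complete once `expected_senders` unique keys arrive."""

    def __init__(self, expected_senders: int) -> None:
        self.expected_senders = expected_senders
        self.seen: Set[str] = set()

    def notify(self, sender_key: str) -> None:
        self.seen.add(sender_key)

    def is_done(self) -> bool:
        return len(self.seen) >= self.expected_senders


# Two prefill TP ranks feed one decode rank; with PP=1 both have pp_rank 0.
by_pp_rank = SenderTracker(expected_senders=2)
for pp_rank in (0, 0):
    by_pp_rank.notify(str(pp_rank))

by_peer_name = SenderTracker(expected_senders=2)
for peer_name in ("agent-rank-0", "agent-rank-1"):  # placeholder identifiers
    by_peer_name.notify(peer_name)
```

Keyed by pp_rank, the tracker never sees a second unique key and would wait until a timeout; keyed by a per-rank identifier, it completes after both notifications.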

For prefill_tp > decode_tp:
- dst_head_offset was using the global local_tp_rank, causing prefill
  ranks in the second decode bucket (e.g. ranks 2,3 for prefill_tp=4,
  decode_tp=2) to write beyond the end of the decode node's pinned
  buffer. Fix: use (local_tp_rank % prefill_ranks_per_decode) so the
  offset is always relative to the decode bucket.

For prefill_tp < decode_tp (batched path):
- head_start was using the absolute decode_tp_rank, so prefill rank 1
  (with local heads 0..H/2) would be asked to send heads starting at
  H/2, H/4*3, etc. which are out of range. Fix: use
  (decode_tp_rank % decode_per_prefill) as the relative rank within the
  group of decode ranks served by this prefill rank.
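
The two modulo fixes above can be illustrated with a small worked example. The sketch assumes heads are split evenly across TP ranks and that an offset is the relative rank multiplied by the number of heads per rank; the commit message only states the modulo expressions, so the multiplication, the function names, and the 8-head configuration are assumptions for illustration.

```python
def dst_head_offset(local_tp_rank: int, prefill_tp: int, decode_tp: int, total_heads: int) -> int:
    """prefill_tp > decode_tp: where a prefill rank's heads land in its decode rank's buffer."""
    prefill_ranks_per_decode = prefill_tp // decode_tp
    heads_per_prefill_rank = total_heads // prefill_tp
    return (local_tp_rank % prefill_ranks_per_decode) * heads_per_prefill_rank


def src_head_start(decode_tp_rank: int, prefill_tp: int, decode_tp: int, total_heads: int) -> int:
    """prefill_tp < decode_tp: first local head a prefill rank sends to a decode rank."""
    decode_per_prefill = decode_tp // prefill_tp
    heads_per_decode_rank = total_heads // decode_tp
    return (decode_tp_rank % decode_per_prefill) * heads_per_decode_rank


# prefill_tp=4, decode_tp=2, 8 heads: each decode rank holds 4 heads.
offsets = [dst_head_offset(r, 4, 2, 8) for r in range(4)]

# Without the modulo, rank 3 would target offset 3 * 2 = 6, past the 4-head buffer.
buggy_rank3 = 3 * (8 // 4)

# prefill_tp=2, decode_tp=4, 8 heads: each prefill rank holds 4 local heads.
starts = [src_head_start(r, 2, 4, 8) for r in range(4)]
```

In both directions the modulo keeps the index relative to the group of ranks that share a peer, so it stays within that peer's local head range.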

Also update the "Performance is NOT guaranteed" warning:
- Suppress it when --nixl-use-cpu-buffer is set (which now correctly
  handles head redistribution); emit an info_once instead.
- When cpu buffer is NOT used, keep the warning and add a suggestion to
  use --nixl-use-cpu-buffer.

Add DEBUG logging for both mixed-TP paths to aid future diagnosis.

Replace blocking torch.cuda.synchronize() after each Triton gather kernel
with a non-blocking CUDA event. NixlKVSender.poll() checks event.query()
and posts the NIXL transfer only once the gather completes, keeping the
model forward pass unblocked. Also removes the scatter sync on the decode
side (stream ordering is sufficient) and drops the backwards-compat shim
in TransferInfo.from_zmq.
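
The polling pattern described in this commit message can be sketched without a GPU by substituting a fake event. In real code the event would be a torch.cuda.Event recorded after the gather kernel launch, whose query() method returns True once the recorded work has finished. The classes `FakeEvent` and `PendingSend` below are hypothetical and only show the control flow, not the PR's NixlKVSender.

```python
from typing import Callable, List


class FakeEvent:
    """Stand-in for a CUDA event: query() turns True after a few polls."""

    def __init__(self, polls_until_ready: int) -> None:
        self._remaining = polls_until_ready

    def query(self) -> bool:
        if self._remaining > 0:
            self._remaining -= 1
            return False
        return True


class PendingSend:
    """Posts the transfer only after the gather event reports completion."""

    def __init__(self, gather_done: FakeEvent, post_transfer: Callable[[], None]) -> None:
        self.gather_done = gather_done
        self.post_transfer = post_transfer
        self.posted = False

    def poll(self) -> bool:
        """Return True once the transfer has been posted; never blocks."""
        if not self.posted and self.gather_done.query():
            self.post_transfer()
            self.posted = True
        return self.posted


posted: List[str] = []
send = PendingSend(FakeEvent(polls_until_ready=2), lambda: posted.append("nixl_write"))
results = [send.poll() for _ in range(4)]
```

Each poll returns immediately, so the caller can keep scheduling forward passes while the gather is still in flight.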

Each chunk computed dst_offset using the chunk token count instead of
total, causing chunk N to overwrite chunk N-1. Accumulate kv_indices
across chunks and issue a single gather+NIXL write on is_last=True.
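
The accumulate-then-send approach described here can be sketched as follows. The class `ChunkAccumulator` and its method are hypothetical; the sketch only shows indices being collected across chunks and released as one list when is_last is set, which is the point at which a single gather and write would be issued.

```python
from typing import List, Optional


class ChunkAccumulator:
    """Collects kv_indices across chunks and releases them all on the last chunk."""

    def __init__(self) -> None:
        self.kv_indices: List[int] = []

    def add_chunk(self, chunk_indices: List[int], is_last: bool) -> Optional[List[int]]:
        self.kv_indices.extend(chunk_indices)
        if not is_last:
            return None  # nothing is sent until the final chunk arrives
        all_indices, self.kv_indices = self.kv_indices, []
        return all_indices


acc = ChunkAccumulator()
first = acc.add_chunk([10, 11], is_last=False)
final = acc.add_chunk([12, 13, 14], is_last=True)
```

Because the destination offset is derived from the full index list at send time, no chunk can overwrite the region of an earlier one.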

Accuracy on GSM8K prefill_tp_larger (TP4->TP2): 86.5% -> 95.0%
TODO: revisit multi-chunk handling more cleanly in the future