
[Disagg][NIXL] Fix heterogeneous TP KV transfer for non-MLA models (same logic with mooncake, Step 1/2 for Qwen3.5 support)#22145

Merged
ShangmingCai merged 3 commits into sgl-project:main from YAMY1234:fix/nixl-heterogeneous-tp-non-mla
Apr 7, 2026

Conversation


@YAMY1234 YAMY1234 commented Apr 5, 2026

Motivation

NIXL disaggregated serving with heterogeneous TP (prefill TP ≠ decode TP) on non-MLA models hangs indefinitely due to two bugs in nixl/conn.py:

  1. Notification key collision: _process_kvcache_transfer uses pp_rank in RDMA notification tags. With PP=1, all prefill ranks share pp_rank=0, so TransferStatus.received_kvs_per_pp only records one key while num_pp_ranks_expected > 1; is_done() never returns True → decode hangs.

  2. Wrong head distribution: send_kvcache_slice uses the per-rank kv_head_num instead of total_kv_head_num, which miscounts heads under GQA (total_kv_heads < tp_size). It also lacks GQA replication handling, so multiple prefill ranks sharing the same KV heads compute an incorrect dst_head_start_offset.
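The notification-key collision can be illustrated with a toy model. This is a minimal sketch, not the sglang code: TransferStatus, received_kvs_per_pp, and is_done() mirror names from nixl/conn.py, but record() and the set-based bookkeeping are simplifications for illustration.

```python
class TransferStatus:
    """Toy stand-in for the decode-side transfer tracker."""
    def __init__(self, num_pp_ranks_expected):
        self.num_pp_ranks_expected = num_pp_ranks_expected
        self.received_kvs_per_pp = set()  # keys seen in KV notifications

    def record(self, key):
        self.received_kvs_per_pp.add(key)

    def is_done(self):
        return len(self.received_kvs_per_pp) >= self.num_pp_ranks_expected

# Buggy tagging: 4 prefill ranks, PP=1, so every rank sends pp_rank=0
# and all notifications collapse into a single key.
status_buggy = TransferStatus(num_pp_ranks_expected=4)
for engine_rank in range(4):
    status_buggy.record(0)            # pp_rank is 0 on every rank
print(status_buggy.is_done())         # False -> decode waits forever

# Fixed tagging: engine_rank is unique per prefill rank.
status_fixed = TransferStatus(num_pp_ranks_expected=4)
for engine_rank in range(4):
    status_fixed.record(engine_rank)
print(status_fixed.is_done())         # True
```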

Modifications

  • send_kvcache_slice(): derive head counts from total_kv_head_num with max(1, ...) guards; add src_replication / unique_head_idx for GQA replication, aligned with Mooncake's implementation.
  • _process_kvcache_transfer(): replace pp_rank with engine_rank in KV/state notification tags.
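The corrected head bookkeeping can be sketched roughly as below. The function name head_slice_params and its exact formulas are illustrative assumptions based on the PR description; the real conn.py code differs in detail.

```python
def head_slice_params(total_kv_head_num, prefill_tp, decode_tp, prefill_rank):
    # Heads owned per rank, clamped with max(1, ...) so GQA replication
    # (total heads < tp size) never yields a zero count.
    src_heads = max(1, total_kv_head_num // prefill_tp)
    dst_heads = max(1, total_kv_head_num // decode_tp)

    # When total_kv_head_num < prefill_tp, several prefill ranks hold
    # copies of the same KV head; they form replication groups of this size.
    src_replication = max(1, prefill_tp // total_kv_head_num)
    unique_head_idx = prefill_rank // src_replication

    # Start offset of this rank's heads inside the destination rank's
    # head range.
    dst_head_start_offset = (unique_head_idx * src_heads) % dst_heads
    return src_heads, dst_heads, unique_head_idx, dst_head_start_offset

# Example shape (8 KV heads assumed for illustration): prefill TP4 -> decode TP1.
print(head_slice_params(8, 4, 1, 2))   # (2, 8, 2, 4)
```

Deriving everything from total_kv_head_num keeps the source and destination head ranges consistent even when the clamped per-rank count no longer divides evenly into the total.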

Accuracy Tests

Setup

  • Model: Qwen3-32B
  • Platform: GB200
  • Topology: 1P4D (prefill TP4 → decode TP1×4)
  • Backend: NIXL
  • Frontend: Dynamo
  • Eval: GSM8K 8-shot
  • Examples: 1311

With fix

100%|██████████| 1311/1311 [08:52<00:00,  2.46it/s]
Total latency: 532.396 s
Score: 0.961
Output throughput: 568960.187 token/s
[METRIC] gsm8k_score=0.9610983981693364 labels={"model": "Qwen/Qwen3-32B", "eval": "gsm8k"}
[METRIC] gsm8k_latency=532.3955084759946 labels={"model": "Qwen/Qwen3-32B", "eval": "gsm8k"}
{'score:std': np.float64(0.19336045926112227), 'score': np.float64(0.9610983981693364), 'latency': 532.3955084759946, 'output_throughput': 568960.1868864341}

Without fix

Same config as above

Result

  • Decode hangs indefinitely

  • 0 completions

  • Manually cancelled after ~50 min

    0%| | 0/1311 [00:00<?, ?it/s]

Prefill drained all requests but they stayed permanently in-flight (#inflight-req never drops to 0):

Prefill batch, #new-seq: 1, #new-token: 1536, #cached-token: 0, token usage: 0.00,
  #running-req: 0, #queue-req: 0, #prealloc-req: 0, #inflight-req: 2

Decode workers had zero decode activity after startup; they only exited due to manual cancellation:

WARN  Performance is NOT guaranteed when using different TP sizes for non-MLA models.
* STEP 1383563.6 ON lyris0105 CANCELLED AT 2026-04-05T00:43:19 DUE to SIGNAL Terminated *

Config details:

  name: "qwen3-32b-hetero-tp4-tp1-nixl-verify"
  model:
    path: "Qwen/Qwen3-32B"
    precision: "bf16"
  resources:
    gpus_per_node: 4
    prefill_nodes: 1      # 1 prefill worker @ TP4 = 4 GPUs
    decode_nodes: 1       # 4 decode workers @ TP1 = 4 GPUs
    prefill_workers: 1
    decode_workers: 4
  backend:
    type: sglang
    sglang_config:
      prefill:
        served-model-name: "Qwen/Qwen3-32B"
        trust-remote-code: true
        tensor-parallel-size: 4
        disaggregation-mode: "prefill"
        disaggregation-transfer-backend: "nixl"
        mem-fraction-static: 0.85
        context-length: 32768
        page-size: 64
        disable-radix-cache: true
        watchdog-timeout: 1000000
      decode:
        served-model-name: "Qwen/Qwen3-32B"
        trust-remote-code: true
        tensor-parallel-size: 1
        disaggregation-mode: "decode"
        disaggregation-transfer-backend: "nixl"
        mem-fraction-static: 0.85
        context-length: 32768
        page-size: 64
        disable-radix-cache: true
        watchdog-timeout: 1000000
  benchmark:
    type: "gsm8k"
    num_examples: 1319
    max_tokens: 16000
    num_threads: 512
    num_shots: 8

Checklist

…geneous TP

send_kvcache_slice used the per-rank kv_head_num instead of total_kv_head_num
for head distribution, which loses information under GQA (total_heads < tp_size).
It also lacked GQA replication handling, causing multiple prefill ranks
sharing the same KV heads to write to the wrong dst offsets.

This resulted in corrupted KV cache data on the decode side, producing 0%
accuracy on GPQA, while the staging path (which uses compute_head_slice_params)
remained correct.

Fix: use total_kv_head_num with max(1,...) guards and add src_replication /
unique_head_idx logic, aligned with Mooncake's send_kvcache_slice.
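A small worked example of why the per-rank count is lossy and how src_replication / unique_head_idx recover the mapping. The numbers are assumed for illustration, not taken from the PR.

```python
# Assumed GQA shape: 2 total KV heads at prefill TP=4, so each head is
# replicated across 2 prefill ranks.
total_kv_head_num = 2
prefill_tp = 4

# Per-rank head count is clamped to 1 for replicated heads, so scaling it
# back up by tp_size over-counts the heads: this is the lossy derivation.
kv_head_num = max(1, total_kv_head_num // prefill_tp)   # 1
wrong_total = kv_head_num * prefill_tp                  # 4, but the model has 2

# Replication-aware mapping: ranks in the same group hold the same head.
src_replication = max(1, prefill_tp // total_kv_head_num)   # 2
rank_to_head = {r: r // src_replication for r in range(prefill_tp)}
print(rank_to_head)   # {0: 0, 1: 0, 2: 1, 3: 1}
```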

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refines the KV cache slicing logic by utilizing the total KV head count to ensure accurate head distribution and handle GQA replication. It also updates transfer notification strings to use engine_rank instead of pp_rank to prevent collisions. A potential ZeroDivisionError was found in the head distribution logic, and a suggestion was made to ensure the total head count is positive before use.

Under heterogeneous TP (prefill TP > decode TP) with PP=1, all prefill
ranks share pp_rank=0, causing RDMA notifications to collapse into a
single key in TransferStatus.received_kvs_per_pp. Since
num_pp_ranks_expected equals the number of prefill ranks (e.g. 4)
but only one unique pp_rank key is ever recorded, is_done() never
returns True and decode hangs indefinitely.

Fix by using engine_rank (which is unique per prefill rank) instead
of pp_rank in kv and state notification tags. This is a pre-existing
bug that affects any NIXL + heterogeneous TP (prefill TP > decode TP)
+ PP=1 configuration with non-MLA models.

Made-with: Cursor
@YAMY1234 YAMY1234 force-pushed the fix/nixl-heterogeneous-tp-non-mla branch from 410e7e6 to 2c7b29e on April 5, 2026 at 09:02

@ShangmingCai ShangmingCai left a comment


LGTM

@ShangmingCai

/tag-and-rerun-ci

@github-actions github-actions bot added the run-ci label Apr 5, 2026
@ShangmingCai

/rerun-failed-ci

@ShangmingCai

/rerun-failed-ci

@ShangmingCai

PD CI has passed.

@ShangmingCai ShangmingCai merged commit 3148742 into sgl-project:main Apr 7, 2026
249 of 299 checks passed