[LoRA] Fix qkv_proj LoRA buffer sizing when tp_size > num_key_value_heads#24420

Merged
Fridge003 merged 3 commits into sgl-project:main from gh1595:sminakov/lora-qkv-kv-replication-fix on May 6, 2026
Conversation

Contributor

@gh1595 gh1595 commented May 5, 2026

Motivation

When QKVParallelLinear is sharded with tp_size > total_num_kv_heads,
KV heads are replicated across tp_size // num_kv_heads ranks rather
than further divided. The current LoRA path doesn't handle that layout
on either side of the per-rank slice, so loading any LoRA adapter that
targets qkv_proj fails with a buffer-shape assert at the first
forward.
Concrete repro on main:

python -m sglang.launch_server \
    --model-path Qwen/Qwen3.5-35B-A3B \
    --tp 4 --ep 4 --enable-lora --max-loras-per-batch 1 --max-lora-rank 32 \
    --lora-target-modules qkv_proj \
    --lora-paths my_lora=<some-qwen3.5-lora-checkpoint>

Error:

AssertionError: LoRA buffer shape torch.Size([576, 16]) does not match
weight shape torch.Size([768, 16])
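The mismatch is plain arithmetic. A hypothetical config (the head counts and head_dim below are illustrative guesses chosen to reproduce the exact shapes in the error, not values read from the Qwen3.5-35B-A3B checkpoint) makes the under-count concrete:

```python
# Hypothetical config chosen to reproduce the shapes in the error above;
# the real Qwen3.5-35B-A3B numbers may differ.
num_q_heads, num_kv_heads, head_dim = 16, 1, 128
tp_size = 4  # tp_size > num_kv_heads -> KV heads get replicated, not split

total_output_dim = (num_q_heads + 2 * num_kv_heads) * head_dim  # 2304

# Old buffer sizing: naive even split across ranks.
naive_per_rank = total_output_dim // tp_size                    # 576

# Actual per-rank weight: a q shard plus one full KV head replica (K and V).
q_per_rank = num_q_heads * head_dim // tp_size                  # 512
actual_per_rank = q_per_rank + 2 * head_dim                     # 768

assert (naive_per_rank, actual_per_rank) == (576, 768)
```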

Modifications

Two cooperating pieces have to agree on the per-rank QKV layout when KV heads are replicated:

LoRAMemoryPool.get_lora_B_shape (mem_pool.py) previously sized the per-rank qkv_proj LoRA B output as divide(total_output_dim, tp_size). That under-counts whenever tp_size > num_kv_heads, because the K/V part of total_output_dim is 2 * num_kv_heads * head_dim, which is not evenly divisible by tp_size.

New helper _column_parallel_lora_b_per_rank_dim returns q_per_rank + 2 * head_dim for qkv_proj in the replicated case (each rank owns exactly head_dim of K and head_dim of V — one KV head replica), and falls back to the naive even split for every other column-parallel module and for the no-replication case (tp_size <= num_kv_heads). It also handles attn_output_gate=True (Qwen3.5) implicitly: gate-doubled q heads are encoded in total_output_dim, so the q remainder after subtracting K+V is correct without special-casing.

QKVParallelLinearWithLoRA.slice_lora_b_weights (layers.py) was indexing the PEFT-format (un-replicated) B tensor [q_total, k_total, v_total] with q_size, k_size, _ = base_layer.output_sizes. But base_layer.output_sizes[1] and [2] are the K/V replicated dims (per-rank K/V × tp_size), so they over-count K/V by num_kv_head_replicas. The V slice then falls past the end of B and silently returns 0 rows.

Fix: use output_sizes[0] for the q offset and output_sizes[1] // num_kv_head_replicas for the k offset, so indexing matches B's un-replicated layout.
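The corrected indexing reduces to offset arithmetic into B's un-replicated [q | k | v] row layout. A sketch with illustrative names (the real code slices torch tensors and concatenates them; this returns row ranges only):

```python
def qkv_lora_b_slice_bounds(
    output_sizes: tuple[int, int, int],  # base layer's (K/V-replicated) sizes
    num_kv_head_replicas: int,
    tp_rank: int,
    tp_size: int,
):
    """Row ranges into the un-replicated PEFT B tensor [q | k | v] (sketch)."""
    q_total = output_sizes[0]
    # output_sizes[1] is per-rank K x tp_size, i.e. over-counted by the
    # replication factor relative to B's un-replicated layout; undo it.
    k_total = output_sizes[1] // num_kv_head_replicas
    q_per_rank = q_total // tp_size
    # tp_size // num_kv_head_replicas == num_kv_heads, so this is head_dim.
    kv_per_rank = k_total // (tp_size // num_kv_head_replicas)

    kv_rank = tp_rank // num_kv_head_replicas  # replica group -> KV head index
    q_start = tp_rank * q_per_rank
    k_start = q_total + kv_rank * kv_per_rank
    v_start = q_total + k_total + kv_rank * kv_per_rank
    return (
        (q_start, q_start + q_per_rank),
        (k_start, k_start + kv_per_rank),
        (v_start, v_start + kv_per_rank),
    )
```

Ranks in the same replica group get identical K/V ranges, and the V range now ends exactly at B's last row instead of falling past it.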

The two changes together restore the previous buffer-shape contract (allocated dim == sliced dim) for the replicated case and leave the non-replicated case bit-identical.

Accuracy Tests

This is a load-time / shape correctness fix on a path that previously crashed before any forward, so existing accuracy tests for qkv_proj LoRA in the no-replication regime continue to pass unchanged.

The new unit test test/registered/unit/lora/test_qkv_lora_kv_replication.py exercises both pieces hermetically (no CUDA, no distributed init):

_column_parallel_lora_b_per_rank_dim for non-qkv modules, no-replication qkv, replicated qkv with and without attn_output_gate, varying replication factor, and the defensive fallback when num_key_value_heads is missing.
slice_lora_b_weights for the Qwen3.5-shaped (TP=16, 2 KV heads) case: per-rank shape (1024 rows = 512 q + 256 k + 256 v), KV replication groups sharing identical K/V slices, and an unchanged-behavior smoke test for the no-replication regime (tp <= num_kv_heads).
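The Qwen3.5-shaped numbers above can be sanity-checked with plain arithmetic (this is a standalone illustration, not part of the test suite):

```python
# Sanity arithmetic for the TP=16, 2-KV-head case described above.
tp_size, num_kv_heads = 16, 2
num_kv_head_replicas = tp_size // num_kv_heads   # 8 ranks per KV head
q_per_rank, head_dim = 512, 256                  # from the stated per-rank shape
assert q_per_rank + 2 * head_dim == 1024         # 512 q + 256 k + 256 v

# Ranks in the same replica group map to the same KV head, so their
# K/V LoRA slices are identical.
def kv_head(rank: int) -> int:
    return rank // num_kv_head_replicas

assert {kv_head(r) for r in range(8)} == {0}     # ranks 0-7 share KV head 0
assert {kv_head(r) for r in range(8, 16)} == {1} # ranks 8-15 share KV head 1
```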
Run:

python -m pytest test/registered/unit/lora/test_qkv_lora_kv_replication.py -v
End-to-end: validated on Qwen3.5-35B-A3B with
--tp 4 --ep 4 --enable-lora --lora-target-modules all, which previously crashed at first request and now serves correctly.

Speed Tests and Profiling

No expected impact: the helper runs once per init_buffers, and slice_lora_b_weights does the same number of slices/concat as before — only the indices change. The non-replication path (the common case) is bit-identical to the previous implementation.

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

@github-actions github-actions Bot added the lora label May 5, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request addresses an issue with QKV LoRA weight slicing and buffer sizing when tensor parallelism size exceeds the number of KV heads, leading to KV replication. Changes include updating the slicing logic in layers.py to use unreplicated K/V sizes and implementing a specialized per-rank dimension calculation in mem_pool.py. A comprehensive unit test suite is also added. Feedback indicates a potential AttributeError in LoRAMemoryPool due to the use of an uninitialized _text_config attribute; using base_hf_config with proper resolution is suggested instead.

Comment thread python/sglang/srt/lora/mem_pool.py Outdated
@jybsuper jybsuper marked this pull request as ready for review May 5, 2026 22:17

@jybsuper jybsuper added the run-ci label May 5, 2026
@yushengsu-thu yushengsu-thu enabled auto-merge (squash) May 5, 2026 22:47
@jybsuper jybsuper force-pushed the sminakov/lora-qkv-kv-replication-fix branch from 80cbdcb to 4080966 Compare May 5, 2026 23:36
@jybsuper jybsuper force-pushed the sminakov/lora-qkv-kv-replication-fix branch from 4080966 to c386016 Compare May 5, 2026 23:50
@yushengsu-thu
Collaborator

you can rerun the last case in stage c

@@ -0,0 +1,386 @@
"""Unit tests for QKV LoRA correctness when ``tp_size > num_key_value_heads``.
Collaborator


We don't need this test

Comment thread python/sglang/srt/lora/mem_pool.py Outdated
head_dim = getattr(cfg, "head_dim", None) or (
cfg.hidden_size // cfg.num_attention_heads
)
# `total_output_dim` from `get_hidden_dim("qkv_proj")` is
Collaborator


We can delete the comments here, since it has been explained in docstring

Comment thread python/sglang/srt/lora/layers.py Outdated
kv_end_idx = kv_start_idx + kv_proj_shard_size

q_size, k_size, _ = base_layer.output_sizes
# `base_layer.output_sizes[1]` and `[2]` are the QKVParallelLinear's
Collaborator


Delete the comments here

@Fridge003 Fridge003 disabled auto-merge May 6, 2026 21:51
@Fridge003 Fridge003 merged commit ece7e95 into sgl-project:main May 6, 2026
1 check passed
ltcs11 added a commit to ltcs11/sglang that referenced this pull request May 7, 2026
* main: (894 commits)
  [Bug Fix] Fix RunAI streamer: corrupted weights, missing quant init, and broken URIs for multimodal models (sgl-project#22715)
  [Kernel] Deprecate DeepGemm in sgl kernel and apply custom wheel sgl-deep-gemm (sgl-project#24268)
  propagate pytest exit code from test __main__ entries (sgl-project#24487)
  [R3] Avoid implicit CUDA sync in routed experts DP slicing (sgl-project#24550)
  Add ChatCompletionRequest-style support to /v1/tokenize (sgl-project#23981)
  Support Triton MLA FP8 KV cache (sgl-project#20479)
  [diffusion] chore: align LTX-2 with official (sgl-project#24313)
  Expand support matrix for pypi wheel release (sgl-project#24565)
  [codex] Optimize Z-Image packed QKV (sgl-project#24117)
  [Misc] Fix breaking weight checker test (sgl-project#24553)
  [LoRA] Fix qkv_proj LoRA buffer sizing when tp_size > num_key_value_heads (sgl-project#24420)
  ci: bump test_mimo_models.py est_time 330 → 610 (sgl-project#24551)
  [CI] Temporarily disable marco/mcdse-2b-v1 in test_embedding_models (sgl-project#24279)
  Improve metrics, observability, and PD deploy tooling (sgl-project#24521)
  Fix diffusion fallback guards and validation (sgl-project#23335)
  [PD] Prevent update_status to Failed from cleared entries (sgl-project#24539)
  [CP] Register KV cache allgather buffer with symmetric memory (sgl-project#24040)
  Support getting checksums in weight checker (sgl-project#24537)
  Refactor buffer patterns in weight checker (sgl-project#24538)
  Add unit and end-to-end tests for weight checker (sgl-project#24536)
  ...

# Conflicts:
#	python/sglang/srt/managers/scheduler.py
#	python/sglang/srt/model_executor/model_runner.py
LLThomas pushed a commit to LLThomas/sglang that referenced this pull request May 8, 2026
