[LoRA] Fix qkv_proj LoRA buffer sizing when tp_size > num_key_value_heads#24420
Conversation
Code Review
This pull request addresses an issue with QKV LoRA weight slicing and buffer sizing when tensor parallelism size exceeds the number of KV heads, leading to KV replication. Changes include updating the slicing logic in layers.py to use unreplicated K/V sizes and implementing a specialized per-rank dimension calculation in mem_pool.py. A comprehensive unit test suite is also added. Feedback indicates a potential AttributeError in LoRAMemoryPool due to the use of an uninitialized _text_config attribute; using base_hf_config with proper resolution is suggested instead.
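The suggested fix amounts to resolving the text config from the base HF config with a defensive fallback, roughly along these lines (a sketch of the review suggestion only; `resolve_text_config` is a hypothetical helper name, not the actual LoRAMemoryPool code):

```python
def resolve_text_config(base_hf_config):
    # Fall back to the base config itself when it has no nested text_config
    # (standard HF config convention); avoids reading a possibly-uninitialized
    # `_text_config` attribute. Illustrative sketch, not sglang source.
    return getattr(base_hf_config, "text_config", None) or base_hf_config
```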
Force-pushed from 80cbdcb to 4080966, then from 4080966 to c386016.
You can rerun the last case in stage c.
Review comment on the new unit test file:

```python
head_dim = getattr(cfg, "head_dim", None) or (
    cfg.hidden_size // cfg.num_attention_heads
)
# `total_output_dim` from `get_hidden_dim("qkv_proj")` is
```

> We can delete the comments here, since it has been explained in the docstring.
Review comment on the layers.py change:

```python
kv_end_idx = kv_start_idx + kv_proj_shard_size
...
q_size, k_size, _ = base_layer.output_sizes
# `base_layer.output_sizes[1]` and `[2]` are the QKVParallelLinear's
```

> Delete the comments here.
Merged main into the PR branch (894 commits); conflicts resolved in python/sglang/srt/managers/scheduler.py and python/sglang/srt/model_executor/model_runner.py.
[LoRA] Fix qkv_proj LoRA buffer sizing when tp_size > num_key_value_heads (sgl-project#24420). Co-authored-by: Yanbin Jiang <jybsuper@gmail.com>
Motivation
When `QKVParallelLinear` is sharded with `tp_size > total_num_kv_heads`, KV heads are replicated across `tp_size // num_kv_heads` ranks rather than further divided. The current LoRA path doesn't handle that layout on either side of the per-rank slice, so loading any LoRA adapter that targets `qkv_proj` fails with a buffer-shape assert at the first forward.
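For intuition, a small worked example of that replicated layout, with purely illustrative numbers (not taken from any particular checkpoint):

```python
# Illustrative attention geometry, not a real model config.
num_attention_heads = 32
num_key_value_heads = 2
head_dim = 128
tp_size = 16

# Q heads still shard evenly across ranks...
q_per_rank = (num_attention_heads // tp_size) * head_dim   # 2 heads -> 256

# ...but K/V cannot shard below one head per rank, so each KV head is
# replicated across tp_size // num_key_value_heads ranks.
num_kv_head_replicas = tp_size // num_key_value_heads       # 8
k_per_rank = head_dim                                        # one replicated K head
v_per_rank = head_dim                                        # one replicated V head

per_rank_qkv_out = q_per_rank + k_per_rank + v_per_rank      # 256 + 128 + 128
assert per_rank_qkv_out == 512
```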
Concrete repro on main:

```bash
python -m sglang.launch_server \
  --model-path Qwen/Qwen3.5-35B-A3B \
  --tp 4 --ep 4 --enable-lora --max-loras-per-batch 1 --max-lora-rank 32 \
  --lora-target-modules qkv_proj \
  --lora-paths my_lora=<some-qwen3.5-lora-checkpoint>
```

Error:
Modifications
Two cooperating pieces have to agree on the per-rank QKV layout when KV heads are replicated:
`LoRAMemoryPool.get_lora_B_shape` (mem_pool.py) previously sized the per-rank `qkv_proj` LoRA B output as `divide(total_output_dim, tp_size)`. That under-counts whenever `tp_size > num_kv_heads`, because the K/V portion of `total_output_dim` is `2 * num_kv_heads * head_dim`, which does not divide evenly by `tp_size`.

The new helper `_column_parallel_lora_b_per_rank_dim` returns `q_per_rank + 2 * head_dim` for `qkv_proj` in the replicated case (each rank owns exactly `head_dim` of K and `head_dim` of V, i.e. one KV-head replica), and falls back to the naive even split for every other column-parallel module and for the no-replication case (`tp_size <= num_kv_heads`). It also handles `attn_output_gate=True` (Qwen3.5) implicitly: the gate-doubled Q heads are already encoded in `total_output_dim`, so the Q remainder after subtracting K+V is correct without special-casing.
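A minimal sketch of that sizing rule, assuming the helper sees the module name, the un-replicated total output dim, and the attention geometry (names and signature are illustrative, not the exact mem_pool.py code):

```python
from typing import Optional


def lora_b_per_rank_dim(
    module_name: str,
    total_output_dim: int,
    tp_size: int,
    num_kv_heads: Optional[int],
    head_dim: int,
) -> int:
    """Per-rank output dim of the LoRA B buffer for a column-parallel module."""
    kv_replicated = (
        module_name == "qkv_proj"
        and num_kv_heads is not None
        and tp_size > num_kv_heads
    )
    if not kv_replicated:
        # Other column-parallel modules and the no-replication case:
        # plain even split across TP ranks.
        return total_output_dim // tp_size

    # K and V contribute 2 * num_kv_heads * head_dim to the un-replicated
    # total; the remainder is Q (this also absorbs gate-doubled Q heads
    # when attn_output_gate=True, since they are part of total_output_dim).
    q_total = total_output_dim - 2 * num_kv_heads * head_dim
    q_per_rank = q_total // tp_size
    # Each rank owns exactly one replicated KV head: head_dim of K + head_dim of V.
    return q_per_rank + 2 * head_dim
```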
`QKVParallelLinearWithLoRA.slice_lora_b_weights` (layers.py) was indexing the PEFT-format (un-replicated) B tensor, laid out as `[q_total, k_total, v_total]`, with `q_size, k_size, _ = base_layer.output_sizes`. But `base_layer.output_sizes[1]` and `[2]` are the K/V replicated dims (per-rank K/V × `tp_size`), so they over-count K/V by `num_kv_head_replicas`. The V slice then falls past the end of B and silently returns 0 rows.

Fix: use `output_sizes[0]` for the q offset and `output_sizes[1] // num_kv_head_replicas` for the k offset, so the indexing matches B's un-replicated layout.
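In sketch form, the corrected offsets into the un-replicated B tensor look roughly like this (a simplified illustration; the real slice_lora_b_weights also extracts the per-rank Q/K/V slices using the TP rank):

```python
def qkv_peft_offsets(output_sizes, num_kv_head_replicas):
    """Row offsets of Q/K/V inside the un-replicated PEFT-format B tensor.

    `output_sizes` is QKVParallelLinear's (q, k, v) sizes, whose K/V entries
    already include the replication factor. Illustrative sketch only.
    """
    q_size, k_size_replicated, _ = output_sizes
    # Divide the replication factor back out to recover the true K extent
    # of the PEFT tensor.
    k_size = k_size_replicated // num_kv_head_replicas
    q_offset = 0
    k_offset = q_size              # K rows start right after the Q block
    v_offset = q_size + k_size     # V rows start right after un-replicated K
    return q_offset, k_offset, v_offset


# Arbitrary illustrative sizes: 8 replicas per KV head.
assert qkv_peft_offsets((512, 256, 256), 8) == (0, 512, 544)
```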
The two changes together restore the previous buffer-shape contract (allocated dim == sliced dim) for the replicated case and leave the non-replicated case bit-identical.
Accuracy Tests
This is a load-time / shape correctness fix on a path that previously crashed before any forward, so existing accuracy tests for qkv_proj LoRA in the no-replication regime continue to pass unchanged.
The new unit test `test/registered/unit/lora/test_qkv_lora_kv_replication.py` exercises both pieces hermetically (no CUDA, no distributed init):

- `_column_parallel_lora_b_per_rank_dim`: non-qkv modules, no-replication qkv, replicated qkv with and without `attn_output_gate`, varying replication factors, and the defensive fallback when `num_key_value_heads` is missing.
- `slice_lora_b_weights` for the Qwen3.5-shaped case (TP=16, 2 KV heads): per-rank shape (1024 rows = 512 q + 256 k + 256 v), KV replication groups sharing identical K/V slices, and an unchanged-behavior smoke test for the no-replication regime (`tp <= num_kv_heads`).
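As a standalone sanity check of the shape those tests assert, a rough sketch (the helper name is hypothetical, and the head_dim/q_total values are inferred from the quoted split rather than read from the real model config):

```python
def expected_qkv_lora_b_rows(q_total, num_kv_heads, head_dim, tp_size):
    # Mirrors the property the unit test checks: with KV replication, each
    # rank gets its Q share plus one K head and one V head worth of rows.
    # Hypothetical helper, not the test's actual code.
    if tp_size <= num_kv_heads:
        return (q_total + 2 * num_kv_heads * head_dim) // tp_size
    return q_total // tp_size + 2 * head_dim


# 1024 rows = 512 q + 256 k + 256 v is consistent with q_total = 512 * 16
# and head_dim = 256 under this rule.
assert expected_qkv_lora_b_rows(512 * 16, num_kv_heads=2, head_dim=256, tp_size=16) == 1024
```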
Run:

```bash
python -m pytest test/registered/unit/lora/test_qkv_lora_kv_replication.py -v
```
End-to-end: validated on Qwen3.5-35B-A3B with `--tp 4 --ep 4 --enable-lora --lora-target-modules all`, which previously crashed at the first request and now serves correctly.
Speed Tests and Profiling
No expected impact: the helper runs once per init_buffers, and slice_lora_b_weights does the same number of slices/concat as before — only the indices change. The non-replication path (the common case) is bit-identical to the previous implementation.
Checklist
Review and Merge Process
/tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci