python/sglang/srt/lora/layers.py (9 additions, 1 deletion)
@@ -629,7 +629,15 @@ def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int) -> torch.Tensor:
         kv_start_idx = kv_proj_shard_size * kv_shard_id
         kv_end_idx = kv_start_idx + kv_proj_shard_size

-        q_size, k_size, _ = base_layer.output_sizes
+        # The adapter weight `B` is in pre-replication layout
+        # [q_total, k_total, v_total]. When tp_size > total_num_kv_heads,
+        # base_layer.output_sizes reflects the post-replication kv size, which
+        # over-counts and pushes the v offset past the end of `B`. Compute
+        # offsets from the pre-replication head counts instead.
+        head_size = base_layer.head_size
+        q_size = base_layer.total_num_heads * head_size
+        k_size = base_layer.total_num_kv_heads * head_size
+
         B_q_shard = B[q_start_idx:q_end_idx, :]
         B_k_shard = B[q_size + kv_start_idx : q_size + kv_end_idx, :]
         B_v_shard = B[q_size + k_size + kv_start_idx : q_size + k_size + kv_end_idx, :]
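For intuition, here is a quick sketch of the offset arithmetic this hunk fixes, using made-up numbers (32 query heads, 4 KV heads, head_size 128, tp_size 8). The values are purely illustrative and not taken from any real model config; only the weight layout described in the diff comment is assumed.

# Illustrative shapes only; chosen so that tp_size > total_num_kv_heads triggers KV replication.
head_size = 128
total_num_heads = 32                               # query heads before sharding
total_num_kv_heads = 4                             # KV heads before replication
tp_size = 8

# Pre-replication layout of the LoRA B weight: [q_total, k_total, v_total]
q_size = total_num_heads * head_size               # 4096
k_size = total_num_kv_heads * head_size            # 512
b_rows = q_size + 2 * k_size                       # 5120 rows in `B`

# Post-replication kv size, as base_layer.output_sizes would report it
kv_heads_per_rank = max(1, total_num_kv_heads // tp_size)       # 1
replicated_k_size = kv_heads_per_rank * tp_size * head_size     # 1024

v_offset_buggy = q_size + replicated_k_size        # 5120, already past the last row of `B`
v_offset_fixed = q_size + k_size                   # 4608, the real start of the v block
assert v_offset_buggy >= b_rows and v_offset_fixed < b_rows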
python/sglang/srt/lora/utils.py (25 additions, 1 deletion)
@@ -54,6 +54,29 @@ class LoRAType(Enum):
     LORA_B = 1


+def get_qkv_lora_kv_total(num_key_value_heads: int) -> int:
+    """Return the kv-head count the LoRA qkv buffer must reserve for, accounting
+    for KV-head replication when ``tp_size > num_key_value_heads``.
+
+    Asserts that DP-attention / context-parallel are off, because the rest of
+    the LoRA path assumes ``attn_tp_size == tp_size`` (the mem_pool divides
+    buffer shapes by global ``tp_size``, not ``attn_tp_size``).
+    """
+    from sglang.srt.distributed import get_tensor_model_parallel_world_size
+    from sglang.srt.layers.dp_attention import get_attention_tp_size
+
+    tp_size = get_tensor_model_parallel_world_size()
+    attn_tp_size = get_attention_tp_size()
+    assert attn_tp_size == tp_size, (
+        f"LoRA qkv sizing assumes attn_tp_size == tp_size, got "
+        f"attn_tp_size={attn_tp_size}, tp_size={tp_size}. DP-attention or "
+        f"context-parallel is not supported with LoRA today (the mem_pool "
+        f"sizes buffers by global tp_size; see lora/mem_pool.py)."
+    )
+    kv_heads_per_rank = max(1, num_key_value_heads // tp_size)
+    return kv_heads_per_rank * tp_size
+
+
 def get_hidden_dim(
     module_name: str,
     config: AutoConfig,
@@ -79,8 +102,9 @@ def get_hidden_dim(
         config, "head_dim", config.hidden_size // config.num_attention_heads
     )
     if module_name == "qkv_proj":
+        kv_total_replicated = get_qkv_lora_kv_total(config.num_key_value_heads)
         return config.hidden_size, head_dim * (
-            config.num_attention_heads + config.num_key_value_heads * 2
+            config.num_attention_heads + kv_total_replicated * 2
         )
     elif module_name == "o_proj":
         o_head_dim = getattr(config, "v_head_dim", None) or head_dim
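To make the sizing change concrete, here is a small worked example of the qkv_proj branch above, under assumed config values (head_dim 128, 32 attention heads, 4 KV heads, tp_size 8; the numbers are illustrative, not from a real config).

# Illustrative config values; the math mirrors get_qkv_lora_kv_total and the qkv_proj branch.
num_attention_heads = 32
num_key_value_heads = 4
head_dim = 128
tp_size = 8

kv_heads_per_rank = max(1, num_key_value_heads // tp_size)      # 1
kv_total_replicated = kv_heads_per_rank * tp_size               # 8, what get_qkv_lora_kv_total returns
# When tp_size <= num_key_value_heads (e.g. tp_size = 2), the helper returns
# num_key_value_heads unchanged, so existing setups keep their old buffer sizes.

old_qkv_dim = head_dim * (num_attention_heads + num_key_value_heads * 2)   # 128 * 40 = 5120
new_qkv_dim = head_dim * (num_attention_heads + kv_total_replicated * 2)   # 128 * 48 = 6144
# Per the docstring above, the mem_pool divides buffer shapes by global tp_size,
# so each rank gets 6144 / 8 = 768 columns: 4 query heads plus one replicated
# k head and one replicated v head, all at head_dim 128.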
python/sglang/srt/models/nemotron_h.py (4 additions, 2 deletions)
@@ -778,10 +778,12 @@ def get_hidden_dim(self, module_name, layer_idx):
         )

         if module_name == "qkv_proj":
+            from sglang.srt.lora.utils import get_qkv_lora_kv_total
+
+            kv_total_replicated = get_qkv_lora_kv_total(config.num_key_value_heads)
             return (
                 hidden_size,
-                head_dim
-                * (config.num_attention_heads + config.num_key_value_heads * 2),
+                head_dim * (config.num_attention_heads + kv_total_replicated * 2),
             )
         elif module_name == "o_proj":
             return (
python/sglang/srt/models/qwen3_5.py (4 additions, 1 deletion)
@@ -957,11 +957,14 @@ def get_hidden_dim(self, module_name: str, layer_idx: int):
         head_dim = config.head_dim or (config.hidden_size // config.num_attention_heads)

         if module_name == "qkv_proj":
+            from sglang.srt.lora.utils import get_qkv_lora_kv_total
+
             attn_output_gate = getattr(config, "attn_output_gate", True)
             q_heads = config.num_attention_heads * (2 if attn_output_gate else 1)
+            kv_total_replicated = get_qkv_lora_kv_total(config.num_key_value_heads)
             return (
                 config.hidden_size,
-                head_dim * (q_heads + config.num_key_value_heads * 2),
+                head_dim * (q_heads + kv_total_replicated * 2),
             )
         elif module_name == "o_proj":
             return config.num_attention_heads * head_dim, config.hidden_size
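The same sizing applies here, with the extra wrinkle that attn_output_gate doubles the query-head term. Continuing the illustrative numbers from above (32 attention heads, 4 KV heads, head_dim 128, tp_size 8; assumed values only):

attn_output_gate = True                               # the getattr default used in the diff above
q_heads = 32 * (2 if attn_output_gate else 1)         # 64: the gate doubles the query term
kv_total_replicated = max(1, 4 // 8) * 8              # 8, as get_qkv_lora_kv_total computes it
qkv_dim = 128 * (q_heads + kv_total_replicated * 2)   # 128 * 80 = 10240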