Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 26 additions & 6 deletions python/sglang/srt/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ def __init__(
reduce_results: bool = True,
layer_id: int = None,
prefix: str = "",
alt_stream: Optional[torch.cuda.Stream] = None,
) -> None:
super().__init__()
self.layer_id = layer_id
Expand Down Expand Up @@ -543,6 +544,8 @@ def __init__(
prefix=add_prefix("attn_mha", prefix),
)

self.alt_stream = alt_stream

self.w_kc = None
self.w_vc = None
self.w_scale = None
Expand Down Expand Up @@ -706,14 +709,32 @@ def forward_absorb(
q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
)
q = self.q_a_layernorm(q)
k_nope = latent_cache[..., : self.kv_lora_rank]

# overlap qk norm
if self.alt_stream is not None and torch.cuda.is_current_stream_capturing():
current_stream = torch.cuda.current_stream()
self.alt_stream.wait_stream(current_stream)
q = self.q_a_layernorm(q)
with torch.cuda.stream(self.alt_stream):
k_nope = self.kv_a_layernorm(k_nope)
current_stream.wait_stream(self.alt_stream)
else:
q = self.q_a_layernorm(q)
k_nope = self.kv_a_layernorm(k_nope)

k_nope = k_nope.unsqueeze(1)
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
else:
q = self.q_proj(hidden_states)[0].view(
-1, self.num_local_heads, self.qk_head_dim
)
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
k_nope = latent_cache[..., : self.kv_lora_rank]
k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)

q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)

if self.use_deep_gemm_bmm:
q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
Expand Down Expand Up @@ -750,11 +771,6 @@ def forward_absorb(
q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)

q_nope_out = q_nope_out.transpose(0, 1)

k_nope = latent_cache[..., : self.kv_lora_rank]
k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)

q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

if self.attention_backend == "fa3":
Expand Down Expand Up @@ -1104,6 +1120,7 @@ def __init__(
quant_config: Optional[QuantizationConfig] = None,
is_nextn: bool = False,
prefix: str = "",
alt_stream: Optional[torch.cuda.Stream] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
Expand Down Expand Up @@ -1133,6 +1150,7 @@ def __init__(
layer_id=layer_id,
reduce_results=False,
prefix=add_prefix("self_attn", prefix),
alt_stream=alt_stream,
)

self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn)
Expand Down Expand Up @@ -1376,13 +1394,15 @@ def __init__(
config.hidden_size,
enable_tp=not global_server_args_dict["enable_dp_attention"],
)
self.alt_stream = torch.cuda.Stream()
self.layers = nn.ModuleList(
[
DeepseekV2DecoderLayer(
config,
layer_id,
quant_config=quant_config,
prefix=add_prefix(f"layers.{layer_id}", prefix),
alt_stream=self.alt_stream,
)
for layer_id in range(config.num_hidden_layers)
]
Expand Down
Loading