
Commit c56e9de

Rename layer to comply with deepseek
Signed-off-by: peaceh <[email protected]>
1 parent: b3ca159

File tree: 2 files changed (+9 lines, -8 lines)

tensorrt_llm/_torch/models/modeling_deepseekv3.py
Lines changed: 2 additions & 2 deletions

@@ -248,7 +248,7 @@ def __init__(
             dtype=config.torch_dtype,
             config=model_config,
             aux_stream=aux_stream)
-        self.fused_a = DeepseekV3Linear(
+        self.kv_a_proj_with_mqa = DeepseekV3Linear(
             config.hidden_size,
             self.kv_lora_rank + self.qk_rope_head_dim +
             (self.q_lora_rank if not self.is_lite else 0),

@@ -1384,7 +1384,7 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
                attn_module.v_b_proj_scale = nn.Parameter(
                    v_b_proj_scale, requires_grad=False)

-           elif names[-1] == "fused_a":
+           elif names[-1] == "kv_a_proj_with_mqa":
                fused_a = weights[
                    f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight"][:]
                if not is_lite:
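The weight-loading branch above matches modules by the last component of their dotted name, so renaming the module to kv_a_proj_with_mqa lets it key directly off the checkpoint tensor of the same name. As a rough illustration of what such a loader has to do in the non-lite case, here is a minimal sketch (the q_a_proj key, the helper name, and the concatenation order are assumptions for illustration, not the repository's exact logic):

import torch


def build_fused_a_weight(weights: dict, prefix: str, is_lite: bool) -> torch.Tensor:
    # Hedged sketch, not the repository's code: assemble the fused input
    # projection from the separate HF checkpoint tensors.
    # kv_a: [kv_lora_rank + qk_rope_head_dim, hidden_size]
    kv_a = weights[f"{prefix}.kv_a_proj_with_mqa.weight"][:]
    if is_lite:
        return kv_a
    # Non-lite checkpoints also carry a low-rank query projection.
    q_a = weights[f"{prefix}.q_a_proj.weight"][:]  # [q_lora_rank, hidden_size]
    # Stack along the output dimension so one GEMM produces all pieces;
    # the split in forward_impl must use the matching order (assumed here).
    return torch.cat([kv_a, q_a], dim=0)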

tensorrt_llm/_torch/modules/attention.py
Lines changed: 7 additions & 6 deletions

@@ -502,7 +502,7 @@ def __init__(
         self.quant_config = quant_config

         if not self.is_lite:
-            self.fused_a = Linear(
+            self.kv_a_proj_with_mqa = Linear(
                 hidden_size,
                 self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
                 bias=bias,

@@ -528,7 +528,7 @@ def __init__(
                 allreduce_strategy=config.allreduce_strategy,
                 force_dynamic_quantization=config.force_dynamic_quantization)
         else:
-            self.fused_a = Linear(
+            self.kv_a_proj_with_mqa = Linear(
                 hidden_size,
                 self.kv_lora_rank + self.qk_rope_head_dim,
                 bias=bias,

@@ -743,14 +743,15 @@ def forward_impl(self,
             torch.Tensor: The output tensor.
         """
         if self.is_lite:
-            compressed_kv, k_pe = self.fused_a(hidden_states).split(
+            compressed_kv, k_pe = self.kv_a_proj_with_mqa(hidden_states).split(
                 [self.kv_lora_rank, self.qk_rope_head_dim], -1)
             compressed_kv = self.kv_a_layernorm(compressed_kv)
             q = hidden_states
         else:
-            q, compressed_kv, k_pe = self.fused_a(hidden_states).split(
-                [self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim],
-                -1)
+            q, compressed_kv, k_pe = self.kv_a_proj_with_mqa(
+                hidden_states).split([
+                    self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim
+                ], -1)

             q, compressed_kv = maybe_execute_in_parallel(
                 lambda: self.q_a_layernorm(q),
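Functionally the renamed module is unchanged: one fused projection followed by a single split along the last dimension. A self-contained sketch of that pattern, using sizes from the published DeepSeek-V3 config (hidden_size=7168, q_lora_rank=1536, kv_lora_rank=512, qk_rope_head_dim=64); the plain nn.Linear stands in for TensorRT-LLM's Linear and is illustrative only:

import torch
import torch.nn as nn

# Illustrative sizes from the DeepSeek-V3 config; everything else here is a
# sketch of the split pattern, not the TensorRT-LLM implementation.
hidden_size, q_lora_rank, kv_lora_rank, qk_rope_head_dim = 7168, 1536, 512, 64

kv_a_proj_with_mqa = nn.Linear(
    hidden_size, q_lora_rank + kv_lora_rank + qk_rope_head_dim, bias=False)

hidden_states = torch.randn(2, 8, hidden_size)  # [batch, seq, hidden]

# One fused GEMM, then one split along the last dim (non-lite branch above).
q, compressed_kv, k_pe = kv_a_proj_with_mqa(hidden_states).split(
    [q_lora_rank, kv_lora_rank, qk_rope_head_dim], dim=-1)

assert q.shape == (2, 8, q_lora_rank)
assert compressed_kv.shape == (2, 8, kv_lora_rank)
assert k_pe.shape == (2, 8, qk_rope_head_dim)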
