@@ -502,7 +502,7 @@ def __init__(
         self.quant_config = quant_config

         if not self.is_lite:
-            self.fused_a = Linear(
+            self.kv_a_proj_with_mqa = Linear(
                 hidden_size,
                 self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
                 bias=bias,
@@ -528,7 +528,7 @@ def __init__(
                 allreduce_strategy=config.allreduce_strategy,
                 force_dynamic_quantization=config.force_dynamic_quantization)
         else:
-            self.fused_a = Linear(
+            self.kv_a_proj_with_mqa = Linear(
                 hidden_size,
                 self.kv_lora_rank + self.qk_rope_head_dim,
                 bias=bias,
@@ -743,14 +743,15 @@ def forward_impl(self,
             torch.Tensor: The output tensor.
         """
         if self.is_lite:
-            compressed_kv, k_pe = self.fused_a(hidden_states).split(
+            compressed_kv, k_pe = self.kv_a_proj_with_mqa(hidden_states).split(
                 [self.kv_lora_rank, self.qk_rope_head_dim], -1)
             compressed_kv = self.kv_a_layernorm(compressed_kv)
             q = hidden_states
         else:
-            q, compressed_kv, k_pe = self.fused_a(hidden_states).split(
-                [self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim],
-                -1)
+            q, compressed_kv, k_pe = self.kv_a_proj_with_mqa(
+                hidden_states).split([
+                    self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim
+                ], -1)
 
         q, compressed_kv = maybe_execute_in_parallel(
             lambda: self.q_a_layernorm(q),
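For readers skimming the diff: this is a pure rename of the fused down-projection (fused_a to kv_a_proj_with_mqa, presumably to align with the DeepSeek checkpoint naming); the computation is unchanged. Below is a minimal sketch of what the layer does in the non-lite path, using illustrative DeepSeek-V3-like dimensions and a plain torch.nn.Linear standing in for the repo's Linear wrapper:

    import torch
    import torch.nn as nn

    # Illustrative sizes only (roughly DeepSeek-V3-like); the real
    # values come from the model config.
    hidden_size = 7168
    q_lora_rank = 1536
    kv_lora_rank = 512
    qk_rope_head_dim = 64

    # One fused GEMM emits the low-rank q latent, the compressed kv
    # latent, and the rotary (k_pe) slice of k in a single output.
    kv_a_proj_with_mqa = nn.Linear(
        hidden_size, q_lora_rank + kv_lora_rank + qk_rope_head_dim,
        bias=False)

    hidden_states = torch.randn(2, 8, hidden_size)  # [batch, seq, hidden]
    q, compressed_kv, k_pe = kv_a_proj_with_mqa(hidden_states).split(
        [q_lora_rank, kv_lora_rank, qk_rope_head_dim], dim=-1)

    print(q.shape)              # torch.Size([2, 8, 1536])
    print(compressed_kv.shape)  # torch.Size([2, 8, 512])
    print(k_pe.shape)           # torch.Size([2, 8, 64])

In the lite path the q slice is absent (q is taken directly from hidden_states), so the layer's output width shrinks to kv_lora_rank + qk_rope_head_dim, which is exactly what the second hunk constructs.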