diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
index 9d0e16518c8..7340b2c73c2 100644
--- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py
+++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -248,7 +248,7 @@ def __init__(
             dtype=config.torch_dtype,
             config=model_config,
             aux_stream=aux_stream)
-        self.fused_a = DeepseekV3Linear(
+        self.kv_a_proj_with_mqa = DeepseekV3Linear(
             config.hidden_size,
             self.kv_lora_rank + self.qk_rope_head_dim +
             (self.q_lora_rank if not self.is_lite else 0),
@@ -1384,7 +1384,7 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
                 attn_module.v_b_proj_scale = nn.Parameter(
                     v_b_proj_scale, requires_grad=False)
 
-            elif names[-1] == "fused_a":
+            elif names[-1] == "kv_a_proj_with_mqa":
                 fused_a = weights[
                     f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight"][:]
                 if not is_lite:
diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py
index 0f2a191a9c0..f9e04a2b5ad 100644
--- a/tensorrt_llm/_torch/modules/attention.py
+++ b/tensorrt_llm/_torch/modules/attention.py
@@ -502,7 +502,7 @@ def __init__(
         self.quant_config = quant_config
 
         if not self.is_lite:
-            self.fused_a = Linear(
+            self.kv_a_proj_with_mqa = Linear(
                 hidden_size,
                 self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
                 bias=bias,
@@ -528,7 +528,7 @@ def __init__(
                 allreduce_strategy=config.allreduce_strategy,
                 force_dynamic_quantization=config.force_dynamic_quantization)
         else:
-            self.fused_a = Linear(
+            self.kv_a_proj_with_mqa = Linear(
                 hidden_size,
                 self.kv_lora_rank + self.qk_rope_head_dim,
                 bias=bias,
@@ -743,14 +743,15 @@ def forward_impl(self,
             torch.Tensor: The output tensor.
         """
         if self.is_lite:
-            compressed_kv, k_pe = self.fused_a(hidden_states).split(
+            compressed_kv, k_pe = self.kv_a_proj_with_mqa(hidden_states).split(
                 [self.kv_lora_rank, self.qk_rope_head_dim], -1)
             compressed_kv = self.kv_a_layernorm(compressed_kv)
             q = hidden_states
         else:
-            q, compressed_kv, k_pe = self.fused_a(hidden_states).split(
-                [self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim],
-                -1)
+            q, compressed_kv, k_pe = self.kv_a_proj_with_mqa(
+                hidden_states).split([
+                    self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim
+                ], -1)
 
             q, compressed_kv = maybe_execute_in_parallel(
                 lambda: self.q_a_layernorm(q),
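
The patch renames the module attribute to match the checkpoint weight name
kv_a_proj_with_mqa that the loader already reads. Below is a minimal,
self-contained sketch of what this fused down-projection computes in
DeepSeek-V3 MLA: one GEMM whose output is split into the query low-rank
component, the compressed KV, and the RoPE key part. It assumes hypothetical
dimensions and substitutes a plain nn.Linear for TensorRT-LLM's Linear
wrapper; MLADownProjSketch and its constructor defaults are illustrative
names, not part of the codebase.

import torch
import torch.nn as nn


class MLADownProjSketch(nn.Module):
    """One GEMM yields the fused [q_lora | kv_lora | k_pe] projection."""

    def __init__(self, hidden_size=1024, q_lora_rank=512, kv_lora_rank=256,
                 qk_rope_head_dim=64):
        super().__init__()
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        # q_lora_rank == 0 models the "lite" variant, which has no query
        # down-projection and passes hidden_states through as q directly.
        out_dim = q_lora_rank + kv_lora_rank + qk_rope_head_dim
        self.kv_a_proj_with_mqa = nn.Linear(hidden_size, out_dim, bias=False)

    def forward(self, hidden_states):
        if self.q_lora_rank == 0:  # "lite" path
            compressed_kv, k_pe = self.kv_a_proj_with_mqa(
                hidden_states).split(
                    [self.kv_lora_rank, self.qk_rope_head_dim], -1)
            return hidden_states, compressed_kv, k_pe
        # Full path: a single fused output is split three ways, mirroring
        # the kv_a_proj_with_mqa(...).split(...) call in the patch.
        q, compressed_kv, k_pe = self.kv_a_proj_with_mqa(hidden_states).split(
            [self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim], -1)
        return q, compressed_kv, k_pe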