diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py
index d7d18526f9..8293b6936f 100644
--- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py
+++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py
@@ -136,6 +136,7 @@ def __init__(
         self.kv_lora_rank: int = fd_config.model_config.kv_lora_rank
         self.qk_rope_head_dim: int = fd_config.model_config.qk_rope_head_dim
         self.qk_head_dim: int = fd_config.model_config.qk_nope_head_dim + fd_config.model_config.qk_rope_head_dim
+        self.v_head_dim: int = fd_config.model_config.v_head_dim
         self.attn_softmax_scale: float = self.qk_head_dim**-0.5
         if fd_config.model_config.rope_scaling:
             mscale_all_dim = fd_config.model_config.rope_scaling.get("mscale_all_dim", False)  # 1.0
@@ -443,6 +444,10 @@ def forward_mixed(
                 **self.flash_attn_kwargs,
             )[0]
 
+            # NOTE: (changwenbin) If Flash-attn2 is used, the padded part of the head dim must be cut off here.
+            if self.flash_attn_func is flash_attn_unpadded:
+                fmha_out = fmha_out[:, :, : self.v_head_dim]
+
             return fmha_out
 
         # Decode
diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py
index faa76be8d2..5c9070a701 100644
--- a/fastdeploy/model_executor/models/deepseek_v3.py
+++ b/fastdeploy/model_executor/models/deepseek_v3.py
@@ -369,8 +369,9 @@ def forward(
             key = paddle.empty_like(query)
             key[..., : self.qk_nope_head_dim] = key_nope
             key[..., self.qk_nope_head_dim :] = key_pe
-            value = paddle.nn.functional.pad(value, [0, self.qk_head_dim - self.v_head_dim], value=0)
 
+            # NOTE: (changwenbin) Flash-attn3 does not need the head dim of value to be padded;
+            # for Flash-attn2 the attention backend already slices the output back to v_head_dim on return.
             fmha_out_prefill = self.mla_attn(
                 q=query,
                 k=key,
@@ -381,8 +382,6 @@ def forward(
                 forward_meta=forward_meta,
             )
 
-            fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
-            fmha_out_prefill = fmha_out_prefill[:, :, : self.v_head_dim]
             fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
             fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast(fmha_out_prefill.dtype)
 
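The change relies on a simple equivalence: when the value head dim is zero-padded up to qk_head_dim (the Flash-attn2 path), the extra output columns are zero, so slicing the output back to v_head_dim gives the same result as attending over the unpadded value (the Flash-attn3 path). Below is a minimal numpy sketch of that equivalence, not part of the PR; `naive_attention` and the shapes are illustrative only and stand in for the actual flash-attn kernels.

```python
# Illustrative only: shows why slicing the padded-V attention output back to
# v_head_dim matches attention over the unpadded V (padded columns stay zero).
import numpy as np

def naive_attention(q, k, v, scale):
    # q, k: [seq, qk_head_dim], v: [seq, v_head_dim]; single head, no mask.
    probs = np.exp(scale * q @ k.T)
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs @ v

rng = np.random.default_rng(0)
seq, qk_head_dim, v_head_dim = 4, 192, 128  # MLA-style shapes, chosen for illustration
q = rng.standard_normal((seq, qk_head_dim))
k = rng.standard_normal((seq, qk_head_dim))
v = rng.standard_normal((seq, v_head_dim))
scale = qk_head_dim ** -0.5

# Flash-attn2-style path: value zero-padded to qk_head_dim, output sliced afterwards.
v_padded = np.pad(v, [(0, 0), (0, qk_head_dim - v_head_dim)])
out_padded_then_sliced = naive_attention(q, k, v_padded, scale)[:, :v_head_dim]

# Flash-attn3-style path: value used directly with its own head dim.
out_unpadded = naive_attention(q, k, v, scale)

assert np.allclose(out_padded_then_sliced, out_unpadded)
```

Moving the slice into the attention backend, guarded by `self.flash_attn_func is flash_attn_unpadded`, appears intended to keep deepseek_v3.py agnostic to which flash-attn version is in use: the model always receives an fmha output whose last dim is v_head_dim.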