From 1e97e2d1227404820eabafd3725d00eab0d56bc7 Mon Sep 17 00:00:00 2001
From: chang-wenbin
Date: Fri, 17 Oct 2025 15:55:15 +0800
Subject: [PATCH] update mla

---
 .../model_executor/layers/attention/mla_attention_backend.py | 5 +++++
 fastdeploy/model_executor/models/deepseek_v3.py              | 5 ++---
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py
index d7d18526f9..8293b6936f 100644
--- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py
+++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py
@@ -136,6 +136,7 @@ def __init__(
         self.kv_lora_rank: int = fd_config.model_config.kv_lora_rank
         self.qk_rope_head_dim: int = fd_config.model_config.qk_rope_head_dim
         self.qk_head_dim: int = fd_config.model_config.qk_nope_head_dim + fd_config.model_config.qk_rope_head_dim
+        self.v_head_dim: int = fd_config.model_config.v_head_dim
         self.attn_softmax_scale: float = self.qk_head_dim**-0.5
         if fd_config.model_config.rope_scaling:
             mscale_all_dim = fd_config.model_config.rope_scaling.get("mscale_all_dim", False)  # 1.0
@@ -443,6 +444,10 @@ def forward_mixed(
                 **self.flash_attn_kwargs,
             )[0]
 
+            # NOTE: (changwenbin) When Flash-attn2 is used, the padded part of the head dim must be cut off.
+            if self.flash_attn_func is flash_attn_unpadded:
+                fmha_out = fmha_out[:, :, : self.v_head_dim]
+
             return fmha_out
 
         # Decode
diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py
index faa76be8d2..5c9070a701 100644
--- a/fastdeploy/model_executor/models/deepseek_v3.py
+++ b/fastdeploy/model_executor/models/deepseek_v3.py
@@ -369,8 +369,9 @@ def forward(
         key = paddle.empty_like(query)
         key[..., : self.qk_nope_head_dim] = key_nope
         key[..., self.qk_nope_head_dim :] = key_pe
-        value = paddle.nn.functional.pad(value, [0, self.qk_head_dim - self.v_head_dim], value=0)
 
+        # NOTE: (changwenbin) Flash-attn3 does not require padding the head dim of Value;
+        # for Flash-attn2 the backend already trims the padding before returning.
         fmha_out_prefill = self.mla_attn(
             q=query,
             k=key,
@@ -381,8 +382,6 @@
             forward_meta=forward_meta,
         )
 
-        fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
-        fmha_out_prefill = fmha_out_prefill[:, :, : self.v_head_dim]
         fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
         fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast(fmha_out_prefill.dtype)
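
Note (not part of the patch): a minimal sketch of the shape bookkeeping this change relies on. It uses plain Paddle with assumed values for the token count, per-rank head count, and head dims (128 nope + 64 rope, v_head_dim 128); none of these are read from FastDeploy config, and the snippet is illustrative rather than backend code. It only shows why slicing the Flash-attn2 output back to v_head_dim in the backend lets deepseek_v3.py reshape directly to num_attention_heads_tp * v_head_dim.

    # Illustrative sketch only; all shapes below are assumptions.
    import paddle

    num_heads_tp = 16                                    # assumed per-rank head count
    qk_nope_head_dim, qk_rope_head_dim = 128, 64         # assumed, DeepSeek-V3-style dims
    v_head_dim = 128                                     # assumed
    qk_head_dim = qk_nope_head_dim + qk_rope_head_dim    # 192, as computed in the backend __init__

    tokens = 1024                                        # assumed number of prefill tokens
    # Stand-in for a prefill attention output whose head dim still matches
    # query/key (the padded width), as on the Flash-attn2 path.
    fmha_out = paddle.randn([tokens, num_heads_tp, qk_head_dim])

    # What the backend now does after flash_attn_unpadded: drop the padded tail.
    fmha_out = fmha_out[:, :, :v_head_dim]

    # deepseek_v3.py can then reshape straight to heads * v_head_dim, which is
    # why the per-layer slice/reshape pair was removed from the model code.
    fmha_out = fmha_out.reshape([-1, num_heads_tp * v_head_dim])
    assert list(fmha_out.shape) == [tokens, num_heads_tp * v_head_dim]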