@@ -136,6 +136,7 @@ def __init__(
self.kv_lora_rank: int = fd_config.model_config.kv_lora_rank
self.qk_rope_head_dim: int = fd_config.model_config.qk_rope_head_dim
self.qk_head_dim: int = fd_config.model_config.qk_nope_head_dim + fd_config.model_config.qk_rope_head_dim
self.v_head_dim: int = fd_config.model_config.v_head_dim
self.attn_softmax_scale: float = self.qk_head_dim**-0.5
if fd_config.model_config.rope_scaling:
mscale_all_dim = fd_config.model_config.rope_scaling.get("mscale_all_dim", False) # 1.0
@@ -443,6 +444,10 @@ def forward_mixed(
**self.flash_attn_kwargs,
)[0]

# NOTE: (changwenbin) When using Flash-Attention 2, the padded part of the head dim must be cut off.
if self.flash_attn_func is flash_attn_unpadded:
fmha_out = fmha_out[:, :, : self.v_head_dim]

return fmha_out

# Decode
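For context, here is a minimal NumPy sketch (not part of the PR; all names and shapes are illustrative) of why slicing the Flash-Attention 2 output back to `v_head_dim` is safe: when Value is zero-padded up to `qk_head_dim`, the trailing head-dim columns of the attention output are exactly zero, so cutting them off recovers the unpadded result.

```python
# Illustrative sketch only, not the FastDeploy implementation.
import numpy as np

seq_len, num_heads, qk_head_dim, v_head_dim = 4, 2, 192, 128  # assumed toy sizes

q = np.random.rand(seq_len, num_heads, qk_head_dim)
k = np.random.rand(seq_len, num_heads, qk_head_dim)
v = np.random.rand(seq_len, num_heads, v_head_dim)

# Zero-pad V's head dim so Q, K and V share one head dim (what the FA2 path needs).
pad = np.zeros((seq_len, num_heads, qk_head_dim - v_head_dim))
v_padded = np.concatenate([v, pad], axis=-1)

def attn(q, k, v):
    # scores: [heads, q_len, k_len]; softmax over keys; weighted sum of V rows.
    scores = np.einsum("qhd,khd->hqk", q, k) / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return np.einsum("hqk,khd->qhd", weights, v)

out_padded = attn(q, k, v_padded)  # [seq_len, num_heads, qk_head_dim]
out_ref = attn(q, k, v)            # [seq_len, num_heads, v_head_dim]

# The zero-padded V columns contribute nothing, so slicing matches the reference.
assert np.allclose(out_padded[:, :, :v_head_dim], out_ref)
```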
fastdeploy/model_executor/models/deepseek_v3.py (5 changes: 2 additions & 3 deletions)
@@ -369,8 +369,9 @@ def forward(
key = paddle.empty_like(query)
key[..., : self.qk_nope_head_dim] = key_nope
key[..., self.qk_nope_head_dim :] = key_pe
value = paddle.nn.functional.pad(value, [0, self.qk_head_dim - self.v_head_dim], value=0)

# NOTE: (changwenbin) Flash-Attention 3 does not need the head dim of Value to be padded;
# with Flash-Attention 2 the attention backend already slices it off before returning.
fmha_out_prefill = self.mla_attn(
q=query,
k=key,
@@ -381,8 +382,6 @@
forward_meta=forward_meta,
)

fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
fmha_out_prefill = fmha_out_prefill[:, :, : self.v_head_dim]
fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast(fmha_out_prefill.dtype)

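As a companion sketch (again illustrative, not the actual model code), this is the shape bookkeeping that remains in deepseek_v3.py after the change: the attention backend is assumed to hand back the prefill output already at `v_head_dim`, so the model only flattens the heads before the output projection.

```python
# Illustrative only: assumes the backend returns [token_num, num_heads_tp, v_head_dim].
import numpy as np

token_num, num_heads_tp, v_head_dim = 8, 16, 128  # assumed toy sizes
fmha_out_prefill = np.ones((token_num, num_heads_tp, v_head_dim))

# The per-head slice to v_head_dim no longer happens here; only the flatten
# for the output projection remains.
flat = fmha_out_prefill.reshape(-1, num_heads_tp * v_head_dim)
assert flat.shape == (token_num, num_heads_tp * v_head_dim)
```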