Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -229,19 +229,17 @@ def forward_mla_core_npu(
k_rope=k_pe,
**(dict(topk_indices=topk_indices) if topk_indices is not None else {}),
)

attn_output = attn_output.view(-1, m.num_local_heads, m.kv_lora_rank)

attn_bmm_output = torch.empty(
(attn_output.shape[0], m.num_local_heads * m.v_head_dim),
(attn_output.shape[0], m.num_local_heads, m.v_head_dim),
dtype=attn_output.dtype,
device=attn_output.device,
)
torch.bmm(
attn_output.transpose(0, 1),
m.w_vc,
out=attn_bmm_output.view(-1, m.num_local_heads, m.v_head_dim).transpose(0, 1),
)
torch.ops.npu.batch_matmul_transpose(attn_output, m.w_vc, attn_bmm_output)

attn_bmm_output = attn_bmm_output.reshape(-1, m.num_local_heads * m.v_head_dim)

output, _ = m.o_proj(attn_bmm_output)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The `attn_bmm_output` tensor has a 3D shape of `(num_tokens, num_local_heads, v_head_dim)`, but `m.o_proj` (a `RowParallelLinear` layer) expects a 2D input whose last dimension is `m.num_local_heads * m.v_head_dim`. You should reshape `attn_bmm_output` before passing it to `m.o_proj`. This is consistent with how it is handled in `forward_dsa_core_npu` and `forward_mha_core_npu`.

Suggested change
output, _ = m.o_proj(attn_bmm_output)
attn_bmm_output = attn_bmm_output.reshape(-1, m.num_local_heads * m.v_head_dim)
output, _ = m.o_proj(attn_bmm_output)


return output
Expand Down Expand Up @@ -358,7 +356,11 @@ def forward_dsa_core_npu(
device=attn_output.device,
)

if not forward_batch.forward_mode.is_decode():
if (
not forward_batch.forward_mode.is_decode()
and not forward_batch.forward_mode.is_draft_extend(include_v2=True)
and not forward_batch.forward_mode.is_target_verify()
):
attn_output = attn_output.transpose(0, 1)
torch.bmm(
attn_output,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -569,9 +569,9 @@ def apply_without_routing_weights(
):
hidden_states = torch.ops.npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight],
scale=[self.w13_weight_scale],
bias=[self.w13_scale_bias],
weight=[layer.w13_weight],
scale=[layer.w13_weight_scale],
bias=[layer.w13_scale_bias],
per_token_scale=[hidden_states_scale],
group_list=group_list,
split_item=2,
Expand All @@ -586,9 +586,9 @@ def apply_without_routing_weights(

hidden_states = torch.ops.npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w2_weight],
scale=[self.w2_weight_scale],
bias=[self.w2_scale_bias],
weight=[layer.w2_weight],
scale=[layer.w2_weight_scale],
bias=[layer.w2_scale_bias],
per_token_scale=[swiglu_out_scale],
group_list=group_list,
split_item=2,
Expand Down
Loading