6 changes: 6 additions & 0 deletions vllm/model_executor/models/hunyuan_vision.py
@@ -247,6 +247,12 @@ def forward(
        qkv, _ = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        out = self.attn(q, k, v)
        out = out.view(
            x.size(0),
            -1,
            self.num_attention_heads_per_partition
            * self.hidden_size_per_attention_head,
        )
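As a shape sanity check on the change above, here is a minimal PyTorch sketch of the QKV split and head-merge flow. The dimensions and the stand-in tensors are illustrative assumptions, not HunYuan's actual configuration:

```python
import torch

# Illustrative dimensions (assumptions, not the model's real config)
batch, seq, num_heads, head_dim = 2, 16, 4, 8
hidden = num_heads * head_dim

x = torch.randn(batch, seq, hidden)
qkv = torch.randn(batch, seq, 3 * hidden)  # stand-in for self.qkv(x)
q, k, v = qkv.chunk(3, dim=-1)             # each [batch, seq, hidden]

# Suppose the attention backend returns per-head output [batch, seq, num_heads, head_dim]
out = torch.randn(batch, seq, num_heads, head_dim)
# Merge the head dimension before the O projection, as the patch does
out = out.view(x.size(0), -1, num_heads * head_dim)
print(out.shape)  # torch.Size([2, 16, 32])
```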
Comment on lines +250 to +255
Contributor (high):
For better maintainability and readability, you can simplify the calculation of the last dimension in the view operation. Instead of re-calculating the partitioned hidden size from the number of heads and head size, you can directly use the shape of the weight from the subsequent projection layer o_proj. This makes the code more robust to future changes as it directly references the expected input dimension of the next layer.

        out = out.view(
            x.size(0),
            -1,
            self.o_proj.weight.shape[1],
        )
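This suggestion relies on `torch.nn.Linear` (and the parallel linear layers built on it) storing the weight as `[out_features, in_features]`, so `weight.shape[1]` is the input width the `view` must produce. A quick sketch with a plain `nn.Linear` as a stand-in for the projection layer (the sizes are illustrative assumptions):

```python
import torch.nn as nn

# nn.Linear(in_features, out_features) stores weight as [out_features, in_features]
o_proj = nn.Linear(32, 64)
print(o_proj.weight.shape)     # torch.Size([64, 32])
print(o_proj.weight.shape[1])  # 32: the input width the preceding view must produce
```

Note that for a tensor-parallel row-parallel layer the second weight dimension is the per-partition input size, which is exactly what the original head-count arithmetic computes.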

Comment on lines +250 to +255
Member:
After the MM encoder attention, the output shape is [B, S, N, D]; we should merge the multi-head dimension before entering the O projection.

Not really; we automatically reshape the output so that it aligns with the q, k, v input:

        is_reshaped = query.dim() != 4
        query, key, value = self.maybe_reshape_qkv_to_4d(
            query, key, value, bsz, q_len, kv_len
        )
        output = vit_flash_attn_wrapper(
            q=query,
            k=key,
            v=value,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=bsz,
            is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
            fa_version=self._fa_version,
        )
        if is_reshaped:
            output = output.reshape(bsz, q_len, -1)

So I think we won't hit the shape mismatch issue here. 🤔
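The pattern the reviewer quotes (detect a 3-D input, run attention in 4-D, reshape back) can be sketched as follows. The `attn_4d` stub and all dimensions are illustrative assumptions standing in for the real flash-attention wrapper:

```python
import torch

def attn_4d(q, k, v):
    # Stand-in for the flash-attention wrapper: returns [B, S, N, D]
    return q

bsz, q_len, num_heads, head_dim = 2, 16, 4, 8
query = torch.randn(bsz, q_len, num_heads * head_dim)  # 3-D input

# Mirror of the quoted pattern (names are illustrative)
is_reshaped = query.dim() != 4
if is_reshaped:
    query = query.view(bsz, q_len, num_heads, head_dim)
output = attn_4d(query, query, query)
if is_reshaped:
    # Merge heads back so the caller sees the same rank it passed in
    output = output.reshape(bsz, q_len, -1)
print(output.shape)  # torch.Size([2, 16, 32])
```

Under this pattern a 3-D caller gets a 3-D result back, which is why the wrapper alone would avoid the mismatch on backends that go through it.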

Contributor Author @Potabk, Dec 27, 2025:
So it seems to be an Ascend-specific issue; I'll check further.

Contributor Author @Potabk, Dec 27, 2025:
Thanks, confirmed it is an issue with forward_oot: vllm-project/vllm-ascend#5443

        output, _ = self.o_proj(out)
        return output
