
[MoE] Fix output_shape calculation in Attention layer to handle 3D query inputs#31596

Merged
LucasWilkinson merged 1 commit into vllm-project:main from ROCm:akaratza_fix_moe_out_shape
Jan 2, 2026

Conversation


AndreasKaratzas (Collaborator) commented Jan 1, 2026

Fixes a regression introduced in #28775 where the output_shape calculation in Attention.forward() assumes 2D query input, causing failures when models pass 3D query tensors.

Problem

The new code added in #28775:

if output_shape is None:
    output_shape = torch.Size((*query.shape[:-1], self.num_heads * self.head_size_v))

This breaks when query is 3D [num_tokens, num_heads, head_dim]:

  • query.shape[:-1] = (num_tokens, num_heads)
  • output_shape = (num_tokens, num_heads, hidden_size) ← incorrect

This causes DeepseekV2Attention (used by DeepSeek-V2/V3 with MLA disabled and MTP layers) to produce incorrect output shapes, leading to downstream failures in FP8 quantization kernels.
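The mismatch can be reproduced without vLLM by modeling the shapes as plain tuples (illustrative dimensions; the names `buggy`, `num_tokens`, `num_heads`, `head_dim` are only for this sketch):

```python
# Shapes modeled as plain tuples to keep the sketch dependency-free.
num_tokens, num_heads, head_dim = 8, 16, 64
hidden = num_heads * head_dim  # 1024

query_2d = (num_tokens, hidden)               # 2D query: [num_tokens, hidden_size]
query_3d = (num_tokens, num_heads, head_dim)  # 3D query: [num_tokens, num_heads, head_dim]

# The expression from #28775: (*query.shape[:-1], num_heads * head_size_v)
def buggy(shape):
    return (*shape[:-1], hidden)

print(buggy(query_2d))  # (8, 1024)     -- correct
print(buggy(query_3d))  # (8, 16, 1024) -- the num_heads dim leaks into output_shape
```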

Fix

Use query.shape[0] to always get num_tokens, making the calculation robust to both 2D and 3D inputs:

if output_shape is None:
    num_tokens = query.shape[0]
    output_shape = torch.Size((num_tokens, self.num_heads * self.head_size_v))
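Using the same tuple sketch, the fixed expression yields the same 2-tuple for both input ranks (a minimal sketch with illustrative dimensions, not vLLM code):

```python
num_tokens, num_heads, head_dim = 8, 16, 64
hidden = num_heads * head_dim  # stands in for self.num_heads * self.head_size_v

# The fixed expression: (query.shape[0], num_heads * head_size_v)
def fixed(shape):
    return (shape[0], hidden)

print(fixed((num_tokens, hidden)))               # (8, 1024) for a 2D query
print(fixed((num_tokens, num_heads, head_dim)))  # (8, 1024) for a 3D query
```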

Testing

On ROCm (8 x MI325 machine):

  • pytest -v -s tests/v1/e2e/test_spec_decode.py::test_mtp_correctness[deepseek]
  • pytest -v -s tests/v1/e2e/test_spec_decode.py::test_eagle_correctness[TRITON_ATTN-deepseek_eagle]

Fixes issue observed on ROCm but the bug exists on all platforms.

…ery inputs

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
AndreasKaratzas (Author):

cc @DarkLight1337 @yt0428

gemini-code-assist bot (Contributor) left a comment

Code Review

The pull request modifies vllm/attention/layer.py to adjust the output_shape calculation in the forward method, explicitly extracting num_tokens from the query shape to handle both 2D and 3D query inputs.

Additionally, in vllm/model_executor/layers/quantization/fp8.py, the DeepGEMM backend activation logic was updated to first respect the user's VLLM_USE_DEEP_GEMM setting and then disable DeepGEMM if the platform does not support it, logging an informational message.

A review comment highlighted that the DeepGEMM logging logic could be misleading: if a user explicitly disables DeepGEMM, the current message incorrectly attributes the disabling to platform incompatibility. The suggested fix proposes to log the platform-incompatibility message only if DeepGEMM was initially requested by the user but is not supported by the platform.

Comment on lines +188 to +193
    if not is_deep_gemm_supported():
        use_deep_gemm = False
        logger.info_once(
            "DeepGEMM is disabled because the platform does not support it.",
            scope="local",
        )
gemini-code-assist bot (Contributor), severity: high

The current logic for checking DeepGEMM support can produce a misleading log message. If a user explicitly disables DeepGEMM by setting VLLM_USE_DEEP_GEMM=0, is_deep_gemm_supported() will return False, causing the message "DeepGEMM is disabled because the platform does not support it" to be logged. This is inaccurate because the user disabled it, not the platform.

The check should only log a message if the user intended to use DeepGEMM, but it's not supported by the platform. I've suggested a change to correct this logic and make the log message more precise.

Suggested change

    # before
    if not is_deep_gemm_supported():
        use_deep_gemm = False
        logger.info_once(
            "DeepGEMM is disabled because the platform does not support it.",
            scope="local",
        )

    # after
    if use_deep_gemm and not is_deep_gemm_supported():
        use_deep_gemm = False
        logger.info_once(
            "DeepGEMM was requested but is disabled because the platform does not support it.",
            scope="local",
        )
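The review's point about log ordering can be shown with a small standalone sketch (the function `resolve_deep_gemm` and its parameters are hypothetical, standing in for the fp8.py logic):

```python
import logging

logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger("deepgemm")

def resolve_deep_gemm(user_wants: bool, platform_supports: bool) -> bool:
    """Only log the platform message when the user actually asked for DeepGEMM."""
    use_deep_gemm = user_wants
    if use_deep_gemm and not platform_supports:
        use_deep_gemm = False
        logger.info(
            "DeepGEMM was requested but is disabled because "
            "the platform does not support it."
        )
    return use_deep_gemm

resolve_deep_gemm(user_wants=False, platform_supports=False)  # no log: user opted out
resolve_deep_gemm(user_wants=True, platform_supports=False)   # logs the platform message
```

Gating on `use_deep_gemm` first means an explicit opt-out never triggers the platform message.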

AndreasKaratzas (Author):

This is effectively the next check that is performed, and the message in the next if statement is the same as the proposed one. So I think this modification is unnecessary.

@tjtanaa added labels on Jan 1, 2026: rocm (Related to AMD ROCm), ready (ONLY add when PR is ready to merge/full CI is needed)
    logger.info_once(
        "DeepGEMM is disabled because the platform does not support it.",
        scope="local",
    )
Collaborator:

These changes are unrelated to the intent of the PR; why did you add this?

AndreasKaratzas (Author):

I can move it to a different PR if that's what you are asking. On ROCm right now, the message logged during a run is that DeepGEMM is requested but not found, which is not accurate because DeepGEMM is not a ROCm-supported feature. So I put together this short block that performs a more precise check first.

Collaborator:

got it, thanks 👍


LucasWilkinson (Collaborator):

LGTM, thanks for the contribution!

@LucasWilkinson LucasWilkinson enabled auto-merge (squash) January 2, 2026 15:18
@LucasWilkinson LucasWilkinson merged commit 6ef770d into vllm-project:main Jan 2, 2026
60 of 61 checks passed
@AndreasKaratzas AndreasKaratzas deleted the akaratza_fix_moe_out_shape branch January 2, 2026 17:40
LucasWilkinson pushed a commit to neuralmagic/vllm that referenced this pull request Jan 6, 2026
…ery inputs (vllm-project#31596)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
yugong333 pushed a commit to yugong333/vllm that referenced this pull request Jan 9, 2026
…ery inputs (vllm-project#31596)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
henryoier added a commit to henryoier/vllm that referenced this pull request Jan 15, 2026
…im] (vllm-project#32274)

Summary:

The breakage was introduced in D89937241 (vllm-project#28775) and D90045073 (vllm-project#31596). We see reshaping errors with the return values of the attention layer.

When the query shape is 4D, [batch_size, num_tokens, num_heads, head_dim], the output shape will be composed as [batch_size, num_heads * head_dim]; however, the correct shape should be [batch_size, num_tokens, num_heads * head_dim].

Test Plan: Patched this diff and tested vLLM local services; it worked with no issue.

Reviewed By: frank-wei

Differential Revision: D90600898
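The same tuple sketch shows why the query.shape[0] approach, in turn, mishandles the 4D case described in that follow-up (illustrative dimensions only; the actual repair landed in the referenced commit):

```python
batch, num_tokens, num_heads, head_dim = 2, 8, 16, 64
hidden = num_heads * head_dim  # 1024

query_4d = (batch, num_tokens, num_heads, head_dim)

# The #31596 expression: (query.shape[0], num_heads * head_size_v)
out = (query_4d[0], hidden)
print(out)  # (2, 1024) -- drops num_tokens; the follow-up expects (2, 8, 1024)
```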
akh64bit pushed a commit to akh64bit/vllm that referenced this pull request Jan 16, 2026
…ery inputs (vllm-project#31596)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
…ery inputs (vllm-project#31596)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
…ery inputs (vllm-project#31596)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>