[MoE] Fix output_shape calculation in Attention layer to handle 3D query inputs#31596
Conversation
…ery inputs Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Code Review
The pull request modifies vllm/attention/layer.py to adjust the `output_shape` calculation in the forward method, explicitly extracting `num_tokens` from the query shape so that both 2D and 3D query inputs are handled.

Additionally, in vllm/model_executor/layers/quantization/fp8.py, the DeepGEMM backend activation logic was updated to first respect the user's VLLM_USE_DEEP_GEMM setting and then disable DeepGEMM if the platform does not support it, logging an informational message.

A review comment highlighted that the DeepGEMM logging logic could be misleading: if a user explicitly disables DeepGEMM, the current message incorrectly attributes the disabling to platform incompatibility. The suggested fix only logs the platform-incompatibility message if DeepGEMM was requested by the user but is not supported by the platform.
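As a rough sketch of the gating order described above (function and variable names are illustrative, not the actual vllm code), the key point is that the environment setting is consulted before the platform check, so the platform message only fires when DeepGEMM was actually requested:

```python
import os

def resolve_use_deep_gemm(platform_supported: bool) -> bool:
    """Illustrative sketch, not the actual vllm implementation."""
    # First honor the user's explicit VLLM_USE_DEEP_GEMM setting.
    use_deep_gemm = os.environ.get("VLLM_USE_DEEP_GEMM", "1") == "1"
    # Only report a platform problem when DeepGEMM was actually requested,
    # so an explicit opt-out never triggers the "not supported" message.
    if use_deep_gemm and not platform_supported:
        print("DeepGEMM was requested but is disabled because the "
              "platform does not support it.")
        use_deep_gemm = False
    return use_deep_gemm
```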
```python
if not is_deep_gemm_supported():
    use_deep_gemm = False
    logger.info_once(
        "DeepGEMM is disabled because the platform does not support it.",
        scope="local",
    )
```
The current logic for checking DeepGEMM support can produce a misleading log message. If a user explicitly disables DeepGEMM by setting VLLM_USE_DEEP_GEMM=0, is_deep_gemm_supported() will return False, causing the message "DeepGEMM is disabled because the platform does not support it" to be logged. This is inaccurate because the user disabled it, not the platform.
The check should only log a message if the user intended to use DeepGEMM, but it's not supported by the platform. I've suggested a change to correct this logic and make the log message more precise.
```diff
-if not is_deep_gemm_supported():
+if use_deep_gemm and not is_deep_gemm_supported():
     use_deep_gemm = False
     logger.info_once(
-        "DeepGEMM is disabled because the platform does not support it.",
+        "DeepGEMM was requested but is disabled because the platform does not support it.",
         scope="local",
     )
```
This is effectively the next check that is performed, and the message in the next if statement is the same as the proposed one. So I think this modification is unnecessary.
```python
logger.info_once(
    "DeepGEMM is disabled because the platform does not support it.",
    scope="local",
)
```
These changes are unrelated to the intent of the PR; why did you add this?
I can move it to a different PR if that's what you are asking. On ROCm right now, the message logged during a run is that DeepGEMM was requested but not found, which is not accurate because DeepGEMM is not a ROCm-supported feature. So I put together this short block that performs a more precise check first.
LGTM, thanks for the contribution!
…ery inputs (vllm-project#31596) Signed-off-by: Andreas Karatzas <akaratza@amd.com>
…im] (vllm-project#32274) Summary: The breakage was introduced in D89937241 (vllm-project#28775) and D90045073 (vllm-project#31596). We see reshaping errors with the return values of the attention layer. When the query shape is 4D, `[batch_size, num_tokens, num_heads, head_dim]`, the output shape is composed as `[batch_size, num_heads * head_dim]`; however, the correct shape should be `[batch_size, num_tokens, num_heads * head_dim]`. Test Plan: Patched this diff and tested vllm local services; it worked with no issue. Reviewed By: frank-wei Differential Revision: D90600898
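To make the reshaping error above concrete, here is a small sketch of the shape arithmetic using plain tuples (names are hypothetical, not the vllm code):

```python
# Illustrative shape arithmetic for the 4D case described above.
num_heads, head_dim = 8, 16
# [batch_size, num_tokens, num_heads, head_dim]
query_shape = (2, 7, num_heads, head_dim)

# Taking only query_shape[0] drops the num_tokens dimension:
wrong_shape = (query_shape[0], num_heads * head_dim)

# The output must keep every leading dimension up to the head axes:
correct_shape = (*query_shape[:-2], num_heads * head_dim)

assert wrong_shape == (2, 128)
assert correct_shape == (2, 7, 128)
```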
Fixes a regression introduced in #28775 where the `output_shape` calculation in `Attention.forward()` assumes 2D query input, causing failures when models pass 3D query tensors.

Problem
The new code added in #28775 derives `output_shape` from `query.shape[:-1]`, which breaks when `query` is 3D `[num_tokens, num_heads, head_dim]`:

- `query.shape[:-1]` = `(num_tokens, num_heads)`
- `output_shape` = `(num_tokens, num_heads, hidden_size)` ← incorrect

This causes `DeepseekV2Attention` (used by DeepSeek-V2/V3 with MLA disabled, and by MTP layers) to produce incorrect output shapes, leading to downstream failures in FP8 quantization kernels.

Fix
Use `query.shape[0]` to always get `num_tokens`, making the calculation robust to both 2D and 3D inputs.

Testing
On ROCm (8 × MI325 machine):

- `pytest -v -s tests/v1/e2e/test_spec_decode.py::test_mtp_correctness[deepseek]`
- `pytest -v -s tests/v1/e2e/test_spec_decode.py::test_eagle_correctness[TRITON_ATTN-deepseek_eagle]`

The issue was observed on ROCm, but the bug exists on all platforms.
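The calculation the PR describes can be sketched as follows (the helper name is illustrative, not the actual vllm code; only the `query.shape[0]` extraction mirrors the fix):

```python
def attn_output_shape(query_shape: tuple[int, ...],
                      hidden_size: int) -> tuple[int, int]:
    """Hypothetical helper mirroring the fix: take num_tokens from the
    leading dimension rather than from query_shape[:-1]."""
    num_tokens = query_shape[0]
    return (num_tokens, hidden_size)

# 2D query: [num_tokens, num_heads * head_dim]
assert attn_output_shape((7, 128), hidden_size=128) == (7, 128)
# 3D query: [num_tokens, num_heads, head_dim] -- the previously broken case
assert attn_output_shape((7, 8, 16), hidden_size=128) == (7, 128)
```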