[XPU] support MLA model on Intel GPU #37143
Conversation
Code Review
This pull request adds support for MLA models on Intel GPUs by enabling flash attention for prefill and Triton MLA for decode on the XPU platform. The changes involve updating XPU custom operations and adding platform-specific logic in the MLA attention and quantization layers. I've found a critical issue in the implementation of forward_xpu in the QuantFP8 layer that could lead to runtime errors. My review includes a suggestion to fix this.
```python
def forward_xpu(
    self,
    x: torch.Tensor,
    scale: torch.Tensor | None = None,
    scale_ub: torch.Tensor | None = None,
    use_triton: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    # XPU currently only supports native implementation.
    return self.forward_cuda(x, scale, scale_ub, use_triton)
```
The implementation of forward_xpu calls self.forward_cuda, but the accompanying comment states that "XPU currently only supports native implementation." This is contradictory and can lead to a critical runtime error.
Specifically, when this method is used by subclasses like _DecodeConcatQuantFP8 (in vllm/model_executor/layers/attention/mla_attention.py), which overrides forward_cuda with a different method signature, calling self.forward_cuda from the base class's forward_xpu will cause a parameter mismatch and a crash.
To fix this and align with the comment's intent, forward_xpu should call self.forward_native instead. This will correctly dispatch to the native PyTorch implementation, which is also correctly wrapped by subclasses like _DecodeConcatQuantFP8.
Suggested change:

```diff
 def forward_xpu(
     self,
     x: torch.Tensor,
     scale: torch.Tensor | None = None,
     scale_ub: torch.Tensor | None = None,
     use_triton: bool = False,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     # XPU currently only supports native implementation.
-    return self.forward_cuda(x, scale, scale_ub, use_triton)
+    return self.forward_native(x, scale, scale_ub, use_triton)
```
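The failure mode behind this suggestion can be sketched in plain Python. The class and method names below mirror the review comment (`forward_native`, `forward_cuda`, `forward_xpu`, and a `_DecodeConcatQuantFP8`-style subclass), but the bodies are illustrative stand-ins, not the actual vLLM implementation:

```python
class QuantFP8Sketch:
    """Base layer: per-platform forward_* methods share one native path."""

    def forward_native(self, x, scale=None, scale_ub=None, use_triton=False):
        # Reference implementation usable on any platform.
        return ("native", x)

    def forward_cuda(self, x, scale=None, scale_ub=None, use_triton=False):
        return ("cuda", x)

    def forward_xpu(self, x, scale=None, scale_ub=None, use_triton=False):
        # Fixed version: dispatch to the native implementation, which
        # subclasses wrap with a consistent signature.
        return self.forward_native(x, scale, scale_ub, use_triton)


class DecodeConcatSketch(QuantFP8Sketch):
    # A subclass that repurposes forward_cuda with a different signature,
    # as _DecodeConcatQuantFP8 does. If the base class's forward_xpu
    # called self.forward_cuda(x, scale, scale_ub, use_triton), this
    # override would receive the wrong arguments and raise TypeError.
    def forward_cuda(self, q_nope, q_pe):
        return ("cuda-concat", q_nope, q_pe)


layer = DecodeConcatSketch()
# The fixed forward_xpu works regardless of the subclass override:
print(layer.forward_xpu(1.0))  # ('native', 1.0)
# Whereas the buggy dispatch path would crash:
try:
    layer.forward_cuda(1.0, None, None, False)  # base-class argument list
except TypeError:
    print("TypeError: signature mismatch in overridden forward_cuda")
```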
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Purpose
Before this PR, MLA models could only be enabled on XPU via `export VLLM_MLA_DISABLE=1`, which always falls back to the MHA backend. This PR instead uses FLASH_ATTN for prefill and TRITON_MLA for decode.
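The before/after behavior can be summarized as a shell fragment (the `VLLM_MLA_DISABLE` variable is the one named in this PR; the rest is illustrative):

```shell
# Before this PR: running an MLA model on XPU required disabling MLA
# entirely, forcing the fallback MHA backend.
export VLLM_MLA_DISABLE=1

# After this PR: MLA is supported natively on XPU (FLASH_ATTN for
# prefill, TRITON_MLA for decode), so the variable can be left unset.
unset VLLM_MLA_DISABLE
```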
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
Update `supported_models.md` and `examples` for a new model.