[XPU] support MLA model on Intel GPU#37143

Merged
jikunshang merged 3 commits into vllm-project:main from jikunshang:kunshang/mla_support
Mar 25, 2026

Conversation

@jikunshang
Collaborator

@jikunshang jikunshang commented Mar 16, 2026

Purpose

Before this PR, MLA models could only be enabled by exporting VLLM_MLA_DISABLE=1, which always falls back to the MHA backend.
This PR uses FLASH_ATTN for prefill and TRITON_MLA for decode.
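A minimal sketch of the per-phase backend split this PR describes. The backend names come from the PR description; the `AttnPhase` enum and `select_xpu_mla_backend` helper are illustrative stand-ins, not vLLM's actual API.

```python
from enum import Enum


class AttnPhase(Enum):
    # Hypothetical phase marker; vLLM tracks prefill/decode differently.
    PREFILL = "prefill"
    DECODE = "decode"


def select_xpu_mla_backend(phase: AttnPhase) -> str:
    # On XPU, MLA models run flash attention for prefill
    # and the Triton MLA kernel for decode.
    if phase is AttnPhase.PREFILL:
        return "FLASH_ATTN"
    return "TRITON_MLA"
```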

Test Plan

python3 examples/basic/offline_inference/generate.py --model deepseek-ai/DeepSeek-V2-Lite  --enforce-eager --temperature 0  -tp 2  

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for MLA models on Intel GPUs by enabling flash attention for prefill and Triton MLA for decode on the XPU platform. The changes involve updating XPU custom operations and adding platform-specific logic in the MLA attention and quantization layers. I've found a critical issue in the implementation of forward_xpu in the QuantFP8 layer that could lead to runtime errors. My review includes a suggestion to fix this.

Comment on lines +168 to +176
def forward_xpu(
    self,
    x: torch.Tensor,
    scale: torch.Tensor | None = None,
    scale_ub: torch.Tensor | None = None,
    use_triton: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    # XPU currently only supports native implementation.
    return self.forward_cuda(x, scale, scale_ub, use_triton)

critical

The implementation of forward_xpu calls self.forward_cuda, but the accompanying comment states that "XPU currently only supports native implementation." This is contradictory and can lead to a critical runtime error.

Specifically, when this method is used by subclasses like _DecodeConcatQuantFP8 (in vllm/model_executor/layers/attention/mla_attention.py), which overrides forward_cuda with a different method signature, calling self.forward_cuda from the base class's forward_xpu will cause a parameter mismatch and a crash.

To fix this and align with the comment's intent, forward_xpu should call self.forward_native instead. This will correctly dispatch to the native PyTorch implementation, which is also correctly wrapped by subclasses like _DecodeConcatQuantFP8.
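The hazard described above can be reproduced in a few lines. The class and method bodies below are simplified stand-ins for QuantFP8 and _DecodeConcatQuantFP8, not vLLM's actual implementations; only the dispatch shape matters.

```python
class QuantBase:
    # Stand-in for QuantFP8.
    def forward_native(self, x, scale=None):
        return ("native", x, scale)

    def forward_cuda(self, x, scale=None):
        return ("cuda", x, scale)

    def forward_xpu_buggy(self, x, scale=None):
        # Buggy: dispatches to forward_cuda, which a subclass may
        # override with an incompatible signature.
        return self.forward_cuda(x, scale)

    def forward_xpu_fixed(self, x, scale=None):
        # Fixed: dispatches to the native implementation, which
        # subclasses wrap with a compatible signature.
        return self.forward_native(x, scale)


class DecodeConcat(QuantBase):
    # Stand-in for _DecodeConcatQuantFP8: overrides forward_cuda
    # with a different signature than the base class.
    def forward_cuda(self, q, k, v):
        return ("cuda-concat", q, k, v)


layer = DecodeConcat()
try:
    layer.forward_xpu_buggy([1.0])
    crashed = False
except TypeError:
    # Missing positional argument: the base-class call does not
    # match the subclass's forward_cuda signature.
    crashed = True

result = layer.forward_xpu_fixed([1.0])
```

Here `crashed` ends up `True` and `result[0]` is `"native"`, matching the review's reasoning: routing `forward_xpu` through `forward_native` avoids the signature mismatch.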

Suggested change

Before:

    def forward_xpu(
        self,
        x: torch.Tensor,
        scale: torch.Tensor | None = None,
        scale_ub: torch.Tensor | None = None,
        use_triton: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # XPU currently only supports native implementation.
        return self.forward_cuda(x, scale, scale_ub, use_triton)

After:

    def forward_xpu(
        self,
        x: torch.Tensor,
        scale: torch.Tensor | None = None,
        scale_ub: torch.Tensor | None = None,
        use_triton: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # XPU currently only supports native implementation.
        return self.forward_native(x, scale, scale_ub, use_triton)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
@jikunshang jikunshang force-pushed the kunshang/mla_support branch from def498e to 9a7becb Compare March 25, 2026 06:03
@jikunshang jikunshang added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 25, 2026
@jikunshang jikunshang merged commit 14771f7 into vllm-project:main Mar 25, 2026
69 checks passed
RhizoNymph pushed a commit to RhizoNymph/vllm that referenced this pull request Mar 26, 2026
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>

Labels

ready ONLY add when PR is ready to merge/full CI is needed
