Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions vllm/_xpu_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def flash_attn_varlen_func(
num_splits=0,
return_softmax_lse: bool | None = False,
s_aux: torch.Tensor | None = None,
return_attn_probs: bool | None = False,
):
assert cu_seqlens_k is not None or seqused_k is not None, (
"cu_seqlens_k or seqused_k must be provided"
Expand Down
4 changes: 4 additions & 0 deletions vllm/model_executor/layers/attention/mla_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,6 +1059,10 @@ class QueryLenSupport(Enum):
"MLA models using TRITON_MLA will require flash_attn. "
"AITER_MLA backends use aiter kernels instead."
)
elif current_platform.is_xpu():
from vllm._xpu_ops import xpu_ops as ops

flash_attn_varlen_func = ops.flash_attn_varlen_func # type: ignore[no-redef]


def dynamic_per_batched_tensor_quant(
Expand Down
10 changes: 10 additions & 0 deletions vllm/model_executor/layers/quantization/input_quant_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,16 @@ def forward_hip(
# Fallback to CUDA implementation
return self.forward_cuda(x, scale, scale_ub)

def forward_xpu(
    self,
    x: torch.Tensor,
    scale: torch.Tensor | None = None,
    scale_ub: torch.Tensor | None = None,
    use_triton: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """FP8-quantize *x* on an XPU device.

    The XPU backend reuses the CUDA implementation unchanged, so this
    simply forwards every argument to ``forward_cuda`` and returns its
    (quantized tensor, scale) result pair.
    """
    # The CUDA code path runs as-is on XPU; delegate directly.
    cuda_path = self.forward_cuda
    return cuda_path(x, scale, scale_ub, use_triton)
Comment on lines +168 to +176
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The implementation of forward_xpu calls self.forward_cuda, but the accompanying comment states that "XPU currently only supports native implementation." This is contradictory and can lead to a critical runtime error.

Specifically, when this method is used by subclasses like _DecodeConcatQuantFP8 (in vllm/model_executor/layers/attention/mla_attention.py), which overrides forward_cuda with a different method signature, calling self.forward_cuda from the base class's forward_xpu will cause a parameter mismatch and a crash.

To fix this and align with the comment's intent, forward_xpu should call self.forward_native instead. This will correctly dispatch to the native PyTorch implementation, which is also correctly wrapped by subclasses like _DecodeConcatQuantFP8.

Suggested change
def forward_xpu(
self,
x: torch.Tensor,
scale: torch.Tensor | None = None,
scale_ub: torch.Tensor | None = None,
use_triton: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
# XPU currently only supports native implementation.
return self.forward_cuda(x, scale, scale_ub, use_triton)
def forward_xpu(
self,
x: torch.Tensor,
scale: torch.Tensor | None = None,
scale_ub: torch.Tensor | None = None,
use_triton: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
# XPU currently only supports native implementation.
return self.forward_native(x, scale, scale_ub, use_triton)


def forward_native(
self,
x: torch.Tensor,
Expand Down
12 changes: 0 additions & 12 deletions vllm/platforms/xpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@ def get_static_graph_wrapper_cls(cls) -> str:
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
cache_config = vllm_config.cache_config
model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
# in V1(or with chunked prefill) block_size is 64
if cache_config and not cache_config.user_specified_block_size:
Expand Down Expand Up @@ -209,17 +208,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
if vllm_config.kv_transfer_config is not None:
vllm_config.kv_transfer_config.enable_permute_local_kv = True

if model_config and model_config.use_mla:
logger.info(
"MLA is enabled on a non-GPU platform; forcing chunked "
"prefill and prefix caching to be disabled."
)
vllm_config.scheduler_config.enable_chunked_prefill = False
vllm_config.scheduler_config.max_num_batched_tokens = max(
vllm_config.model_config.max_model_len,
vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
)

# In some cases, the internal memory type cache can misdetect GPU
# memory as host memory, also leading to invalid memory access.
# This cache can be disabled by setting UCX_MEMTYPE_CACHE=n.
Expand Down
Loading