15 changes: 9 additions & 6 deletions CMakeLists.txt
@@ -62,14 +62,17 @@ set(VLLM_ASCEND_CUSTOM_OP
)

set(VLLM_ASCEND_CUSTOM_OP_EXCLUDE
${KERNEL_FILES}/bgmv_expand.cpp
${KERNEL_FILES}/bgmv_shrink.cpp
${KERNEL_FILES}/sgmv_expand.cpp
${KERNEL_FILES}/sgmv_shrink.cpp
${CMAKE_CURRENT_SOURCE_DIR}/csrc/kernels/bgmv_expand.cpp
${CMAKE_CURRENT_SOURCE_DIR}/csrc/kernels/bgmv_shrink.cpp
${CMAKE_CURRENT_SOURCE_DIR}/csrc/kernels/sgmv_expand.cpp
${CMAKE_CURRENT_SOURCE_DIR}/csrc/kernels/sgmv_shrink.cpp
${CMAKE_CURRENT_SOURCE_DIR}/csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp
${CMAKE_CURRENT_SOURCE_DIR}/csrc/batch_matmul_transpose/op_kernel/batch_matmul_transpose_kernel.cpp
)

if(SOC_VERSION STREQUAL "ASCEND310P3")
if(SOC_VERSION STREQUAL "ascend310p3")
message(STATUS "310P hardware detected: disabling MLAPO operators")
message(STATUS "310P hardware detected: excluding batch_matmul_transpose operators")
list(REMOVE_ITEM VLLM_ASCEND_CUSTOM_OP ${VLLM_ASCEND_CUSTOM_OP_EXCLUDE})
endif()

@@ -79,7 +82,7 @@ ascendc_library(vllm_ascend_kernels SHARED

message("TORCH_NPU_PATH is ${TORCH_NPU_PATH}")

if(SOC_VERSION STREQUAL "ASCEND310P3")
if(SOC_VERSION STREQUAL "ascend310p3")
file(GLOB VLLM_ASCEND_SRC
${CMAKE_CURRENT_SOURCE_DIR}/csrc/*.cpp
${CMAKE_CURRENT_SOURCE_DIR}/csrc/aclnn_torch_adapter/*.cpp)
5 changes: 4 additions & 1 deletion tests/e2e/310p/test_offline_inference_310p.py
@@ -36,7 +36,9 @@ def test_llm_models(dtype: str, max_tokens: int) -> None:
vllm_model.generate_greedy(example_prompts, max_tokens)


def test_multimodal_vl():
@pytest.mark.skip(reason="310P: multimodal test skipped, offline is ok")
@pytest.mark.parametrize("dtype", ["float16"])
def test_multimodal_vl(dtype: str):
image = ImageAsset("cherry_blossom").pil_image.convert("RGB")

img_questions = [
@@ -60,6 +62,7 @@ def test_multimodal_vl():
"max_pixels": 1280 * 28 * 28,
"fps": 1,
},
dtype=dtype,
max_model_len=8192,
enforce_eager=True,
limit_mm_per_prompt={"image": 1}) as vllm_model:
18 changes: 17 additions & 1 deletion tests/e2e/310p/test_offline_inference_parallel_310p.py
@@ -1,11 +1,27 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.

import pytest

from tests.e2e.conftest import VllmRunner


@pytest.mark.parametrize("dtype", ["float16"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.skip(reason="310p does not support parallel inference now. Fix me")
def test_models(dtype: str, max_tokens: int) -> None:
example_prompts = [
"Hello, my name is",
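A side note on the skip marker in this test: pytest.skip() is an imperative call that raises the skip exception the moment it runs, so stacking it as a decorator aborts collection instead of marking the test; pytest.mark.skip is the declarative form that lets the test be collected and reported as skipped. A minimal sketch of the two forms, with a hypothetical test name and reason (not from this PR):

import pytest


# Declarative: the test is collected, then reported as skipped.
@pytest.mark.skip(reason="310P: parallel inference not supported yet")
def test_parallel_inference_placeholder() -> None:
    ...


def test_conditional_skip() -> None:
    # Imperative: only call pytest.skip() from inside a running test body.
    pytest.skip("skipping at runtime on 310P")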
Empty file added vllm_ascend/_310p/__init__.py
98 changes: 98 additions & 0 deletions vllm_ascend/_310p/attention/attention_mask.py
@@ -0,0 +1,98 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

from collections.abc import Callable
from typing import Any

import torch
import torch_npu

import vllm_ascend.attention.attention_mask as _base_mask
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, nd_to_nz_spec

_BASE_BUILDER: Callable[[torch.device], Any] = _base_mask.AttentionMaskBuilder


def _gen_causal_additive_mask_fp16(max_seq_len: int, device: torch.device) -> torch.Tensor:
tril = torch.ones((max_seq_len, max_seq_len), dtype=torch.bool, device=device).tril_()
upper = ~tril
m = torch.zeros((max_seq_len, max_seq_len), dtype=torch.float16, device=device)
m.masked_fill_(upper, float("-inf"))
return m


def build_splitfuse_attn_mask_310p(attn_metadata, device, *, full_mask_cache=None, full_mask_cache_len=0):
qsl = attn_metadata.query_start_loc.detach().to("cpu", dtype=torch.int32)
qlens = qsl[1:] - qsl[:-1]

context_lens = attn_metadata.seq_lens.to(dtype=torch.int32)
L = int(context_lens.max().item())

q_list = qlens.tolist()
c_list = context_lens.detach().to("cpu", dtype=torch.int64).tolist()
pos_list = [p for ql, cl in zip(q_list, c_list) for p in range(cl - ql, cl)]
position = torch.tensor(pos_list, dtype=torch.long, device=device)

if full_mask_cache is None or full_mask_cache.device != device or full_mask_cache_len < L:
tril = torch.ones((L, L), dtype=torch.bool, device=device).tril_()
full = torch.zeros((L, L), dtype=torch.float16, device=device)
full.masked_fill_(~tril, float("-inf"))
full_mask_cache, full_mask_cache_len = full, L
else:
full = full_mask_cache[:L, :L].contiguous()

rows = full.index_select(0, position).contiguous()
mask = torch_npu.npu_format_cast(nd_to_nz_spec(rows).contiguous(), ACL_FORMAT_FRACTAL_NZ)
return mask, full_mask_cache, full_mask_cache_len


class _AttentionMaskBuilder310P:
"""
310P adapter:
- overrides fp16 causal additive mask generation (use -inf fp16)
- delegates all other behaviors to base AttentionMaskBuilder
- pooling runner_type is NOT supported on 310P (explicit)
"""

def __init__(self, device: torch.device):
self._base = _BASE_BUILDER(device)

self._fp16_mask_cache: torch.Tensor | None = None
self._fp16_mask_cached_len: int = 0

def __getattr__(self, name: str) -> Any:
return getattr(self._base, name)

@property
def device(self) -> torch.device:
return self._base.device

def _get_fp16_mask(self, max_seq_len: int) -> torch.Tensor:
if self._fp16_mask_cache is None or max_seq_len > self._fp16_mask_cached_len:
self._fp16_mask_cache = _gen_causal_additive_mask_fp16(max_seq_len, self.device)
self._fp16_mask_cached_len = max_seq_len
assert self._fp16_mask_cache is not None
return self._fp16_mask_cache[:max_seq_len, :max_seq_len].contiguous()

def get_attention_mask(self, model_config) -> torch.Tensor:
if getattr(model_config, "runner_type", None) == "pooling":
raise NotImplementedError("310P does not support runner_type='pooling'")
return self._get_fp16_mask(2048)


def AttentionMaskBuilder(device: torch.device) -> _AttentionMaskBuilder310P:
return _AttentionMaskBuilder310P(device)
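To make the splitfuse mask construction in build_splitfuse_attn_mask_310p concrete, here is a minimal CPU-only sketch of the same position and row-selection logic; the NZ format cast and NPU-specific calls are deliberately omitted, and the helper name and sample numbers are illustrative only:

import torch


def splitfuse_mask_rows(query_start_loc: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
    # Each new query token sits at absolute position (seq_len - q_len + i) in its
    # sequence, so its row of the shared causal mask allows keys 0..position.
    qlens = (query_start_loc[1:] - query_start_loc[:-1]).tolist()
    clens = seq_lens.tolist()
    max_len = max(clens)

    positions = [p for ql, cl in zip(qlens, clens) for p in range(cl - ql, cl)]

    tril = torch.ones((max_len, max_len), dtype=torch.bool).tril_()
    full = torch.zeros((max_len, max_len), dtype=torch.float16)
    full.masked_fill_(~tril, float("-inf"))

    return full.index_select(0, torch.tensor(positions, dtype=torch.long))


# Two requests: one with 3 cached + 2 new tokens, one prefilling 3 tokens from scratch.
rows = splitfuse_mask_rows(torch.tensor([0, 2, 5]), torch.tensor([5, 3]))
print(rows.shape)  # torch.Size([5, 5]); rows correspond to positions [3, 4, 0, 1, 2]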
172 changes: 172 additions & 0 deletions vllm_ascend/_310p/attention/attention_v1.py
@@ -0,0 +1,172 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#


import torch
import torch_npu

from vllm_ascend._310p.attention.attention_mask import AttentionMaskBuilder, build_splitfuse_attn_mask_310p
from vllm_ascend._310p.attention.metadata_builder import AscendAttentionMetadataBuilder310P
from vllm_ascend.attention.attention_v1 import AscendAttentionBackend as _BaseBackend
from vllm_ascend.attention.attention_v1 import AscendAttentionBackendImpl as _BaseImpl
from vllm_ascend.attention.attention_v1 import AscendAttentionMetadataBuilder, AscendAttentionState
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, aligned_16, nd_to_nz_2d


class AscendAttentionBackend310(_BaseBackend):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.attn_mask_builder = AttentionMaskBuilder(self.device)

@staticmethod
def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, head_size: int):
# Align to a multiple of 16, as required by the 310P device.
return (2, num_blocks, (num_kv_heads * head_size) // 16, block_size, 16)

@staticmethod
def get_impl_cls():
return AscendAttentionBackendImpl310

@staticmethod
def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]:
return AscendAttentionMetadataBuilder310P


class AscendAttentionBackendImpl310(_BaseImpl):
def forward_paged_attention(self, query, attn_metadata, output):
if attn_metadata.seq_lens.device != query.device:
attn_metadata.seq_lens = attn_metadata.seq_lens.to(device=query.device, non_blocking=True)
return super().forward_paged_attention(query, attn_metadata, output)

def _forward_prefill_310p_fallback(self, query, key, value, attn_metadata, output):
real_tokens = int(attn_metadata.seq_lens.sum().item())

query, key, value, output = (aligned_16(t) for t in (query, key, value, output))

seq_len = attn_metadata.seq_lens
if seq_len.dtype != torch.int32:
seq_len = seq_len.to(torch.int32)

aligned_tokens = int(query.shape[0])
delta = aligned_tokens - real_tokens
if delta:
seq_len = seq_len.clone()
seq_len[-1] += delta

mask = attn_metadata.attn_mask
if mask is not None and mask.dim() == 2:
max_len = int(seq_len.max().item())
aligned_len = ((max_len + 15) // 16) * 16

mask2d = mask[:aligned_len, :aligned_len].contiguous()
mask2d = mask2d.to(torch.float16)
mask_nz = nd_to_nz_2d(mask2d).contiguous()

bsz = int(seq_len.numel())
if bsz > 1:
mask_nz = mask_nz.repeat(bsz, 1, 1, 1).contiguous()

mask = torch_npu.npu_format_cast(mask_nz, ACL_FORMAT_FRACTAL_NZ)

torch_npu._npu_flash_attention(
query=query,
key=key,
value=value,
mask=mask,
seq_len=seq_len,
scale_value=self.scale,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
out=output,
)

out_real = output[:real_tokens, :, :]
return out_real

def _forward_chunked_prefill_310p(self, query, attn_metadata, output):
assert attn_metadata is not None

if query.dtype == torch.float32:
query = query.to(torch.float16)

qsl_cpu = attn_metadata.query_start_loc.detach().to("cpu", dtype=torch.int32)
qlens = (qsl_cpu[1:] - qsl_cpu[:-1]).to(torch.int32)

context_lens = attn_metadata.seq_lens
if context_lens.dtype != torch.int32:
context_lens = context_lens.to(torch.int32)

block_table = attn_metadata.block_tables.detach()
if block_table.dtype != torch.int32:
block_table = block_table.to(torch.int32)

if not hasattr(self, "_sf_full_mask_cache"):
self._sf_full_mask_cache = None
self._sf_full_mask_cache_len = 0

mask, self._sf_full_mask_cache, self._sf_full_mask_cache_len = build_splitfuse_attn_mask_310p(
attn_metadata,
query.device,
full_mask_cache=self._sf_full_mask_cache,
full_mask_cache_len=int(self._sf_full_mask_cache_len),
)

if qlens.device.type != "cpu":
qlens = qlens.to("cpu")
if context_lens.device != query.device:
context_lens = context_lens.to(query.device, non_blocking=True)

torch_npu._npu_paged_attention_splitfuse(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
mask=mask,
block_table=block_table,
seq_len=qlens,
context_lens=context_lens,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
out=output,
)

def forward_impl(self, query, key, value, kv_cache, attn_metadata, output):
state = attn_metadata.attn_state

if state == AscendAttentionState.DecodeOnly:
return self.forward_paged_attention(query, attn_metadata, output)

if state == AscendAttentionState.PrefillNoCache:
num_tokens = query.shape[0]
q = query[:num_tokens]
k = key[:num_tokens]
v = value[:num_tokens]
out = self._forward_prefill_310p_fallback(q, k, v, attn_metadata, output)
output[:num_tokens] = out
return output

if state == AscendAttentionState.ChunkedPrefill:
self._forward_chunked_prefill_310p(query, attn_metadata, output)
return output

raise NotImplementedError(
f"{self.__class__.__name__}.forward_impl: 310P only supports "
f"{AscendAttentionState.DecodeOnly.name}, "
f"{AscendAttentionState.PrefillNoCache.name}, "
f"{AscendAttentionState.ChunkedPrefill.name}, "
f"got {state!r}."
)
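The PrefillNoCache fallback above pads every tensor so the token dimension is a multiple of 16 and books the padding against the last sequence, so _npu_flash_attention sees consistent lengths while only the first real_tokens output rows are kept. A rough standalone sketch of that bookkeeping, with a local padding helper standing in for aligned_16 (whose exact behavior is assumed here):

import torch


def pad_tokens_to_16(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for vllm_ascend.utils.aligned_16: pad dim 0 up to a multiple of 16.
    n = t.shape[0]
    target = ((n + 15) // 16) * 16
    if target == n:
        return t
    pad = t.new_zeros((target - n, *t.shape[1:]))
    return torch.cat((t, pad), dim=0)


# Three sequences totalling 37 prefill tokens.
seq_len = torch.tensor([10, 15, 12], dtype=torch.int32)
real_tokens = int(seq_len.sum())            # 37
query = torch.randn(real_tokens, 8, 64)     # [tokens, num_heads, head_dim]

query = pad_tokens_to_16(query)             # token dim padded from 37 to 48
delta = query.shape[0] - real_tokens        # 11 padding tokens
if delta:
    seq_len = seq_len.clone()
    seq_len[-1] += delta                    # attribute the padding to the last sequence

print(query.shape[0], seq_len.tolist())     # 48 [10, 15, 23]
# After the kernel runs, only output[:real_tokens] would be returned, as in the fallback above.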
40 changes: 40 additions & 0 deletions vllm_ascend/_310p/attention/metadata_builder.py
@@ -0,0 +1,40 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

from __future__ import annotations

from typing import Any

import torch
from vllm.config import VllmConfig
from vllm.v1.kv_cache_interface import AttentionSpec

from vllm_ascend._310p.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionMetadataBuilder as _BaseBuilder


class AscendAttentionMetadataBuilder310P(_BaseBuilder):
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)

self.attn_mask_builder: Any = AttentionMaskBuilder(self.device)
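Both _AttentionMaskBuilder310P and AscendAttentionMetadataBuilder310P follow the same wrap-and-override approach: hold or subclass the base object, override only what 310P needs, and let the rest fall through. A generic sketch of the __getattr__ forwarding used by the mask builder, with placeholder classes that are not part of the PR:

from typing import Any


class _Base:
    def __init__(self, device: str) -> None:
        self.device = device

    def scale(self) -> float:
        return 1.0


class _Adapter310P:
    # Wrap a base object, override a few members, forward everything else.

    def __init__(self, device: str) -> None:
        self._base = _Base(device)

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal attribute lookup fails, so overrides win.
        return getattr(self._base, name)

    def scale(self) -> float:  # 310P-specific behavior
        return 0.5


adapter = _Adapter310P("npu:0")
print(adapter.scale(), adapter.device)  # 0.5 npu:0  (device is forwarded to the base)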