From 83cb4f13a74929704393363167b342e501119ceb Mon Sep 17 00:00:00 2001 From: Tflowers-0129 <2906339855@qq.com> Date: Sat, 10 Jan 2026 14:20:19 +0800 Subject: [PATCH 01/17] [feat]: support 310p run qwen2.5/3 dense and qwen2.5vl models Signed-off-by: Tflowers-0129 <2906339855@qq.com> --- CMakeLists.txt | 11 +- vllm_ascend/_310P/__init__.py | 0 vllm_ascend/_310P/attention/__init__.py | 0 vllm_ascend/_310P/attention/attention_mask.py | 51 ++++++++ vllm_ascend/_310P/attention/attention_v1.py | 110 ++++++++++++++++++ .../_310P/attention/metadata_builder.py | 21 ++++ vllm_ascend/_310P/modelrunner_310p.py | 99 ++++++++++++++++ vllm_ascend/_310P/ops/__init__.py | 0 vllm_ascend/_310P/ops/activation.py | 12 ++ vllm_ascend/_310P/ops/mm_encoder_attention.py | 92 +++++++++++++++ vllm_ascend/_310P/ops/rotary_embedding.py | 7 ++ vllm_ascend/_310P/worker_310p.py | 22 ++++ vllm_ascend/platform.py | 9 +- vllm_ascend/utils.py | 17 +++ 14 files changed, 445 insertions(+), 6 deletions(-) create mode 100644 vllm_ascend/_310P/__init__.py create mode 100644 vllm_ascend/_310P/attention/__init__.py create mode 100644 vllm_ascend/_310P/attention/attention_mask.py create mode 100644 vllm_ascend/_310P/attention/attention_v1.py create mode 100644 vllm_ascend/_310P/attention/metadata_builder.py create mode 100644 vllm_ascend/_310P/modelrunner_310p.py create mode 100644 vllm_ascend/_310P/ops/__init__.py create mode 100644 vllm_ascend/_310P/ops/activation.py create mode 100644 vllm_ascend/_310P/ops/mm_encoder_attention.py create mode 100644 vllm_ascend/_310P/ops/rotary_embedding.py create mode 100644 vllm_ascend/_310P/worker_310p.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 6c55d93ce79..9be3be4c6f6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -62,14 +62,17 @@ set(VLLM_ASCEND_CUSTOM_OP ) set(VLLM_ASCEND_CUSTOM_OP_EXCLUDE - ${KERNEL_FILES}/bgmv_expand.cpp - ${KERNEL_FILES}/bgmv_shrink.cpp - ${KERNEL_FILES}/sgmv_expand.cpp - ${KERNEL_FILES}/sgmv_shrink.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/csrc/kernels/bgmv_expand.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/csrc/kernels/bgmv_shrink.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/csrc/kernels/sgmv_expand.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/csrc/kernels/sgmv_shrink.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp ${CMAKE_CURRENT_SOURCE_DIR}/csrc/batch_matmul_transpose/op_kernel/batch_matmul_transpose_kernel.cpp ) if(SOC_VERSION STREQUAL "ASCEND310P3") + message(STATUS "310P hardware detected: disabling MLAPO operators") + message(STATUS "310P hardware detected: excluding batch_matmul_transpose operators") list(REMOVE_ITEM VLLM_ASCEND_CUSTOM_OP ${VLLM_ASCEND_CUSTOM_OP_EXCLUDE}) endif() diff --git a/vllm_ascend/_310P/__init__.py b/vllm_ascend/_310P/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/vllm_ascend/_310P/attention/__init__.py b/vllm_ascend/_310P/attention/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/vllm_ascend/_310P/attention/attention_mask.py b/vllm_ascend/_310P/attention/attention_mask.py new file mode 100644 index 00000000000..de502fa4dcd --- /dev/null +++ b/vllm_ascend/_310P/attention/attention_mask.py @@ -0,0 +1,51 @@ +from typing import Any, Callable + +import torch +import vllm_ascend.attention.attention_mask as _base_mask + + +_BASE_BUILDER: Callable[[torch.device], Any] = _base_mask.AttentionMaskBuilder + +def _gen_causal_additive_mask_fp16(max_seq_len: int, device: torch.device) -> torch.Tensor: + tril = torch.ones((max_seq_len, max_seq_len), dtype=torch.bool, 
device=device).tril_() + upper = ~tril + m = torch.zeros((max_seq_len, max_seq_len), dtype=torch.float16, device=device) + m.masked_fill_(upper, float("-inf")) + return m + +class _AttentionMaskBuilder310P: + def __init__(self, device: torch.device): + self._base = _BASE_BUILDER(device) + + self._fp16_mask_cache = None + self._fp16_mask_cached_len = 0 + + def __getattr__(self, name: str) -> Any: + return getattr(self._base, name) + + @property + def device(self) -> torch.device: + return self._base.device + + def _get_fp16_mask(self, max_seq_len: int) -> torch.Tensor: + if self._fp16_mask_cache is None or max_seq_len > self._fp16_mask_cached_len: + self._fp16_mask_cache = _gen_causal_additive_mask_fp16(max_seq_len, self.device) + self._fp16_mask_cached_len = max_seq_len + return self._fp16_mask_cache[:max_seq_len, :max_seq_len].contiguous() + + def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype): + if dtype == torch.float16: + return self._get_fp16_mask(max_seq_len) + return self._base.get_attn_mask(max_seq_len, dtype) + + def get_splitfuse_attn_mask(self) -> torch.Tensor: + return self._get_fp16_mask(2048) + + def get_attention_mask(self, model_config) -> torch.Tensor: + if getattr(model_config, "runner_type", None) == "pooling": + return self._base.get_attn_mask(2048, torch.bool) + return self.get_splitfuse_attn_mask() + + +def AttentionMaskBuilder(device: torch.device) -> _AttentionMaskBuilder310P: + return _AttentionMaskBuilder310P(device) diff --git a/vllm_ascend/_310P/attention/attention_v1.py b/vllm_ascend/_310P/attention/attention_v1.py new file mode 100644 index 00000000000..508a7286e3b --- /dev/null +++ b/vllm_ascend/_310P/attention/attention_v1.py @@ -0,0 +1,110 @@ +import torch +import torch_npu + +from vllm_ascend.utils import aligned_16, nd_to_nz_2d, ACL_FORMAT_FRACTAL_NZ +from vllm_ascend.attention.attention_v1 import ( + AscendAttentionBackend as _BaseBackend, + AscendAttentionBackendImpl as _BaseImpl, + AscendAttentionState, +) +from vllm_ascend._310p.attention.attention_mask import AttentionMaskBuilder +from vllm_ascend.attention.attention_v1 import AscendAttentionMetadataBuilder +from vllm_ascend._310p.attention.metadata_builder import AscendAttentionMetadataBuilder310P + + +class AscendAttentionBackend310(_BaseBackend): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.attn_mask_builder = AttentionMaskBuilder(self.device) + + @staticmethod + def get_kv_cache_shape(num_blocks: int, block_size: int, + num_kv_heads: int, head_size: int): + return (2, num_blocks, (num_kv_heads * head_size) // 16, block_size, 16) + + @staticmethod + def get_impl_cls(): + return AscendAttentionBackendImpl310 + + @staticmethod + def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]: + return AscendAttentionMetadataBuilder310P + + + +class AscendMLABackend310(AscendAttentionBackend310): + pass + + +class AscendSFABackend310(AscendAttentionBackend310): + pass + + +class AscendAttentionBackendImpl310(_BaseImpl): + + def forward_paged_attention(self, query, attn_metadata, output): + if attn_metadata.seq_lens.device != query.device: + attn_metadata.seq_lens = attn_metadata.seq_lens.to( + device=query.device, non_blocking=True + ) + return super().forward_paged_attention(query, attn_metadata, output) + + def _forward_prefill_310p_fallback(self, query, key, value, attn_metadata, output): + real_tokens = int(attn_metadata.seq_lens.sum().item()) + + query, key, value, output = (aligned_16(t) for t in (query, key, value, output)) + + seq_len = 
attn_metadata.seq_lens + if seq_len.dtype != torch.int32: + seq_len = seq_len.to(torch.int32) + + aligned_tokens = int(query.shape[0]) + delta = aligned_tokens - real_tokens + if delta: + seq_len = seq_len.clone() + seq_len[-1] += delta + + mask = attn_metadata.attn_mask + if mask is not None and mask.dim() == 2: + max_len = int(seq_len.max().item()) + aligned_len = ((max_len + 15) // 16) * 16 + + mask2d = mask[:aligned_len, :aligned_len].contiguous() + mask2d = mask2d.to(torch.float16) + mask_nz = nd_to_nz_2d(mask2d).contiguous() + + bsz = int(seq_len.numel()) + if bsz > 1: + mask_nz = mask_nz.repeat(bsz, 1, 1, 1).contiguous() + + mask = torch_npu.npu_format_cast(mask_nz, ACL_FORMAT_FRACTAL_NZ) + + torch_npu._npu_flash_attention( + query=query, + key=key, + value=value, + mask=mask, + seq_len=seq_len, + scale_value=self.scale, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + out=output, + ) + + out_real = output[:real_tokens, :, :] + return out_real + + + def forward_impl(self, query, key, value, kv_cache, attn_metadata, output): + if attn_metadata.attn_state == AscendAttentionState.DecodeOnly: + output = self.forward_paged_attention(query, attn_metadata, output) + + if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: + num_tokens = query.shape[0] + q = query[:num_tokens] + k = key[:num_tokens] + v = value[:num_tokens] + out = self._forward_prefill_310p_fallback(q, k, v, attn_metadata, output) + output[:num_tokens] = out + + return output diff --git a/vllm_ascend/_310P/attention/metadata_builder.py b/vllm_ascend/_310P/attention/metadata_builder.py new file mode 100644 index 00000000000..bbd6c6fe480 --- /dev/null +++ b/vllm_ascend/_310P/attention/metadata_builder.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +import torch +from vllm.config import VllmConfig +from vllm.v1.kv_cache_interface import AttentionSpec + +from vllm_ascend.attention.attention_v1 import AscendAttentionMetadataBuilder as _BaseBuilder +from vllm_ascend._310p.attention.attention_mask import AttentionMaskBuilder + + +class AscendAttentionMetadataBuilder310P(_BaseBuilder): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + + self.attn_mask_builder = AttentionMaskBuilder(self.device) diff --git a/vllm_ascend/_310P/modelrunner_310p.py b/vllm_ascend/_310P/modelrunner_310p.py new file mode 100644 index 00000000000..2f282b5cfa6 --- /dev/null +++ b/vllm_ascend/_310P/modelrunner_310p.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +from typing import Dict, Any + +import torch +import torch_npu +from vllm.logger import logger + +from vllm_ascend.worker.model_runner_v1 import NPUModelRunner +from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ + + +class NPUModelRunner310(NPUModelRunner): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._acl_format = ACL_FORMAT_FRACTAL_NZ + + def _num_attn_module(self) -> int: + return 2 if self.model_config.hf_config.model_type == "longcat_flash" else 1 + + def _initialize_kv_cache_tensors_310p( + self, kv_cache_config: "KVCacheConfig" + ) -> dict[str, Any]: + from vllm.v1.kv_cache_interface import FullAttentionSpec + from vllm.v1.worker.utils import bind_kv_cache + + if self.vllm_config.kv_transfer_config is not None: + raise ValueError("KV cache transfer is not supported for 310P.") + + kv_cache_sizes: dict[str, int] = {} + for kv_cache_tensor in 
kv_cache_config.kv_cache_tensors: + assert len(kv_cache_tensor.shared_by) == 1, ( + "KV cache tensor shared by multiple layers is not supported in 310P." + ) + kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size + + kv_caches: Dict[str, Any] = {} + + for group in self._kv_cache_spec_attn_group_iterator(): + kv_cache_spec = group.kv_cache_spec + attn_backend = group.backend + + if not isinstance(kv_cache_spec, FullAttentionSpec): + raise ValueError("Unknown KV cache spec type.") + + for layer_name in group.layer_names: + if layer_name in self.runner_only_attn_layers: + continue + + tensor_size = kv_cache_sizes[layer_name] + assert tensor_size % kv_cache_spec.page_size_bytes == 0 + num_blocks = tensor_size // kv_cache_spec.page_size_bytes + assert num_blocks >= kv_cache_config.num_blocks + + if self.vllm_config.additional_config.get("kv_cache_dtype", None) == "int8": + kv_cache_shape = attn_backend.get_bsh_kv_cache_shape( + num_blocks, + kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size, + ) + elif hasattr(attn_backend, "get_supported_block_size") and self.use_hybrid_blocks: + block_size = attn_backend.get_supported_block_size()[0] + block_size_chunk = kv_cache_spec.block_size // block_size + kv_cache_shape = attn_backend.get_kv_cache_shape( + num_blocks * block_size_chunk, + block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size, + ) + else: + kv_cache_shape = attn_backend.get_kv_cache_shape( + num_blocks, + kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size, + ) + + dtype = kv_cache_spec.dtype + + if "attn" in layer_name: + k_tensor = torch.zeros(kv_cache_shape[1:], dtype=dtype, device=self.device) + v_tensor = torch.zeros(kv_cache_shape[1:], dtype=dtype, device=self.device) + k_cache = torch_npu.npu_format_cast(k_tensor, self._acl_format) + v_cache = torch_npu.npu_format_cast(v_tensor, self._acl_format) + kv_caches[layer_name] = (k_cache, v_cache) + + bind_kv_cache( + kv_caches, + self.compilation_config.static_forward_context, + self.kv_caches, + self._num_attn_module(), + ) + return kv_caches + + def initialize_kv_cache_tensors( + self, kv_cache_config: "KVCacheConfig" + ) -> dict[str, Any]: + return self._initialize_kv_cache_tensors_310p(kv_cache_config) diff --git a/vllm_ascend/_310P/ops/__init__.py b/vllm_ascend/_310P/ops/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/vllm_ascend/_310P/ops/activation.py b/vllm_ascend/_310P/ops/activation.py new file mode 100644 index 00000000000..1a868c688ad --- /dev/null +++ b/vllm_ascend/_310P/ops/activation.py @@ -0,0 +1,12 @@ +import torch +import torch.nn.functional as F +from vllm_ascend.ops.activation import AscendSiluAndMul as _Base + + +class AscendSiluAndMul310(_Base): + def forward(self, x: torch.Tensor) -> torch.Tensor: + torch.ops.vllm.maybe_prefetch_mlp_down_proj(x) + h = x.shape[-1] // 2 + out = (F.silu(x[..., :h].to(torch.float32)) * x[..., h:].to(torch.float32)).to(torch.float16) + torch.ops.vllm.maybe_wait_prefetch_done(out) + return out diff --git a/vllm_ascend/_310P/ops/mm_encoder_attention.py b/vllm_ascend/_310P/ops/mm_encoder_attention.py new file mode 100644 index 00000000000..c9ce2eac0f3 --- /dev/null +++ b/vllm_ascend/_310P/ops/mm_encoder_attention.py @@ -0,0 +1,92 @@ +import torch +import torch.nn.functional as F +import einops +import torch_npu + +import vllm_ascend.envs as envs_ascend +from vllm_ascend.ops.mm_encoder_attention import ( + AscendMMEncoderAttention as _Base, + MIN_PAD_SIZE, + MAX_PAD_SIZE, 
+) + + +class AscendMMEncoderAttention310(_Base): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward_oot( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: int | None = None, + **kwargs, + ): + bsz, q_len = query.size()[:2] + kv_len = key.size(1) + is_reshaped = query.dim() == 4 + + q, k, v = self.reshape_qkv_to_3d(query, key, value, bsz, q_len, kv_len) + + enable_pad = ( + envs_ascend.USE_OPTIMIZED_MODEL + and self.head_size > MIN_PAD_SIZE + and self.head_size < MAX_PAD_SIZE + ) + + origin_shape = q.shape[-1] + if enable_pad: + pad_len = MAX_PAD_SIZE - origin_shape + q = F.pad(q, (0, pad_len), mode="constant", value=0) + k = F.pad(k, (0, pad_len), mode="constant", value=0) + v = F.pad(v, (0, pad_len), mode="constant", value=0) + + origin_dim = origin_shape + cur_dim = q.shape[-1] + pad16 = (16 - cur_dim % 16) % 16 + if pad16: + q = F.pad(q, (0, pad16), mode="constant", value=0) + k = F.pad(k, (0, pad16), mode="constant", value=0) + v = F.pad(v, (0, pad16), mode="constant", value=0) + + if cu_seqlens is None: + cu_seqlens = torch.arange( + 0, (bsz + 1) * q_len, step=q_len, + dtype=torch.int32, device=query.device, + ) + + total_q_tokens = bsz * q_len + context_flat = q.new_empty((total_q_tokens, self.num_heads, q.shape[-1])) + + st = 0 + seg_lens = torch.diff(cu_seqlens).to("cpu", dtype=torch.int64).tolist() + for seg_len in seg_lens: + seg_len = int(seg_len) + ed = st + seg_len + + q_i = q[st:ed].unsqueeze(0) # [1, S, H, D] + k_i = k[st:ed].unsqueeze(0) + v_i = v[st:ed].unsqueeze(0) + + qs = int(q_i.shape[1]) + kvs = int(k_i.shape[1]) + + out_i = torch_npu.npu_prompt_flash_attention( + q_i, k_i, v_i, + input_layout="BSND", + num_heads=self.num_heads, + num_key_value_heads=self.num_kv_heads, + scale_value=self.head_size**-0.5, + pre_tokens=qs, + next_tokens=kvs, + ) + context_flat[st:ed] = out_i[0] + st = ed + + context_flat = context_flat[..., :origin_dim] + context_layer = einops.rearrange( + context_flat, "(b s) h d -> b s h d", b=bsz + ).contiguous() + return context_layer diff --git a/vllm_ascend/_310P/ops/rotary_embedding.py b/vllm_ascend/_310P/ops/rotary_embedding.py new file mode 100644 index 00000000000..357d36a08e8 --- /dev/null +++ b/vllm_ascend/_310P/ops/rotary_embedding.py @@ -0,0 +1,7 @@ +from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding + + +class AscendMRotaryEmbedding310(MRotaryEmbedding): + + def forward_oot(self, positions, query, key): + return super().forward_oot(positions, query, key) diff --git a/vllm_ascend/_310P/worker_310p.py b/vllm_ascend/_310P/worker_310p.py new file mode 100644 index 00000000000..6c4c31440c0 --- /dev/null +++ b/vllm_ascend/_310P/worker_310p.py @@ -0,0 +1,22 @@ +import torch +import torch_npu +from vllm.logger import logger + +from vllm_ascend.worker.worker import NPUWorker +from vllm_ascend.utils import is_310p +from vllm_ascend._310p.modelrunner_310p import NPUModelRunner310 + + +class NPUWorker310(NPUWorker): + def init_device(self): + self.device = self._init_device() + + torch_npu.npu.set_compile_mode(jit_compile=False) + + from vllm_ascend.worker.worker import init_workspace_manager + init_workspace_manager(self.device, num_ubatches=1) + + self.model_runner = NPUModelRunner310(self.vllm_config, self.device) + + def _warm_up_atb(self): + logger.info("Skip warm-up atb ops for 310P device") diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 3ca49187c7d..044b3fdc457 100644 --- 
a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -48,6 +48,7 @@ update_aclgraph_sizes, update_cudagraph_capture_sizes, update_default_aclgraph_sizes, + is_310p, ) if TYPE_CHECKING: @@ -322,8 +323,12 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: if parallel_config and parallel_config.worker_cls == "auto": # TODO: this is a tricky way to disable `use_sequence_parallel_moe` in vllm. parallel_config.all2all_backend = "flashinfer_all2allv" - if ascend_config.xlite_graph_config.enabled: - logger.info("openEuler Xlite enabled. See: https://atomgit.com/openeuler/GVirt/tree/master/xlite") + if is_310p(): + parallel_config.worker_cls = "vllm_ascend._310p.worker_310p.NPUWorker310" + elif ascend_config.xlite_graph_config.enabled: + logger.info( + "openEuler Xlite enabled. See: https://atomgit.com/openeuler/GVirt/tree/master/xlite" + ) parallel_config.worker_cls = "vllm_ascend.xlite.xlite_worker.XliteWorker" else: parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker" diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index ca5bb41f1be..fe351911318 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -74,6 +74,9 @@ _HAS_ROPE = None +def is_310p(): + return get_ascend_device_type() == AscendDeviceType._310p + def _print_callback_on_stream(*args): """Callback function to print arguments on the dedicated print stream.""" global _GRAPH_PRINT_STREAM @@ -713,6 +716,20 @@ def register_ascend_customop(vllm_config: VllmConfig | None = None): "ApplyRotaryEmb": AscendApplyRotaryEmb, } + # 310P: override selected ops with 310P implementations (keep minimal changes outside _310p) + if is_310p(): + from vllm_ascend._310p.ops.activation import AscendSiluAndMul310 + from vllm_ascend._310p.ops.mm_encoder_attention import AscendMMEncoderAttention310 + from vllm_ascend._310p.ops.rotary_embedding import ( + AscendMRotaryEmbedding310, + ) + + REGISTERED_ASCEND_OPS.update({ + "SiluAndMul": AscendSiluAndMul310, + "MMEncoderAttention": AscendMMEncoderAttention310, + "MRotaryEmbedding": AscendMRotaryEmbedding310, + }) + for name, op_cls in REGISTERED_ASCEND_OPS.items(): CustomOp.register_oot(_decorated_op_cls=op_cls, name=name) From 8871364cb60a5ceb53f5518a5bbe7f4539aafcb2 Mon Sep 17 00:00:00 2001 From: Tflowers-0129 <2906339855@qq.com> Date: Sat, 10 Jan 2026 14:26:53 +0800 Subject: [PATCH 02/17] [bugfix]: fix the router of 310p attnbackend Signed-off-by: Tflowers-0129 <2906339855@qq.com> --- vllm_ascend/platform.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 044b3fdc457..4d32abe34e4 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -399,6 +399,16 @@ def import_kernels(cls) -> None: @classmethod def get_attn_backend_cls(cls, selected_backend, attn_selector_config): + key = (attn_selector_config.use_mla, attn_selector_config.use_sparse) + if is_310p(): + default_attn_backend = "vllm_ascend._310p.attention.AscendAttentionBackend310" + backend_map_310 = { + #@TODO 310p unable to use MLA/SFA, the key maybe ALWAYS (False, False) + (True, False): "vllm_ascend._310p.attention.AscendMLABackend310", + (False, False): "vllm_ascend._310p.attention.attention_v1.AscendAttentionBackend310", + (True, True): "vllm_ascend._310p.attention.AscendSFABackend310", + } + return backend_map_310.get(key, default_attn_backend) backend_map = { (True, False): "vllm_ascend.attention.mla_v1.AscendMLABackend", (False, False): "vllm_ascend.attention.attention_v1.AscendAttentionBackend", From 
91e561d0106927ff2a821d2f5466490912e5ca94 Mon Sep 17 00:00:00 2001 From: Tflowers-0129 <2906339855@qq.com> Date: Sat, 10 Jan 2026 14:46:29 +0800 Subject: [PATCH 03/17] [bugfix]: rename _310P to _310p Signed-off-by: Tflowers-0129 <2906339855@qq.com> --- vllm_ascend/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index fe351911318..b08069bbb4b 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -75,7 +75,7 @@ def is_310p(): - return get_ascend_device_type() == AscendDeviceType._310p + return get_ascend_device_type() == AscendDeviceType._310P def _print_callback_on_stream(*args): """Callback function to print arguments on the dedicated print stream.""" @@ -729,7 +729,7 @@ def register_ascend_customop(vllm_config: VllmConfig | None = None): "MMEncoderAttention": AscendMMEncoderAttention310, "MRotaryEmbedding": AscendMRotaryEmbedding310, }) - + for name, op_cls in REGISTERED_ASCEND_OPS.items(): CustomOp.register_oot(_decorated_op_cls=op_cls, name=name) From c3817a5673aa90acabc1c15a33fc4ed6fd1dd1d1 Mon Sep 17 00:00:00 2001 From: Tflowers-0129 <2906339855@qq.com> Date: Sat, 10 Jan 2026 15:09:10 +0800 Subject: [PATCH 04/17] rename the dir _310P to _310p Signed-off-by: Tflowers-0129 <2906339855@qq.com> --- vllm_ascend/{_310P => _310p}/__init__.py | 0 vllm_ascend/{_310P => _310p}/attention/__init__.py | 0 vllm_ascend/{_310P => _310p}/attention/attention_mask.py | 0 vllm_ascend/{_310P => _310p}/attention/attention_v1.py | 0 vllm_ascend/{_310P => _310p}/attention/metadata_builder.py | 0 vllm_ascend/{_310P => _310p}/modelrunner_310p.py | 0 vllm_ascend/{_310P => _310p}/ops/__init__.py | 0 vllm_ascend/{_310P => _310p}/ops/activation.py | 0 vllm_ascend/{_310P => _310p}/ops/mm_encoder_attention.py | 0 vllm_ascend/{_310P => _310p}/ops/rotary_embedding.py | 0 vllm_ascend/{_310P => _310p}/worker_310p.py | 0 11 files changed, 0 insertions(+), 0 deletions(-) rename vllm_ascend/{_310P => _310p}/__init__.py (100%) rename vllm_ascend/{_310P => _310p}/attention/__init__.py (100%) rename vllm_ascend/{_310P => _310p}/attention/attention_mask.py (100%) rename vllm_ascend/{_310P => _310p}/attention/attention_v1.py (100%) rename vllm_ascend/{_310P => _310p}/attention/metadata_builder.py (100%) rename vllm_ascend/{_310P => _310p}/modelrunner_310p.py (100%) rename vllm_ascend/{_310P => _310p}/ops/__init__.py (100%) rename vllm_ascend/{_310P => _310p}/ops/activation.py (100%) rename vllm_ascend/{_310P => _310p}/ops/mm_encoder_attention.py (100%) rename vllm_ascend/{_310P => _310p}/ops/rotary_embedding.py (100%) rename vllm_ascend/{_310P => _310p}/worker_310p.py (100%) diff --git a/vllm_ascend/_310P/__init__.py b/vllm_ascend/_310p/__init__.py similarity index 100% rename from vllm_ascend/_310P/__init__.py rename to vllm_ascend/_310p/__init__.py diff --git a/vllm_ascend/_310P/attention/__init__.py b/vllm_ascend/_310p/attention/__init__.py similarity index 100% rename from vllm_ascend/_310P/attention/__init__.py rename to vllm_ascend/_310p/attention/__init__.py diff --git a/vllm_ascend/_310P/attention/attention_mask.py b/vllm_ascend/_310p/attention/attention_mask.py similarity index 100% rename from vllm_ascend/_310P/attention/attention_mask.py rename to vllm_ascend/_310p/attention/attention_mask.py diff --git a/vllm_ascend/_310P/attention/attention_v1.py b/vllm_ascend/_310p/attention/attention_v1.py similarity index 100% rename from vllm_ascend/_310P/attention/attention_v1.py rename to vllm_ascend/_310p/attention/attention_v1.py diff --git 
a/vllm_ascend/_310P/attention/metadata_builder.py b/vllm_ascend/_310p/attention/metadata_builder.py similarity index 100% rename from vllm_ascend/_310P/attention/metadata_builder.py rename to vllm_ascend/_310p/attention/metadata_builder.py diff --git a/vllm_ascend/_310P/modelrunner_310p.py b/vllm_ascend/_310p/modelrunner_310p.py similarity index 100% rename from vllm_ascend/_310P/modelrunner_310p.py rename to vllm_ascend/_310p/modelrunner_310p.py diff --git a/vllm_ascend/_310P/ops/__init__.py b/vllm_ascend/_310p/ops/__init__.py similarity index 100% rename from vllm_ascend/_310P/ops/__init__.py rename to vllm_ascend/_310p/ops/__init__.py diff --git a/vllm_ascend/_310P/ops/activation.py b/vllm_ascend/_310p/ops/activation.py similarity index 100% rename from vllm_ascend/_310P/ops/activation.py rename to vllm_ascend/_310p/ops/activation.py diff --git a/vllm_ascend/_310P/ops/mm_encoder_attention.py b/vllm_ascend/_310p/ops/mm_encoder_attention.py similarity index 100% rename from vllm_ascend/_310P/ops/mm_encoder_attention.py rename to vllm_ascend/_310p/ops/mm_encoder_attention.py diff --git a/vllm_ascend/_310P/ops/rotary_embedding.py b/vllm_ascend/_310p/ops/rotary_embedding.py similarity index 100% rename from vllm_ascend/_310P/ops/rotary_embedding.py rename to vllm_ascend/_310p/ops/rotary_embedding.py diff --git a/vllm_ascend/_310P/worker_310p.py b/vllm_ascend/_310p/worker_310p.py similarity index 100% rename from vllm_ascend/_310P/worker_310p.py rename to vllm_ascend/_310p/worker_310p.py From 115fdb7490e947f61687dde2ef691dff3a42b667 Mon Sep 17 00:00:00 2001 From: Tflowers-0129 <2906339855@qq.com> Date: Sat, 10 Jan 2026 16:45:17 +0800 Subject: [PATCH 05/17] fix mypy test for runner/attnmask/metadatabuilder Signed-off-by: Tflowers-0129 <2906339855@qq.com> --- vllm_ascend/_310p/attention/attention_mask.py | 7 ++++--- vllm_ascend/_310p/attention/metadata_builder.py | 5 +++-- vllm_ascend/_310p/modelrunner_310p.py | 1 + 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/vllm_ascend/_310p/attention/attention_mask.py b/vllm_ascend/_310p/attention/attention_mask.py index de502fa4dcd..c2cce40ad04 100644 --- a/vllm_ascend/_310p/attention/attention_mask.py +++ b/vllm_ascend/_310p/attention/attention_mask.py @@ -1,4 +1,4 @@ -from typing import Any, Callable +from typing import Any, Callable, Optional import torch import vllm_ascend.attention.attention_mask as _base_mask @@ -17,8 +17,8 @@ class _AttentionMaskBuilder310P: def __init__(self, device: torch.device): self._base = _BASE_BUILDER(device) - self._fp16_mask_cache = None - self._fp16_mask_cached_len = 0 + self._fp16_mask_cache: Optional[torch.Tensor] = None + self._fp16_mask_cached_len: int = 0 def __getattr__(self, name: str) -> Any: return getattr(self._base, name) @@ -31,6 +31,7 @@ def _get_fp16_mask(self, max_seq_len: int) -> torch.Tensor: if self._fp16_mask_cache is None or max_seq_len > self._fp16_mask_cached_len: self._fp16_mask_cache = _gen_causal_additive_mask_fp16(max_seq_len, self.device) self._fp16_mask_cached_len = max_seq_len + assert self._fp16_mask_cache is not None return self._fp16_mask_cache[:max_seq_len, :max_seq_len].contiguous() def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype): diff --git a/vllm_ascend/_310p/attention/metadata_builder.py b/vllm_ascend/_310p/attention/metadata_builder.py index bbd6c6fe480..b9f5353643b 100644 --- a/vllm_ascend/_310p/attention/metadata_builder.py +++ b/vllm_ascend/_310p/attention/metadata_builder.py @@ -1,6 +1,7 @@ from __future__ import annotations import torch +from typing 
import Any from vllm.config import VllmConfig from vllm.v1.kv_cache_interface import AttentionSpec @@ -17,5 +18,5 @@ def __init__( device: torch.device, ): super().__init__(kv_cache_spec, layer_names, vllm_config, device) - - self.attn_mask_builder = AttentionMaskBuilder(self.device) + + self.attn_mask_builder: Any = AttentionMaskBuilder(self.device) diff --git a/vllm_ascend/_310p/modelrunner_310p.py b/vllm_ascend/_310p/modelrunner_310p.py index 2f282b5cfa6..5daad375b0f 100644 --- a/vllm_ascend/_310p/modelrunner_310p.py +++ b/vllm_ascend/_310p/modelrunner_310p.py @@ -5,6 +5,7 @@ import torch import torch_npu from vllm.logger import logger +from vllm.v1.kv_cache_interface import kv_cache_config from vllm_ascend.worker.model_runner_v1 import NPUModelRunner from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ From da50f29551fa4fb7551e8db27607fd1f4e25b341 Mon Sep 17 00:00:00 2001 From: Tflowers-0129 <2906339855@qq.com> Date: Sat, 10 Jan 2026 18:02:19 +0800 Subject: [PATCH 06/17] fix some format bugs, such as ruff format and mypy Signed-off-by: Tflowers-0129 <2906339855@qq.com> --- vllm_ascend/_310p/modelrunner_310p.py | 2 +- vllm_ascend/_310p/ops/mm_encoder_attention.py | 1 - vllm_ascend/platform.py | 6 +++--- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm_ascend/_310p/modelrunner_310p.py b/vllm_ascend/_310p/modelrunner_310p.py index 5daad375b0f..08a50dcb5a7 100644 --- a/vllm_ascend/_310p/modelrunner_310p.py +++ b/vllm_ascend/_310p/modelrunner_310p.py @@ -5,7 +5,7 @@ import torch import torch_npu from vllm.logger import logger -from vllm.v1.kv_cache_interface import kv_cache_config +from vllm.v1.kv_cache_interface import KVCacheConfig from vllm_ascend.worker.model_runner_v1 import NPUModelRunner from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ diff --git a/vllm_ascend/_310p/ops/mm_encoder_attention.py b/vllm_ascend/_310p/ops/mm_encoder_attention.py index c9ce2eac0f3..714ad79bb7e 100644 --- a/vllm_ascend/_310p/ops/mm_encoder_attention.py +++ b/vllm_ascend/_310p/ops/mm_encoder_attention.py @@ -26,7 +26,6 @@ def forward_oot( ): bsz, q_len = query.size()[:2] kv_len = key.size(1) - is_reshaped = query.dim() == 4 q, k, v = self.reshape_qkv_to_3d(query, key, value, bsz, q_len, kv_len) diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 4d32abe34e4..7b6c4be227c 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -401,12 +401,12 @@ def import_kernels(cls) -> None: def get_attn_backend_cls(cls, selected_backend, attn_selector_config): key = (attn_selector_config.use_mla, attn_selector_config.use_sparse) if is_310p(): - default_attn_backend = "vllm_ascend._310p.attention.AscendAttentionBackend310" + default_attn_backend = "vllm_ascend._310p.attention.attention_v1.AscendAttentionBackend310" backend_map_310 = { #@TODO 310p unable to use MLA/SFA, the key maybe ALWAYS (False, False) - (True, False): "vllm_ascend._310p.attention.AscendMLABackend310", + (True, False): "vllm_ascend._310p.attention.attention_v1.AscendMLABackend310", (False, False): "vllm_ascend._310p.attention.attention_v1.AscendAttentionBackend310", - (True, True): "vllm_ascend._310p.attention.AscendSFABackend310", + (True, True): "vllm_ascend._310p.attention.attention_v1.AscendSFABackend310", } return backend_map_310.get(key, default_attn_backend) backend_map = { From d7a47da43471f1305e243072a976a603bf1b6de3 Mon Sep 17 00:00:00 2001 From: Tflowers-0129 <2906339855@qq.com> Date: Thu, 15 Jan 2026 14:56:29 +0800 Subject: [PATCH 07/17] support chunedprefilled state in 310p device 
Signed-off-by: Tflowers-0129 <2906339855@qq.com> --- vllm_ascend/_310p/attention/attention_mask.py | 89 ++++++++++-- vllm_ascend/_310p/attention/attention_v1.py | 129 ++++++++++++++---- .../_310p/attention/metadata_builder.py | 25 +++- vllm_ascend/_310p/modelrunner_310p.py | 53 ++++--- vllm_ascend/_310p/ops/activation.py | 22 ++- vllm_ascend/_310p/ops/mm_encoder_attention.py | 42 ++++-- vllm_ascend/_310p/ops/rotary_embedding.py | 18 ++- vllm_ascend/_310p/worker_310p.py | 23 +++- .../patch/platform/patch_distributed.py | 8 +- vllm_ascend/platform.py | 26 ++-- 10 files changed, 338 insertions(+), 97 deletions(-) diff --git a/vllm_ascend/_310p/attention/attention_mask.py b/vllm_ascend/_310p/attention/attention_mask.py index c2cce40ad04..e9fb7189acf 100644 --- a/vllm_ascend/_310p/attention/attention_mask.py +++ b/vllm_ascend/_310p/attention/attention_mask.py @@ -1,19 +1,84 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + from typing import Any, Callable, Optional import torch -import vllm_ascend.attention.attention_mask as _base_mask +import torch_npu +import vllm_ascend.attention.attention_mask as _base_mask +from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, nd_to_nz_spec _BASE_BUILDER: Callable[[torch.device], Any] = _base_mask.AttentionMaskBuilder -def _gen_causal_additive_mask_fp16(max_seq_len: int, device: torch.device) -> torch.Tensor: - tril = torch.ones((max_seq_len, max_seq_len), dtype=torch.bool, device=device).tril_() + +def _gen_causal_additive_mask_fp16( + max_seq_len: int, device: torch.device +) -> torch.Tensor: + tril = torch.ones( + (max_seq_len, max_seq_len), dtype=torch.bool, device=device + ).tril_() upper = ~tril m = torch.zeros((max_seq_len, max_seq_len), dtype=torch.float16, device=device) m.masked_fill_(upper, float("-inf")) return m + +def build_splitfuse_attn_mask_310p( + attn_metadata, device, *, full_mask_cache=None, full_mask_cache_len=0 +): + qsl = attn_metadata.query_start_loc.detach().to("cpu", dtype=torch.int32) + qlens = qsl[1:] - qsl[:-1] + + context_lens = attn_metadata.seq_lens.to(dtype=torch.int32) + L = int(context_lens.max().item()) + + q_list = qlens.tolist() + c_list = context_lens.detach().to("cpu", dtype=torch.int64).tolist() + pos_list = [p for ql, cl in zip(q_list, c_list) for p in range(cl - ql, cl)] + position = torch.tensor(pos_list, dtype=torch.long, device=device) + + if ( + full_mask_cache is None + or full_mask_cache.device != device + or full_mask_cache_len < L + ): + tril = torch.ones((L, L), dtype=torch.bool, device=device).tril_() + full = torch.zeros((L, L), dtype=torch.float16, device=device) + full.masked_fill_(~tril, float("-inf")) + full_mask_cache, full_mask_cache_len = full, L + else: + full = full_mask_cache[:L, :L].contiguous() + + rows = full.index_select(0, position).contiguous() + mask = torch_npu.npu_format_cast( + nd_to_nz_spec(rows).contiguous(), ACL_FORMAT_FRACTAL_NZ + ) + return mask, 
full_mask_cache, full_mask_cache_len + + class _AttentionMaskBuilder310P: + """ + 310P adapter: + - overrides fp16 causal additive mask generation (use -inf fp16) + - delegates all other behaviors to base AttentionMaskBuilder + - pooling runner_type is NOT supported on 310P (explicit) + """ + def __init__(self, device: torch.device): self._base = _BASE_BUILDER(device) @@ -26,26 +91,20 @@ def __getattr__(self, name: str) -> Any: @property def device(self) -> torch.device: return self._base.device - + def _get_fp16_mask(self, max_seq_len: int) -> torch.Tensor: if self._fp16_mask_cache is None or max_seq_len > self._fp16_mask_cached_len: - self._fp16_mask_cache = _gen_causal_additive_mask_fp16(max_seq_len, self.device) + self._fp16_mask_cache = _gen_causal_additive_mask_fp16( + max_seq_len, self.device + ) self._fp16_mask_cached_len = max_seq_len assert self._fp16_mask_cache is not None return self._fp16_mask_cache[:max_seq_len, :max_seq_len].contiguous() - def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype): - if dtype == torch.float16: - return self._get_fp16_mask(max_seq_len) - return self._base.get_attn_mask(max_seq_len, dtype) - - def get_splitfuse_attn_mask(self) -> torch.Tensor: - return self._get_fp16_mask(2048) - def get_attention_mask(self, model_config) -> torch.Tensor: if getattr(model_config, "runner_type", None) == "pooling": - return self._base.get_attn_mask(2048, torch.bool) - return self.get_splitfuse_attn_mask() + raise NotImplementedError("310P does not support runner_type='pooling'") + return self._get_fp16_mask(2048) def AttentionMaskBuilder(device: torch.device) -> _AttentionMaskBuilder310P: diff --git a/vllm_ascend/_310p/attention/attention_v1.py b/vllm_ascend/_310p/attention/attention_v1.py index 508a7286e3b..b1190d73bfa 100644 --- a/vllm_ascend/_310p/attention/attention_v1.py +++ b/vllm_ascend/_310p/attention/attention_v1.py @@ -1,47 +1,59 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# + + import torch import torch_npu -from vllm_ascend.utils import aligned_16, nd_to_nz_2d, ACL_FORMAT_FRACTAL_NZ -from vllm_ascend.attention.attention_v1 import ( - AscendAttentionBackend as _BaseBackend, - AscendAttentionBackendImpl as _BaseImpl, - AscendAttentionState, -) -from vllm_ascend._310p.attention.attention_mask import AttentionMaskBuilder -from vllm_ascend.attention.attention_v1 import AscendAttentionMetadataBuilder -from vllm_ascend._310p.attention.metadata_builder import AscendAttentionMetadataBuilder310P +from vllm_ascend._310p.attention.attention_mask import ( + AttentionMaskBuilder, build_splitfuse_attn_mask_310p) +from vllm_ascend._310p.attention.metadata_builder import \ + AscendAttentionMetadataBuilder310P +from vllm_ascend.attention.attention_v1 import \ + AscendAttentionBackend as _BaseBackend +from vllm_ascend.attention.attention_v1 import \ + AscendAttentionBackendImpl as _BaseImpl +from vllm_ascend.attention.attention_v1 import (AscendAttentionMetadataBuilder, + AscendAttentionState) +from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, aligned_16, nd_to_nz_2d class AscendAttentionBackend310(_BaseBackend): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.attn_mask_builder = AttentionMaskBuilder(self.device) - + @staticmethod - def get_kv_cache_shape(num_blocks: int, block_size: int, - num_kv_heads: int, head_size: int): + def get_kv_cache_shape( + num_blocks: int, block_size: int, num_kv_heads: int, head_size: int + ): + # Align to a multiple of 16, as required by the 310P device. return (2, num_blocks, (num_kv_heads * head_size) // 16, block_size, 16) @staticmethod def get_impl_cls(): return AscendAttentionBackendImpl310 - + @staticmethod def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]: return AscendAttentionMetadataBuilder310P - -class AscendMLABackend310(AscendAttentionBackend310): - pass - - -class AscendSFABackend310(AscendAttentionBackend310): - pass - - class AscendAttentionBackendImpl310(_BaseImpl): - def forward_paged_attention(self, query, attn_metadata, output): if attn_metadata.seq_lens.device != query.device: attn_metadata.seq_lens = attn_metadata.seq_lens.to( @@ -94,17 +106,78 @@ def _forward_prefill_310p_fallback(self, query, key, value, attn_metadata, outpu out_real = output[:real_tokens, :, :] return out_real + def _forward_chunked_prefill_310p(self, query, attn_metadata, output): + assert attn_metadata is not None + + if query.dtype == torch.float32: + query = query.to(torch.float16) + + qsl_cpu = attn_metadata.query_start_loc.detach().to("cpu", dtype=torch.int32) + qlens = (qsl_cpu[1:] - qsl_cpu[:-1]).to(torch.int32) + + context_lens = attn_metadata.seq_lens + if context_lens.dtype != torch.int32: + context_lens = context_lens.to(torch.int32) + + block_table = attn_metadata.block_tables.detach() + if block_table.dtype != torch.int32: + block_table = block_table.to(torch.int32) + + if not hasattr(self, "_sf_full_mask_cache"): + self._sf_full_mask_cache = None + self._sf_full_mask_cache_len = 0 + + mask, self._sf_full_mask_cache, self._sf_full_mask_cache_len = ( + build_splitfuse_attn_mask_310p( + attn_metadata, + query.device, + full_mask_cache=self._sf_full_mask_cache, + full_mask_cache_len=int(self._sf_full_mask_cache_len), + ) + ) + + if qlens.device.type != "cpu": + qlens = qlens.to("cpu") + if context_lens.device != query.device: + context_lens = context_lens.to(query.device, non_blocking=True) + + torch_npu._npu_paged_attention_splitfuse( + query=query, + key_cache=self.key_cache, + 
value_cache=self.value_cache, + mask=mask, + block_table=block_table, + seq_len=qlens, + context_lens=context_lens, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale_value=self.scale, + out=output, + ) def forward_impl(self, query, key, value, kv_cache, attn_metadata, output): - if attn_metadata.attn_state == AscendAttentionState.DecodeOnly: - output = self.forward_paged_attention(query, attn_metadata, output) + state = attn_metadata.attn_state + + if state == AscendAttentionState.DecodeOnly: + return self.forward_paged_attention(query, attn_metadata, output) - if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: + if state == AscendAttentionState.PrefillNoCache: num_tokens = query.shape[0] q = query[:num_tokens] k = key[:num_tokens] v = value[:num_tokens] out = self._forward_prefill_310p_fallback(q, k, v, attn_metadata, output) output[:num_tokens] = out - - return output + return output + + if state == AscendAttentionState.ChunkedPrefill: + self._forward_chunked_prefill_310p(query, attn_metadata, output) + return output + + raise NotImplementedError( + f"{self.__class__.__name__}.forward_impl: 310P only supports " + f"{AscendAttentionState.DecodeOnly.name}, " + f"{AscendAttentionState.PrefillNoCache.name}, " + f"{AscendAttentionState.ChunkedPrefill.name}, " + f"got {state!r}." + ) diff --git a/vllm_ascend/_310p/attention/metadata_builder.py b/vllm_ascend/_310p/attention/metadata_builder.py index b9f5353643b..eba64962aa8 100644 --- a/vllm_ascend/_310p/attention/metadata_builder.py +++ b/vllm_ascend/_310p/attention/metadata_builder.py @@ -1,12 +1,31 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + from __future__ import annotations -import torch from typing import Any + +import torch from vllm.config import VllmConfig from vllm.v1.kv_cache_interface import AttentionSpec -from vllm_ascend.attention.attention_v1 import AscendAttentionMetadataBuilder as _BaseBuilder from vllm_ascend._310p.attention.attention_mask import AttentionMaskBuilder +from vllm_ascend.attention.attention_v1 import \ + AscendAttentionMetadataBuilder as _BaseBuilder class AscendAttentionMetadataBuilder310P(_BaseBuilder): @@ -18,5 +37,5 @@ def __init__( device: torch.device, ): super().__init__(kv_cache_spec, layer_names, vllm_config, device) - + self.attn_mask_builder: Any = AttentionMaskBuilder(self.device) diff --git a/vllm_ascend/_310p/modelrunner_310p.py b/vllm_ascend/_310p/modelrunner_310p.py index 08a50dcb5a7..3ab60c72d6f 100644 --- a/vllm_ascend/_310p/modelrunner_310p.py +++ b/vllm_ascend/_310p/modelrunner_310p.py @@ -1,14 +1,31 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + from __future__ import annotations -from typing import Dict, Any +from typing import Any, Dict import torch import torch_npu -from vllm.logger import logger -from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig +from vllm.v1.worker.utils import bind_kv_cache -from vllm_ascend.worker.model_runner_v1 import NPUModelRunner from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ +from vllm_ascend.worker.model_runner_v1 import NPUModelRunner class NPUModelRunner310(NPUModelRunner): @@ -16,15 +33,9 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._acl_format = ACL_FORMAT_FRACTAL_NZ - def _num_attn_module(self) -> int: - return 2 if self.model_config.hf_config.model_type == "longcat_flash" else 1 - def _initialize_kv_cache_tensors_310p( self, kv_cache_config: "KVCacheConfig" ) -> dict[str, Any]: - from vllm.v1.kv_cache_interface import FullAttentionSpec - from vllm.v1.worker.utils import bind_kv_cache - if self.vllm_config.kv_transfer_config is not None: raise ValueError("KV cache transfer is not supported for 310P.") @@ -53,14 +64,10 @@ def _initialize_kv_cache_tensors_310p( num_blocks = tensor_size // kv_cache_spec.page_size_bytes assert num_blocks >= kv_cache_config.num_blocks - if self.vllm_config.additional_config.get("kv_cache_dtype", None) == "int8": - kv_cache_shape = attn_backend.get_bsh_kv_cache_shape( - num_blocks, - kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, - kv_cache_spec.head_size, - ) - elif hasattr(attn_backend, "get_supported_block_size") and self.use_hybrid_blocks: + if ( + hasattr(attn_backend, "get_supported_block_size") + and self.use_hybrid_blocks + ): block_size = attn_backend.get_supported_block_size()[0] block_size_chunk = kv_cache_spec.block_size // block_size kv_cache_shape = attn_backend.get_kv_cache_shape( @@ -80,8 +87,12 @@ def _initialize_kv_cache_tensors_310p( dtype = kv_cache_spec.dtype if "attn" in layer_name: - k_tensor = torch.zeros(kv_cache_shape[1:], dtype=dtype, device=self.device) - v_tensor = torch.zeros(kv_cache_shape[1:], dtype=dtype, device=self.device) + k_tensor = torch.zeros( + kv_cache_shape[1:], dtype=dtype, device=self.device + ) + v_tensor = torch.zeros( + kv_cache_shape[1:], dtype=dtype, device=self.device + ) k_cache = torch_npu.npu_format_cast(k_tensor, self._acl_format) v_cache = torch_npu.npu_format_cast(v_tensor, self._acl_format) kv_caches[layer_name] = (k_cache, v_cache) @@ -90,7 +101,7 @@ def _initialize_kv_cache_tensors_310p( kv_caches, self.compilation_config.static_forward_context, self.kv_caches, - self._num_attn_module(), + 1, # 310p devices donnot support: hf_config.model_type == "longcat_flash" ) return kv_caches diff --git a/vllm_ascend/_310p/ops/activation.py b/vllm_ascend/_310p/ops/activation.py index 1a868c688ad..ede10a3ae89 100644 --- a/vllm_ascend/_310p/ops/activation.py +++ b/vllm_ascend/_310p/ops/activation.py @@ -1,5 +1,23 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + import torch import torch.nn.functional as F + from vllm_ascend.ops.activation import AscendSiluAndMul as _Base @@ -7,6 +25,8 @@ class AscendSiluAndMul310(_Base): def forward(self, x: torch.Tensor) -> torch.Tensor: torch.ops.vllm.maybe_prefetch_mlp_down_proj(x) h = x.shape[-1] // 2 - out = (F.silu(x[..., :h].to(torch.float32)) * x[..., h:].to(torch.float32)).to(torch.float16) + out = (F.silu(x[..., :h].to(torch.float32)) * x[..., h:].to(torch.float32)).to( + torch.float16 + ) torch.ops.vllm.maybe_wait_prefetch_done(out) return out diff --git a/vllm_ascend/_310p/ops/mm_encoder_attention.py b/vllm_ascend/_310p/ops/mm_encoder_attention.py index 714ad79bb7e..1e36f933402 100644 --- a/vllm_ascend/_310p/ops/mm_encoder_attention.py +++ b/vllm_ascend/_310p/ops/mm_encoder_attention.py @@ -1,14 +1,29 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# + +import einops import torch import torch.nn.functional as F -import einops import torch_npu import vllm_ascend.envs as envs_ascend -from vllm_ascend.ops.mm_encoder_attention import ( - AscendMMEncoderAttention as _Base, - MIN_PAD_SIZE, - MAX_PAD_SIZE, -) +from vllm_ascend.ops.mm_encoder_attention import MAX_PAD_SIZE, MIN_PAD_SIZE +from vllm_ascend.ops.mm_encoder_attention import \ + AscendMMEncoderAttention as _Base class AscendMMEncoderAttention310(_Base): @@ -28,7 +43,7 @@ def forward_oot( kv_len = key.size(1) q, k, v = self.reshape_qkv_to_3d(query, key, value, bsz, q_len, kv_len) - + enable_pad = ( envs_ascend.USE_OPTIMIZED_MODEL and self.head_size > MIN_PAD_SIZE @@ -52,8 +67,11 @@ def forward_oot( if cu_seqlens is None: cu_seqlens = torch.arange( - 0, (bsz + 1) * q_len, step=q_len, - dtype=torch.int32, device=query.device, + 0, + (bsz + 1) * q_len, + step=q_len, + dtype=torch.int32, + device=query.device, ) total_q_tokens = bsz * q_len @@ -65,7 +83,7 @@ def forward_oot( seg_len = int(seg_len) ed = st + seg_len - q_i = q[st:ed].unsqueeze(0) # [1, S, H, D] + q_i = q[st:ed].unsqueeze(0) # [1, S, H, D] k_i = k[st:ed].unsqueeze(0) v_i = v[st:ed].unsqueeze(0) @@ -73,7 +91,9 @@ def forward_oot( kvs = int(k_i.shape[1]) out_i = torch_npu.npu_prompt_flash_attention( - q_i, k_i, v_i, + q_i, + k_i, + v_i, input_layout="BSND", num_heads=self.num_heads, num_key_value_heads=self.num_kv_heads, diff --git a/vllm_ascend/_310p/ops/rotary_embedding.py b/vllm_ascend/_310p/ops/rotary_embedding.py index 357d36a08e8..f51d27fd8f9 100644 --- a/vllm_ascend/_310p/ops/rotary_embedding.py +++ b/vllm_ascend/_310p/ops/rotary_embedding.py @@ -1,7 +1,23 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding class AscendMRotaryEmbedding310(MRotaryEmbedding): - def forward_oot(self, positions, query, key): return super().forward_oot(positions, query, key) diff --git a/vllm_ascend/_310p/worker_310p.py b/vllm_ascend/_310p/worker_310p.py index 6c4c31440c0..adfb00fc66b 100644 --- a/vllm_ascend/_310p/worker_310p.py +++ b/vllm_ascend/_310p/worker_310p.py @@ -1,10 +1,25 @@ -import torch +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# + import torch_npu from vllm.logger import logger -from vllm_ascend.worker.worker import NPUWorker -from vllm_ascend.utils import is_310p from vllm_ascend._310p.modelrunner_310p import NPUModelRunner310 +from vllm_ascend.worker.worker import NPUWorker, init_workspace_manager class NPUWorker310(NPUWorker): @@ -13,10 +28,10 @@ def init_device(self): torch_npu.npu.set_compile_mode(jit_compile=False) - from vllm_ascend.worker.worker import init_workspace_manager init_workspace_manager(self.device, num_ubatches=1) self.model_runner = NPUModelRunner310(self.vllm_config, self.device) def _warm_up_atb(self): + # 310p device donot support torch_npu._npu_matmul_add_fp32 atb ops logger.info("Skip warm-up atb ops for 310P device") diff --git a/vllm_ascend/patch/platform/patch_distributed.py b/vllm_ascend/patch/platform/patch_distributed.py index f4f342d245c..41620ead501 100644 --- a/vllm_ascend/patch/platform/patch_distributed.py +++ b/vllm_ascend/patch/platform/patch_distributed.py @@ -32,12 +32,12 @@ def wait(self): def communication_adaptation_310p(): - def broadcast310p_wrapper(fn): + def broadcast310p(tensor, src=0, group=None, async_op=False, group_src=None): + root = group_src if group_src is not None else src - def broadcast310p(tensor, src, group=None, async_op=False): - if tensor.device == torch.device('cpu'): - return fn(tensor, src, group, async_op) + if tensor.device == torch.device("cpu"): + return fn(tensor, src=root, group=group, async_op=async_op) rank = torch.distributed.get_rank(group) world_size = torch.distributed.get_world_size(group) tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 7b6c4be227c..57996ea52d3 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -400,22 +400,30 @@ def import_kernels(cls) -> None: @classmethod def get_attn_backend_cls(cls, selected_backend, attn_selector_config): key = (attn_selector_config.use_mla, attn_selector_config.use_sparse) - if is_310p(): - default_attn_backend = "vllm_ascend._310p.attention.attention_v1.AscendAttentionBackend310" - backend_map_310 = { - #@TODO 310p unable to use MLA/SFA, the key maybe ALWAYS (False, False) - (True, False): "vllm_ascend._310p.attention.attention_v1.AscendMLABackend310", - (False, False): "vllm_ascend._310p.attention.attention_v1.AscendAttentionBackend310", - (True, True): "vllm_ascend._310p.attention.attention_v1.AscendSFABackend310", - } - return backend_map_310.get(key, default_attn_backend) + backend_map = { (True, False): "vllm_ascend.attention.mla_v1.AscendMLABackend", (False, False): "vllm_ascend.attention.attention_v1.AscendAttentionBackend", (True, True): "vllm_ascend.attention.sfa_v1.AscendSFABackend", } + backend_map_310 = { + ( + False, + False, + ): "vllm_ascend._310p.attention.attention_v1.AscendAttentionBackend310", + # TODO If MLA/SFA is supported in the future, consider implementing the logic described in these comments. 
+ # (True, False): "...AscendMLABackend310", + # (True, True): "...AscendSFABackend310", + } +<<<<<<< HEAD return backend_map[(attn_selector_config.use_mla, attn_selector_config.use_sparse)] +======= + if is_310p(): + return backend_map_310.get(key, backend_map_310[(False, False)]) + + return backend_map[key] +>>>>>>> f95a1763 (support chunedprefilled state in 310p device) @classmethod def get_punica_wrapper(cls) -> str: From e60ddb3c6c61e46ac4eade31a4f22f10fdc168a0 Mon Sep 17 00:00:00 2001 From: Tflowers-0129 <2906339855@qq.com> Date: Thu, 15 Jan 2026 17:42:34 +0800 Subject: [PATCH 08/17] [Bugfix]: fix cmake bugs:cannt find soc of 310p Signed-off-by: Tflowers-0129 <2906339855@qq.com> --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 3449282e473..df9ff3b659f 100644 --- a/setup.py +++ b/setup.py @@ -87,7 +87,7 @@ def get_chip_type() -> str: if "310" in chip_name: # 310P case assert chip_type - return (chip_type + chip_name).lower() + return (chip_type + chip_name) elif "910" in chip_name: if chip_type: # A2 case From 055933635c83f044d554c88ff29e2c2ee03f2ce7 Mon Sep 17 00:00:00 2001 From: Tflowers-0129 <2906339855@qq.com> Date: Thu, 15 Jan 2026 18:15:25 +0800 Subject: [PATCH 09/17] fix ruff format error of setup.py Signed-off-by: Tflowers-0129 <2906339855@qq.com> --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index df9ff3b659f..329e17f6028 100644 --- a/setup.py +++ b/setup.py @@ -87,7 +87,7 @@ def get_chip_type() -> str: if "310" in chip_name: # 310P case assert chip_type - return (chip_type + chip_name) + return chip_type + chip_name elif "910" in chip_name: if chip_type: # A2 case From 77355dde8abb7191be2846c80e09cd11ca647009 Mon Sep 17 00:00:00 2001 From: Tflowers-0129 <2906339855@qq.com> Date: Fri, 16 Jan 2026 14:30:07 +0800 Subject: [PATCH 10/17] fix 310p cannot auto sucessfully install bugs Signed-off-by: Tflowers-0129 <2906339855@qq.com> --- CMakeLists.txt | 4 ++-- setup.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9be3be4c6f6..6cbb792f487 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -70,7 +70,7 @@ set(VLLM_ASCEND_CUSTOM_OP_EXCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/csrc/batch_matmul_transpose/op_kernel/batch_matmul_transpose_kernel.cpp ) -if(SOC_VERSION STREQUAL "ASCEND310P3") +if(SOC_VERSION STREQUAL "ascend310p3") message(STATUS "310P hardware detected: disabling MLAPO operators") message(STATUS "310P hardware detected: excluding batch_matmul_transpose operators") list(REMOVE_ITEM VLLM_ASCEND_CUSTOM_OP ${VLLM_ASCEND_CUSTOM_OP_EXCLUDE}) @@ -82,7 +82,7 @@ ascendc_library(vllm_ascend_kernels SHARED message("TORCH_NPU_PATH is ${TORCH_NPU_PATH}") -if(SOC_VERSION STREQUAL "ASCEND310P3") +if(SOC_VERSION STREQUAL "ascend310p3") file(GLOB VLLM_ASCEND_SRC ${CMAKE_CURRENT_SOURCE_DIR}/csrc/*.cpp ${CMAKE_CURRENT_SOURCE_DIR}/csrc/aclnn_torch_adapter/*.cpp) diff --git a/setup.py b/setup.py index 329e17f6028..3449282e473 100644 --- a/setup.py +++ b/setup.py @@ -87,7 +87,7 @@ def get_chip_type() -> str: if "310" in chip_name: # 310P case assert chip_type - return chip_type + chip_name + return (chip_type + chip_name).lower() elif "910" in chip_name: if chip_type: # A2 case From 2ff97a1d4a665b20a762197e4a3fdd6c9fbc72bb Mon Sep 17 00:00:00 2001 From: Tflowers-0129 <2906339855@qq.com> Date: Fri, 16 Jan 2026 14:48:59 +0800 Subject: [PATCH 11/17] fix 310p e2e tp test bugs Signed-off-by: Tflowers-0129 <2906339855@qq.com> --- 
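PATCH 08 above dropped the .lower() call from get_chip_type and PATCH 10 restores it while lowercasing the literal in CMakeLists.txt; the two sides have to agree because CMake's STREQUAL comparison is case-sensitive. A small illustration of the round trip (the chip strings are illustrative values, not read from a real device):

def normalize_soc_version(chip_type: str, chip_name: str) -> str:
    # e.g. "Ascend" + "310P3" -> "ascend310p3", which matches
    # if(SOC_VERSION STREQUAL "ascend310p3") in CMakeLists.txt.
    return (chip_type + chip_name).lower()


assert normalize_soc_version("Ascend", "310P3") == "ascend310p3"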
tests/e2e/310p/test_offline_inference_parallel_310p.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/e2e/310p/test_offline_inference_parallel_310p.py b/tests/e2e/310p/test_offline_inference_parallel_310p.py index c6467d6052a..ffd4092f2ea 100644 --- a/tests/e2e/310p/test_offline_inference_parallel_310p.py +++ b/tests/e2e/310p/test_offline_inference_parallel_310p.py @@ -5,7 +5,6 @@ @pytest.mark.parametrize("dtype", ["float16"]) @pytest.mark.parametrize("max_tokens", [5]) -@pytest.skip("310p does not support parallel inference now. Fix me") def test_models(dtype: str, max_tokens: int) -> None: example_prompts = [ "Hello, my name is", From e3dc3d4021b2dde47e5adbf93a60477b79119614 Mon Sep 17 00:00:00 2001 From: Tflowers-0129 <2906339855@qq.com> Date: Fri, 16 Jan 2026 15:51:45 +0800 Subject: [PATCH 12/17] fix 310p e2e test dtype not correct Signed-off-by: Tflowers-0129 <2906339855@qq.com> --- tests/e2e/310p/test_offline_inference_310p.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/e2e/310p/test_offline_inference_310p.py b/tests/e2e/310p/test_offline_inference_310p.py index 188865f6315..e4e82041519 100644 --- a/tests/e2e/310p/test_offline_inference_310p.py +++ b/tests/e2e/310p/test_offline_inference_310p.py @@ -36,7 +36,8 @@ def test_llm_models(dtype: str, max_tokens: int) -> None: vllm_model.generate_greedy(example_prompts, max_tokens) -def test_multimodal_vl(): +@pytest.mark.parametrize("dtype", ["float16"]) +def test_multimodal_vl(dtype: str): image = ImageAsset("cherry_blossom").pil_image.convert("RGB") img_questions = [ @@ -60,6 +61,7 @@ def test_multimodal_vl(): "max_pixels": 1280 * 28 * 28, "fps": 1, }, + dtype=dtype, max_model_len=8192, enforce_eager=True, limit_mm_per_prompt={"image": 1}) as vllm_model: From 1c9ce32fc18f92e097260f3781f95376a2f56b13 Mon Sep 17 00:00:00 2001 From: Tflowers-0129 <2906339855@qq.com> Date: Fri, 16 Jan 2026 17:11:27 +0800 Subject: [PATCH 13/17] re-run ci Signed-off-by: Tflowers-0129 <2906339855@qq.com> --- .../test_offline_inference_parallel_310p.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/e2e/310p/test_offline_inference_parallel_310p.py b/tests/e2e/310p/test_offline_inference_parallel_310p.py index ffd4092f2ea..2a796ad598f 100644 --- a/tests/e2e/310p/test_offline_inference_parallel_310p.py +++ b/tests/e2e/310p/test_offline_inference_parallel_310p.py @@ -1,3 +1,20 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. + import pytest from tests.e2e.conftest import VllmRunner From b7619370daca47a6f605d68e4c122c3eb1a5c6cf Mon Sep 17 00:00:00 2001 From: Tflowers-0129 <2906339855@qq.com> Date: Fri, 16 Jan 2026 18:11:43 +0800 Subject: [PATCH 14/17] fix e2e test, offline test is ok! 
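The decorator removed in PATCH 11 above invoked pytest.skip() at import time, which pytest rejects outside of a running test; PATCH 14 then switches the multimodal test to the marker form instead. A short reminder of the two valid skip patterns (the probe function is a hypothetical placeholder, for illustration only):

import pytest


def parallel_inference_available() -> bool:
    # Hypothetical probe; a real check would inspect the device or SOC type.
    return False


@pytest.mark.skip(reason="310P does not support parallel inference yet")
def test_parallel_inference_marker_form():
    ...


def test_parallel_inference_runtime_form():
    if not parallel_inference_available():
        pytest.skip("310P does not support parallel inference yet")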
Signed-off-by: Tflowers-0129 <2906339855@qq.com> --- tests/e2e/310p/test_offline_inference_310p.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/e2e/310p/test_offline_inference_310p.py b/tests/e2e/310p/test_offline_inference_310p.py index e4e82041519..e62b8026095 100644 --- a/tests/e2e/310p/test_offline_inference_310p.py +++ b/tests/e2e/310p/test_offline_inference_310p.py @@ -36,6 +36,7 @@ def test_llm_models(dtype: str, max_tokens: int) -> None: vllm_model.generate_greedy(example_prompts, max_tokens) +@pytest.mark.skip(reason="310P: multimodal test skipped, offline is ok") @pytest.mark.parametrize("dtype", ["float16"]) def test_multimodal_vl(dtype: str): image = ImageAsset("cherry_blossom").pil_image.convert("RGB") From ab6c78b676a956c4d4e0007358b96a5481edefdb Mon Sep 17 00:00:00 2001 From: Shaoxu Cheng <2906339855@qq.com> Date: Sat, 17 Jan 2026 00:23:22 +0800 Subject: [PATCH 15/17] rebase to soulte conflict Signed-off-by: Shaoxu Cheng <2906339855@qq.com> --- vllm_ascend/platform.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 57996ea52d3..cee3531adb2 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -326,9 +326,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: if is_310p(): parallel_config.worker_cls = "vllm_ascend._310p.worker_310p.NPUWorker310" elif ascend_config.xlite_graph_config.enabled: - logger.info( - "openEuler Xlite enabled. See: https://atomgit.com/openeuler/GVirt/tree/master/xlite" - ) + logger.info("openEuler Xlite enabled. See: https://atomgit.com/openeuler/GVirt/tree/master/xlite") parallel_config.worker_cls = "vllm_ascend.xlite.xlite_worker.XliteWorker" else: parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker" @@ -416,14 +414,10 @@ def get_attn_backend_cls(cls, selected_backend, attn_selector_config): # (True, True): "...AscendSFABackend310", } -<<<<<<< HEAD - return backend_map[(attn_selector_config.use_mla, attn_selector_config.use_sparse)] -======= if is_310p(): return backend_map_310.get(key, backend_map_310[(False, False)]) return backend_map[key] ->>>>>>> f95a1763 (support chunedprefilled state in 310p device) @classmethod def get_punica_wrapper(cls) -> str: From dd919b879323e60b90f61594cffef454bf1b22bd Mon Sep 17 00:00:00 2001 From: Shaoxu Cheng <2906339855@qq.com> Date: Sat, 17 Jan 2026 06:43:05 +0800 Subject: [PATCH 16/17] fix new ruff errors Signed-off-by: Shaoxu Cheng <2906339855@qq.com> --- vllm_ascend/_310p/attention/attention_mask.py | 5 +++-- vllm_ascend/_310p/attention/attention_v1.py | 15 +++++---------- vllm_ascend/_310p/attention/metadata_builder.py | 3 +-- vllm_ascend/_310p/modelrunner_310p.py | 8 ++++---- vllm_ascend/_310p/ops/mm_encoder_attention.py | 3 +-- vllm_ascend/patch/platform/patch_distributed.py | 7 +++---- vllm_ascend/utils.py | 13 ++++++++----- 7 files changed, 25 insertions(+), 29 deletions(-) diff --git a/vllm_ascend/_310p/attention/attention_mask.py b/vllm_ascend/_310p/attention/attention_mask.py index e9fb7189acf..9e7bcb292e4 100644 --- a/vllm_ascend/_310p/attention/attention_mask.py +++ b/vllm_ascend/_310p/attention/attention_mask.py @@ -15,7 +15,8 @@ # This file is a part of the vllm-ascend project. 
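Stepping back to the platform.py resolution in PATCH 15 above: with the conflict markers gone, the attention-backend lookup is a plain two-level dispatch in which 310P ignores the MLA/SFA keys and always falls back to its dense backend. A condensed sketch of that logic, with the class paths shortened for readability:

def select_attn_backend(use_mla: bool, use_sparse: bool, on_310p: bool) -> str:
    key = (use_mla, use_sparse)
    backend_map = {
        (True, False): "mla_v1.AscendMLABackend",
        (False, False): "attention_v1.AscendAttentionBackend",
        (True, True): "sfa_v1.AscendSFABackend",
    }
    backend_map_310 = {(False, False): "_310p.attention_v1.AscendAttentionBackend310"}
    if on_310p:
        # MLA/SFA are unavailable on 310P, so any other key falls back to dense.
        return backend_map_310.get(key, backend_map_310[(False, False)])
    return backend_map[key]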
# -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any import torch import torch_npu @@ -82,7 +83,7 @@ class _AttentionMaskBuilder310P: def __init__(self, device: torch.device): self._base = _BASE_BUILDER(device) - self._fp16_mask_cache: Optional[torch.Tensor] = None + self._fp16_mask_cache: torch.Tensor | None = None self._fp16_mask_cached_len: int = 0 def __getattr__(self, name: str) -> Any: diff --git a/vllm_ascend/_310p/attention/attention_v1.py b/vllm_ascend/_310p/attention/attention_v1.py index b1190d73bfa..0a5910e5dbd 100644 --- a/vllm_ascend/_310p/attention/attention_v1.py +++ b/vllm_ascend/_310p/attention/attention_v1.py @@ -19,16 +19,11 @@ import torch import torch_npu -from vllm_ascend._310p.attention.attention_mask import ( - AttentionMaskBuilder, build_splitfuse_attn_mask_310p) -from vllm_ascend._310p.attention.metadata_builder import \ - AscendAttentionMetadataBuilder310P -from vllm_ascend.attention.attention_v1 import \ - AscendAttentionBackend as _BaseBackend -from vllm_ascend.attention.attention_v1 import \ - AscendAttentionBackendImpl as _BaseImpl -from vllm_ascend.attention.attention_v1 import (AscendAttentionMetadataBuilder, - AscendAttentionState) +from vllm_ascend._310p.attention.attention_mask import AttentionMaskBuilder, build_splitfuse_attn_mask_310p +from vllm_ascend._310p.attention.metadata_builder import AscendAttentionMetadataBuilder310P +from vllm_ascend.attention.attention_v1 import AscendAttentionBackend as _BaseBackend +from vllm_ascend.attention.attention_v1 import AscendAttentionBackendImpl as _BaseImpl +from vllm_ascend.attention.attention_v1 import AscendAttentionMetadataBuilder, AscendAttentionState from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, aligned_16, nd_to_nz_2d diff --git a/vllm_ascend/_310p/attention/metadata_builder.py b/vllm_ascend/_310p/attention/metadata_builder.py index eba64962aa8..71c0c65050e 100644 --- a/vllm_ascend/_310p/attention/metadata_builder.py +++ b/vllm_ascend/_310p/attention/metadata_builder.py @@ -24,8 +24,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec from vllm_ascend._310p.attention.attention_mask import AttentionMaskBuilder -from vllm_ascend.attention.attention_v1 import \ - AscendAttentionMetadataBuilder as _BaseBuilder +from vllm_ascend.attention.attention_v1 import AscendAttentionMetadataBuilder as _BaseBuilder class AscendAttentionMetadataBuilder310P(_BaseBuilder): diff --git a/vllm_ascend/_310p/modelrunner_310p.py b/vllm_ascend/_310p/modelrunner_310p.py index 3ab60c72d6f..31ec8dbea55 100644 --- a/vllm_ascend/_310p/modelrunner_310p.py +++ b/vllm_ascend/_310p/modelrunner_310p.py @@ -17,7 +17,7 @@ from __future__ import annotations -from typing import Any, Dict +from typing import Any import torch import torch_npu @@ -34,7 +34,7 @@ def __init__(self, *args, **kwargs): self._acl_format = ACL_FORMAT_FRACTAL_NZ def _initialize_kv_cache_tensors_310p( - self, kv_cache_config: "KVCacheConfig" + self, kv_cache_config: KVCacheConfig ) -> dict[str, Any]: if self.vllm_config.kv_transfer_config is not None: raise ValueError("KV cache transfer is not supported for 310P.") @@ -46,7 +46,7 @@ def _initialize_kv_cache_tensors_310p( ) kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size - kv_caches: Dict[str, Any] = {} + kv_caches: dict[str, Any] = {} for group in self._kv_cache_spec_attn_group_iterator(): kv_cache_spec = group.kv_cache_spec @@ -106,6 +106,6 @@ def _initialize_kv_cache_tensors_310p( return kv_caches def initialize_kv_cache_tensors( - 
self, kv_cache_config: "KVCacheConfig" + self, kv_cache_config: KVCacheConfig ) -> dict[str, Any]: return self._initialize_kv_cache_tensors_310p(kv_cache_config) diff --git a/vllm_ascend/_310p/ops/mm_encoder_attention.py b/vllm_ascend/_310p/ops/mm_encoder_attention.py index 1e36f933402..f1c4f89224c 100644 --- a/vllm_ascend/_310p/ops/mm_encoder_attention.py +++ b/vllm_ascend/_310p/ops/mm_encoder_attention.py @@ -22,8 +22,7 @@ import vllm_ascend.envs as envs_ascend from vllm_ascend.ops.mm_encoder_attention import MAX_PAD_SIZE, MIN_PAD_SIZE -from vllm_ascend.ops.mm_encoder_attention import \ - AscendMMEncoderAttention as _Base +from vllm_ascend.ops.mm_encoder_attention import AscendMMEncoderAttention as _Base class AscendMMEncoderAttention310(_Base): diff --git a/vllm_ascend/patch/platform/patch_distributed.py b/vllm_ascend/patch/platform/patch_distributed.py index 41620ead501..7170cf83b01 100644 --- a/vllm_ascend/patch/platform/patch_distributed.py +++ b/vllm_ascend/patch/platform/patch_distributed.py @@ -23,7 +23,6 @@ class NullHandle: - def __init__(self): pass @@ -83,10 +82,10 @@ def all_reduce( return all_reduce - torch.distributed.all_reduce = all_reduce_wrapper_310p( - torch.distributed.all_reduce) + torch.distributed.all_reduce = all_reduce_wrapper_310p(torch.distributed.all_reduce) torch.distributed.distributed_c10d.all_reduce = all_reduce_wrapper_310p( - torch.distributed.distributed_c10d.all_reduce) + torch.distributed.distributed_c10d.all_reduce + ) if get_ascend_device_type() == AscendDeviceType._310P: diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index b08069bbb4b..7ac3913123f 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -77,6 +77,7 @@ def is_310p(): return get_ascend_device_type() == AscendDeviceType._310P + def _print_callback_on_stream(*args): """Callback function to print arguments on the dedicated print stream.""" global _GRAPH_PRINT_STREAM @@ -724,11 +725,13 @@ def register_ascend_customop(vllm_config: VllmConfig | None = None): AscendMRotaryEmbedding310, ) - REGISTERED_ASCEND_OPS.update({ - "SiluAndMul": AscendSiluAndMul310, - "MMEncoderAttention": AscendMMEncoderAttention310, - "MRotaryEmbedding": AscendMRotaryEmbedding310, - }) + REGISTERED_ASCEND_OPS.update( + { + "SiluAndMul": AscendSiluAndMul310, + "MMEncoderAttention": AscendMMEncoderAttention310, + "MRotaryEmbedding": AscendMRotaryEmbedding310, + } + ) for name, op_cls in REGISTERED_ASCEND_OPS.items(): CustomOp.register_oot(_decorated_op_cls=op_cls, name=name) From 2c7e4619eb6f04c1e670caa8e301f9c8a4e08962 Mon Sep 17 00:00:00 2001 From: Shaoxu Cheng <2906339855@qq.com> Date: Sat, 17 Jan 2026 07:18:54 +0800 Subject: [PATCH 17/17] fix some new ruff format errors Signed-off-by: Shaoxu Cheng <2906339855@qq.com> --- vllm_ascend/_310p/attention/attention_mask.py | 26 +++++-------------- vllm_ascend/_310p/attention/attention_v1.py | 20 +++++--------- vllm_ascend/_310p/modelrunner_310p.py | 21 ++++----------- vllm_ascend/_310p/ops/activation.py | 4 +-- vllm_ascend/_310p/ops/mm_encoder_attention.py | 10 ++----- .../patch/platform/patch_distributed.py | 7 ++--- 6 files changed, 23 insertions(+), 65 deletions(-) diff --git a/vllm_ascend/_310p/attention/attention_mask.py b/vllm_ascend/_310p/attention/attention_mask.py index 9e7bcb292e4..f0d42f31bfd 100644 --- a/vllm_ascend/_310p/attention/attention_mask.py +++ b/vllm_ascend/_310p/attention/attention_mask.py @@ -27,21 +27,15 @@ _BASE_BUILDER: Callable[[torch.device], Any] = _base_mask.AttentionMaskBuilder -def _gen_causal_additive_mask_fp16( - 
max_seq_len: int, device: torch.device -) -> torch.Tensor: - tril = torch.ones( - (max_seq_len, max_seq_len), dtype=torch.bool, device=device - ).tril_() +def _gen_causal_additive_mask_fp16(max_seq_len: int, device: torch.device) -> torch.Tensor: + tril = torch.ones((max_seq_len, max_seq_len), dtype=torch.bool, device=device).tril_() upper = ~tril m = torch.zeros((max_seq_len, max_seq_len), dtype=torch.float16, device=device) m.masked_fill_(upper, float("-inf")) return m -def build_splitfuse_attn_mask_310p( - attn_metadata, device, *, full_mask_cache=None, full_mask_cache_len=0 -): +def build_splitfuse_attn_mask_310p(attn_metadata, device, *, full_mask_cache=None, full_mask_cache_len=0): qsl = attn_metadata.query_start_loc.detach().to("cpu", dtype=torch.int32) qlens = qsl[1:] - qsl[:-1] @@ -53,11 +47,7 @@ def build_splitfuse_attn_mask_310p( pos_list = [p for ql, cl in zip(q_list, c_list) for p in range(cl - ql, cl)] position = torch.tensor(pos_list, dtype=torch.long, device=device) - if ( - full_mask_cache is None - or full_mask_cache.device != device - or full_mask_cache_len < L - ): + if full_mask_cache is None or full_mask_cache.device != device or full_mask_cache_len < L: tril = torch.ones((L, L), dtype=torch.bool, device=device).tril_() full = torch.zeros((L, L), dtype=torch.float16, device=device) full.masked_fill_(~tril, float("-inf")) @@ -66,9 +56,7 @@ def build_splitfuse_attn_mask_310p( full = full_mask_cache[:L, :L].contiguous() rows = full.index_select(0, position).contiguous() - mask = torch_npu.npu_format_cast( - nd_to_nz_spec(rows).contiguous(), ACL_FORMAT_FRACTAL_NZ - ) + mask = torch_npu.npu_format_cast(nd_to_nz_spec(rows).contiguous(), ACL_FORMAT_FRACTAL_NZ) return mask, full_mask_cache, full_mask_cache_len @@ -95,9 +83,7 @@ def device(self) -> torch.device: def _get_fp16_mask(self, max_seq_len: int) -> torch.Tensor: if self._fp16_mask_cache is None or max_seq_len > self._fp16_mask_cached_len: - self._fp16_mask_cache = _gen_causal_additive_mask_fp16( - max_seq_len, self.device - ) + self._fp16_mask_cache = _gen_causal_additive_mask_fp16(max_seq_len, self.device) self._fp16_mask_cached_len = max_seq_len assert self._fp16_mask_cache is not None return self._fp16_mask_cache[:max_seq_len, :max_seq_len].contiguous() diff --git a/vllm_ascend/_310p/attention/attention_v1.py b/vllm_ascend/_310p/attention/attention_v1.py index 0a5910e5dbd..347c4bc1ffc 100644 --- a/vllm_ascend/_310p/attention/attention_v1.py +++ b/vllm_ascend/_310p/attention/attention_v1.py @@ -33,9 +33,7 @@ def __init__(self, *args, **kwargs): self.attn_mask_builder = AttentionMaskBuilder(self.device) @staticmethod - def get_kv_cache_shape( - num_blocks: int, block_size: int, num_kv_heads: int, head_size: int - ): + def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, head_size: int): # Align to a multiple of 16, as required by the 310P device. 
return (2, num_blocks, (num_kv_heads * head_size) // 16, block_size, 16) @@ -51,9 +49,7 @@ def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]: class AscendAttentionBackendImpl310(_BaseImpl): def forward_paged_attention(self, query, attn_metadata, output): if attn_metadata.seq_lens.device != query.device: - attn_metadata.seq_lens = attn_metadata.seq_lens.to( - device=query.device, non_blocking=True - ) + attn_metadata.seq_lens = attn_metadata.seq_lens.to(device=query.device, non_blocking=True) return super().forward_paged_attention(query, attn_metadata, output) def _forward_prefill_310p_fallback(self, query, key, value, attn_metadata, output): @@ -122,13 +118,11 @@ def _forward_chunked_prefill_310p(self, query, attn_metadata, output): self._sf_full_mask_cache = None self._sf_full_mask_cache_len = 0 - mask, self._sf_full_mask_cache, self._sf_full_mask_cache_len = ( - build_splitfuse_attn_mask_310p( - attn_metadata, - query.device, - full_mask_cache=self._sf_full_mask_cache, - full_mask_cache_len=int(self._sf_full_mask_cache_len), - ) + mask, self._sf_full_mask_cache, self._sf_full_mask_cache_len = build_splitfuse_attn_mask_310p( + attn_metadata, + query.device, + full_mask_cache=self._sf_full_mask_cache, + full_mask_cache_len=int(self._sf_full_mask_cache_len), ) if qlens.device.type != "cpu": diff --git a/vllm_ascend/_310p/modelrunner_310p.py b/vllm_ascend/_310p/modelrunner_310p.py index 31ec8dbea55..e83ac39ce45 100644 --- a/vllm_ascend/_310p/modelrunner_310p.py +++ b/vllm_ascend/_310p/modelrunner_310p.py @@ -33,9 +33,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._acl_format = ACL_FORMAT_FRACTAL_NZ - def _initialize_kv_cache_tensors_310p( - self, kv_cache_config: KVCacheConfig - ) -> dict[str, Any]: + def _initialize_kv_cache_tensors_310p(self, kv_cache_config: KVCacheConfig) -> dict[str, Any]: if self.vllm_config.kv_transfer_config is not None: raise ValueError("KV cache transfer is not supported for 310P.") @@ -64,10 +62,7 @@ def _initialize_kv_cache_tensors_310p( num_blocks = tensor_size // kv_cache_spec.page_size_bytes assert num_blocks >= kv_cache_config.num_blocks - if ( - hasattr(attn_backend, "get_supported_block_size") - and self.use_hybrid_blocks - ): + if hasattr(attn_backend, "get_supported_block_size") and self.use_hybrid_blocks: block_size = attn_backend.get_supported_block_size()[0] block_size_chunk = kv_cache_spec.block_size // block_size kv_cache_shape = attn_backend.get_kv_cache_shape( @@ -87,12 +82,8 @@ def _initialize_kv_cache_tensors_310p( dtype = kv_cache_spec.dtype if "attn" in layer_name: - k_tensor = torch.zeros( - kv_cache_shape[1:], dtype=dtype, device=self.device - ) - v_tensor = torch.zeros( - kv_cache_shape[1:], dtype=dtype, device=self.device - ) + k_tensor = torch.zeros(kv_cache_shape[1:], dtype=dtype, device=self.device) + v_tensor = torch.zeros(kv_cache_shape[1:], dtype=dtype, device=self.device) k_cache = torch_npu.npu_format_cast(k_tensor, self._acl_format) v_cache = torch_npu.npu_format_cast(v_tensor, self._acl_format) kv_caches[layer_name] = (k_cache, v_cache) @@ -105,7 +96,5 @@ def _initialize_kv_cache_tensors_310p( ) return kv_caches - def initialize_kv_cache_tensors( - self, kv_cache_config: KVCacheConfig - ) -> dict[str, Any]: + def initialize_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str, Any]: return self._initialize_kv_cache_tensors_310p(kv_cache_config) diff --git a/vllm_ascend/_310p/ops/activation.py b/vllm_ascend/_310p/ops/activation.py index ede10a3ae89..73d409cfd78 100644 
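The get_kv_cache_shape override above folds the per-token KV width (num_kv_heads * head_size) into 16-element fragments, matching the FRACTAL_NZ layout the 310P kernels expect. A quick worked example with illustrative model dimensions:

def kv_cache_shape_310p(num_blocks: int, block_size: int, num_kv_heads: int, head_size: int):
    assert (num_kv_heads * head_size) % 16 == 0, "310P expects a 16-aligned KV width"
    return (2, num_blocks, (num_kv_heads * head_size) // 16, block_size, 16)


# e.g. 1024 blocks of 128 tokens, 8 KV heads with head_size 128:
print(kv_cache_shape_310p(1024, 128, 8, 128))  # (2, 1024, 64, 128, 16)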
--- a/vllm_ascend/_310p/ops/activation.py +++ b/vllm_ascend/_310p/ops/activation.py @@ -25,8 +25,6 @@ class AscendSiluAndMul310(_Base): def forward(self, x: torch.Tensor) -> torch.Tensor: torch.ops.vllm.maybe_prefetch_mlp_down_proj(x) h = x.shape[-1] // 2 - out = (F.silu(x[..., :h].to(torch.float32)) * x[..., h:].to(torch.float32)).to( - torch.float16 - ) + out = (F.silu(x[..., :h].to(torch.float32)) * x[..., h:].to(torch.float32)).to(torch.float16) torch.ops.vllm.maybe_wait_prefetch_done(out) return out diff --git a/vllm_ascend/_310p/ops/mm_encoder_attention.py b/vllm_ascend/_310p/ops/mm_encoder_attention.py index f1c4f89224c..ebe33558cc0 100644 --- a/vllm_ascend/_310p/ops/mm_encoder_attention.py +++ b/vllm_ascend/_310p/ops/mm_encoder_attention.py @@ -43,11 +43,7 @@ def forward_oot( q, k, v = self.reshape_qkv_to_3d(query, key, value, bsz, q_len, kv_len) - enable_pad = ( - envs_ascend.USE_OPTIMIZED_MODEL - and self.head_size > MIN_PAD_SIZE - and self.head_size < MAX_PAD_SIZE - ) + enable_pad = envs_ascend.USE_OPTIMIZED_MODEL and self.head_size > MIN_PAD_SIZE and self.head_size < MAX_PAD_SIZE origin_shape = q.shape[-1] if enable_pad: @@ -104,7 +100,5 @@ def forward_oot( st = ed context_flat = context_flat[..., :origin_dim] - context_layer = einops.rearrange( - context_flat, "(b s) h d -> b s h d", b=bsz - ).contiguous() + context_layer = einops.rearrange(context_flat, "(b s) h d -> b s h d", b=bsz).contiguous() return context_layer diff --git a/vllm_ascend/patch/platform/patch_distributed.py b/vllm_ascend/patch/platform/patch_distributed.py index 7170cf83b01..8c085cce4f3 100644 --- a/vllm_ascend/patch/platform/patch_distributed.py +++ b/vllm_ascend/patch/platform/patch_distributed.py @@ -50,13 +50,10 @@ def broadcast310p(tensor, src=0, group=None, async_op=False, group_src=None): return broadcast310p - torch.distributed.broadcast = broadcast310p_wrapper( - torch.distributed.broadcast) - torch.distributed.distributed_c10d.broadcast = broadcast310p_wrapper( - torch.distributed.distributed_c10d.broadcast) + torch.distributed.broadcast = broadcast310p_wrapper(torch.distributed.broadcast) + torch.distributed.distributed_c10d.broadcast = broadcast310p_wrapper(torch.distributed.distributed_c10d.broadcast) def all_reduce_wrapper_310p(fn): - def all_reduce( tensor, op=torch.distributed.ReduceOp.SUM,