Merged

25 commits (changes from all commits)
7766f8d  [bugfix] fix xlite error: has no attribute 'query_start_loc_cpu' (Dec 23, 2025)
2015711  [bugfix] fix xlite error: has no attribute 'query_start_loc_cpu' (Dec 23, 2025)
32a1cfc  Merge branch 'vllm-project:main' into refactor_attention (weijinqian0, Dec 24, 2025)
5432405  [Refactor] use cos_sin_cache (Dec 24, 2025)
e58e977  [Refactor] use cos_sin_cache (Dec 24, 2025)
09ea370  [Refactor] use cos_sin_cache (Dec 24, 2025)
5bfb03a  [Refactor] use cos_sin_cache & remove parameter like model in builder. (Dec 24, 2025)
03630b8  [Refactor] use cos_sin_cache & remove parameter like model in builder. (Dec 24, 2025)
e6dd46c  Merge branch 'vllm-project:main' into refactor_attention (weijinqian0, Dec 24, 2025)
946971e  [Refactor] use cos_sin_cache & remove parameter like model in builder. (Dec 24, 2025)
4e3095c  [Refactor] use cos_sin_cache & remove parameter like model in builder. (Dec 24, 2025)
45e184f  [Refactor] use cos_sin_cache & remove parameter like model in builder. (Dec 24, 2025)
31abe7a  [Refactor] use cos_sin_cache & remove parameter like model in builder. (Dec 24, 2025)
607384d  [Refactor] use cos_sin_cache & remove parameter like model in builder. (Dec 25, 2025)
cef4b3f  [Refactor] use cos_sin_cache & remove parameter like model in builder. (Dec 25, 2025)
35d4c89  [Refactor] use cos_sin_cache & remove parameter like model in builder. (Dec 25, 2025)
c79b784  [Refactor] use cos_sin_cache & remove parameter like model in builder. (Dec 25, 2025)
f5a795e  [Refactor] use cos_sin_cache & remove parameter like model in builder. (Dec 25, 2025)
2604468  [Refactor] use cos_sin_cache & remove parameter like model in builder. (Dec 25, 2025)
6aef9b3  Merge branch 'main' into refactor_attention (weijinqian0, Dec 25, 2025)
2546349  Merge branch 'main' into refactor_attention (weijinqian0, Dec 25, 2025)
1814fc8  Merge branch 'main' into refactor_attention (weijinqian0, Dec 26, 2025)
e1cf263  [Refactor] use cos_sin_cache & remove parameter like model in builder. (Dec 26, 2025)
ac20bdf  [Refactor] use cos_sin_cache & remove parameter like model in builder. (Dec 26, 2025)
dd13b53  Merge branch 'main' into refactor_attention (weijinqian0, Dec 28, 2025)
61 changes: 36 additions & 25 deletions tests/ut/attention/test_mla_v1.py
@@ -289,14 +289,16 @@ def test_ascend_mla_metadata_builder_spec_decode(self, mock_dcp,
builder.chunked_prefill_enabled,
mock_vllm_config.scheduler_config.enable_chunked_prefill)

@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch('vllm.distributed.parallel_state._PCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch('vllm.distributed.parallel_state.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
def test_ascend_mla_metadata_builder_build_full_graph(
self, mock_dcp, mock_get_dcp_group, mock_pcp, mock_get_pcp_group):
self, mock_dcp, mock_get_dcp_group, mock_pcp, mock_get_pcp_group,
mock_get_cos_and_sin_mla):
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64
@@ -330,7 +332,6 @@ def test_ascend_mla_metadata_builder_build_full_graph(
builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
mock_device)
common_metadata = MagicMock()
model = MagicMock()
common_metadata.graph_pad_size = 8
common_metadata.num_reqs = 4
common_metadata.num_actual_tokens = 5
@@ -343,7 +344,9 @@
block_table = torch.Tensor([[1, 0], [2, 0], [3, 0], [4, 0]]).int()
common_metadata.block_table_tensor = block_table
common_metadata.prefill_context_parallel_metadata = None
metadata = builder.build(0, common_metadata, model)
mock_get_cos_and_sin_mla.return_value = (torch.tensor([6, 6]),
torch.Tensor([6, 6]))
metadata = builder.build(0, common_metadata)

self.assertEqual(metadata.decode.actual_seq_lengths_q,
[1, 2, 4, 5, 6, 6, 7, 8])
@@ -526,6 +529,7 @@ def setUp(self):
self.kv_cache_spec.head_size = 128
self.kv_cache_spec.num_heads = 32

@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
@@ -534,7 +538,8 @@
@patch("torch.npu.is_available")
def test_build_prefix_no_cache_metadata(self, mock_npu_available,
mock_zeros, mock_dcp_world_size,
mock_get_pcp_group):
mock_get_pcp_group,
mock_get_cos_and_sin_mla):
mock_npu_available.return_value = False
mock_dcp_world_size.return_value = 1
torch.Tensor.pin_memory = lambda x: x # noqa
@@ -579,9 +584,9 @@ def zeros_override(*args, **kwargs):
layer_names=["layer_0", "layer_1"],
vllm_config=self.mock_vllm_config,
device=self.mock_device)

mock_model = MagicMock()
metadata = builder.build(1, common_attn_metadata, mock_model)
mock_get_cos_and_sin_mla.return_value = (torch.tensor(10),
torch.Tensor(10))
metadata = builder.build(1, common_attn_metadata)

self.assertIsInstance(metadata, AscendMLAMetadata)
self.assertEqual(metadata.num_actual_tokens,
@@ -590,6 +595,7 @@
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)

@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
@@ -598,7 +604,8 @@
@patch("torch.npu.is_available")
def test_build_chunked_prefix_metadata(self, mock_npu_available,
mock_zeros, mock_dcp_world_size,
mock_get_pcp_group):
mock_get_pcp_group,
mock_get_cos_and_sin_mla):
mock_npu_available.return_value = False
mock_dcp_world_size.return_value = 1
torch.Tensor.pin_memory = lambda x: x # noqa
@@ -644,9 +651,9 @@ def zeros_override(*args, **kwargs):
layer_names=["layer_0", "layer_1"],
vllm_config=self.mock_vllm_config,
device=self.mock_device)

mock_model = MagicMock()
metadata = builder.build(1, common_attn_metadata, mock_model)
mock_get_cos_and_sin_mla.return_value = (torch.tensor(10),
torch.Tensor(10))
metadata = builder.build(1, common_attn_metadata)

self.assertIsInstance(metadata, AscendMLAMetadata)
self.assertEqual(metadata.num_actual_tokens,
@@ -655,11 +662,13 @@
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)

@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def test_build_decode_only_metadata(self, mock_dcp_world_size,
mock_get_pcp_group):
mock_get_pcp_group,
mock_get_cos_and_sin_mla):
mock_dcp_world_size.return_value = 1
torch.Tensor.pin_memory = lambda x: x # noqa

@@ -697,9 +706,9 @@ def test_build_decode_only_metadata(self, mock_dcp_world_size,
layer_names=["layer_0", "layer_1"],
vllm_config=self.mock_vllm_config,
device=self.mock_device)

mock_model = MagicMock()
metadata = builder.build(1, common_attn_metadata, mock_model)
mock_get_cos_and_sin_mla.return_value = (torch.tensor([10, 10]),
torch.Tensor([10, 10]))
metadata = builder.build(1, common_attn_metadata)

self.assertIsInstance(metadata, AscendMLAMetadata)
self.assertEqual(metadata.num_actual_tokens,
@@ -708,11 +717,13 @@
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)

@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def test_build_for_graph_capture_decode_only(self, mock_dcp_world_size,
mock_get_pcp_group):
mock_get_pcp_group,
mock_get_cos_and_sin_mla):
mock_dcp_world_size.return_value = 1
torch.Tensor.pin_memory = lambda x: x # noqa

@@ -750,10 +761,10 @@ def test_build_for_graph_capture_decode_only(self, mock_dcp_world_size,
layer_names=["layer_0", "layer_1"],
vllm_config=self.mock_vllm_config,
device=self.mock_device)

mock_model = MagicMock()
mock_get_cos_and_sin_mla.return_value = (torch.tensor([10, 10]),
torch.Tensor([10, 10]))
metadata = builder.build_for_graph_capture(
common_attn_metadata, AscendAttentionState.DecodeOnly, mock_model)
common_attn_metadata, AscendAttentionState.DecodeOnly)

self.assertIsInstance(metadata, AscendMLAMetadata)
self.assertEqual(metadata.num_actual_tokens,
@@ -762,11 +773,13 @@
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)

@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def test_build_for_graph_capture_prefill(self, mock_dcp_world_size,
mock_get_pcp_group):
mock_get_pcp_group,
mock_get_cos_and_sin_mla):
mock_dcp_world_size.return_value = 1
torch.Tensor.pin_memory = lambda x: x # noqa
pcp_group = MagicMock(spec=GroupCoordinator)
@@ -795,13 +808,11 @@ def test_build_for_graph_capture_prefill(self, mock_dcp_world_size,
layer_names=["layer_0", "layer_1"],
vllm_config=self.mock_vllm_config,
device=self.mock_device)

mock_model = MagicMock()

mock_get_cos_and_sin_mla.return_value = (torch.tensor(10),
torch.Tensor(10))
with self.assertRaises(NotImplementedError) as ctx:
builder.build_for_graph_capture(
common_attn_metadata, AscendAttentionState.PrefillNoCache,
mock_model)
common_attn_metadata, AscendAttentionState.PrefillNoCache)
self.assertIn(
"Currently we only support building dummy metadata for DecodeOnly and SpecDecoding state",
str(ctx.exception))
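
Aside: the recurring pattern in the test changes above is worth isolating. Because build() no longer receives a model, tests patch the module-level helper that the builder calls and stub its (cos, sin) return value instead of assembling a mock model. The sketch below is self-contained and illustrative only; get_cos_and_sin_mla and DemoBuilder are local stand-ins for the real vllm_ascend helper and AscendMLAMetadataBuilder.

    from unittest.mock import patch

    import torch


    def get_cos_and_sin_mla(num_tokens):
        # Stand-in for the real helper, which reads the rotary cos/sin cache.
        angles = torch.arange(num_tokens, dtype=torch.float32)
        return torch.cos(angles), torch.sin(angles)


    class DemoBuilder:
        # Stand-in for AscendMLAMetadataBuilder after this refactor.
        def build(self, common_prefix_len, common_attn_metadata,
                  fast_build=False):
            cos, sin = get_cos_and_sin_mla(4)
            return {"cos": cos, "sin": sin}


    # Patch the helper where the builder looks it up, then call build()
    # without the removed `model` argument.
    with patch(f"{__name__}.get_cos_and_sin_mla",
               return_value=(torch.ones(4), torch.ones(4))) as mocked:
        metadata = DemoBuilder().build(0, common_attn_metadata=None)
    assert torch.equal(metadata["cos"], torch.ones(4))
    mocked.assert_called_once()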
21 changes: 10 additions & 11 deletions tests/ut/attention/test_sfa_v1.py
@@ -1,5 +1,5 @@
import sys
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import torch
from vllm.v1.attention.backends.utils import AttentionCGSupport
@@ -102,7 +102,8 @@ def test_ascend_sfa_metadata_builder_default(self):
assert builder.device == device
assert builder.vllm_config == vllm_config

def test_ascend_sfa_metadata_builder_build(self):
@patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
def test_ascend_sfa_metadata_builder_build(self, mock_get_cos_and_sin_mla):
kv_cache_spec = MagicMock()
layer_names = ["layer1", "layer2"]
vllm_config = MagicMock()
@@ -133,21 +134,21 @@ def test_ascend_sfa_metadata_builder_build(self):
common_attn_metadata.sin = None
common_attn_metadata.num_input_tokens = 100

model = MagicMock()
model.model.layers = [MagicMock() for _ in range(10)]
model.model.start_layer = 0
mock_get_cos_and_sin_mla.return_value = (torch.randn(100),
torch.randn(100))

metadata = builder.build(
common_prefix_len=10,
common_attn_metadata=common_attn_metadata,
model=model,
)

assert isinstance(metadata, AscendSFAMetadata)
assert metadata.num_actual_tokens == common_attn_metadata.num_actual_tokens
assert metadata.slot_mapping.shape == (100, 4, 1024)

def test_ascend_sfa_metadata_builder_build_for_graph_capture(self):
@patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
def test_ascend_sfa_metadata_builder_build_for_graph_capture(
self, mock_get_cos_and_sin_mla):
kv_cache_spec = MagicMock()
layer_names = ["layer1", "layer2"]
vllm_config = MagicMock()
@@ -178,14 +179,12 @@ def test_ascend_sfa_metadata_builder_build_for_graph_capture(self):
common_attn_metadata.sin = None
common_attn_metadata.num_input_tokens = 100

model = MagicMock()
model.model.layers = [MagicMock() for _ in range(10)]
model.model.start_layer = 0
mock_get_cos_and_sin_mla.return_value = (torch.randn(100),
torch.randn(100))

attn_metadata = builder.build_for_graph_capture(
common_attn_metadata=common_attn_metadata,
attn_state=AscendAttentionState.DecodeOnly,
model=model,
)

assert isinstance(attn_metadata, AscendSFAMetadata)
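
One detail these test diffs rely on: stacked @patch decorators inject mocks bottom-up, so the @patch added at the top of each stack supplies the last parameter — which is why mock_get_cos_and_sin_mla is appended at the end of every amended test signature. A minimal standalone illustration (the patched targets here are arbitrary stdlib functions chosen only for the demo):

    from unittest.mock import MagicMock, patch


    @patch("os.getcwd")    # topmost decorator -> last argument
    @patch("os.getlogin")  # closest to the def -> first argument
    def demo(mock_getlogin, mock_getcwd):
        assert isinstance(mock_getlogin, MagicMock)
        assert isinstance(mock_getcwd, MagicMock)
        return mock_getlogin, mock_getcwd


    first, last = demo()
    assert first is not last  # each @patch injects its own mock, in stack order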
3 changes: 1 addition & 2 deletions vllm_ascend/attention/attention_cp.py
@@ -20,7 +20,6 @@
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch_npu
from vllm.config import VllmConfig
from vllm.distributed import (get_dcp_group,
@@ -90,7 +89,7 @@ def build(
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: Optional[nn.Module] = None,
fast_build: bool = False,
):
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
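
The same two-line change lands in every builder: build() drops the model parameter and gains fast_build. A hedged before/after sketch — the class bodies are placeholders, not the real implementations:

    from typing import Optional

    import torch.nn as nn


    class BuilderBefore:
        def build(self, common_prefix_len: int, common_attn_metadata,
                  model: Optional[nn.Module] = None):
            return common_attn_metadata


    class BuilderAfter:
        def build(self, common_prefix_len: int, common_attn_metadata,
                  fast_build: bool = False):
            # Nothing model-specific is needed any more; rotary cos/sin
            # come from a module-level cache helper (see the diffs above).
            return common_attn_metadata


    BuilderAfter().build(0, common_attn_metadata=None)  # no model threaded through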
11 changes: 5 additions & 6 deletions vllm_ascend/attention/attention_v1.py
@@ -20,7 +20,6 @@
from typing import ClassVar, List, Optional, Tuple, Type

import torch
import torch.nn as nn
import torch_npu
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
@@ -29,7 +28,8 @@
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec

@@ -170,7 +170,7 @@ class AscendMetadata:
model_runner_type: str = ""


class AscendAttentionMetadataBuilder:
class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
# Does this backend/builder support ACL Graphs for attention (default: no).
aclgraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.ALWAYS
@@ -217,8 +217,8 @@ def build(
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: Optional[nn.Module] = None,
):
fast_build: bool = False,
) -> AscendMetadata:
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
@@ -261,7 +261,6 @@ def build_for_graph_capture(
self,
common_attn_metadata: AscendCommonAttentionMetadata,
attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
model: Optional[nn.Module] = None,
):
if attn_state == AscendAttentionState.DecodeOnly:
attn_metadata = self.build(
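
The new base class makes the return type checkable: AscendAttentionMetadataBuilder now inherits from vLLM's generic AttentionMetadataBuilder, parameterized by AscendMetadata. A self-contained sketch of the idea, with a simplified stand-in for the upstream generic base:

    from dataclasses import dataclass
    from typing import Generic, TypeVar

    _M = TypeVar("_M")


    class AttentionMetadataBuilder(Generic[_M]):
        # Simplified stand-in for the base class in
        # vllm.v1.attention.backends.utils.
        def build(self, common_prefix_len: int, common_attn_metadata,
                  fast_build: bool = False) -> _M:
            raise NotImplementedError


    @dataclass
    class AscendMetadata:  # stand-in for the dataclass in attention_v1.py
        num_actual_tokens: int


    class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
        def build(self, common_prefix_len: int, common_attn_metadata,
                  fast_build: bool = False) -> AscendMetadata:
            # A type checker can now verify that every builder returns
            # its declared metadata type.
            return AscendMetadata(num_actual_tokens=common_attn_metadata)


    meta = AscendAttentionMetadataBuilder().build(0, common_attn_metadata=5)
    assert meta.num_actual_tokens == 5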
30 changes: 14 additions & 16 deletions vllm_ascend/attention/mla_cp.py
@@ -4,7 +4,6 @@
import torch
import torch.distributed as dist
import torch_npu
from torch import nn
from vllm.config import VllmConfig
from vllm.distributed import (get_dcp_group,
get_decode_context_model_parallel_rank,
@@ -50,14 +49,17 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
understand this class
"""

def __init__(self,
kv_cache_spec: MLAAttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
metadata_cls: Optional[AscendMLAMetadata] = None):
def __init__(
self,
kv_cache_spec: MLAAttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
metadata_cls: type[AscendMLAMetadata] | None = None,
supports_dcp_with_varlen: bool = False,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device,
metadata_cls)
metadata_cls, supports_dcp_with_varlen)

self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group(
@@ -92,7 +94,6 @@ def build_cp_metadata(
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module,
) -> AscendPCPMetadata | None:
common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
assert common_long_seq_metadata is not None
@@ -121,10 +122,9 @@ def build_chunked_metadata(
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module,
):
chunked_context_metadata = super().build_chunked_metadata(
common_prefix_len, common_attn_metadata, model)
common_prefix_len, common_attn_metadata)
if chunked_context_metadata is None:
return None

@@ -205,12 +205,11 @@ def build_prefill_metadata(
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module,
) -> AscendMLAPrefillMetadata:
prefill_metadata = super().build_prefill_metadata(
common_prefix_len, common_attn_metadata, model)
common_prefix_len, common_attn_metadata)
prefill_metadata.pcp_metadata = self.build_cp_metadata(
common_prefix_len, common_attn_metadata, model)
common_prefix_len, common_attn_metadata)
prefill_metadata.block_table = self.block_table[
self.num_decodes_flatten:, ...]
return prefill_metadata
@@ -219,10 +218,9 @@ def build_decode_metadata(
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module,
) -> AscendMLADecodeMetadata:
decode_metadata = super().build_decode_metadata(
common_prefix_len, common_attn_metadata, model)
common_prefix_len, common_attn_metadata)

long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
assert long_seq_metadata is not None
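
Finally, the constructor change in mla_cp.py is mostly typing and forwarding: metadata_cls is now annotated as a class object (type[AscendMLAMetadata] | None, PEP 604 syntax, so Python >= 3.10), and the new supports_dcp_with_varlen flag is passed straight through to the parent. A stand-in sketch — the real constructors also take kv_cache_spec, layer_names, vllm_config, and device, elided here:

    class AscendMLAMetadata:  # stand-in for the real metadata dataclass
        pass


    class BaseBuilder:  # stand-in for AscendMLAMetadataBuilder
        def __init__(self, metadata_cls: type[AscendMLAMetadata] | None = None,
                     supports_dcp_with_varlen: bool = False):
            # `metadata_cls` is a class object, instantiated later per build().
            self.metadata_cls = metadata_cls or AscendMLAMetadata
            self.supports_dcp_with_varlen = supports_dcp_with_varlen


    class CPBuilder(BaseBuilder):  # stand-in for AscendMlaCPMetadataBuilder
        def __init__(self, metadata_cls: type[AscendMLAMetadata] | None = None,
                     supports_dcp_with_varlen: bool = False):
            # Forward the new flag so the subclass stays drop-in
            # compatible with the parent's signature.
            super().__init__(metadata_cls, supports_dcp_with_varlen)


    builder = CPBuilder()
    assert builder.metadata_cls is AscendMLAMetadata
    assert builder.supports_dcp_with_varlen is False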