185 changes: 185 additions & 0 deletions tests/ut/attention/test_sfa_v1.py
@@ -0,0 +1,185 @@
from unittest.mock import MagicMock

import torch
from vllm.v1.attention.backends.utils import AttentionCGSupport

from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.sfa_v1 import (AscendSFABackend, AscendSFAImpl,
AscendSFAMetadata,
AscendSFAMetadataBuilder)


class TestAscendSFABackend(TestBase):

def test_get_name(self):
self.assertEqual(AscendSFABackend.get_name(), "ASCEND_SFA")

def test_get_metadata_cls(self):
self.assertEqual(AscendSFABackend.get_metadata_cls(),
AscendSFAMetadata)

def test_get_builder_cls(self):
self.assertEqual(AscendSFABackend.get_builder_cls(),
AscendSFAMetadataBuilder)

def test_get_kv_cache_shape(self):
result = AscendSFABackend.get_kv_cache_shape(2, 4, 8, 128)
self.assertEqual(result, (2, 4, 8, 128))

def test_get_impl_cls(self):
result = AscendSFABackend.get_impl_cls()
self.assertEqual(result, AscendSFAImpl)


class TestAscendSFAMetadata(TestBase):

def test_ascend_sfa_metadata_default(self):
has_prefill = True
num_actual_tokens = 100
slot_mapping = torch.randn(100, 4, 1024)
seq_lens = torch.tensor([30, 50])
cum_query_lens = torch.tensor([0, 30, 80])
block_tables = torch.randint(0, 100, (100, 4))

rope_dim = 32
max_seq_len = int(seq_lens.max().item())
sin = torch.randn(max_seq_len, rope_dim)
cos = torch.randn(max_seq_len, rope_dim)

num_input_tokens = 2
head_dim = None
attn_mask = None
attn_state = AscendAttentionState.ChunkedPrefill

metadata = AscendSFAMetadata(
has_prefill=has_prefill,
num_actual_tokens=num_actual_tokens,
slot_mapping=slot_mapping,
seq_lens=seq_lens,
cum_query_lens=cum_query_lens,
block_tables=block_tables,
sin=sin,
cos=cos,
num_input_tokens=num_input_tokens,
head_dim=head_dim,
attn_mask=attn_mask,
attn_state=attn_state,
)

self.assertEqual(metadata.has_prefill, has_prefill)
self.assertEqual(metadata.num_actual_tokens, num_actual_tokens)
self.assertIs(metadata.slot_mapping, slot_mapping)
self.assertTrue(torch.equal(metadata.seq_lens, seq_lens))
self.assertTrue(torch.equal(metadata.cum_query_lens, cum_query_lens))
self.assertIs(metadata.block_tables, block_tables)
self.assertIs(metadata.sin, sin)
self.assertIs(metadata.cos, cos)
self.assertEqual(metadata.num_input_tokens, num_input_tokens)
self.assertIs(metadata.head_dim, head_dim)
self.assertIs(metadata.attn_mask, attn_mask)
self.assertEqual(metadata.attn_state, attn_state)


class TestAscendSFAMetadataBuilder(TestBase):

def test_ascend_sfa_metadata_builder_default(self):
kv_cache_spec = MagicMock()
layer_names = ["layer1", "layer2"]
vllm_config = MagicMock()
speculative_config = MagicMock()
speculative_config.num_speculative_tokens = 4
vllm_config.speculative_config = speculative_config
device = torch.device("cpu")

builder = AscendSFAMetadataBuilder(kv_cache_spec=kv_cache_spec,
layer_names=layer_names,
vllm_config=vllm_config,
device=device)

assert builder.aclgraph_support == AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
assert builder.device == device
assert builder.vllm_config == vllm_config

def test_ascend_sfa_metadata_builder_build(self):
kv_cache_spec = MagicMock()
layer_names = ["layer1", "layer2"]
vllm_config = MagicMock()
speculative_config = MagicMock()
speculative_config.num_speculative_tokens = 4
vllm_config.speculative_config = speculative_config
device = torch.device("cpu")

builder = AscendSFAMetadataBuilder(kv_cache_spec=kv_cache_spec,
layer_names=layer_names,
vllm_config=vllm_config,
device=device)

common_attn_metadata = MagicMock()
common_attn_metadata.num_reqs = 10
common_attn_metadata.num_actual_tokens = 100
common_attn_metadata.query_start_loc = torch.tensor(
[0, 10, 20, 30, 40, 50, 60, 70, 80, 90])
common_attn_metadata.query_start_loc_cpu = torch.tensor(
[0, 10, 20, 30, 40, 50, 60, 70, 80, 90])
common_attn_metadata.slot_mapping = torch.randn(100, 4, 1024)
common_attn_metadata.seq_lens_cpu = torch.tensor([2] * 10)
common_attn_metadata.positions = torch.randn(100)
common_attn_metadata.attn_mask = None
common_attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
common_attn_metadata.block_table_tensor = torch.randn(100, 4)

model = MagicMock()
model.model.layers = [MagicMock() for _ in range(10)]
model.model.start_layer = 0

metadata = builder.build(
common_prefix_len=10,
common_attn_metadata=common_attn_metadata,
model=model,
)

assert isinstance(metadata, AscendSFAMetadata)
assert metadata.num_actual_tokens == common_attn_metadata.num_actual_tokens
assert metadata.slot_mapping.shape == (100, 4, 1024)

def test_ascend_sfa_metadata_builder_build_for_graph_capture(self):
kv_cache_spec = MagicMock()
layer_names = ["layer1", "layer2"]
vllm_config = MagicMock()
speculative_config = MagicMock()
speculative_config.num_speculative_tokens = 4
vllm_config.speculative_config = speculative_config
device = torch.device("cpu")

builder = AscendSFAMetadataBuilder(kv_cache_spec=kv_cache_spec,
layer_names=layer_names,
vllm_config=vllm_config,
device=device)

common_attn_metadata = MagicMock()
common_attn_metadata.num_reqs = 10
common_attn_metadata.num_actual_tokens = 100
common_attn_metadata.query_start_loc = torch.tensor(
[0, 10, 20, 30, 40, 50, 60, 70, 80, 90])
common_attn_metadata.query_start_loc_cpu = torch.tensor(
[0, 10, 20, 30, 40, 50, 60, 70, 80, 90])
common_attn_metadata.slot_mapping = torch.randn(100, 4, 1024)
common_attn_metadata.seq_lens_cpu = torch.tensor([2] * 10)
common_attn_metadata.positions = torch.randn(100)
common_attn_metadata.attn_mask = None
common_attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
common_attn_metadata.block_table_tensor = torch.randn(100, 4)

model = MagicMock()
model.model.layers = [MagicMock() for _ in range(10)]
model.model.start_layer = 0

attn_metadata = builder.build_for_graph_capture(
common_attn_metadata=common_attn_metadata,
attn_state=AscendAttentionState.DecodeOnly,
model=model,
)

assert isinstance(attn_metadata, AscendSFAMetadata)
assert attn_metadata.attn_state == AscendAttentionState.DecodeOnly
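
The new tests above follow the existing mock-based layout under tests/ut (MagicMock configs, CPU tensors, no device-specific setup). Assuming the repository's usual pytest-based invocation, something like the following should exercise just this file:

    pytest tests/ut/attention/test_sfa_v1.py -v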
22 changes: 21 additions & 1 deletion vllm_ascend/attention/sfa_v1.py
@@ -91,7 +91,7 @@ class AscendSFAMetadata:
class AscendSFAMetadataBuilder:
# Does this backend/builder support ACL Graphs for attention (default: no).
aclgraph_support: ClassVar[AttentionCGSupport] = \
-        AttentionCGSupport.NEVER
+        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
@@ -189,6 +189,26 @@ def build(
sin=sin,
cos=cos)

def build_for_graph_capture(
self,
common_attn_metadata: AscendCommonAttentionMetadata,
attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
model: Optional[nn.Module] = None,
):
if attn_state == AscendAttentionState.DecodeOnly:
attn_metadata = self.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
model=model,
)
else:
raise NotImplementedError(
"Currently we only support building dummy metadata for DecodeOnly state"
)

attn_metadata.attn_state = attn_state
return attn_metadata


class AscendSFAImpl(MLAAttentionImpl):
"""
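
Taken together with the unit tests above, the new hook's contract is narrow: it only builds dummy metadata for single-token decode. A minimal, mock-based sketch in the style of tests/ut/attention/test_sfa_v1.py (the MagicMock stand-ins are illustrative, not real runner state):

    from unittest.mock import MagicMock

    import torch

    from vllm_ascend.attention.attention_v1 import AscendAttentionState
    from vllm_ascend.attention.sfa_v1 import AscendSFAMetadataBuilder

    # Mock-based setup mirroring the unit tests; only the fields the builder
    # actually touches are populated.
    vllm_config = MagicMock()
    vllm_config.speculative_config.num_speculative_tokens = 4
    builder = AscendSFAMetadataBuilder(kv_cache_spec=MagicMock(),
                                       layer_names=["layer1", "layer2"],
                                       vllm_config=vllm_config,
                                       device=torch.device("cpu"))

    # Any state other than DecodeOnly is rejected up front, before the common
    # metadata is ever inspected, so a bare MagicMock suffices here.
    try:
        builder.build_for_graph_capture(
            common_attn_metadata=MagicMock(),
            attn_state=AscendAttentionState.ChunkedPrefill)
    except NotImplementedError:
        pass  # expected: only DecodeOnly dummy metadata is supported

    # For DecodeOnly, the call delegates to build() and then stamps
    # attn_state = AscendAttentionState.DecodeOnly on the result, as covered by
    # test_ascend_sfa_metadata_builder_build_for_graph_capture above.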
15 changes: 10 additions & 5 deletions vllm_ascend/worker/model_runner_v1.py
@@ -1864,7 +1864,8 @@ def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
)

forward_context = get_forward_context()
-        if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
+        if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL \
+                and not self.use_sparse:
# TODO: maybe_padded_num_tokens will be removed, use num_input_tokens instead
if self.vllm_config.model_config.use_mla:
if self.pcp_size * self.dcp_size > 1:
@@ -2657,11 +2658,15 @@ def _build_dummy_attn_metadata(
[0] * dcp_world_size for _ in range(pcp_world_size)
] for _ in range(num_tokens)]
long_seq_metadata.num_computed_tokens_of_pcp_dcp = num_computed_tokens_of_pcp_dcp
-        common_attn_metadata = AscendCommonAttentionMetadata(
-            query_start_loc=torch.tensor(
+        if self.speculative_config:
+            query_start_loc = torch.tensor(
                 [0] + self.actual_seq_lengths_q[:num_reqs],
                 device=self.device,
-                dtype=torch.int32),
+                dtype=torch.int32)
+        else:
+            query_start_loc = self.query_start_loc[:num_reqs + 1]
+        common_attn_metadata = AscendCommonAttentionMetadata(
+            query_start_loc=query_start_loc,
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
1],
seq_lens_cpu=self.seq_lens_cpu,
@@ -2707,7 +2712,7 @@ def _generate_dummy_run_hidden_states(self, with_prefill,
forward_context = get_forward_context()
assert forward_context is not None
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
-                not forward_context.capturing:
+                not forward_context.capturing and not self.use_sparse:
if self.vllm_config.model_config.use_mla:
# FIXME: Try using `auto_dispatch_capture=True`
if self.pcp_size * self.dcp_size > 1:
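
The query_start_loc selection added to _build_dummy_attn_metadata above can be illustrated in isolation. A minimal sketch with made-up sizes, assuming actual_seq_lengths_q holds cumulative per-request query end offsets (which is what the [0] + ... prepend implies):

    import torch

    # Stand-ins for runner state; values are illustrative only.
    num_reqs = 4
    speculative = True                       # i.e. self.speculative_config is set
    actual_seq_lengths_q = [2, 4, 6, 8]      # cumulative query end offsets per request
    query_start_loc_buf = torch.arange(32, dtype=torch.int32)  # pre-built buffer

    if speculative:
        # Multi-token (draft) queries: rebuild the offsets from the actual lengths.
        query_start_loc = torch.tensor([0] + actual_seq_lengths_q[:num_reqs],
                                       dtype=torch.int32)
    else:
        # Plain single-token decode: the pre-computed buffer already matches.
        query_start_loc = query_start_loc_buf[:num_reqs + 1]

    print(query_start_loc)  # tensor([0, 2, 4, 6, 8], dtype=torch.int32)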