63 changes: 63 additions & 0 deletions tests/ut/attention/test_mla_v1.py
@@ -208,11 +208,38 @@ def test_ascend_mla_metadata_default(self):

class TestAscendMLAMetadataBuilder(TestBase):

def setUp(self):
# Mock the parent class __init__ to avoid complex initialization,
# but still set the essential attributes that the child class needs
def mock_parent_init(self, kv_cache_spec, layer_names, vllm_config,
device, metadata_cls, supports_dcp_with_varlen):
self.metadata_cls = metadata_cls
self.kv_cache_spec = kv_cache_spec
self.model_config = vllm_config.model_config
self.vllm_config = vllm_config
self.device = device
self.chunked_prefill_workspace_size = 128 * 1024
self.chunked_prefill_workspace = torch.empty(
(self.chunked_prefill_workspace_size,
vllm_config.model_config.get_head_size()),
dtype=vllm_config.model_config.dtype,
device=device,
)

self.parent_init_patcher = patch(
"vllm.v1.attention.backends.mla.common.MLACommonMetadataBuilder.__init__",
mock_parent_init)
self.parent_init_patcher.start()

def tearDown(self):
self.parent_init_patcher.stop()

def test_ascend_mla_metadata_builder_default(self):
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64
mock_vllm_config.model_config.dtype = torch.float16
mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.max_num_seqs = 4
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
@@ -238,6 +265,7 @@ def test_ascend_mla_metadata_builder_spec_decode(self):
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64
mock_vllm_config.model_config.dtype = torch.float16
mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.max_num_seqs = 4
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
@@ -274,10 +302,12 @@ def test_ascend_mla_metadata_builder_build_full_graph(
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64
mock_vllm_config.model_config.dtype = torch.float16
mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.max_num_seqs = 4
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_vllm_config.scheduler_config.enable_chunked_prefill = False
mock_device = 'cpu'
torch.Tensor.pin_memory = lambda x: x # noqa

@@ -314,6 +344,9 @@ def test_reorder_batch(self):

mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64
mock_vllm_config.model_config.dtype = torch.float16
mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.max_num_seqs = 4
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
@@ -352,10 +385,12 @@ def test_pad_actual_seq_lens_q_mtp_disable_pad(self):
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64
mock_vllm_config.model_config.dtype = torch.float16
mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.max_num_seqs = 4
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_vllm_config.scheduler_config.enable_chunked_prefill = False
mock_device = 'cpu'
mock_vllm_config.speculative_config = None

@@ -374,10 +409,12 @@ def test_pad_actual_seq_lens_q_mtp_enable_pad(self):
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64
mock_vllm_config.model_config.dtype = torch.float16
mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.max_num_seqs = 4
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_vllm_config.scheduler_config.enable_chunked_prefill = False
mock_device = 'cpu'
mock_vllm_config.speculative_config = None

@@ -398,11 +435,34 @@ def test_pad_actual_seq_lens_q_mtp_enable_pad(self):
class TestAscendMLAMetadataBuilderBuild(TestBase):

def setUp(self):
# Mock the parent class __init__ to avoid complex initialization,
# but still set the essential attributes that the child class needs
def mock_parent_init(self, kv_cache_spec, layer_names, vllm_config,
device, metadata_cls, supports_dcp_with_varlen):
self.metadata_cls = metadata_cls
self.kv_cache_spec = kv_cache_spec
self.model_config = vllm_config.model_config
self.vllm_config = vllm_config
self.device = device
self.chunked_prefill_workspace_size = 128 * 1024
self.chunked_prefill_workspace = torch.empty(
(self.chunked_prefill_workspace_size,
vllm_config.model_config.get_head_size()),
dtype=vllm_config.model_config.dtype,
device=device,
)

self.parent_init_patcher = patch(
"vllm.v1.attention.backends.mla.common.MLACommonMetadataBuilder.__init__",
mock_parent_init)
self.parent_init_patcher.start()

self.mock_vllm_config = MagicMock(spec=VllmConfig)
self.mock_vllm_config.cache_config = CacheConfig(block_size=32)
mock_scheduler_config = MagicMock(spec=SchedulerConfig)
mock_scheduler_config.max_num_seqs = 8
mock_scheduler_config.chunked_prefill_enabled = True
mock_scheduler_config.enable_chunked_prefill = True
self.mock_vllm_config.scheduler_config = mock_scheduler_config
self.mock_vllm_config.speculative_config = None
self.mock_device = torch.device("cpu")
@@ -423,6 +483,9 @@ def setUp(self):
self.kv_cache_spec.head_size = 64
self.kv_cache_spec.num_heads = 32

def tearDown(self):
self.parent_init_patcher.stop()

@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
@patch('vllm_ascend.attention.attention_mask.get_pcp_group')
@patch('vllm.distributed.parallel_state.get_pcp_group')
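A note on the setUp/tearDown pattern added in these tests: patching MLACommonMetadataBuilder.__init__ lets each test construct the Ascend builder without running vLLM's full initialization, while the stand-in still sets the few attributes the subclass actually reads. Below is a minimal, self-contained sketch of the same technique; Parent, Child, and their attributes are illustrative stand-ins, not the real vLLM classes.

    import unittest
    from unittest.mock import patch

    class Parent:
        def __init__(self, config):
            raise RuntimeError("heavy setup we want to skip in unit tests")

    class Child(Parent):
        def __init__(self, config):
            super().__init__(config)  # would normally run the heavy setup
            self.block_size = 16

    class TestChild(unittest.TestCase):
        def setUp(self):
            # Stand-in __init__ that sets only what Child actually reads.
            def mock_parent_init(self, config):
                self.config = config
            self.parent_init_patcher = patch.object(
                Parent, "__init__", mock_parent_init)
            self.parent_init_patcher.start()

        def tearDown(self):
            # Stop the patcher so other tests see the real __init__ again.
            self.parent_init_patcher.stop()

        def test_child_builds(self):
            child = Child(config="dummy")
            self.assertEqual(child.config, "dummy")
            self.assertEqual(child.block_size, 16)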
41 changes: 41 additions & 0 deletions tests/ut/attention/test_sfa_v1.py
@@ -99,13 +99,44 @@ def setUp(self):
return_value=self.mock_cfg)
self.patcher.start()

# Mock the parent class __init__ to avoid complex initialization,
# but still set the essential attributes that the child class needs
def mock_parent_init(self, kv_cache_spec, layer_names, vllm_config,
device, metadata_cls, supports_dcp_with_varlen):
self.metadata_cls = metadata_cls
self.kv_cache_spec = kv_cache_spec
self.model_config = vllm_config.model_config
self.vllm_config = vllm_config
self.device = device
self.chunked_prefill_workspace_size = 128 * 1024
self.chunked_prefill_workspace = torch.empty(
(self.chunked_prefill_workspace_size,
vllm_config.model_config.get_head_size()),
dtype=vllm_config.model_config.dtype,
device=device,
)

self.parent_init_patcher = patch(
"vllm.v1.attention.backends.mla.common.MLACommonMetadataBuilder.__init__",
mock_parent_init)
self.parent_init_patcher.start()

if hasattr(enable_dsa_cp, "cache_clear"):
enable_dsa_cp.cache_clear()

def tearDown(self):
self.patcher.stop()
self.parent_init_patcher.stop()

def test_ascend_sfa_metadata_builder_default(self):
kv_cache_spec = MagicMock()
layer_names = ["layer1", "layer2"]
vllm_config = MagicMock()
vllm_config.cache_config.block_size = 16
vllm_config.model_config.max_model_len = 1024
vllm_config.model_config.get_head_size.return_value = 64
vllm_config.model_config.dtype = torch.float16
vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
speculative_config = MagicMock()
speculative_config.num_speculative_tokens = 4
vllm_config.speculative_config = speculative_config
@@ -138,6 +169,11 @@ def test_ascend_sfa_metadata_builder_build(
kv_cache_spec = MagicMock()
layer_names = ["layer1", "layer2"]
vllm_config = MagicMock()
vllm_config.cache_config.block_size = 16
vllm_config.model_config.max_model_len = 1024
vllm_config.model_config.get_head_size.return_value = 64
vllm_config.model_config.dtype = torch.float16
vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
speculative_config = MagicMock()
speculative_config.num_speculative_tokens = 4
vllm_config.speculative_config = speculative_config
@@ -190,6 +226,11 @@ def test_ascend_sfa_metadata_builder_build_for_graph_capture(
kv_cache_spec = MagicMock()
layer_names = ["layer1", "layer2"]
vllm_config = MagicMock()
vllm_config.cache_config.block_size = 16
vllm_config.model_config.max_model_len = 1024
vllm_config.model_config.get_head_size.return_value = 64
vllm_config.model_config.dtype = torch.float16
vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
speculative_config = MagicMock()
speculative_config.num_speculative_tokens = 4
vllm_config.speculative_config = speculative_config
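The hasattr(enable_dsa_cp, "cache_clear") guard in setUp suggests enable_dsa_cp is memoized (for example with functools.cache or lru_cache); the diff does not show its definition, so treat that as an assumption. Clearing the cache between tests matters because a memoized flag evaluated under one test's mocked config would otherwise stay frozen for every later test. A small sketch with a hypothetical flag:

    import functools

    _CONFIG = {"dsa_cp": False}  # stands in for a global config object

    @functools.cache
    def enable_feature() -> bool:
        # Memoized: _CONFIG is consulted once, then the result is frozen.
        return _CONFIG["dsa_cp"]

    print(enable_feature())        # False, and now cached
    _CONFIG["dsa_cp"] = True
    print(enable_feature())        # still False -- the cached value is stale
    enable_feature.cache_clear()   # what the test's setUp does
    print(enable_feature())        # True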
49 changes: 15 additions & 34 deletions vllm_ascend/attention/mla_v1.py
@@ -19,12 +19,10 @@
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.context_parallel.common_cp import (
AscendPCPMetadata, CPChunkedContextMetadata)
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
enable_cp,
maybe_save_kv_layer_to_connector,
split_decodes_and_prefills,
trans_rope_weight, transdata,
wait_for_kv_layer_from_connector)
from vllm_ascend.attention.utils import (
AscendCommonAttentionMetadata, ascend_chunked_prefill_workspace_size,
enable_cp, maybe_save_kv_layer_to_connector, split_decodes_and_prefills,
trans_rope_weight, transdata, wait_for_kv_layer_from_connector)
from vllm_ascend.compilation.acl_graph import (
get_draft_graph_params, get_graph_params,
update_draft_graph_params_workspaces, update_graph_params_workspaces)
@@ -215,11 +213,11 @@ def __init__(
metadata_cls: type[AscendMLAMetadata] | None = None,
supports_dcp_with_varlen: bool = False,
):
self.metadata_cls = (metadata_cls if metadata_cls is not None else
AscendMLAMetadata)
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.device = device
super().__init__(
kv_cache_spec, layer_names, vllm_config, device,
metadata_cls if metadata_cls is not None else AscendMLAMetadata,
supports_dcp_with_varlen)

scheduler_config = vllm_config.scheduler_config
self.block_size = vllm_config.cache_config.block_size
self.max_blocks = (vllm_config.model_config.max_model_len +
@@ -236,29 +234,7 @@ def __init__(
got {self.decode_threshold}"

self.reorder_batch_threshold = self.decode_threshold
if self.chunked_prefill_enabled:
self.chunked_prefill_workspace_size = min(
# Max sure there is enough for 8 full length request or at least
# 4 pages of cache per request
max(8 * self.model_config.max_model_len,
4 * scheduler_config.max_num_seqs * self.block_size),
# For long-context models try not to over-allocate limiting
# kv-cache space, limiting it to 64k tokens,
# which would result in the workspace being:
# 2*(576)*(64*1024) = 144mb
# (assuming 576 MLA head dim, and fp16)
# which would result in up-projected context being
# 2*(192*128)*(64*1024) = 3gb
# (assuming 192 QK head dim, 128 heads, and fp16)
128 * 1024)
assert self.chunked_prefill_workspace_size >= \
scheduler_config.max_num_seqs * self.block_size
self.chunked_prefill_workspace = torch.empty(
(self.chunked_prefill_workspace_size,
self.model_config.get_head_size()),
dtype=self.model_config.dtype,
device=device,
)

self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
self.cos_cache = None
self.sin_cache = None
@@ -280,6 +256,11 @@ def __init__(
self.seq_lens: torch.Tensor = None
self.attn_mask_builder = AttentionMaskBuilder(self.device)

@staticmethod
def determine_chunked_prefill_workspace_size(
vllm_config: VllmConfig) -> int:
return ascend_chunked_prefill_workspace_size(vllm_config)

@classmethod
def get_cudagraph_support(
cls: type["AscendMLAMetadataBuilder"],
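The mla_v1.py change above swaps hand-rolled attribute assignments for a call to super().__init__() and moves workspace sizing behind a determine_chunked_prefill_workspace_size hook. Assuming the base class consults that static method during initialization — implied by the diff but not shown — the shape of the pattern is the following sketch; BaseBuilder and the constants are illustrative, not the vLLM API.

    class BaseBuilder:
        def __init__(self, config):
            # Dispatching through self picks up a subclass override,
            # even though the hook is a staticmethod.
            self.workspace_size = (
                self.determine_chunked_prefill_workspace_size(config))

        @staticmethod
        def determine_chunked_prefill_workspace_size(config) -> int:
            return 64 * 1024  # generic default

    class AscendBuilder(BaseBuilder):
        @staticmethod
        def determine_chunked_prefill_workspace_size(config) -> int:
            return 128 * 1024  # backend-specific sizing

    print(AscendBuilder(config=None).workspace_size)  # 131072

One allocation path stays in the base class, and each backend overrides only the sizing policy.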
17 changes: 11 additions & 6 deletions vllm_ascend/attention/sfa_v1.py
@@ -20,6 +20,7 @@
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE, MLAPO_MAX_SUPPORTED_TOKENS
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
ascend_chunked_prefill_workspace_size,
maybe_save_kv_layer_to_connector,
trans_rope_weight, transdata,
wait_for_kv_layer_from_connector)
@@ -131,7 +132,6 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
understand this class
"""

# _attn_mask_builder = None
def __init__(
self,
kv_cache_spec,
@@ -141,11 +141,11 @@ def __init__(
metadata_cls: type[AscendSFAMetadata] | None = None,
supports_dcp_with_varlen: bool = False,
):
self.metadata_cls = (metadata_cls if metadata_cls is not None else
AscendSFAMetadata)
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.device = device
super().__init__(
kv_cache_spec, layer_names, vllm_config, device,
metadata_cls if metadata_cls is not None else AscendSFAMetadata,
supports_dcp_with_varlen)

self.block_size = vllm_config.cache_config.block_size
self.max_blocks = (vllm_config.model_config.max_model_len +
self.block_size - 1) // self.block_size
@@ -169,6 +169,11 @@ def __init__(
), "FlashComm1 is not compatible with FULL_DECODE_ONLY. Please set graph_mode to 'piecewise' or disable FlashComm1."
self.attn_mask_builder = AttentionMaskBuilder(self.device)

@staticmethod
def determine_chunked_prefill_workspace_size(
vllm_config: VllmConfig) -> int:
return ascend_chunked_prefill_workspace_size(vllm_config)

@classmethod
def get_cudagraph_support(
cls: type["AscendSFAMetadataBuilder"],
28 changes: 28 additions & 0 deletions vllm_ascend/attention/utils.py
@@ -12,6 +12,34 @@
from vllm_ascend.utils import AscendDeviceType, get_ascend_config, get_ascend_device_type


def ascend_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int:
scheduler_config = vllm_config.scheduler_config
cache_config = vllm_config.cache_config
model_config = vllm_config.model_config

chunked_prefill_workspace_size = min(
# Make sure there is enough for 8 full-length requests or at least
# 4 pages of cache per request
max(8 * model_config.max_model_len, 4 * scheduler_config.max_num_seqs * cache_config.block_size),
# For long-context models, try not to over-allocate (crowding out
# kv-cache space) by capping the workspace at 128k tokens,
# which would result in the workspace being:
# 2*(576)*(128*1024) = 144mb
# (assuming 576 MLA head dim, and fp16)
# which would result in up-projected context being
# 2*(192*128)*(128*1024) = 6gb
# (assuming 192 QK head dim, 128 heads, and fp16)
128 * 1024,
)

chunked_prefill_workspace_size = max(
chunked_prefill_workspace_size,
scheduler_config.max_num_seqs * cache_config.block_size,
)

return chunked_prefill_workspace_size


def using_paged_attention(runtime_shape: int, vllm_config: VllmConfig) -> bool:
if vllm_config.speculative_config is not None:
return False
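To make the sizing logic concrete, here is a worked example using the values the unit tests above mock (max_model_len=1024, max_num_seqs=4, block_size=16); plain ints stand in for VllmConfig:

    max_model_len, max_num_seqs, block_size = 1024, 4, 16

    size = min(
        max(8 * max_model_len,               # 8192: 8 full-length requests
            4 * max_num_seqs * block_size),  # 256: 4 cache pages per request
        128 * 1024,                          # long-context cap (131072 tokens)
    )
    size = max(size, max_num_seqs * block_size)  # floor: one block per sequence
    print(size)  # 8192

Note that the final max() replaces the assert in the old mla_v1.py code: instead of failing when max_num_seqs * block_size exceeds the computed size, the workspace is grown to fit.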