From 7766f8d72e99b42fae01bae40146465de9145bb0 Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Tue, 23 Dec 2025 13:32:54 +0800 Subject: [PATCH 01/19] [bugfix] fix xlite error: has no attribute 'query_start_loc_cpu' Signed-off-by: weijinqian_v1 --- vllm_ascend/attention/attention_v1.py | 5 ++++- vllm_ascend/xlite/xlite.py | 4 +--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 001d58fbebb..a2fbaea75f3 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -29,7 +29,7 @@ from vllm.config import VllmConfig, get_current_vllm_config from vllm.forward_context import ForwardContext, get_forward_context from vllm.utils.math_utils import cdiv -from vllm.v1.attention.backends.utils import AttentionCGSupport +from vllm.v1.attention.backends.utils import AttentionCGSupport, AttentionMetadataBuilder from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import AttentionSpec @@ -146,6 +146,7 @@ class AscendMetadata: actual_seq_lengths_q: List[int] = None # type: ignore query_start_loc: torch.Tensor = None + query_lens: torch.Tensor = None # Maximum query length in the batch (None for decoding). max_query_len: Optional[int] = None @@ -229,6 +230,7 @@ def build( split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold) block_table = common_attn_metadata.block_table_tensor + query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens] @@ -244,6 +246,7 @@ def build( num_decode_tokens=num_decode_tokens, block_tables=block_table, query_start_loc=query_start_loc, + query_lens=query_lens, seq_lens=seq_lens, seq_lens_list=seq_lens.tolist(), max_query_len=common_attn_metadata.max_query_len, diff --git a/vllm_ascend/xlite/xlite.py b/vllm_ascend/xlite/xlite.py index 462052d7e45..b594b7edfe7 100644 --- a/vllm_ascend/xlite/xlite.py +++ b/vllm_ascend/xlite/xlite.py @@ -257,9 +257,7 @@ def __call__( if not with_prefill or self.full_mode: batch = attn_metadata.num_prefills + attn_metadata.num_decodes seq_lens = attn_metadata.seq_lens[:batch] - query_lens = attn_metadata.query_start_loc_cpu[ - 1:] - attn_metadata.query_start_loc_cpu[:-1] - query_lens = query_lens[:batch] + query_lens = attn_metadata.query_lens[:batch] cached_lens = seq_lens - query_lens xlite_attn_metadata = ModelAttnMeta() From 2015711bf027713de783057b8c5dff295c69c2c8 Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Tue, 23 Dec 2025 13:34:45 +0800 Subject: [PATCH 02/19] [bugfix] fix xlite error: has no attribute 'query_start_loc_cpu' Signed-off-by: weijinqian_v1 --- vllm_ascend/attention/attention_v1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index a2fbaea75f3..8a8942fde90 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -29,7 +29,7 @@ from vllm.config import VllmConfig, get_current_vllm_config from vllm.forward_context import ForwardContext, get_forward_context from vllm.utils.math_utils import cdiv -from vllm.v1.attention.backends.utils import AttentionCGSupport, AttentionMetadataBuilder +from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import AttentionSpec From 
54324054b1a89faec9e6a6b06ecf673f47509d4f Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Wed, 24 Dec 2025 12:52:44 +0800 Subject: [PATCH 03/19] [Refactor] use cos_sin_cache Signed-off-by: weijinqian_v1 --- vllm_ascend/attention/attention_v1.py | 14 ++--- vllm_ascend/attention/mla_v1.py | 86 ++++++++++----------------- vllm_ascend/attention/sfa_v1.py | 23 ++++--- vllm_ascend/ops/rotary_embedding.py | 22 ++++++- 4 files changed, 69 insertions(+), 76 deletions(-) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index e1dd41cc6dc..0c14bea01ce 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -29,7 +29,7 @@ from vllm.config import VllmConfig, get_current_vllm_config from vllm.forward_context import ForwardContext, get_forward_context from vllm.utils.math_utils import cdiv -from vllm.v1.attention.backends.utils import AttentionCGSupport +from vllm.v1.attention.backends.utils import AttentionCGSupport, AttentionMetadataBuilder from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import AttentionSpec @@ -171,7 +171,7 @@ class AscendMetadata: model_runner_type: str = "" -class AscendAttentionMetadataBuilder: +class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]): # Does this backend/builder support ACL Graphs for attention (default: no). aclgraph_support: ClassVar[AttentionCGSupport] = \ AttentionCGSupport.ALWAYS @@ -215,11 +215,11 @@ def reorder_batch(self, input_batch, return False def build( - self, - common_prefix_len: int, - common_attn_metadata: AscendCommonAttentionMetadata, - model: Optional[nn.Module] = None, - ): + self, + common_prefix_len: int, + common_attn_metadata: AscendCommonAttentionMetadata, + fast_build: bool = False, + ) -> AscendMetadata: num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index b77682dbd2f..3c0a4c0eb17 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -400,16 +400,17 @@ def build( self.slot_mapping = common_attn_metadata.slot_mapping[:self. 
num_actual_tokens] - if self.cos_cache is None: - self.cos_cache = model.model.layers[ - model.model.start_layer].self_attn.rotary_emb.cos_cached - self.sin_cache = model.model.layers[ - model.model.start_layer].self_attn.rotary_emb.sin_cached - if self.cos_cache.dtype != self.model_config.dtype: # type: ignore - self.cos_cache = self.cos_cache.to( # type: ignore - self.model_config.dtype) # type: ignore - self.sin_cache = self.sin_cache.to( # type: ignore - self.model_config.dtype) # type: ignore + # if self.cos_cache is None: + # self.cos_cache = model.model.layers[ + # model.model.start_layer].self_attn.rotary_emb.cos_cached + # self.sin_cache = model.model.layers[ + # model.model.start_layer].self_attn.rotary_emb.sin_cached + # if self.cos_cache.dtype != self.model_config.dtype: # type: ignore + # self.cos_cache = self.cos_cache.to( # type: ignore + # self.model_config.dtype) # type: ignore + # self.sin_cache = self.sin_cache.to( # type: ignore + # self.model_config.dtype) # type: ignore + query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] self.query_lens = query_seq_lens_cpu[:num_reqs] @@ -539,12 +540,13 @@ def build_prefill_metadata( reqs_start:] - query_start_loc[reqs_start] prefill_input_positions = input_positions[tokens_start:] - cos = self.cos_cache[ - prefill_input_positions].unsqueeze( # type: ignore - 1).unsqueeze(2) - sin = self.sin_cache[ - prefill_input_positions].unsqueeze( # type: ignore - 1).unsqueeze(2) + cos, sin = get_cos_and_sin_mla(prefill_input_positions) + # cos = self.cos_cache[ + # prefill_input_positions].unsqueeze( # type: ignore + # 1).unsqueeze(2) + # sin = self.sin_cache[ + # prefill_input_positions].unsqueeze( # type: ignore + # 1).unsqueeze(2) return AscendMLAPrefillMetadata( attn_mask=common_attn_metadata.attn_mask, query_lens=self.query_lens[reqs_start:].to(torch.int32), @@ -573,7 +575,6 @@ def build_decode_metadata( num_actual_tokens].long( ) - cos, sin = get_cos_and_sin_mla() # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario actual_seq_lengths_q = query_start_loc_cpu[1:self.num_decodes + 1].tolist() @@ -643,44 +644,19 @@ def build_decode_metadata( # TODO: After the fullgraph supports MTP, the if branch needs to deleted assert self.cos_cache is not None assert self.sin_cache is not None - if cos is None and sin is None: - cos = self.cos_cache[input_positions].unsqueeze( # type: ignore - 1).unsqueeze(2) - sin = self.sin_cache[input_positions].unsqueeze( # type: ignore - 1).unsqueeze(2) - - decode_metadata = AscendMLADecodeMetadata( - input_positions=input_positions, - block_table=self.block_table, - seq_lens=self.seq_lens, - seq_lens_list=seq_lens_list, - max_seq_lens=max_seq_lens, - attn_mask=common_attn_metadata.spec_attn_mask, - actual_seq_lengths_q=actual_seq_lengths_q, - sin=sin, - cos=cos, - cp_seq_len=cp_seq_len, - batch_seq_mask=batch_seq_mask) - else: - cos[:self.num_decode_tokens, - ...] = self.cos_cache[input_positions].unsqueeze(1).unsqueeze( - 2) - sin[:self.num_decode_tokens, - ...] 
= self.sin_cache[input_positions].unsqueeze(1).unsqueeze( - 2) - - decode_metadata = AscendMLADecodeMetadata( - input_positions=input_positions, - block_table=self.block_table, - seq_lens=self.seq_lens, - seq_lens_list=seq_lens_list, - max_seq_lens=max_seq_lens, - attn_mask=common_attn_metadata.spec_attn_mask, - actual_seq_lengths_q=actual_seq_lengths_q, - sin=sin[:self.num_decode_tokens, ...], - cos=cos[:self.num_decode_tokens, ...], - cp_seq_len=cp_seq_len, - batch_seq_mask=batch_seq_mask) + cos, sin = get_cos_and_sin_mla(input_positions) + decode_metadata = AscendMLADecodeMetadata( + input_positions=input_positions, + block_table=self.block_table, + seq_lens=self.seq_lens, + seq_lens_list=seq_lens_list, + max_seq_lens=max_seq_lens, + attn_mask=common_attn_metadata.spec_attn_mask, + actual_seq_lengths_q=actual_seq_lengths_q, + sin=sin[:self.num_decode_tokens, ...], + cos=cos[:self.num_decode_tokens, ...], + cp_seq_len=cp_seq_len, + batch_seq_mask=batch_seq_mask) return decode_metadata def build_for_graph_capture( diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index 48aac26c11f..31cc6388ce9 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -192,19 +192,18 @@ def build( cum_query_lens = common_attn_metadata.query_start_loc[1:num_reqs + 1] seq_lens = common_attn_metadata.seq_lens[:num_reqs] - cos, sin = get_cos_and_sin_mla() - assert self.cos_cache is not None and self.sin_cache is not None - new_cos = self.cos_cache[input_positions][:, None, None] - new_sin = self.sin_cache[input_positions][:, None, None] - - if (cos is not None and sin is not None - and num_input_tokens <= cos.shape[0] - and num_input_tokens <= sin.shape[0]): - cos[:num_input_tokens] = new_cos - sin[:num_input_tokens] = new_sin - else: - cos, sin = new_cos, new_sin + # new_cos = self.cos_cache[input_positions][:, None, None] + # new_sin = self.sin_cache[input_positions][:, None, None] + + cos, sin = get_cos_and_sin_mla(input_positions) + # if (cos is not None and sin is not None + # and num_input_tokens <= cos.shape[0] + # and num_input_tokens <= sin.shape[0]): + # cos[:num_input_tokens] = new_cos + # sin[:num_input_tokens] = new_sin + # else: + # cos, sin = new_cos, new_sin sfa_cp_context = None if self.enable_sfa_cp: diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index 12995575269..c0ef91830ec 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -42,6 +42,8 @@ # by different approaches. _cos_mla: Optional[torch.Tensor] = None _sin_mla: Optional[torch.Tensor] = None +_cos_cache: Optional[torch.Tensor] = None +_sin_cache: Optional[torch.Tensor] = None _cos_sin_cache: Optional[torch.Tensor] = None _cos: Optional[torch.Tensor] = None _sin: Optional[torch.Tensor] = None @@ -101,8 +103,15 @@ def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype, device=device) -def get_cos_and_sin_mla(): - return _cos_mla, _sin_mla +def get_cos_and_sin_mla(positions): + global _cos_cache + global _sin_cache + global _cos_mla + global _sin_mla + num_tokens = positions.size(0) + _cos_mla[:num_tokens, ...] = _cos_cache[positions].unsqueeze(1).unsqueeze(2) + _sin_mla[:num_tokens, ...] = _sin_cache[positions].unsqueeze(1).unsqueeze(2) + return _cos_mla[:num_tokens, ...], _sin_mla[:num_tokens, ...] 
def _record_cos_sin_cache(cos_sin_cache): @@ -112,6 +121,13 @@ def _record_cos_sin_cache(cos_sin_cache): _cos_sin_cache = cos_sin_cache +def _record_cos_and_sin_cache(cos_cache, sin_cache): + global _cos_cache + global _sin_cache + _cos_cache = cos_cache + _sin_cache = sin_cache + + def update_cos_sin(positions): global _cos global _sin @@ -469,6 +485,8 @@ def _set_cos_sin_cache(self, max_seq_len, device, dtype): self.register_buffer("cos_sin_cache", cache, persistent=False) self.register_buffer("cos_cached", cos_cached, persistent=False) self.register_buffer("sin_cached", sin_cached, persistent=False) + _record_cos_sin_cache(cache) + _record_cos_and_sin_cache(cos_cached, sin_cached) def forward(self, positions: torch.Tensor, From e58e977e112694987547d1380da9a00d087e9c6f Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Wed, 24 Dec 2025 14:06:59 +0800 Subject: [PATCH 04/19] [Refactor] use cos_sin_cache Signed-off-by: weijinqian_v1 --- vllm_ascend/attention/mla_v1.py | 3 --- vllm_ascend/ops/rotary_embedding.py | 27 +++++++++++++-------------- 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 3c0a4c0eb17..a1d20313205 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -641,9 +641,6 @@ def build_decode_metadata( num_reqs_pad_size, num_reqs, actual_seq_lengths_q, common_attn_metadata) - # TODO: After the fullgraph supports MTP, the if branch needs to deleted - assert self.cos_cache is not None - assert self.sin_cache is not None cos, sin = get_cos_and_sin_mla(input_positions) decode_metadata = AscendMLADecodeMetadata( input_positions=input_positions, diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index c0ef91830ec..febf6c70e8f 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -69,20 +69,19 @@ def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype, max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens if model_config.use_mla: - if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY: - rope_dim = model_config.hf_text_config.qk_rope_head_dim - _cos_mla = torch.ones(max_num_reqs * decode_token_per_req, - 1, - 1, - rope_dim, - dtype=dtype, - device=device) - _sin_mla = torch.zeros(max_num_reqs * decode_token_per_req, - 1, - 1, - rope_dim, - dtype=dtype, - device=device) + rope_dim = model_config.hf_text_config.qk_rope_head_dim + _cos_mla = torch.ones(max_num_reqs * decode_token_per_req, + 1, + 1, + rope_dim, + dtype=dtype, + device=device) + _sin_mla = torch.zeros(max_num_reqs * decode_token_per_req, + 1, + 1, + rope_dim, + dtype=dtype, + device=device) elif not is_vl_model(vllm_config) and has_rope(vllm_config): rope_dim = model_config.get_head_size() # For models using partial rope like Qwen3-Next. 
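
Note on the two commits above: rotary cos/sin values for MLA now come from module-level caches that are recorded once in _set_cos_sin_cache and gathered per step, instead of being pulled out of each layer's rotary_emb inside the metadata builder. A minimal sketch of the lookup path, with illustrative sizes only (the real rope_dim is qk_rope_head_dim, and the buffers are sized from the scheduler limits):

    import torch

    rope_dim, max_pos, max_tokens = 64, 4096, 1024  # illustrative values

    # recorded once by _record_cos_and_sin_cache(cos_cached, sin_cached)
    _cos_cache = torch.randn(max_pos, rope_dim)
    _sin_cache = torch.randn(max_pos, rope_dim)

    # persistent output buffers allocated up front by set_cos_and_sin
    _cos_mla = torch.ones(max_tokens, 1, 1, rope_dim)
    _sin_mla = torch.zeros(max_tokens, 1, 1, rope_dim)

    def get_cos_and_sin_mla(positions: torch.Tensor):
        n = positions.size(0)
        # gather one row per token, then add the two singleton head dims:
        # (n, rope_dim) -> (n, 1, 1, rope_dim)
        _cos_mla[:n, ...] = _cos_cache[positions].unsqueeze(1).unsqueeze(2)
        _sin_mla[:n, ...] = _sin_cache[positions].unsqueeze(1).unsqueeze(2)
        return _cos_mla[:n, ...], _sin_mla[:n, ...]

Filling fixed, preallocated buffers rather than returning fresh tensors keeps the output addresses stable across steps, which is what makes the decode path safe to replay under graph capture.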
From 09ea3703d7aea942ec6adea7acf8ecc4a50a9e66 Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Wed, 24 Dec 2025 14:09:45 +0800 Subject: [PATCH 05/19] [Refactor] use cos_sin_cache Signed-off-by: weijinqian_v1 --- vllm_ascend/ops/rotary_embedding.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index febf6c70e8f..f2a8be8ca39 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -64,19 +64,18 @@ def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype, _sin is not None: return - compilation_config = vllm_config.compilation_config model_config = vllm_config.model_config max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens if model_config.use_mla: rope_dim = model_config.hf_text_config.qk_rope_head_dim - _cos_mla = torch.ones(max_num_reqs * decode_token_per_req, + _cos_mla = torch.ones(max_num_batched_tokens, 1, 1, rope_dim, dtype=dtype, device=device) - _sin_mla = torch.zeros(max_num_reqs * decode_token_per_req, + _sin_mla = torch.zeros(max_num_batched_tokens, 1, 1, rope_dim, From 5bfb03abc5513db386a6b9c42216de414b1c42ac Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Wed, 24 Dec 2025 15:14:28 +0800 Subject: [PATCH 06/19] [Refactor] use cos_sin_cache & remove parameter like model in builder. Signed-off-by: weijinqian_v1 --- tests/ut/attention/test_mla_v1.py | 8 +-- tests/ut/attention/test_sfa_v1.py | 5 -- vllm_ascend/attention/attention_cp.py | 2 +- vllm_ascend/attention/attention_v1.py | 1 - vllm_ascend/attention/mla_cp.py | 12 ++--- vllm_ascend/attention/mla_v1.py | 65 +++++++++---------------- vllm_ascend/attention/sfa_v1.py | 53 ++++++-------------- vllm_ascend/spec_decode/mtp_proposer.py | 4 +- vllm_ascend/worker/model_runner_v1.py | 21 ++------ 9 files changed, 53 insertions(+), 118 deletions(-) diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index 06c5dc6d327..05fb984dbf9 100755 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -751,9 +751,8 @@ def test_build_for_graph_capture_decode_only(self, mock_dcp_world_size, vllm_config=self.mock_vllm_config, device=self.mock_device) - mock_model = MagicMock() metadata = builder.build_for_graph_capture( - common_attn_metadata, AscendAttentionState.DecodeOnly, mock_model) + common_attn_metadata, AscendAttentionState.DecodeOnly) self.assertIsInstance(metadata, AscendMLAMetadata) self.assertEqual(metadata.num_actual_tokens, @@ -796,12 +795,9 @@ def test_build_for_graph_capture_prefill(self, mock_dcp_world_size, vllm_config=self.mock_vllm_config, device=self.mock_device) - mock_model = MagicMock() - with self.assertRaises(NotImplementedError) as ctx: builder.build_for_graph_capture( - common_attn_metadata, AscendAttentionState.PrefillNoCache, - mock_model) + common_attn_metadata, AscendAttentionState.PrefillNoCache) self.assertIn( "Currently we only support building dummy metadata for DecodeOnly and SpecDecoding state", str(ctx.exception)) diff --git a/tests/ut/attention/test_sfa_v1.py b/tests/ut/attention/test_sfa_v1.py index caa8cec6e62..c4f9502d1c9 100644 --- a/tests/ut/attention/test_sfa_v1.py +++ b/tests/ut/attention/test_sfa_v1.py @@ -178,14 +178,9 @@ def test_ascend_sfa_metadata_builder_build_for_graph_capture(self): common_attn_metadata.sin = None common_attn_metadata.num_input_tokens = 100 - model = MagicMock() - model.model.layers = [MagicMock() for _ in range(10)] - model.model.start_layer 
= 0 - attn_metadata = builder.build_for_graph_capture( common_attn_metadata=common_attn_metadata, attn_state=AscendAttentionState.DecodeOnly, - model=model, ) assert isinstance(attn_metadata, AscendSFAMetadata) diff --git a/vllm_ascend/attention/attention_cp.py b/vllm_ascend/attention/attention_cp.py index 22c58369727..ff7242e1e68 100644 --- a/vllm_ascend/attention/attention_cp.py +++ b/vllm_ascend/attention/attention_cp.py @@ -90,7 +90,7 @@ def build( self, common_prefix_len: int, common_attn_metadata: AscendCommonAttentionMetadata, - model: Optional[nn.Module] = None, + fast_build: bool = False, ): num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 0c14bea01ce..2502dbfa621 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -264,7 +264,6 @@ def build_for_graph_capture( self, common_attn_metadata: AscendCommonAttentionMetadata, attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly, - model: Optional[nn.Module] = None, ): if attn_state == AscendAttentionState.DecodeOnly: attn_metadata = self.build( diff --git a/vllm_ascend/attention/mla_cp.py b/vllm_ascend/attention/mla_cp.py index 0a3aed14f18..4a9bcada180 100644 --- a/vllm_ascend/attention/mla_cp.py +++ b/vllm_ascend/attention/mla_cp.py @@ -92,7 +92,6 @@ def build_cp_metadata( self, common_prefix_len: int, common_attn_metadata: AscendCommonAttentionMetadata, - model: nn.Module, ) -> AscendPCPMetadata | None: common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata assert common_long_seq_metadata is not None @@ -121,10 +120,9 @@ def build_chunked_metadata( self, common_prefix_len: int, common_attn_metadata: AscendCommonAttentionMetadata, - model: nn.Module, ): chunked_context_metadata = super().build_chunked_metadata( - common_prefix_len, common_attn_metadata, model) + common_prefix_len, common_attn_metadata) if chunked_context_metadata is None: return None @@ -205,12 +203,11 @@ def build_prefill_metadata( self, common_prefix_len: int, common_attn_metadata: AscendCommonAttentionMetadata, - model: nn.Module, ) -> AscendMLAPrefillMetadata: prefill_metadata = super().build_prefill_metadata( - common_prefix_len, common_attn_metadata, model) + common_prefix_len, common_attn_metadata) prefill_metadata.pcp_metadata = self.build_cp_metadata( - common_prefix_len, common_attn_metadata, model) + common_prefix_len, common_attn_metadata) prefill_metadata.block_table = self.block_table[ self.num_decodes_flatten:, ...] 
return prefill_metadata @@ -219,10 +216,9 @@ def build_decode_metadata( self, common_prefix_len: int, common_attn_metadata: AscendCommonAttentionMetadata, - model: nn.Module, ) -> AscendMLADecodeMetadata: decode_metadata = super().build_decode_metadata( - common_prefix_len, common_attn_metadata, model) + common_prefix_len, common_attn_metadata) long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata assert long_seq_metadata is not None diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index a1d20313205..cdc5907af29 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -13,8 +13,9 @@ from vllm.logger import logger from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.utils.math_utils import cdiv, round_down -from vllm.v1.attention.backends.utils import AttentionCGSupport -from vllm.v1.kv_cache_interface import MLAAttentionSpec +from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder +from vllm.v1.attention.backends.utils import AttentionCGSupport, AttentionMetadataBuilder +from vllm.v1.kv_cache_interface import MLAAttentionSpec, AttentionSpec from vllm_ascend import envs from vllm_ascend.ascend_config import get_ascend_config @@ -177,7 +178,7 @@ def __post_init__(self): M = TypeVar("M", bound=AscendMLAMetadata) -class AscendMLAMetadataBuilder: +class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): # Does this backend/builder support ACL Graphs for attention (default: no). aclgraph_support: ClassVar[AttentionCGSupport] = \ AttentionCGSupport.UNIFORM_BATCH @@ -185,15 +186,18 @@ class AscendMLAMetadataBuilder: NOTE: Please read the comment at the top of the file before trying to understand this class """ - - def __init__(self, - kv_cache_spec: MLAAttentionSpec, - layer_names: list[str], - vllm_config: VllmConfig, - device: torch.device, - metadata_cls: Optional[AscendMLAMetadata] = None): - self.metadata_cls: Optional[AscendMLAMetadata] = metadata_cls \ - if metadata_cls is not None else AscendMLAMetadata # type: ignore + def __init__( + self, + kv_cache_spec: MLAAttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + metadata_cls: type[AscendMLAMetadata] | None = None, + supports_dcp_with_varlen: bool = False, + ): + self.metadata_cls = ( + metadata_cls if metadata_cls is not None else AscendMLAMetadata + ) self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.device = device @@ -381,10 +385,10 @@ def set_num_actual_tokens( self.num_actual_tokens = common_attn_metadata.num_actual_tokens def build( - self, - common_prefix_len: int, - common_attn_metadata: AscendCommonAttentionMetadata, - model: nn.Module, + self, + common_prefix_len: int, + common_attn_metadata: AscendCommonAttentionMetadata, + fast_build: bool = False, ) -> AscendMLAMetadata: num_reqs = common_attn_metadata.num_reqs query_start_loc = common_attn_metadata.query_start_loc @@ -400,18 +404,6 @@ def build( self.slot_mapping = common_attn_metadata.slot_mapping[:self. 
num_actual_tokens] - # if self.cos_cache is None: - # self.cos_cache = model.model.layers[ - # model.model.start_layer].self_attn.rotary_emb.cos_cached - # self.sin_cache = model.model.layers[ - # model.model.start_layer].self_attn.rotary_emb.sin_cached - # if self.cos_cache.dtype != self.model_config.dtype: # type: ignore - # self.cos_cache = self.cos_cache.to( # type: ignore - # self.model_config.dtype) # type: ignore - # self.sin_cache = self.sin_cache.to( # type: ignore - # self.model_config.dtype) # type: ignore - - query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] self.query_lens = query_seq_lens_cpu[:num_reqs] self.seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] @@ -421,12 +413,12 @@ def build( prefill_metadata = None if self.num_prefills > 0: prefill_metadata = self.build_prefill_metadata( - common_prefix_len, common_attn_metadata, model) + common_prefix_len, common_attn_metadata) decode_metadata = None if self.num_decodes > 0: decode_metadata = self.build_decode_metadata( - common_prefix_len, common_attn_metadata, model) + common_prefix_len, common_attn_metadata) return self.metadata_cls( # type: ignore num_actual_tokens_pcp_padded=self.num_actual_tokens, @@ -451,7 +443,6 @@ def build_chunked_metadata( self, common_prefix_len: int, common_attn_metadata: AscendCommonAttentionMetadata, - model: nn.Module, ): if not self.chunked_prefill_enabled: return None @@ -521,7 +512,6 @@ def build_prefill_metadata( self, common_prefix_len: int, common_attn_metadata: AscendCommonAttentionMetadata, - model: nn.Module, ) -> AscendMLAPrefillMetadata: query_start_loc = common_attn_metadata.query_start_loc @@ -531,7 +521,7 @@ def build_prefill_metadata( ) chunked_context_metadata = self.build_chunked_metadata( - common_prefix_len, common_attn_metadata, model) + common_prefix_len, common_attn_metadata) reqs_start = self.num_decodes # prefill_start tokens_start = self.num_decode_tokens max_query_len = self.query_lens[reqs_start:].max().item() @@ -541,12 +531,6 @@ def build_prefill_metadata( prefill_input_positions = input_positions[tokens_start:] cos, sin = get_cos_and_sin_mla(prefill_input_positions) - # cos = self.cos_cache[ - # prefill_input_positions].unsqueeze( # type: ignore - # 1).unsqueeze(2) - # sin = self.sin_cache[ - # prefill_input_positions].unsqueeze( # type: ignore - # 1).unsqueeze(2) return AscendMLAPrefillMetadata( attn_mask=common_attn_metadata.attn_mask, query_lens=self.query_lens[reqs_start:].to(torch.int32), @@ -566,7 +550,6 @@ def build_decode_metadata( self, common_prefix_len: int, common_attn_metadata: AscendCommonAttentionMetadata, - model: nn.Module, ) -> AscendMLADecodeMetadata: num_reqs = common_attn_metadata.num_reqs query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu @@ -660,7 +643,6 @@ def build_for_graph_capture( self, common_attn_metadata: AscendCommonAttentionMetadata, attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly, - model: Optional[nn.Module] = None, ): if attn_state in { AscendAttentionState.DecodeOnly, @@ -669,7 +651,6 @@ def build_for_graph_capture( attn_metadata = self.build( common_prefix_len=0, common_attn_metadata=common_attn_metadata, - model=model, ) else: raise NotImplementedError( diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index 31cc6388ce9..ef1c2751d8d 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -12,7 +12,9 @@ from vllm.model_executor.layers.linear import (ReplicatedLinear, UnquantizedLinearMethod) from 
vllm.triton_utils import HAS_TRITON +from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder from vllm.v1.attention.backends.utils import AttentionCGSupport +from vllm.v1.kv_cache_interface import AttentionSpec from vllm_ascend import envs from vllm_ascend.ascend_config import get_ascend_config @@ -107,7 +109,7 @@ class AscendSFAMetadata: M = TypeVar("M", bound=AscendSFAMetadata) -class AscendSFAMetadataBuilder: +class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]): # Does this backend/builder support ACL Graphs for attention (default: no). aclgraph_support: ClassVar[AttentionCGSupport] = \ AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE @@ -117,14 +119,18 @@ class AscendSFAMetadataBuilder: """ # _attn_mask_builder = None - def __init__(self, - kv_cache_spec, - layer_names, - vllm_config: VllmConfig, - device: torch.device, - metadata_cls: Optional[AscendSFAMetadata] = None): - self.metadata_cls: Optional[AscendSFAMetadata] = metadata_cls \ - if metadata_cls is not None else AscendSFAMetadata # type: ignore + def __init__( + self, + kv_cache_spec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + metadata_cls: type[AscendSFAMetadata] | None = None, + supports_dcp_with_varlen: bool = False, + ): + self.metadata_cls = ( + metadata_cls if metadata_cls is not None else AscendSFAMetadata + ) self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.device = device @@ -142,9 +148,6 @@ def __init__(self, got {self.decode_threshold}" self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim - self.cos_cache = None - self.sin_cache = None - self.enable_sfa_cp = enable_sp() and \ hasattr(self.model_config.hf_config, "index_topk") @@ -163,7 +166,7 @@ def build( self, common_prefix_len: int, common_attn_metadata: AscendCommonAttentionMetadata, - model: nn.Module, + fast_build: bool = False, ) -> AscendSFAMetadata: num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens @@ -178,32 +181,10 @@ def build( query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] has_prefill = any(query_lens_cpu > self.decode_threshold) - if self.cos_cache is None: - self.cos_cache = model.model.layers[ - model.model.start_layer].self_attn.rotary_emb.cos_cached - self.sin_cache = model.model.layers[ - model.model.start_layer].self_attn.rotary_emb.sin_cached - if self.cos_cache.dtype != self.model_config.dtype: # type: ignore - self.cos_cache = self.cos_cache.to( # type: ignore - self.model_config.dtype) # type: ignore - self.sin_cache = self.sin_cache.to( # type: ignore - self.model_config.dtype) # type: ignore - cum_query_lens = common_attn_metadata.query_start_loc[1:num_reqs + 1] seq_lens = common_attn_metadata.seq_lens[:num_reqs] - assert self.cos_cache is not None and self.sin_cache is not None - # new_cos = self.cos_cache[input_positions][:, None, None] - # new_sin = self.sin_cache[input_positions][:, None, None] - cos, sin = get_cos_and_sin_mla(input_positions) - # if (cos is not None and sin is not None - # and num_input_tokens <= cos.shape[0] - # and num_input_tokens <= sin.shape[0]): - # cos[:num_input_tokens] = new_cos - # sin[:num_input_tokens] = new_sin - # else: - # cos, sin = new_cos, new_sin sfa_cp_context = None if self.enable_sfa_cp: @@ -298,7 +279,6 @@ def build_for_graph_capture( self, common_attn_metadata: AscendCommonAttentionMetadata, attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly, - model: Optional[nn.Module] = None, ): if 
attn_state in { AscendAttentionState.DecodeOnly, @@ -307,7 +287,6 @@ def build_for_graph_capture( attn_metadata = self.build( common_prefix_len=0, common_attn_metadata=common_attn_metadata, - model=model, ) else: raise NotImplementedError( diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 66dd65bdd77..7a472b26c9c 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -279,8 +279,8 @@ def dummy_run(self, builder = self.runner.attn_groups[0][0].get_metadata_builder() attn_metadata_mtp = builder.build_for_graph_capture( - common_attn_metadata, AscendAttentionState.SpecDecoding, - self.runner.get_model()) + common_attn_metadata, AscendAttentionState.SpecDecoding + ) attn_metadata = {} for layer_name in self.attn_layer_name: attn_metadata[layer_name] = attn_metadata_mtp diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index c048a577836..6be48424f37 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1106,21 +1106,10 @@ def _prepare_inputs( num_decode_draft_tokens_cpu=self. num_decode_draft_tokens.cpu[:num_reqs], ) - attn_metadata_i = builder.build( - common_prefix_len=common_prefix_len, - common_attn_metadata=common_attn_metadata, - **extra_attn_metadata_args) - elif self.model_config.runner_type == "pooling": - attn_metadata_i = builder.build( - common_prefix_len=common_prefix_len, - common_attn_metadata=common_attn_metadata, - **extra_attn_metadata_args) - else: - attn_metadata_i = builder.build( - common_prefix_len=common_prefix_len, - common_attn_metadata=common_attn_metadata, - model=self.get_model(), - **extra_attn_metadata_args) + attn_metadata_i = builder.build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, + **extra_attn_metadata_args) for layer_name in attn_group.layer_names: attn_metadata[layer_name] = attn_metadata_i @@ -1910,7 +1899,7 @@ def _build_dummy_attn_metadata( common_metadata) else: attn_metadata_full_attention = builder.build_for_graph_capture( - common_attn_metadata, attn_state, self.get_model()) + common_attn_metadata, attn_state) for layer_name in kv_cache_group_spec.layer_names: if "linear_attn" in layer_name: attn_metadata[ From 03630b8e5d80e029c04455fa201e1ffe48b17257 Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Wed, 24 Dec 2025 15:15:01 +0800 Subject: [PATCH 07/19] [Refactor] use cos_sin_cache & remove parameter like model in builder. 
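
This is a formatting-only pass over the preceding commits: continuation indents are normalized to the project style, and imports left unused once the model parameter was dropped (for example torch.nn) are removed. The shared builder entry point is unchanged in behavior; sketched, it is now uniformly:

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: AscendCommonAttentionMetadata,
        fast_build: bool = False,
    ) -> AscendMLAMetadata:
        ...

so callers such as _prepare_inputs can use a single call path for generate and pooling runners alike.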
Signed-off-by: weijinqian_v1 --- vllm_ascend/attention/attention_cp.py | 1 - vllm_ascend/attention/attention_v1.py | 12 ++++----- vllm_ascend/attention/mla_cp.py | 1 - vllm_ascend/attention/mla_v1.py | 33 ++++++++++++------------- vllm_ascend/attention/sfa_v1.py | 20 +++++++-------- vllm_ascend/ops/rotary_embedding.py | 7 +++--- vllm_ascend/spec_decode/mtp_proposer.py | 3 +-- 7 files changed, 36 insertions(+), 41 deletions(-) diff --git a/vllm_ascend/attention/attention_cp.py b/vllm_ascend/attention/attention_cp.py index ff7242e1e68..759fb674620 100644 --- a/vllm_ascend/attention/attention_cp.py +++ b/vllm_ascend/attention/attention_cp.py @@ -20,7 +20,6 @@ import numpy as np import torch import torch.distributed as dist -import torch.nn as nn import torch_npu from vllm.config import VllmConfig from vllm.distributed import (get_dcp_group, diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 2502dbfa621..5870c39b502 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -20,7 +20,6 @@ from typing import ClassVar, List, Optional, Tuple, Type import torch -import torch.nn as nn import torch_npu from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionType) @@ -29,7 +28,8 @@ from vllm.config import VllmConfig, get_current_vllm_config from vllm.forward_context import ForwardContext, get_forward_context from vllm.utils.math_utils import cdiv -from vllm.v1.attention.backends.utils import AttentionCGSupport, AttentionMetadataBuilder +from vllm.v1.attention.backends.utils import (AttentionCGSupport, + AttentionMetadataBuilder) from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import AttentionSpec @@ -215,10 +215,10 @@ def reorder_batch(self, input_batch, return False def build( - self, - common_prefix_len: int, - common_attn_metadata: AscendCommonAttentionMetadata, - fast_build: bool = False, + self, + common_prefix_len: int, + common_attn_metadata: AscendCommonAttentionMetadata, + fast_build: bool = False, ) -> AscendMetadata: num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens diff --git a/vllm_ascend/attention/mla_cp.py b/vllm_ascend/attention/mla_cp.py index 4a9bcada180..96007c1bebb 100644 --- a/vllm_ascend/attention/mla_cp.py +++ b/vllm_ascend/attention/mla_cp.py @@ -4,7 +4,6 @@ import torch import torch.distributed as dist import torch_npu -from torch import nn from vllm.config import VllmConfig from vllm.distributed import (get_dcp_group, get_decode_context_model_parallel_rank, diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index cdc5907af29..e0b8c6e934e 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -5,7 +5,6 @@ import numpy as np import torch import torch_npu -from torch import nn from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.config import VllmConfig, get_current_vllm_config @@ -14,8 +13,8 @@ from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.utils.math_utils import cdiv, round_down from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder -from vllm.v1.attention.backends.utils import AttentionCGSupport, AttentionMetadataBuilder -from vllm.v1.kv_cache_interface import MLAAttentionSpec, AttentionSpec +from vllm.v1.attention.backends.utils import AttentionCGSupport 
+from vllm.v1.kv_cache_interface import MLAAttentionSpec from vllm_ascend import envs from vllm_ascend.ascend_config import get_ascend_config @@ -186,18 +185,18 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): NOTE: Please read the comment at the top of the file before trying to understand this class """ + def __init__( - self, - kv_cache_spec: MLAAttentionSpec, - layer_names: list[str], - vllm_config: VllmConfig, - device: torch.device, - metadata_cls: type[AscendMLAMetadata] | None = None, - supports_dcp_with_varlen: bool = False, + self, + kv_cache_spec: MLAAttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + metadata_cls: type[AscendMLAMetadata] | None = None, + supports_dcp_with_varlen: bool = False, ): - self.metadata_cls = ( - metadata_cls if metadata_cls is not None else AscendMLAMetadata - ) + self.metadata_cls = (metadata_cls if metadata_cls is not None else + AscendMLAMetadata) self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.device = device @@ -385,10 +384,10 @@ def set_num_actual_tokens( self.num_actual_tokens = common_attn_metadata.num_actual_tokens def build( - self, - common_prefix_len: int, - common_attn_metadata: AscendCommonAttentionMetadata, - fast_build: bool = False, + self, + common_prefix_len: int, + common_attn_metadata: AscendCommonAttentionMetadata, + fast_build: bool = False, ) -> AscendMLAMetadata: num_reqs = common_attn_metadata.num_reqs query_start_loc = common_attn_metadata.query_start_loc diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index ef1c2751d8d..58dcc52b341 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -14,7 +14,6 @@ from vllm.triton_utils import HAS_TRITON from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder from vllm.v1.attention.backends.utils import AttentionCGSupport -from vllm.v1.kv_cache_interface import AttentionSpec from vllm_ascend import envs from vllm_ascend.ascend_config import get_ascend_config @@ -120,17 +119,16 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]): # _attn_mask_builder = None def __init__( - self, - kv_cache_spec, - layer_names: list[str], - vllm_config: VllmConfig, - device: torch.device, - metadata_cls: type[AscendSFAMetadata] | None = None, - supports_dcp_with_varlen: bool = False, + self, + kv_cache_spec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + metadata_cls: type[AscendSFAMetadata] | None = None, + supports_dcp_with_varlen: bool = False, ): - self.metadata_cls = ( - metadata_cls if metadata_cls is not None else AscendSFAMetadata - ) + self.metadata_cls = (metadata_cls if metadata_cls is not None else + AscendSFAMetadata) self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.device = device diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index f2a8be8ca39..bcb58d2daad 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -21,7 +21,6 @@ import einops import torch import torch_npu -from vllm.config import CUDAGraphMode from vllm.model_executor.layers.rotary_embedding import ( DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding, YaRNScalingRotaryEmbedding) @@ -107,8 +106,10 @@ def get_cos_and_sin_mla(positions): global _cos_mla global _sin_mla num_tokens = positions.size(0) - _cos_mla[:num_tokens, ...] 
= _cos_cache[positions].unsqueeze(1).unsqueeze(2) - _sin_mla[:num_tokens, ...] = _sin_cache[positions].unsqueeze(1).unsqueeze(2) + _cos_mla[:num_tokens, + ...] = _cos_cache[positions].unsqueeze(1).unsqueeze(2) + _sin_mla[:num_tokens, + ...] = _sin_cache[positions].unsqueeze(1).unsqueeze(2) return _cos_mla[:num_tokens, ...], _sin_mla[:num_tokens, ...] diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 7a472b26c9c..94137272c54 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -279,8 +279,7 @@ def dummy_run(self, builder = self.runner.attn_groups[0][0].get_metadata_builder() attn_metadata_mtp = builder.build_for_graph_capture( - common_attn_metadata, AscendAttentionState.SpecDecoding - ) + common_attn_metadata, AscendAttentionState.SpecDecoding) attn_metadata = {} for layer_name in self.attn_layer_name: attn_metadata[layer_name] = attn_metadata_mtp From 946971e2013346db9d71c2b5cede36fb000ccf4c Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Wed, 24 Dec 2025 15:54:34 +0800 Subject: [PATCH 08/19] [Refactor] use cos_sin_cache & remove parameter like model in builder. Signed-off-by: weijinqian_v1 --- vllm_ascend/attention/attention_v1.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 5870c39b502..2e9fc57ea39 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -146,7 +146,6 @@ class AscendMetadata: actual_seq_lengths_q: List[int] = None # type: ignore query_start_loc: torch.Tensor = None - query_lens: torch.Tensor = None # Maximum query length in the batch (None for decoding). max_query_len: Optional[int] = None @@ -230,7 +229,6 @@ def build( split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold) block_table = common_attn_metadata.block_table_tensor - query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens] @@ -246,7 +244,6 @@ def build( num_decode_tokens=num_decode_tokens, block_tables=block_table, query_start_loc=query_start_loc, - query_lens=query_lens, seq_lens=seq_lens, seq_lens_list=seq_lens.tolist(), max_query_len=common_attn_metadata.max_query_len, From 4e3095cfbb180cdf30b2b8bc91f4e114bb366f5b Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Wed, 24 Dec 2025 16:25:35 +0800 Subject: [PATCH 09/19] [Refactor] use cos_sin_cache & remove parameter like model in builder. Signed-off-by: weijinqian_v1 --- vllm_ascend/xlite/xlite.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm_ascend/xlite/xlite.py b/vllm_ascend/xlite/xlite.py index b594b7edfe7..28b3da21634 100644 --- a/vllm_ascend/xlite/xlite.py +++ b/vllm_ascend/xlite/xlite.py @@ -257,7 +257,9 @@ def __call__( if not with_prefill or self.full_mode: batch = attn_metadata.num_prefills + attn_metadata.num_decodes seq_lens = attn_metadata.seq_lens[:batch] - query_lens = attn_metadata.query_lens[:batch] + query_lens = attn_metadata.query_start_loc_cpu[ + 1:] - attn_metadata.query_start_loc_cpu[:-1] + query_lens = query_lens[:batch] cached_lens = seq_lens - query_lens xlite_attn_metadata = ModelAttnMeta() From 45e184f4c56f489d9f0bd2756da582ac701ba1bc Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Wed, 24 Dec 2025 16:26:49 +0800 Subject: [PATCH 10/19] [Refactor] use cos_sin_cache & remove parameter like model in builder. 
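
Indentation fix only; the xlite fallback restored in the previous commit is unchanged. For reference, per-request query lengths are recovered from the cumulative start offsets by adjacent difference; a tiny worked example with assumed values:

    import torch

    # cumulative token offsets per request: [0, q0, q0+q1, ...]
    query_start_loc_cpu = torch.tensor([0, 3, 4, 9])
    query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
    # -> tensor([3, 1, 5]); request i contributes query_lens[i] tokens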
Signed-off-by: weijinqian_v1 --- vllm_ascend/xlite/xlite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/xlite/xlite.py b/vllm_ascend/xlite/xlite.py index 28b3da21634..462052d7e45 100644 --- a/vllm_ascend/xlite/xlite.py +++ b/vllm_ascend/xlite/xlite.py @@ -258,7 +258,7 @@ def __call__( batch = attn_metadata.num_prefills + attn_metadata.num_decodes seq_lens = attn_metadata.seq_lens[:batch] query_lens = attn_metadata.query_start_loc_cpu[ - 1:] - attn_metadata.query_start_loc_cpu[:-1] + 1:] - attn_metadata.query_start_loc_cpu[:-1] query_lens = query_lens[:batch] cached_lens = seq_lens - query_lens From 31abe7a9b042cfecb32e42a0ed0561eca0935da5 Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Wed, 24 Dec 2025 19:12:21 +0800 Subject: [PATCH 11/19] [Refactor] use cos_sin_cache & remove parameter like model in builder. Signed-off-by: weijinqian_v1 --- vllm_ascend/attention/mla_cp.py | 17 ++++++++++------- vllm_ascend/ops/rotary_embedding.py | 18 +++++++++--------- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/vllm_ascend/attention/mla_cp.py b/vllm_ascend/attention/mla_cp.py index 96007c1bebb..4ce90cb12d7 100644 --- a/vllm_ascend/attention/mla_cp.py +++ b/vllm_ascend/attention/mla_cp.py @@ -49,14 +49,17 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder): understand this class """ - def __init__(self, - kv_cache_spec: MLAAttentionSpec, - layer_names: list[str], - vllm_config: VllmConfig, - device: torch.device, - metadata_cls: Optional[AscendMLAMetadata] = None): + def __init__( + self, + kv_cache_spec: MLAAttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + metadata_cls: type[AscendMLAMetadata] | None = None, + supports_dcp_with_varlen: bool = False, + ): super().__init__(kv_cache_spec, layer_names, vllm_config, device, - metadata_cls) + metadata_cls, supports_dcp_with_varlen) self.pcp_size = get_pcp_group().world_size self.pcp_rank = get_pcp_group( diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index bcb58d2daad..09a85f0e5e9 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -39,15 +39,15 @@ # AscendAttentionBackendImpl for GQA models, we cannot pass cos && sin by # attn_metadata. This causes that rope in GQA models must pass cos && sin # by different approaches. -_cos_mla: Optional[torch.Tensor] = None -_sin_mla: Optional[torch.Tensor] = None -_cos_cache: Optional[torch.Tensor] = None -_sin_cache: Optional[torch.Tensor] = None -_cos_sin_cache: Optional[torch.Tensor] = None -_cos: Optional[torch.Tensor] = None -_sin: Optional[torch.Tensor] = None -_cos_slice: Optional[torch.Tensor] = None -_sin_slice: Optional[torch.Tensor] = None +_cos_mla: torch.Tensor = None +_sin_mla: torch.Tensor = None +_cos_cache: torch.Tensor = None +_sin_cache: torch.Tensor = None +_cos_sin_cache: torch.Tensor = None +_cos: torch.Tensor = None +_sin: torch.Tensor = None +_cos_slice: torch.Tensor = None +_sin_slice: torch.Tensor = None def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype, From 607384df3d25a85555f5f8f2d7b0eed10d34f61f Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Thu, 25 Dec 2025 13:04:46 +0800 Subject: [PATCH 12/19] [Refactor] use cos_sin_cache & remove parameter like model in builder. 
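
get_cos_and_sin_mla gains a use_cache flag. With use_cache=True (the decode path) the gathered values are written into the persistent _cos_mla/_sin_mla buffers and sliced views are returned, keeping addresses stable for graph replay; with the default False (prefill) fresh tensors are returned and the buffers are left untouched. Sketched call sites, assuming the module-level caches are already populated:

    # decode: stable, preallocated storage
    cos, sin = get_cos_and_sin_mla(input_positions, use_cache=True)

    # prefill: plain gather, new tensors on each call
    cos, sin = get_cos_and_sin_mla(prefill_input_positions)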
Signed-off-by: weijinqian_v1 --- tests/ut/attention/test_sfa_v1.py | 3 ++- vllm_ascend/attention/mla_v1.py | 2 +- vllm_ascend/ops/rotary_embedding.py | 12 +++++++----- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/tests/ut/attention/test_sfa_v1.py b/tests/ut/attention/test_sfa_v1.py index c4f9502d1c9..8c210efb57a 100644 --- a/tests/ut/attention/test_sfa_v1.py +++ b/tests/ut/attention/test_sfa_v1.py @@ -1,5 +1,5 @@ import sys -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import torch from vllm.v1.attention.backends.utils import AttentionCGSupport @@ -147,6 +147,7 @@ def test_ascend_sfa_metadata_builder_build(self): assert metadata.num_actual_tokens == common_attn_metadata.num_actual_tokens assert metadata.slot_mapping.shape == (100, 4, 1024) + @patch("vllm_ascend.device_allocator.camem.init_module") def test_ascend_sfa_metadata_builder_build_for_graph_capture(self): kv_cache_spec = MagicMock() layer_names = ["layer1", "layer2"] diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index e0b8c6e934e..1d3df0e133b 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -623,7 +623,7 @@ def build_decode_metadata( num_reqs_pad_size, num_reqs, actual_seq_lengths_q, common_attn_metadata) - cos, sin = get_cos_and_sin_mla(input_positions) + cos, sin = get_cos_and_sin_mla(input_positions, use_cache=True) decode_metadata = AscendMLADecodeMetadata( input_positions=input_positions, block_table=self.block_table, diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index 09a85f0e5e9..63aa3e28e2c 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -100,16 +100,18 @@ def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype, device=device) -def get_cos_and_sin_mla(positions): +def get_cos_and_sin_mla(positions, use_cache=False): global _cos_cache global _sin_cache + cos = _cos_cache[positions].unsqueeze(1).unsqueeze(2) + sin = _sin_cache[positions].unsqueeze(1).unsqueeze(2) + if not use_cache: + return cos, sin global _cos_mla global _sin_mla num_tokens = positions.size(0) - _cos_mla[:num_tokens, - ...] = _cos_cache[positions].unsqueeze(1).unsqueeze(2) - _sin_mla[:num_tokens, - ...] = _sin_cache[positions].unsqueeze(1).unsqueeze(2) + _cos_mla[:num_tokens, ...] = cos + _sin_mla[:num_tokens, ...] = sin return _cos_mla[:num_tokens, ...], _sin_mla[:num_tokens, ...] From cef4b3fd6581b47c890a3bb3de4a7f1000a5185d Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Thu, 25 Dec 2025 13:27:03 +0800 Subject: [PATCH 13/19] [Refactor] use cos_sin_cache & remove parameter like model in builder. 
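
With the builders reading rotary state from module level instead of a model argument, the unit tests mock get_cos_and_sin_mla rather than constructing a fake model. A sketch of the pattern, assuming unittest.mock and a builder that unpacks the return value into (cos, sin):

    import torch
    from unittest.mock import patch

    @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
    def test_build(self, mock_get_cos_and_sin_mla):
        mock_get_cos_and_sin_mla.return_value = (torch.zeros(1),
                                                 torch.zeros(1))
        ...

Note that mock.patch must target the module that looks the name up (mla_v1 or sfa_v1), not rotary_embedding where it is defined; the MLA tests below initially target sfa_v1 and are retargeted in a later commit.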
Signed-off-by: weijinqian_v1 --- tests/ut/attention/test_mla_v1.py | 24 ++++++++++++++++++------ tests/ut/attention/test_sfa_v1.py | 8 +++++--- vllm_ascend/ops/rotary_embedding.py | 1 + 3 files changed, 24 insertions(+), 9 deletions(-) diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index 05fb984dbf9..4dd12545f49 100755 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -289,6 +289,7 @@ def test_ascend_mla_metadata_builder_spec_decode(self, mock_dcp, builder.chunked_prefill_enabled, mock_vllm_config.scheduler_config.enable_chunked_prefill) + @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla") @patch('vllm.distributed.parallel_state.get_pcp_group') @patch('vllm.distributed.parallel_state._PCP', new_callable=lambda: MagicMock(spec=GroupCoordinator)) @@ -296,7 +297,8 @@ def test_ascend_mla_metadata_builder_spec_decode(self, mock_dcp, @patch('vllm.distributed.parallel_state._DCP', new_callable=lambda: MagicMock(spec=GroupCoordinator)) def test_ascend_mla_metadata_builder_build_full_graph( - self, mock_dcp, mock_get_dcp_group, mock_pcp, mock_get_pcp_group): + self, mock_dcp, mock_get_dcp_group, mock_pcp, mock_get_pcp_group, + mock_get_cos_and_sin_mla): mock_vllm_config = MagicMock() mock_vllm_config.model_config.max_model_len = 1024 mock_vllm_config.model_config.get_head_size.return_value = 64 @@ -526,6 +528,7 @@ def setUp(self): self.kv_cache_spec.head_size = 128 self.kv_cache_spec.num_heads = 32 + @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla") @patch('vllm.distributed.parallel_state.get_pcp_group') @patch("vllm.distributed.get_decode_context_model_parallel_world_size", return_value=1) @@ -534,7 +537,8 @@ def setUp(self): @patch("torch.npu.is_available") def test_build_prefix_no_cache_metadata(self, mock_npu_available, mock_zeros, mock_dcp_world_size, - mock_get_pcp_group): + mock_get_pcp_group, + mock_get_cos_and_sin_mla): mock_npu_available.return_value = False mock_dcp_world_size.return_value = 1 torch.Tensor.pin_memory = lambda x: x # noqa @@ -590,6 +594,7 @@ def zeros_override(*args, **kwargs): torch.all(metadata.slot_mapping == base_inputs["slot_mapping"])) self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size) + @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla") @patch('vllm.distributed.parallel_state.get_pcp_group') @patch("vllm.distributed.get_decode_context_model_parallel_world_size", return_value=1) @@ -598,7 +603,8 @@ def zeros_override(*args, **kwargs): @patch("torch.npu.is_available") def test_build_chunked_prefix_metadata(self, mock_npu_available, mock_zeros, mock_dcp_world_size, - mock_get_pcp_group): + mock_get_pcp_group, + mock_get_cos_and_sin_mla): mock_npu_available.return_value = False mock_dcp_world_size.return_value = 1 torch.Tensor.pin_memory = lambda x: x # noqa @@ -655,11 +661,13 @@ def zeros_override(*args, **kwargs): torch.all(metadata.slot_mapping == base_inputs["slot_mapping"])) self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size) + @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla") @patch('vllm.distributed.parallel_state.get_pcp_group') @patch("vllm.distributed.get_decode_context_model_parallel_world_size", return_value=1) def test_build_decode_only_metadata(self, mock_dcp_world_size, - mock_get_pcp_group): + mock_get_pcp_group, + mock_get_cos_and_sin_mla): mock_dcp_world_size.return_value = 1 torch.Tensor.pin_memory = lambda x: x # noqa @@ -708,11 +716,13 @@ def test_build_decode_only_metadata(self, mock_dcp_world_size, 
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"])) self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size) + @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla") @patch('vllm.distributed.parallel_state.get_pcp_group') @patch("vllm.distributed.get_decode_context_model_parallel_world_size", return_value=1) def test_build_for_graph_capture_decode_only(self, mock_dcp_world_size, - mock_get_pcp_group): + mock_get_pcp_group, + mock_get_cos_and_sin_mla): mock_dcp_world_size.return_value = 1 torch.Tensor.pin_memory = lambda x: x # noqa @@ -761,11 +771,13 @@ def test_build_for_graph_capture_decode_only(self, mock_dcp_world_size, torch.all(metadata.slot_mapping == base_inputs["slot_mapping"])) self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size) + @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla") @patch('vllm.distributed.parallel_state.get_pcp_group') @patch("vllm.distributed.get_decode_context_model_parallel_world_size", return_value=1) def test_build_for_graph_capture_prefill(self, mock_dcp_world_size, - mock_get_pcp_group): + mock_get_pcp_group, + mock_get_cos_and_sin_mla): mock_dcp_world_size.return_value = 1 torch.Tensor.pin_memory = lambda x: x # noqa pcp_group = MagicMock(spec=GroupCoordinator) diff --git a/tests/ut/attention/test_sfa_v1.py b/tests/ut/attention/test_sfa_v1.py index 8c210efb57a..6c4803d969d 100644 --- a/tests/ut/attention/test_sfa_v1.py +++ b/tests/ut/attention/test_sfa_v1.py @@ -102,7 +102,8 @@ def test_ascend_sfa_metadata_builder_default(self): assert builder.device == device assert builder.vllm_config == vllm_config - def test_ascend_sfa_metadata_builder_build(self): + @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla") + def test_ascend_sfa_metadata_builder_build(self, mock_get_cos_and_sin_mla): kv_cache_spec = MagicMock() layer_names = ["layer1", "layer2"] vllm_config = MagicMock() @@ -147,8 +148,9 @@ def test_ascend_sfa_metadata_builder_build(self): assert metadata.num_actual_tokens == common_attn_metadata.num_actual_tokens assert metadata.slot_mapping.shape == (100, 4, 1024) - @patch("vllm_ascend.device_allocator.camem.init_module") - def test_ascend_sfa_metadata_builder_build_for_graph_capture(self): + @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla") + def test_ascend_sfa_metadata_builder_build_for_graph_capture( + self, mock_get_cos_and_sin_mla): kv_cache_spec = MagicMock() layer_names = ["layer1", "layer2"] vllm_config = MagicMock() diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index 63aa3e28e2c..d09a1990f88 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -63,6 +63,7 @@ def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype, _sin is not None: return + compilation_config = vllm_config.compilation_config model_config = vllm_config.model_config max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens From 35d4c8928c52d4ed9430f458923f3e9297accb40 Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Thu, 25 Dec 2025 13:27:58 +0800 Subject: [PATCH 14/19] [Refactor] use cos_sin_cache & remove parameter like model in builder. 
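
Reverts the stray compilation_config local reintroduced by the previous commit; it has been unused in set_cos_and_sin since the CUDAGraphMode.FULL_DECODE_ONLY gate was dropped.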
Signed-off-by: weijinqian_v1 --- vllm_ascend/ops/rotary_embedding.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index d09a1990f88..63aa3e28e2c 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -63,7 +63,6 @@ def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype, _sin is not None: return - compilation_config = vllm_config.compilation_config model_config = vllm_config.model_config max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens From c79b784215f97a9934fe960ea741f8dc9121777c Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Thu, 25 Dec 2025 13:36:49 +0800 Subject: [PATCH 15/19] [Refactor] use cos_sin_cache & remove parameters like model from the builder. Signed-off-by: weijinqian_v1 --- vllm_ascend/attention/sfa_v1.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index 58dcc52b341..8c3a2226ee8 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -181,8 +181,10 @@ def build( cum_query_lens = common_attn_metadata.query_start_loc[1:num_reqs + 1] seq_lens = common_attn_metadata.seq_lens[:num_reqs] - - cos, sin = get_cos_and_sin_mla(input_positions) + if has_prefill: + cos, sin = get_cos_and_sin_mla(input_positions) + else: + cos, sin = get_cos_and_sin_mla(input_positions, True) sfa_cp_context = None if self.enable_sfa_cp: From f5a795e0e3046ff2789ddb11b58853b292103174 Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Thu, 25 Dec 2025 14:17:48 +0800 Subject: [PATCH 16/19] [Refactor] use cos_sin_cache & remove parameters like model from the builder. Signed-off-by: weijinqian_v1 --- tests/ut/attention/test_mla_v1.py | 31 ++++++++++++++----------------- tests/ut/attention/test_sfa_v1.py | 7 +++---- 2 files changed, 17 insertions(+), 21 deletions(-) diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index 4dd12545f49..389f13831f2 100755 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -289,7 +289,7 @@ def test_ascend_mla_metadata_builder_spec_decode(self, mock_dcp, builder.chunked_prefill_enabled, mock_vllm_config.scheduler_config.enable_chunked_prefill) - @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla") + @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla") @patch('vllm.distributed.parallel_state.get_pcp_group') @patch('vllm.distributed.parallel_state._PCP', new_callable=lambda: MagicMock(spec=GroupCoordinator)) @@ -332,7 +332,6 @@ def test_ascend_mla_metadata_builder_build_full_graph( builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config, mock_device) common_metadata = MagicMock() - model = MagicMock() common_metadata.graph_pad_size = 8 common_metadata.num_reqs = 4 common_metadata.num_actual_tokens = 5 @@ -345,7 +344,8 @@ def test_ascend_mla_metadata_builder_build_full_graph( block_table = torch.Tensor([[1, 0], [2, 0], [3, 0], [4, 0]]).int() common_metadata.block_table_tensor = block_table common_metadata.prefill_context_parallel_metadata = None - metadata = builder.build(0, common_metadata, model) + mock_get_cos_and_sin_mla.return_value = (torch.tensor(6), torch.Tensor(6)) + metadata = builder.build(0, common_metadata) self.assertEqual(metadata.decode.actual_seq_lengths_q, [1, 2, 4, 5, 6, 6, 7, 8]) @@ -528,7 +528,7 @@ def setUp(self): self.kv_cache_spec.head_size = 128 self.kv_cache_spec.num_heads = 32 -
@patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla") + @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla") @patch('vllm.distributed.parallel_state.get_pcp_group') @patch("vllm.distributed.get_decode_context_model_parallel_world_size", return_value=1) @@ -583,9 +583,8 @@ def zeros_override(*args, **kwargs): layer_names=["layer_0", "layer_1"], vllm_config=self.mock_vllm_config, device=self.mock_device) - - mock_model = MagicMock() - metadata = builder.build(1, common_attn_metadata, mock_model) + mock_get_cos_and_sin_mla.return_value = (torch.tensor(10), torch.Tensor(10)) + metadata = builder.build(1, common_attn_metadata) self.assertIsInstance(metadata, AscendMLAMetadata) self.assertEqual(metadata.num_actual_tokens, @@ -594,7 +593,7 @@ def zeros_override(*args, **kwargs): torch.all(metadata.slot_mapping == base_inputs["slot_mapping"])) self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size) - @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla") + @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla") @patch('vllm.distributed.parallel_state.get_pcp_group') @patch("vllm.distributed.get_decode_context_model_parallel_world_size", return_value=1) @@ -650,9 +649,8 @@ def zeros_override(*args, **kwargs): layer_names=["layer_0", "layer_1"], vllm_config=self.mock_vllm_config, device=self.mock_device) - - mock_model = MagicMock() - metadata = builder.build(1, common_attn_metadata, mock_model) + mock_get_cos_and_sin_mla.return_value = (torch.tensor(10), torch.Tensor(10)) + metadata = builder.build(1, common_attn_metadata) self.assertIsInstance(metadata, AscendMLAMetadata) self.assertEqual(metadata.num_actual_tokens, @@ -661,7 +659,7 @@ def zeros_override(*args, **kwargs): torch.all(metadata.slot_mapping == base_inputs["slot_mapping"])) self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size) - @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla") + @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla") @patch('vllm.distributed.parallel_state.get_pcp_group') @patch("vllm.distributed.get_decode_context_model_parallel_world_size", return_value=1) @@ -705,9 +703,8 @@ def test_build_decode_only_metadata(self, mock_dcp_world_size, layer_names=["layer_0", "layer_1"], vllm_config=self.mock_vllm_config, device=self.mock_device) - - mock_model = MagicMock() - metadata = builder.build(1, common_attn_metadata, mock_model) + mock_get_cos_and_sin_mla.return_value = (torch.tensor(10), torch.Tensor(10)) + metadata = builder.build(1, common_attn_metadata) self.assertIsInstance(metadata, AscendMLAMetadata) self.assertEqual(metadata.num_actual_tokens, @@ -716,7 +713,7 @@ def test_build_decode_only_metadata(self, mock_dcp_world_size, torch.all(metadata.slot_mapping == base_inputs["slot_mapping"])) self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size) - @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla") + @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla") @patch('vllm.distributed.parallel_state.get_pcp_group') @patch("vllm.distributed.get_decode_context_model_parallel_world_size", return_value=1) @@ -771,7 +768,7 @@ def test_build_for_graph_capture_decode_only(self, mock_dcp_world_size, torch.all(metadata.slot_mapping == base_inputs["slot_mapping"])) self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size) - @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla") + @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla") @patch('vllm.distributed.parallel_state.get_pcp_group') 
@patch("vllm.distributed.get_decode_context_model_parallel_world_size", return_value=1) diff --git a/tests/ut/attention/test_sfa_v1.py b/tests/ut/attention/test_sfa_v1.py index 6c4803d969d..2da07a16958 100644 --- a/tests/ut/attention/test_sfa_v1.py +++ b/tests/ut/attention/test_sfa_v1.py @@ -134,14 +134,11 @@ def test_ascend_sfa_metadata_builder_build(self, mock_get_cos_and_sin_mla): common_attn_metadata.sin = None common_attn_metadata.num_input_tokens = 100 - model = MagicMock() - model.model.layers = [MagicMock() for _ in range(10)] - model.model.start_layer = 0 + mock_get_cos_and_sin_mla.return_value = (torch.randn(100), torch.randn(100)) metadata = builder.build( common_prefix_len=10, common_attn_metadata=common_attn_metadata, - model=model, ) assert isinstance(metadata, AscendSFAMetadata) @@ -181,6 +178,8 @@ def test_ascend_sfa_metadata_builder_build_for_graph_capture( common_attn_metadata.sin = None common_attn_metadata.num_input_tokens = 100 + mock_get_cos_and_sin_mla.return_value = (torch.randn(100), torch.randn(100)) + attn_metadata = builder.build_for_graph_capture( common_attn_metadata=common_attn_metadata, attn_state=AscendAttentionState.DecodeOnly, From 2604468405df9a50654fdf77d69740cf3dc93813 Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Thu, 25 Dec 2025 14:39:22 +0800 Subject: [PATCH 17/19] [Refactor] use cos_sin_cache & remove parameter like model in builder. Signed-off-by: weijinqian_v1 --- tests/ut/attention/test_mla_v1.py | 18 ++++++++++++------ tests/ut/attention/test_sfa_v1.py | 6 ++++-- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index 389f13831f2..88d5071d7b9 100755 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -344,7 +344,8 @@ def test_ascend_mla_metadata_builder_build_full_graph( block_table = torch.Tensor([[1, 0], [2, 0], [3, 0], [4, 0]]).int() common_metadata.block_table_tensor = block_table common_metadata.prefill_context_parallel_metadata = None - mock_get_cos_and_sin_mla.return_value = (torch.tensor(6), torch.Tensor(6)) + mock_get_cos_and_sin_mla.return_value = (torch.tensor([6, 6]), + torch.Tensor([6, 6])) metadata = builder.build(0, common_metadata) self.assertEqual(metadata.decode.actual_seq_lengths_q, @@ -583,7 +584,8 @@ def zeros_override(*args, **kwargs): layer_names=["layer_0", "layer_1"], vllm_config=self.mock_vllm_config, device=self.mock_device) - mock_get_cos_and_sin_mla.return_value = (torch.tensor(10), torch.Tensor(10)) + mock_get_cos_and_sin_mla.return_value = (torch.tensor(10), + torch.Tensor(10)) metadata = builder.build(1, common_attn_metadata) self.assertIsInstance(metadata, AscendMLAMetadata) @@ -649,7 +651,8 @@ def zeros_override(*args, **kwargs): layer_names=["layer_0", "layer_1"], vllm_config=self.mock_vllm_config, device=self.mock_device) - mock_get_cos_and_sin_mla.return_value = (torch.tensor(10), torch.Tensor(10)) + mock_get_cos_and_sin_mla.return_value = (torch.tensor(10), + torch.Tensor(10)) metadata = builder.build(1, common_attn_metadata) self.assertIsInstance(metadata, AscendMLAMetadata) @@ -703,7 +706,8 @@ def test_build_decode_only_metadata(self, mock_dcp_world_size, layer_names=["layer_0", "layer_1"], vllm_config=self.mock_vllm_config, device=self.mock_device) - mock_get_cos_and_sin_mla.return_value = (torch.tensor(10), torch.Tensor(10)) + mock_get_cos_and_sin_mla.return_value = (torch.tensor([10, 10]), + torch.Tensor([10, 10])) metadata = builder.build(1, common_attn_metadata) 
self.assertIsInstance(metadata, AscendMLAMetadata) @@ -757,7 +761,8 @@ def test_build_for_graph_capture_decode_only(self, mock_dcp_world_size, layer_names=["layer_0", "layer_1"], vllm_config=self.mock_vllm_config, device=self.mock_device) - + mock_get_cos_and_sin_mla.return_value = (torch.tensor([10, 10]), + torch.Tensor([10, 10])) metadata = builder.build_for_graph_capture( common_attn_metadata, AscendAttentionState.DecodeOnly) @@ -803,7 +808,8 @@ def test_build_for_graph_capture_prefill(self, mock_dcp_world_size, layer_names=["layer_0", "layer_1"], vllm_config=self.mock_vllm_config, device=self.mock_device) - + mock_get_cos_and_sin_mla.return_value = (torch.tensor(10), + torch.Tensor(10)) with self.assertRaises(NotImplementedError) as ctx: builder.build_for_graph_capture( common_attn_metadata, AscendAttentionState.PrefillNoCache) diff --git a/tests/ut/attention/test_sfa_v1.py b/tests/ut/attention/test_sfa_v1.py index 2da07a16958..dd4c2f5e8e4 100644 --- a/tests/ut/attention/test_sfa_v1.py +++ b/tests/ut/attention/test_sfa_v1.py @@ -134,7 +134,8 @@ def test_ascend_sfa_metadata_builder_build(self, mock_get_cos_and_sin_mla): common_attn_metadata.sin = None common_attn_metadata.num_input_tokens = 100 - mock_get_cos_and_sin_mla.return_value = (torch.randn(100), torch.randn(100)) + mock_get_cos_and_sin_mla.return_value = (torch.randn(100), + torch.randn(100)) metadata = builder.build( common_prefix_len=10, @@ -178,7 +179,8 @@ def test_ascend_sfa_metadata_builder_build_for_graph_capture( common_attn_metadata.sin = None common_attn_metadata.num_input_tokens = 100 - mock_get_cos_and_sin_mla.return_value = (torch.randn(100), torch.randn(100)) + mock_get_cos_and_sin_mla.return_value = (torch.randn(100), + torch.randn(100)) attn_metadata = builder.build_for_graph_capture( common_attn_metadata=common_attn_metadata, From e1cf2632e7dab140cb05eed2ab9293870e4fb64f Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Fri, 26 Dec 2025 11:20:43 +0800 Subject: [PATCH 18/19] [Refactor] use cos_sin_cache & remove parameters like model from the builder. Signed-off-by: weijinqian_v1 --- vllm_ascend/spec_decode/mtp_proposer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 94137272c54..4a85a8215da 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -34,6 +34,7 @@ from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper, update_mla_attn_dcp_pcp_params, update_mla_attn_params) +from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable, shared_expert_dp_enabled) @@ -944,10 +945,7 @@ def _propose( graph_pad_size - batch_size, batch_size, decode_metadata.actual_seq_lengths_q) - decode_metadata.cos = builder.cos_cache[ positions[:batch_size]].unsqueeze(1).unsqueeze(2) - decode_metadata.sin = builder.sin_cache[ positions[:batch_size]].unsqueeze(1).unsqueeze(2) + decode_metadata.cos, decode_metadata.sin = get_cos_and_sin_mla(positions[:batch_size]) # NOTE(woosuk): We should handle the case where the draft model # generates tokens beyond the max model length.
Since it is complex # to remove such requests from the batch, we keep them in the batch From ac20bdf531234671b2946f3936ca1f0c2993e573 Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Fri, 26 Dec 2025 11:30:30 +0800 Subject: [PATCH 19/19] [Refactor] use cos_sin_cache & remove parameters like model from the builder. Signed-off-by: weijinqian_v1 --- vllm_ascend/spec_decode/mtp_proposer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 4a85a8215da..d577304efe8 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -945,7 +945,8 @@ def _propose( graph_pad_size - batch_size, batch_size, decode_metadata.actual_seq_lengths_q) - decode_metadata.cos, decode_metadata.sin = get_cos_and_sin_mla(positions[:batch_size]) + decode_metadata.cos, decode_metadata.sin = get_cos_and_sin_mla( + positions[:batch_size]) # NOTE(woosuk): We should handle the case where the draft model # generates tokens beyond the max model length. Since it is complex # to remove such requests from the batch, we keep them in the batch
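
For readers following the series, below is a minimal self-contained sketch of the cos/sin-cache pattern the patches above converge on: precompute rotary cos/sin tables once, then gather per-token values by position through a shared helper instead of each call site slicing builder.cos_cache / builder.sin_cache by hand. This is an illustration, not vllm-ascend's implementation: the real set_cos_and_sin in vllm_ascend/ops/rotary_embedding.py takes vllm_config, max_num_reqs, decode_token_per_req and dtype (see the hunk headers in PATCH 13/14), and the real get_cos_and_sin_mla presumably distinguishes prefill from decode (PATCH 15 passes True on the decode-only path). The simplified signatures, module-level cache, and inert is_decode flag here are assumptions.

from typing import Optional, Tuple

import torch

_cos: Optional[torch.Tensor] = None
_sin: Optional[torch.Tensor] = None


def set_cos_and_sin(max_position: int,
                    rotary_dim: int,
                    base: float = 10000.0,
                    dtype: torch.dtype = torch.float32) -> None:
    """Precompute the rotary cos/sin tables once; repeat calls are no-ops."""
    global _cos, _sin
    if _cos is not None and _sin is not None:
        return
    inv_freq = 1.0 / (base**(
        torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
    t = torch.arange(max_position, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)  # [max_position, rotary_dim // 2]
    _cos = freqs.cos().to(dtype)
    _sin = freqs.sin().to(dtype)


def get_cos_and_sin_mla(
        positions: torch.Tensor,
        is_decode: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
    """Gather per-token cos/sin, pre-shaped for the MLA attention path."""
    assert _cos is not None and _sin is not None, "call set_cos_and_sin first"
    # is_decode is a placeholder in this sketch; the real helper presumably
    # selects a decode-specific layout when it is True (see PATCH 15).
    cos = _cos[positions].unsqueeze(1).unsqueeze(2)
    sin = _sin[positions].unsqueeze(1).unsqueeze(2)
    return cos, sin


if __name__ == "__main__":
    set_cos_and_sin(max_position=4096, rotary_dim=64)
    cos, sin = get_cos_and_sin_mla(torch.tensor([0, 1, 2, 5]))
    print(cos.shape, sin.shape)  # torch.Size([4, 1, 1, 32]) for both

The [num_tokens, 1, 1, rotary_dim // 2] shape returned by the sketch mirrors the .unsqueeze(1).unsqueeze(2) reshape that PATCH 18 deletes from mtp_proposer.py; moving that reshape into the shared helper is what lets the call site shrink to a single get_cos_and_sin_mla(positions[:batch_size]) call, which PATCH 19 then merely re-wraps for line length.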