diff --git a/tests/ut/attention/test_mla_cp.py b/tests/ut/attention/test_mla_cp.py
index a8857d14b09..f02c1b88f14 100755
--- a/tests/ut/attention/test_mla_cp.py
+++ b/tests/ut/attention/test_mla_cp.py
@@ -450,11 +450,11 @@ def test_process_attn_out_lse(self):
         self.assertEqual(result.shape[2], self.impl.kv_lora_rank + 1)
 
     @patch('vllm_ascend.attention.context_parallel.mla_cp.get_forward_context')
-    @patch("torch_npu.atb.npu_multi_head_latent_attention")
+    @patch("torch_npu.npu_fused_infer_attention_score")
     @patch('torch_npu.npu_attention_update')
     @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
     def test_forward_decode_pcp_dcp(self, mock_npu_attention_update,
-                                    mock_npu_multi_head_latent_attention,
+                                    mock_npu_fused_infer_attention_score,
                                     mock_get_forward_context):
         self.impl.dcp_size = 2
         self.impl.pcp_size = 2
@@ -470,8 +470,8 @@ def test_forward_decode_pcp_dcp(self, mock_npu_attention_update,
 
         q_nope = torch.randn(B, N, self.impl.qk_nope_head_dim)
         q_pe = torch.randn(B, N, self.impl.qk_rope_head_dim)
-        k_nope = torch.randn(NB, BS, 1, self.impl.kv_lora_rank)
-        k_pe = torch.randn(NB, BS, 1, self.impl.qk_rope_head_dim)
+        k_nope = torch.randn(NB, 1, BS, self.impl.kv_lora_rank)
+        k_pe = torch.randn(NB, 1, BS, self.impl.qk_rope_head_dim)
 
         attn_metadata = MagicMock()
         attn_metadata.attn_state = AscendAttentionState.SpecDecoding
@@ -485,7 +485,7 @@ def test_forward_decode_pcp_dcp(self, mock_npu_attention_update,
         mock_npu_attention_update.return_value = (torch.randn(
             B, self.impl.num_heads, self.impl.kv_lora_rank), None)
 
-        mock_npu_multi_head_latent_attention.return_value = [
+        mock_npu_fused_infer_attention_score.return_value = [
            torch.randn(B, N, self.impl.kv_lora_rank),
            torch.randn(B, N, 1)
         ]
diff --git a/tests/ut/compilation/test_acl_graph.py b/tests/ut/compilation/test_acl_graph.py
index 88c687cf836..0261971ff82 100644
--- a/tests/ut/compilation/test_acl_graph.py
+++ b/tests/ut/compilation/test_acl_graph.py
@@ -754,7 +754,7 @@ def setUp(self):
 
     @patch('torch.npu.graph_task_update_end', )
     @patch('torch.npu.graph_task_update_begin', MagicMock())
-    @patch('torch_npu.atb.npu_multi_head_latent_attention', MagicMock())
+    @patch('torch_npu.npu_fused_infer_attention_score.out', MagicMock())
     def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end):
         input_positions = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
         block_table = torch.zeros(2, 5, dtype=torch.long)
@@ -793,16 +793,20 @@ def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end):
         qk_rope_head_dim = 32
         qk_nope_head_dim = 64
         query = torch.randn(4, num_heads, qk_head_dim)
-        q_pe = query[..., qk_nope_head_dim:]
+        q_nope = query[..., :qk_nope_head_dim]
+        q_pe = query[..., qk_rope_head_dim:]
         k_nope = torch.randn(4, num_heads, qk_nope_head_dim)
         k_pe = torch.randn(4, num_heads, qk_rope_head_dim)
+        input_layout = "BNSD"
+        actual_seq_lengths_kv = [1, 1]
         out = torch.randn(2, 16, 128)
         lse = torch.randn(2, 16, 8)
 
         self.graph_params.attn_params[4] = []
         self.graph_params.attn_params[4].append(
-            (q_nope, q_pe, k_nope, k_pe, block_table, seq_lens, num_heads,
-             scale, num_kv_heads, out, lse))
+            (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
+             None, 0, scale, block_table, 128, None, actual_seq_lengths_kv,
+             out, lse))
 
         with patch("torch_npu._C._npu_setStream", return_value=None):
             update_mla_attn_dcp_pcp_params(self.update_stream, forward_context,
diff --git a/vllm_ascend/attention/context_parallel/mla_cp.py b/vllm_ascend/attention/context_parallel/mla_cp.py
index c9c8bd0b1b5..0a75323b353 100755
--- a/vllm_ascend/attention/context_parallel/mla_cp.py
+++ b/vllm_ascend/attention/context_parallel/mla_cp.py
@@ -13,6 +13,8 @@
 from vllm.v1.attention.backends.utils import AttentionCGSupport
 from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
 
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
+
 # isort: off
 from vllm_ascend.attention.mla_v1 import (
     AscendMLADecodeMetadata, AscendMLAImpl, AscendMLAMetadata,
@@ -259,8 +261,12 @@ def build_decode_metadata(
                 batch_seq_mask, non_blocking=True)
             batch_seq_mask = self.batch_seq_mask_buf[:batch_seq_mask.shape[0]]
             cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
-            decode_metadata.cp_seq_len = cp_seq_len
+            decode_metadata.cp_seq_len = cp_seq_len.tolist()
             decode_metadata.batch_seq_mask = batch_seq_mask
+
+            actual_seq_lengths_q = torch.arange(self.num_decodes_flatten) + 1
+            decode_metadata.actual_seq_lengths_q = actual_seq_lengths_q
+
         return decode_metadata
@@ -575,19 +581,51 @@ def _forward_decode(
         else:
             num_heads = self.num_heads
 
-        k_nope = k_nope.view(-1, block_size, self.num_kv_heads,
+        # use pcp & dcp split computed token nums from scheduler to compute actual seq_len and seq_mask
+        k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
                              self.kv_lora_rank)
-        k_pe = k_pe.view(-1, block_size, self.num_kv_heads,
+        k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
                          self.qk_rope_head_dim)
-        q_nope = q_nope.view(num_tokens, num_heads, -1)
-        q_pe = q_pe.view(num_tokens, num_heads, -1)
-        # use pcp & dcp split computed token nums from scheduler to compute actual seq_len and seq_mask
-        seq_len = decode_meta.cp_seq_len
+
+        actual_seq_lengths = None
+        input_layout = "BNSD"
+
+        if (attn_metadata.attn_state in [
+                AscendAttentionState.SpecDecoding,
+                AscendAttentionState.ChunkedPrefill,
+                AscendAttentionState.DecodeOnly,
+        ] and self.speculative_config is not None):
+            input_layout = "TND"
+            # TODO: If the driver is upgraded later, the contiguous function can be deleted.
+            q_nope = q_nope.view(num_tokens, num_heads, -1).contiguous()
+            q_pe = q_pe.view(num_tokens, num_heads, -1)
+            sparse_mode = 3
+            spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
+            actual_seq_lengths = decode_meta.actual_seq_lengths_q
+        else:
+            q_nope = q_nope.view(num_tokens, num_heads, 1, -1).contiguous()
+            q_pe = q_pe.view(num_tokens, num_heads, 1, -1)
+            sparse_mode = 0
+            spec_attn_mask = None
 
         common_kwargs = {
-            "return_lse": True,
-            "calc_type": "calc_type_ring",
+            "query_rope": q_pe,
+            "key_rope": k_pe,
+            "num_heads": num_heads,
+            "num_key_value_heads": self.num_kv_heads,
+            "input_layout": input_layout,
+            "atten_mask": spec_attn_mask,
+            "sparse_mode": sparse_mode,
+            "scale": self.scale,
+            "antiquant_mode": 0,
+            "antiquant_scale": None,
+            "block_table": decode_meta.block_table,
+            "block_size": block_size,
+            "actual_seq_lengths": actual_seq_lengths,
+            "actual_seq_lengths_kv": decode_meta.cp_seq_len,
+            "softmax_lse_flag": True,
         }
+
         forward_context: ForwardContext = get_forward_context()
         if forward_context.is_draft_model:
             graph_params = get_draft_graph_params()
@@ -601,57 +639,53 @@ def _forward_decode(
             graph_params.events[num_tokens].append(event)
             workspace = graph_params.workspaces.get(num_tokens)
             if workspace is None:
-                workspace = torch_npu.atb._npu_multi_head_latent_attention_get_workspace(
-                    q_nope, q_pe, k_nope, k_pe, decode_meta.block_table,
-                    seq_len, num_heads, self.scale, self.num_kv_heads,
-                    **common_kwargs)
+                workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                    q_nope, k_nope, k_nope, **common_kwargs)
                 update_graph_params_workspaces(num_tokens, workspace)
 
             attn_output = torch.empty_like(q_nope)
-            softmax_lse = torch.empty((num_tokens, num_heads, 1),
-                                      dtype=q_nope.dtype,
-                                      device=q_nope.device)
+            if input_layout == "BNSD":
+                softmax_lse = torch.empty((num_tokens, num_heads, 1, 1),
+                                          dtype=torch.float,
+                                          device=q_nope.device)
+            else:
+                softmax_lse = torch.empty((num_tokens, num_heads, 1),
+                                          dtype=torch.float,
+                                          device=q_nope.device)
+
             graph_params.attn_params[num_tokens].append(
-                (weak_ref_tensors(q_nope), weak_ref_tensors(q_pe),
-                 weak_ref_tensors(k_nope), weak_ref_tensors(k_pe),
-                 decode_meta.block_table, seq_len, num_heads, self.scale,
-                 self.num_kv_heads, weak_ref_tensors(attn_output),
-                 weak_ref_tensors(softmax_lse)))
+                (weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
+                 weak_ref_tensors(q_pe), weak_ref_tensors(k_pe), num_heads,
+                 self.num_kv_heads, input_layout,
+                 weak_ref_tensors(spec_attn_mask) if spec_attn_mask is not None
+                 else None, sparse_mode, self.scale,
+                 weak_ref_tensors(decode_meta.block_table), block_size,
+                 actual_seq_lengths, decode_meta.cp_seq_len,
+                 weak_ref_tensors(attn_output), weak_ref_tensors(softmax_lse)))
 
             torch.npu.graph_task_group_begin(stream)
-            torch_npu.atb.npu_multi_head_latent_attention(
+            torch_npu.npu_fused_infer_attention_score.out(
                 q_nope,
-                q_pe,
                 k_nope,
-                k_pe,
-                decode_meta.block_table,
-                seq_len,
-                num_heads,
-                self.scale,
-                self.num_kv_heads,
+                k_nope,
                 **common_kwargs,
                 workspace=workspace,
-                output=attn_output,
-                lse=softmax_lse)
+                out=[attn_output, softmax_lse])
             handle = torch.npu.graph_task_group_end(stream)
             graph_params.handles[num_tokens].append(handle)
         else:
-            attn_output = torch.empty_like(q_nope)
-            softmax_lse = torch.empty((num_tokens, num_heads, 1),
-                                      dtype=q_nope.dtype,
-                                      device=q_nope.device)
-            torch_npu.atb.npu_multi_head_latent_attention(
+            attn_output, softmax_lse = torch_npu.npu_fused_infer_attention_score(
                 q_nope,
-                q_pe,
                 k_nope,
-                k_pe,
-                decode_meta.block_table,
-                seq_len,
-                num_heads,
-                self.scale,
-                self.num_kv_heads,
-                return_lse=True,
-                calc_type="calc_type_ring",
-                output=attn_output,
-                lse=softmax_lse)
+                k_nope,
+                **common_kwargs,
+            )
+            if input_layout == "BNSD":
+                B_attn, N_attn, S, D = attn_output.shape
+                B_lse, N_lse, Q_S, _ = softmax_lse.shape
+
+                attn_output = attn_output.permute(0, 2, 1, 3).reshape(
+                    B_attn * S, N_attn, D)
+                softmax_lse = softmax_lse.permute(0, 2, 1, 3).reshape(
+                    B_lse * Q_S, N_lse, 1)
 
         # Update out&lse
         attn_out_lse = _process_attn_out_lse(attn_output, softmax_lse,
diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py
index 842a48890bf..bbe0b61d1bc 100644
--- a/vllm_ascend/compilation/acl_graph.py
+++ b/vllm_ascend/compilation/acl_graph.py
@@ -466,38 +466,60 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
             graph_params.handles[runtime_shape],
             graph_params.events[runtime_shape],
     ):
-        (q_nope, q_pe, k_nope, k_pe, block_table, seq_len, num_heads,
-         scale, num_kv_heads, attn_output, softmax_lse) = param
+        (
+            q_nope,
+            k_nope,
+            q_pe,
+            k_pe,
+            num_heads,
+            num_kv_heads,
+            input_layout,
+            spec_attn_mask,
+            sparse_mode,
+            scale,
+            block_table,
+            block_size,
+            actual_seq_lengths,
+            actual_seq_lengths_kv,
+            attn_output,
+            softmax_lse,
+        ) = param
         decode_meta = forward_context.attn_metadata[key].decode
         seq_len = decode_meta.cp_seq_len
+        if isinstance(seq_len, torch.Tensor):
+            seq_len = seq_len.tolist()
+        actual_seq_lengths_kv = seq_len
 
-        # For pcp + spec decode, we flatten seq_lens
-        # to avoid irregular attn_mask shape,
-        # so there's no need to divide runtime_shape by spec_multiple
-        pad_length = runtime_shape - len(seq_len)
-        pad_tensor = torch.zeros(pad_length,
-                                 dtype=seq_len.dtype,
-                                 device=seq_len.device)
-        seq_len = torch.cat([seq_len, pad_tensor], dim=0)
+        pad_length = runtime_shape - len(actual_seq_lengths_kv)
+        if pad_length > 0:
+            actual_seq_lengths_kv = actual_seq_lengths_kv + [0] * (
+                runtime_shape - len(actual_seq_lengths_kv))
 
         torch.npu.graph_task_update_begin(update_stream, handle)
-        torch_npu.atb.npu_multi_head_latent_attention(
+        torch_npu.npu_fused_infer_attention_score.out(
             q_nope,
-            q_pe,
             k_nope,
-            k_pe,
-            block_table,
-            seq_len,
-            num_heads,
-            scale,
-            num_kv_heads,
-            return_lse=True,
-            calc_type="calc_type_ring",
+            k_nope,
+            query_rope=q_pe,
+            key_rope=k_pe,
+            num_heads=num_heads,
+            num_key_value_heads=num_kv_heads,
+            input_layout=input_layout,
+            atten_mask=spec_attn_mask,
+            sparse_mode=sparse_mode,
+            scale=scale,
+            antiquant_mode=0,
+            antiquant_scale=None,
+            softmax_lse_flag=True,
+            block_table=block_table,
+            block_size=block_size,
+            actual_seq_lengths_kv=actual_seq_lengths_kv,
+            actual_seq_lengths=actual_seq_lengths,
             workspace=graph_params.workspaces.get(runtime_shape),
-            output=attn_output,
-            lse=softmax_lse)
+            out=[attn_output, softmax_lse],
+        )
         torch.npu.graph_task_update_end(update_stream)
         event.record(update_stream)
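
Reviewer note (not part of the patch): the sketch below is a minimal, self-contained illustration of the replacement kernel call in the plain BNSD decode path, assembled only from the arguments that appear in `common_kwargs` above. The tensor shapes, the scale value, and the metadata contents are placeholder assumptions, and running it requires `torch_npu` on an Ascend NPU.

```python
import torch
import torch_npu  # assumed available; needs an Ascend NPU runtime

# Placeholder sizes, loosely mirroring the BNSD branch of _forward_decode.
batch, num_heads, num_kv_heads = 4, 16, 1
kv_lora_rank, rope_dim, block_size, num_blocks = 512, 64, 128, 2

# Query in BNSD layout (S = 1 for decode); paged latent KV cache laid out as
# (num_blocks, num_kv_heads, block_size, head_dim), matching the k_nope/k_pe views above.
q_nope = torch.randn(batch, num_heads, 1, kv_lora_rank, dtype=torch.float16).npu()
q_pe = torch.randn(batch, num_heads, 1, rope_dim, dtype=torch.float16).npu()
k_nope = torch.randn(num_blocks, num_kv_heads, block_size, kv_lora_rank, dtype=torch.float16).npu()
k_pe = torch.randn(num_blocks, num_kv_heads, block_size, rope_dim, dtype=torch.float16).npu()
block_table = torch.zeros(batch, 1, dtype=torch.int32).npu()  # every request maps to block 0
cp_seq_len = [1] * batch  # per-request KV length on this context-parallel rank

attn_output, softmax_lse = torch_npu.npu_fused_infer_attention_score(
    q_nope, k_nope, k_nope,          # latent KV is passed as both key and value
    query_rope=q_pe,
    key_rope=k_pe,
    num_heads=num_heads,
    num_key_value_heads=num_kv_heads,
    input_layout="BNSD",
    atten_mask=None,
    sparse_mode=0,
    scale=kv_lora_rank**-0.5,        # placeholder for self.scale
    antiquant_mode=0,
    antiquant_scale=None,
    block_table=block_table,
    block_size=block_size,
    actual_seq_lengths=None,
    actual_seq_lengths_kv=cp_seq_len,
    softmax_lse_flag=True,
)

# BNSD results come back as (B, N, S, D) and (B, N, S, 1); the patch then permutes
# and reshapes them to (B*S, N, D) and (B*S, N, 1) before _process_attn_out_lse
# merges out & lse across context-parallel ranks.
print(attn_output.shape, softmax_lse.shape)
```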