diff --git a/tests/e2e/multicard/4-cards/long_sequence/test_accuracy.py b/tests/e2e/multicard/4-cards/long_sequence/test_accuracy.py index 61d99c97e8e..ee0b4aeb5c8 100644 --- a/tests/e2e/multicard/4-cards/long_sequence/test_accuracy.py +++ b/tests/e2e/multicard/4-cards/long_sequence/test_accuracy.py @@ -210,3 +210,72 @@ def test_accuracy_pcp_only(max_tokens: int, ) -> None: name_0="vllm_eager_outputs", name_1="vllm_pcp_only_outputs", ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("max_tokens", [10]) +def test_cp_kv_cache_interleave_size_between_tp_and_cp( + model: str, + max_tokens: int, +) -> None: + prompts = [ + "The president of the United States is", "The capital of France is" + ] + + common_kwargs = { + "max_model_len": 1024, + } + + if model == "vllm-ascend/DeepSeek-V2-Lite-W8A8": + cp_kwargs = { + "tensor_parallel_size": 2, + "decode_context_parallel_size": 2, + "prefill_context_parallel_size": 2, + "enable_expert_parallel": True, + "cp_kv_cache_interleave_size": 128, + "enforce_eager": True, + "quantization": "ascend", + } + tp_kwargs = { + "tensor_parallel_size": 4, + "enable_expert_parallel": True, + "enforce_eager": True, + "quantization": "ascend", + } + + else: + cp_kwargs = { + "tensor_parallel_size": 1, + "decode_context_parallel_size": 1, + "prefill_context_parallel_size": 2, + "cp_kv_cache_interleave_size": 128, + "compilation_config": { + "cudagraph_mode": "FULL_DECODE_ONLY", + "cudagraph_capture_sizes": [4, 8, 24, 48, 60] + }, + } + tp_kwargs = { + "tensor_parallel_size": 2, + "enforce_eager": True, + } + + cp_full_kwargs = {} + cp_full_kwargs.update(common_kwargs) # type: ignore + cp_full_kwargs.update(cp_kwargs) # type: ignore + + tp_full_kwargs = {} + tp_full_kwargs.update(common_kwargs) # type: ignore + tp_full_kwargs.update(tp_kwargs) # type: ignore + with VllmRunner(model, **cp_full_kwargs) as runner: # type: ignore + vllm_context_parallel_outputs = runner.generate_greedy( + prompts, max_tokens) + + with VllmRunner(model, **tp_full_kwargs) as runner: # type: ignore + vllm_eager_outputs = runner.generate_greedy(prompts, max_tokens) + + check_outputs_equal( + outputs_0_lst=vllm_eager_outputs, + outputs_1_lst=vllm_context_parallel_outputs, + name_0="vllm_eager_outputs", + name_1="vllm_context_parallel_outputs", + ) diff --git a/tests/ut/attention/test_attention_cp.py b/tests/ut/attention/test_attention_cp.py index cc518fdab53..4d94a4b2fe6 100644 --- a/tests/ut/attention/test_attention_cp.py +++ b/tests/ut/attention/test_attention_cp.py @@ -113,8 +113,6 @@ def mock_npu_fused_infer_attention_score_func(query, k_nope, value, attn_metadata = MagicMock() attn_metadata.decode_meta = MagicMock() - attn_metadata.decode_meta.batch_seq_mask = torch.tensor( - [1, 0], dtype=torch.bool) output = self.impl._forward_decode_pcp_dcp(query, attn_metadata) self.assertEqual(output.shape[0], 2) @@ -137,8 +135,10 @@ def test_prefill_query_all_gather(self): self.assertEqual(output.shape[2], 128) @patch('torch.ops.npu.npu_fused_infer_attention_score') + @patch('torch_npu.npu_attention_update') @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False) - def test_compute_prefill_context(self, mock_npu_attention): + def test_compute_prefill_context(self, mock_npu_attention_update, + mock_npu_attention): block_num = 100 block_size = 128 @@ -181,7 +181,9 @@ def mock_load_kv_for_chunk(attn_metadata, kv_cache, head_size), torch.randn( batch_size, num_heads, 1) - + mock_npu_attention_update.return_value = torch.randn( + batch_size, self.impl.num_heads, + 
head_size), torch.randn(batch_size, self.impl.num_heads, 1) context_output = self.impl._compute_prefill_context( query, kv_cache, attn_metadata) local_context_output = torch.cat(context_output, @@ -406,11 +408,9 @@ def test_attention_with_nomask_none(self, mock_npu_attention): self.assertEqual(attn_lse.shape, (96, 8, 1)) @patch('torch.ops.npu.npu_fused_infer_attention_score') - @patch( - 'vllm_ascend.attention.context_parallel.attention_cp.AscendAttentionCPImpl._update_out_and_lse' - ) + @patch('torch_npu.npu_attention_update') def test_attention_with_nomask_and_mask_chunk( - self, mock_update_out_and_lse, + self, mock_npu_attention_update, mock_npu_fused_infer_attention_score): # Mock input data q = torch.randn(self.q_total_tokens, self.impl.num_heads, @@ -432,7 +432,7 @@ def test_attention_with_nomask_and_mask_chunk( self.q_total_tokens, self.impl.num_heads, self.impl.head_size), torch.randn(self.q_total_tokens, self.impl.num_heads, 1) - mock_update_out_and_lse.return_value = torch.randn( + mock_npu_attention_update.return_value = torch.randn( self.q_total_tokens, self.impl.num_heads, self.impl.head_size), torch.randn(self.q_total_tokens, self.impl.num_heads, 1) @@ -481,8 +481,12 @@ def test_attention_with_nomask_and_mask_nochunk( self.q_total_tokens, self.impl.num_heads, self.impl.head_size), torch.randn(self.q_total_tokens, self.impl.num_heads, 1) - mock_npu_attn_out_lse_update.return_value = torch.randn( - self.q_total_tokens, self.impl.num_heads, self.impl.head_size) + mock_npu_attn_out_lse_update.return_value = (torch.randn( + self.q_total_tokens, self.impl.num_heads, self.impl.head_size), + torch.randn( + self.q_total_tokens, + self.impl.num_heads, + 1)) # Call the method under test output, attn_lse = self.impl._attention_with_nomask_and_mask( @@ -500,7 +504,6 @@ def test_attention_with_nomask_and_mask_nochunk( mock_npu_attn_out_lse_update.assert_called_once() self.assertEqual(mock_npu_fused_infer_attention_score.call_count, 2) self.assertIsNotNone(output) - self.assertEqual(attn_lse, None) @patch( 'vllm_ascend.attention.context_parallel.attention_cp.AscendAttentionCPImpl._npu_attn_out_lse_update' @@ -550,14 +553,14 @@ def test_npu_attn_out_lse_update(self, mock_npu_attention_update): attn_out_nomask = torch.randn(8, 128, 128) # Mock output - mock_npu_attention_update.return_value = (torch.randn(8 * 128, - 128), None) + mock_npu_attention_update.return_value = (torch.randn(8 * 128, 128), + torch.randn(8 * 128, 1)) # Call the method under test - output = self.impl._npu_attn_out_lse_update(attn_lse_mask, - attn_lse_nomask, - attn_out_mask, - attn_out_nomask) + output, _ = self.impl._npu_attn_out_lse_update(attn_lse_mask, + attn_lse_nomask, + attn_out_mask, + attn_out_nomask) # Assert the method call self.assertIsInstance(output, torch.Tensor) @@ -565,28 +568,11 @@ def test_npu_attn_out_lse_update(self, mock_npu_attention_update): mock_npu_attention_update.assert_called_once() - def test_update_out_and_lse(self): - # Mock input data - out_list = torch.randn(3, 2, 4, - 8) # [N, batch_size, num_heads, head_size] - lse_list = torch.randn(3, 2, 4, 1) # [N, batch_size, num_heads, 1] - - # Call the method under test - out_final, lse_final = self.impl._update_out_and_lse( - out_list, lse_list) - - # Assert the method call - self.assertEqual(out_final.shape, - (2, 4, 8)) # [batch_size, num_heads, head_size] - self.assertEqual(lse_final.shape, - (2, 4, 1)) # [batch_size, num_heads, 1] - - self.assertIsInstance(out_final, torch.Tensor) - self.assertIsInstance(lse_final, torch.Tensor) - + 
@patch('torch_npu.npu_attention_update') @patch_distributed_groups(dcp_size=2, pcp_size=3) def test_update_chunk_attn_out_lse_dcp2_pcp3(self, mock_all_to_all_single, - mock_dcp, mock_pcp): + mock_dcp, mock_pcp, + mock_npu_attention_update): # Mock input data prefix_chunk_output = torch.randn(2, 4, 8) prefix_chunk_lse = torch.randn(2, 4, 1) @@ -601,6 +587,8 @@ def test_update_chunk_attn_out_lse_dcp2_pcp3(self, mock_all_to_all_single, chunk_data) global_context_output = global_context_output.permute([2, 0, 1 ]).contiguous() + mock_npu_attention_update.return_value = (torch.randn(2, 2, 8), + torch.randn(2, 2, 1)) output, lse = self.impl._update_global_context_output( global_context_output) @@ -613,9 +601,11 @@ def test_update_chunk_attn_out_lse_dcp2_pcp3(self, mock_all_to_all_single, mock_all_to_all_single.assert_called_once() mock_pcp.all_gather.assert_called_once() + @patch('torch_npu.npu_attention_update') @patch_distributed_groups(dcp_size=2) def test_update_chunk_attn_out_lse_dcp2_pcp1(self, mock_all_to_all_single, - mock_dcp, mock_pcp): + mock_dcp, mock_pcp, + mock_npu_attention_update): # Mock input data prefix_chunk_output = torch.randn(2, 4, 8) prefix_chunk_lse = torch.randn(2, 4, 1) @@ -631,6 +621,8 @@ def test_update_chunk_attn_out_lse_dcp2_pcp1(self, mock_all_to_all_single, chunk_data) global_context_output = global_context_output.permute([2, 0, 1 ]).contiguous() + mock_npu_attention_update.return_value = (torch.randn(2, 2, 8), + torch.randn(2, 2, 1)) output, lse = self.impl._update_global_context_output( global_context_output) @@ -643,9 +635,11 @@ def test_update_chunk_attn_out_lse_dcp2_pcp1(self, mock_all_to_all_single, mock_all_to_all_single.assert_called_once() mock_pcp.all_gather.assert_not_called() + @patch('torch_npu.npu_attention_update') @patch_distributed_groups(pcp_size=2) def test_update_chunk_attn_out_lse_dcp1_pcp2(self, mock_all_to_all_single, - mock_dcp, mock_pcp): + mock_dcp, mock_pcp, + mock_npu_attention_update): # Mock input data prefix_chunk_output = torch.randn(2, 4, 8) prefix_chunk_lse = torch.randn(2, 4, 1) @@ -661,6 +655,9 @@ def test_update_chunk_attn_out_lse_dcp1_pcp2(self, mock_all_to_all_single, chunk_data) global_context_output = global_context_output.permute([2, 0, 1 ]).contiguous() + mock_npu_attention_update.return_value = torch.randn(2, 4, + 8), torch.randn( + 2, 4, 1) output, lse = self.impl._update_global_context_output( global_context_output) diff --git a/tests/ut/attention/test_mla_cp.py b/tests/ut/attention/test_mla_cp.py index a8857d14b09..a8efce3756a 100755 --- a/tests/ut/attention/test_mla_cp.py +++ b/tests/ut/attention/test_mla_cp.py @@ -439,22 +439,19 @@ def test_process_attn_out_lse(self): decode_metadata = MagicMock() decode_metadata.actual_seq_lengths_q = MagicMock() decode_metadata.seq_lens_list = MagicMock() - decode_metadata.batch_seq_mask = torch.tensor([True, False], - dtype=torch.bool) - result = _process_attn_out_lse(attn_output, softmax_lse, - decode_metadata.batch_seq_mask) + result = _process_attn_out_lse(attn_output, softmax_lse) self.assertEqual(result.shape[0], B * self.impl.pcp_size) self.assertEqual(result.shape[1], N) self.assertEqual(result.shape[2], self.impl.kv_lora_rank + 1) @patch('vllm_ascend.attention.context_parallel.mla_cp.get_forward_context') - @patch("torch_npu.atb.npu_multi_head_latent_attention") + @patch("torch_npu.npu_fused_infer_attention_score") @patch('torch_npu.npu_attention_update') @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False) def test_forward_decode_pcp_dcp(self, 
mock_npu_attention_update, - mock_npu_multi_head_latent_attention, + mock_npu_fused_infer_attention_score, mock_get_forward_context): self.impl.dcp_size = 2 self.impl.pcp_size = 2 @@ -470,22 +467,20 @@ def test_forward_decode_pcp_dcp(self, mock_npu_attention_update, q_nope = torch.randn(B, N, self.impl.qk_nope_head_dim) q_pe = torch.randn(B, N, self.impl.qk_rope_head_dim) - k_nope = torch.randn(NB, BS, 1, self.impl.kv_lora_rank) - k_pe = torch.randn(NB, BS, 1, self.impl.qk_rope_head_dim) + k_nope = torch.randn(NB, 1, BS, self.impl.kv_lora_rank) + k_pe = torch.randn(NB, 1, BS, self.impl.qk_rope_head_dim) attn_metadata = MagicMock() attn_metadata.attn_state = AscendAttentionState.SpecDecoding attn_metadata.decode = MagicMock() attn_metadata.decode.actual_seq_lengths_q = MagicMock() attn_metadata.decode.seq_lens_list = MagicMock() - attn_metadata.decode.batch_seq_mask = torch.tensor([False, False], - dtype=torch.bool) self.impl.enable_kv_nz = True mock_npu_attention_update.return_value = (torch.randn( B, self.impl.num_heads, self.impl.kv_lora_rank), None) - mock_npu_multi_head_latent_attention.return_value = [ + mock_npu_fused_infer_attention_score.return_value = [ torch.randn(B, N, self.impl.kv_lora_rank), torch.randn(B, N, 1) ] @@ -886,12 +881,8 @@ def test_process_attn_out_lse_with_dcp_pcp(self, mock_all_to_all, mock_dcp, # Inputs attn_output = torch.randn(B, H, D) softmax_lse = torch.randn(B, H, 1) - batch_seq_mask = torch.tensor([False, True, False, False]) # [B] - decode_meta = MagicMock() - decode_meta.batch_seq_mask = batch_seq_mask - result = _process_attn_out_lse(attn_output, softmax_lse, - batch_seq_mask) + result = _process_attn_out_lse(attn_output, softmax_lse) # [PCP * S, DCP * H, D + 1] self.assertIsInstance(result, torch.Tensor) assert result.shape == (B * self.impl.pcp_size, H, D + 1) diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index 46a58626753..e7bea5135dd 100755 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -121,7 +121,6 @@ def test_ascend_mla_decode_metadata_default(self): seq_lens_list = [2, 3] attn_mask = None cp_seq_len = torch.tensor([2, 3]) - batch_seq_mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]]) metadata = AscendMLADecodeMetadata(input_positions=input_positions, block_table=block_table, @@ -129,8 +128,7 @@ def test_ascend_mla_decode_metadata_default(self): max_seq_lens=max_seq_lens, seq_lens_list=seq_lens_list, attn_mask=attn_mask, - cp_seq_len=cp_seq_len, - batch_seq_mask=batch_seq_mask) + cp_seq_len=cp_seq_len) self.assertIs(metadata.input_positions, input_positions) self.assertIs(metadata.block_table, block_table) @@ -139,7 +137,6 @@ def test_ascend_mla_decode_metadata_default(self): self.assertEqual(metadata.seq_lens_list, seq_lens_list) self.assertIsNone(attn_mask) self.assertIs(metadata.cp_seq_len, cp_seq_len) - self.assertIs(metadata.batch_seq_mask, batch_seq_mask) class TestAscendMLAMetadata(TestBase): diff --git a/tests/ut/compilation/test_acl_graph.py b/tests/ut/compilation/test_acl_graph.py index 88c687cf836..0261971ff82 100644 --- a/tests/ut/compilation/test_acl_graph.py +++ b/tests/ut/compilation/test_acl_graph.py @@ -754,7 +754,7 @@ def setUp(self): @patch('torch.npu.graph_task_update_end', ) @patch('torch.npu.graph_task_update_begin', MagicMock()) - @patch('torch_npu.atb.npu_multi_head_latent_attention', MagicMock()) + @patch('torch_npu.npu_fused_infer_attention_score.out', MagicMock()) def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end): input_positions = 
torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) block_table = torch.zeros(2, 5, dtype=torch.long) @@ -793,16 +793,20 @@ def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end): qk_rope_head_dim = 32 qk_nope_head_dim = 64 query = torch.randn(4, num_heads, qk_head_dim) - q_pe = query[..., qk_nope_head_dim:] + q_nope = query[..., :qk_nope_head_dim] + q_pe = query[..., qk_rope_head_dim:] k_nope = torch.randn(4, num_heads, qk_nope_head_dim) k_pe = torch.randn(4, num_heads, qk_rope_head_dim) + input_layout = "BNSD" + actual_seq_lengths_kv = [1, 1] out = torch.randn(2, 16, 128) lse = torch.randn(2, 16, 8) self.graph_params.attn_params[4] = [] self.graph_params.attn_params[4].append( - (q_nope, q_pe, k_nope, k_pe, block_table, seq_lens, num_heads, - scale, num_kv_heads, out, lse)) + (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout, + None, 0, scale, block_table, 128, None, actual_seq_lengths_kv, + out, lse)) with patch("torch_npu._C._npu_setStream", return_value=None): update_mla_attn_dcp_pcp_params(self.update_stream, forward_context, diff --git a/vllm_ascend/attention/context_parallel/attention_cp.py b/vllm_ascend/attention/context_parallel/attention_cp.py index affcf643247..bce0c202527 100644 --- a/vllm_ascend/attention/context_parallel/attention_cp.py +++ b/vllm_ascend/attention/context_parallel/attention_cp.py @@ -59,10 +59,6 @@ def __init__( device: torch.device, ): super().__init__(kv_cache_spec, layer_names, vllm_config, device) - self.batch_seq_mask_buf = torch.empty( - vllm_config.scheduler_config.max_num_batched_tokens, - dtype=torch.uint8, - device=device) self.pcp_size = get_pcp_group().world_size self.pcp_rank = get_pcp_group( ).rank_in_group if self.pcp_size > 1 else 0 @@ -226,15 +222,9 @@ def build( num_computed_tokens_array = np.array( num_computed_tokens_of_pcp_dcp) num_computed_tokens_array = num_computed_tokens_array[:num_decodes] - batch_seq_mask = (num_computed_tokens_array[:, self.pcp_rank, - self.dcp_rank] == 0) # TODO: numpy array mode of the shared memory is used to improve performance - self.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_( - torch.from_numpy(batch_seq_mask), non_blocking=True) decode_metadata = AscendMetadataForDecode( num_computed_tokens_of_pcp_dcp=num_computed_tokens_array, - batch_seq_mask=self.batch_seq_mask_buf[:batch_seq_mask. 
- shape[0]], block_tables=block_table[:num_decodes]) attn_metadata = AscendMetadata( @@ -337,17 +327,8 @@ def _attention_with_nomask_and_mask(self, q: torch.Tensor, output = attn_out_mask attn_lse = attn_lse_mask if k_nomask is not None: - if attn_metadata.prefill is not None and attn_metadata.prefill.chunked_context is None: - output = self._npu_attn_out_lse_update(attn_lse_mask, - attn_lse_nomask, - attn_out_mask, - attn_out_nomask) - attn_lse = None - else: - output, attn_lse = self._update_out_and_lse( - torch.stack([attn_out_nomask, attn_out_mask], dim=0), - torch.stack([attn_lse_nomask, attn_lse_mask], dim=0)) - + output, attn_lse = self._npu_attn_out_lse_update( + attn_lse_mask, attn_lse_nomask, attn_out_mask, attn_out_nomask) return output, attn_lse def _npu_attn_out_lse_update(self, attn_lse_mask, attn_lse_nomask, @@ -363,13 +344,14 @@ def _npu_attn_out_lse_update(self, attn_lse_mask, attn_lse_nomask, attn_out_nomask = attn_out_nomask.to(torch.float32) attn_lse_mask = attn_lse_mask.to(torch.float32) attn_lse_nomask = attn_lse_nomask.to(torch.float32) - attn_output = [attn_out_nomask, attn_out_mask] - attn_lse = [attn_lse_nomask, attn_lse_mask] - update_type = 0 - output, _ = torch_npu.npu_attention_update(attn_lse, attn_output, - update_type) - output = output.view(T, N, D) - return output + attn_output_list = [attn_out_nomask, attn_out_mask] + attn_lse_list = [attn_lse_nomask, attn_lse_mask] + update_type = 1 + attn_output, attn_lse = torch_npu.npu_attention_update( + attn_lse_list, attn_output_list, update_type) + attn_output = attn_output.view(T, N, D) + attn_lse = attn_lse.view(T, N, 1) + return attn_output, attn_lse def _forward_prefill_cp(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -546,26 +528,10 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor, else: attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score( query, k_nope, value, **common_kwargs) - attn_out_lse = _process_attn_out_lse( - attn_out, attn_lse, attn_metadata.decode_meta.batch_seq_mask) + attn_out_lse = _process_attn_out_lse(attn_out, attn_lse) attn_out = _npu_attention_update(self.head_size, attn_out_lse) return attn_out - def _update_out_and_lse(self, out_list: torch.Tensor, - lse_list: torch.Tensor) -> torch.Tensor: - """LSE_final = log(sum(exp(LSE_i))), O_final = sum(exp(LSE_i - LSE_final) * O_i) - Args: - out_list: shape = [N, batch_size, num_heads, head_size] - lse_list: shape = [N, batch_size, num_heads, 1] - Returns: - out_final: shape = [batch_size, num_heads, head_size] - lse_final: shape = [batch_size, num_heads, 1] - """ - lse_final = torch.logsumexp(lse_list, dim=0, keepdim=False) - out_final = torch.sum(torch.exp(lse_list - lse_final) * out_list, - dim=0) - return out_final, lse_final - def _update_chunk_attn_out_lse_with_current_attn_out_lse( self, current_attn_output_prefill, current_attn_lse_prefill, attn_output_full_chunk, attn_lse_full_chunk, prefill_query, @@ -651,11 +617,6 @@ def _compute_prefill_context(self, query: torch.Tensor, actual_seq_lengths_kv, actual_seq_lengths=attn_metadata.prefill.chunked_context. 
actual_chunk_seq_lengths) - batch_chunk_seq_mask = attn_metadata.prefill.chunked_context.batch_chunk_seq_mask - lse_mask = batch_chunk_seq_mask[:, None, - None].expand_as(prefix_chunk_lse) - prefix_chunk_lse = torch.where(lse_mask, -torch.inf, - prefix_chunk_lse) return prefix_chunk_output, prefix_chunk_lse @@ -781,8 +742,9 @@ def _update_global_context_output(self, global_context_output): # Split out lse attn_out_allgather, attn_lse_allgather = torch.split( x, [D, 1], dim=-1) # [N, S, H, D], [N, S, H, 1] - context_output, context_lse = self._update_out_and_lse( - attn_out_allgather, attn_lse_allgather) + context_output, context_lse = self._npu_attn_out_lse_update( + attn_lse_allgather[1], attn_lse_allgather[0], + attn_out_allgather[1], attn_out_allgather[0]) return context_output, context_lse def forward_impl( diff --git a/vllm_ascend/attention/context_parallel/common_cp.py b/vllm_ascend/attention/context_parallel/common_cp.py index 018919c01dd..00c305c0ff7 100644 --- a/vllm_ascend/attention/context_parallel/common_cp.py +++ b/vllm_ascend/attention/context_parallel/common_cp.py @@ -71,19 +71,14 @@ class ChunkedContextMetadata: class AscendMetadataForDecode: """ Decode Specific Metadata for Ascend""" num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None - batch_seq_mask: torch.Tensor = None block_tables: torch.Tensor = None -def _process_attn_out_lse(attn_output: torch.Tensor, softmax_lse: torch.Tensor, - batch_seq_mask: torch.Tensor) -> torch.Tensor: +def _process_attn_out_lse(attn_output: torch.Tensor, + softmax_lse: torch.Tensor) -> torch.Tensor: pcp_size = get_pcp_group().world_size dcp_size = get_decode_context_model_parallel_world_size() dcp_group = get_dcp_group().device_group if dcp_size > 1 else None - out_mask = batch_seq_mask[:, None, None].expand_as(attn_output) - attn_output = torch.where(out_mask, 0, attn_output) - lse_mask = batch_seq_mask[:, None, None].expand_as(softmax_lse) - softmax_lse = torch.where(lse_mask, -torch.inf, softmax_lse) softmax_lse = softmax_lse.to(torch.float32) attn_output = attn_output.to(torch.float32) # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1] diff --git a/vllm_ascend/attention/context_parallel/mla_cp.py b/vllm_ascend/attention/context_parallel/mla_cp.py index 6ff20557336..211a094e768 100644 --- a/vllm_ascend/attention/context_parallel/mla_cp.py +++ b/vllm_ascend/attention/context_parallel/mla_cp.py @@ -13,6 +13,8 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec +from vllm_ascend.attention.attention_v1 import AscendAttentionState + # isort: off from vllm_ascend.attention.mla_v1 import ( AscendMLADecodeMetadata, AscendMLAImpl, AscendMLAMetadata, @@ -61,14 +63,6 @@ def __init__( ) if self.dcp_size > 1 else 0 self.cp_local_block_size = vllm_config.parallel_config.cp_kv_cache_interleave_size self.cp_virtual_block_size = self.cp_local_block_size * self.dcp_size * self.pcp_size - scheduler_config = vllm_config.scheduler_config - decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs', - 0) - max_num_seqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs) - self.batch_seq_mask_buf = torch.empty(max_num_seqs * - self.decode_threshold, - dtype=torch.uint8, - device=device) def build( self, @@ -251,13 +245,11 @@ def build_decode_metadata( cp_seq_len = num_computed_tokens_of_cp_dcp_array[:, self.pcp_rank, self.dcp_rank] cp_seq_len = torch.tensor(cp_seq_len, dtype=torch.int32) - 
batch_seq_mask = (cp_seq_len == 0) - self.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_( - batch_seq_mask, non_blocking=True) - batch_seq_mask = self.batch_seq_mask_buf[:batch_seq_mask.shape[0]] - cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len) - decode_metadata.cp_seq_len = cp_seq_len - decode_metadata.batch_seq_mask = batch_seq_mask + + decode_metadata.cp_seq_len = cp_seq_len.tolist() + actual_seq_lengths_q = torch.arange(self.num_decodes_flatten) + 1 + decode_metadata.actual_seq_lengths_q = actual_seq_lengths_q + return decode_metadata @@ -571,20 +563,51 @@ def _forward_decode( num_heads = self.num_heads * self.dcp_size else: num_heads = self.num_heads - - k_nope = k_nope.view(-1, block_size, self.num_kv_heads, + # use pcp & dcp split computed token nums from scheduler to compute actual seq_len and seq_mask + k_nope = k_nope.view(-1, self.num_kv_heads, block_size, self.kv_lora_rank) - k_pe = k_pe.view(-1, block_size, self.num_kv_heads, + k_pe = k_pe.view(-1, self.num_kv_heads, block_size, self.qk_rope_head_dim) - q_nope = q_nope.view(num_tokens, num_heads, -1) - q_pe = q_pe.view(num_tokens, num_heads, -1) - # use pcp & dcp split computed token nums from scheduler to compute actual seq_len and seq_mask - seq_len = decode_meta.cp_seq_len + + actual_seq_lengths = None + input_layout = "BNSD" + + if attn_metadata.attn_state in [ + AscendAttentionState.SpecDecoding, + AscendAttentionState.ChunkedPrefill, + AscendAttentionState.DecodeOnly, + ] and self.speculative_config is not None: + input_layout = "TND" + # TODO: If the driver is upgraded later, the contiguous function can be deleted. + q_nope = q_nope.view(num_tokens, num_heads, -1).contiguous() + q_pe = q_pe.view(num_tokens, num_heads, -1) + sparse_mode = 3 + spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore + actual_seq_lengths = decode_meta.actual_seq_lengths_q + else: + q_nope = q_nope.view(num_tokens, num_heads, 1, -1).contiguous() + q_pe = q_pe.view(num_tokens, num_heads, 1, -1) + sparse_mode = 0 + spec_attn_mask = None common_kwargs = { - "return_lse": True, - "calc_type": "calc_type_ring", + 'query_rope': q_pe, + 'key_rope': k_pe, + 'num_heads': num_heads, + 'num_key_value_heads': self.num_kv_heads, + 'input_layout': input_layout, + 'atten_mask': spec_attn_mask, + 'sparse_mode': sparse_mode, + 'scale': self.scale, + 'antiquant_mode': 0, + 'antiquant_scale': None, + 'block_table': decode_meta.block_table, + 'block_size': block_size, + 'actual_seq_lengths': actual_seq_lengths, + 'actual_seq_lengths_kv': decode_meta.cp_seq_len, + 'softmax_lse_flag': True, } + forward_context: ForwardContext = get_forward_context() if forward_context.is_draft_model: graph_params = get_draft_graph_params() @@ -598,61 +621,56 @@ def _forward_decode( graph_params.events[num_tokens].append(event) workspace = graph_params.workspaces.get(num_tokens) if workspace is None: - workspace = torch_npu.atb._npu_multi_head_latent_attention_get_workspace( - q_nope, q_pe, k_nope, k_pe, decode_meta.block_table, - seq_len, num_heads, self.scale, self.num_kv_heads, - **common_kwargs) + workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace( + q_nope, k_nope, k_nope, **common_kwargs) update_graph_params_workspaces(num_tokens, workspace) attn_output = torch.empty_like(q_nope) - softmax_lse = torch.empty((num_tokens, num_heads, 1), - dtype=q_nope.dtype, - device=q_nope.device) + if input_layout == "BNSD": + softmax_lse = torch.empty((num_tokens, num_heads, 1, 1), + dtype=torch.float, + device=q_nope.device) + else: + softmax_lse = 
torch.empty((num_tokens, num_heads, 1), + dtype=torch.float, + device=q_nope.device) + graph_params.attn_params[num_tokens].append( - (weak_ref_tensors(q_nope), weak_ref_tensors(q_pe), - weak_ref_tensors(k_nope), weak_ref_tensors(k_pe), - decode_meta.block_table, seq_len, num_heads, self.scale, - self.num_kv_heads, weak_ref_tensors(attn_output), - weak_ref_tensors(softmax_lse))) + (weak_ref_tensors(q_nope), weak_ref_tensors(k_nope), + weak_ref_tensors(q_pe), weak_ref_tensors(k_pe), num_heads, + self.num_kv_heads, input_layout, + weak_ref_tensors(spec_attn_mask) if spec_attn_mask is not None + else None, sparse_mode, self.scale, + weak_ref_tensors(decode_meta.block_table), block_size, + actual_seq_lengths, decode_meta.cp_seq_len, + weak_ref_tensors(attn_output), weak_ref_tensors(softmax_lse))) torch.npu.graph_task_group_begin(stream) - torch_npu.atb.npu_multi_head_latent_attention( + torch_npu.npu_fused_infer_attention_score.out( q_nope, - q_pe, k_nope, - k_pe, - decode_meta.block_table, - seq_len, - num_heads, - self.scale, - self.num_kv_heads, + k_nope, **common_kwargs, workspace=workspace, - output=attn_output, - lse=softmax_lse) + out=[attn_output, softmax_lse]) handle = torch.npu.graph_task_group_end(stream) graph_params.handles[num_tokens].append(handle) else: - attn_output = torch.empty_like(q_nope) - softmax_lse = torch.empty((num_tokens, num_heads, 1), - dtype=q_nope.dtype, - device=q_nope.device) - torch_npu.atb.npu_multi_head_latent_attention( + attn_output, softmax_lse = torch_npu.npu_fused_infer_attention_score( q_nope, - q_pe, k_nope, - k_pe, - decode_meta.block_table, - seq_len, - num_heads, - self.scale, - self.num_kv_heads, - return_lse=True, - calc_type="calc_type_ring", - output=attn_output, - lse=softmax_lse) + k_nope, + **common_kwargs, + ) + if input_layout == "BNSD": + B_attn, N_attn, S, D = attn_output.shape + B_lse, N_lse, Q_S, _ = softmax_lse.shape + + attn_output = attn_output.permute(0, 2, 1, 3).reshape( + B_attn * S, N_attn, D) + softmax_lse = softmax_lse.permute(0, 2, 1, 3).reshape( + B_lse * Q_S, N_lse, 1) # Update out&lse - attn_out_lse = _process_attn_out_lse(attn_output, softmax_lse, - decode_meta.batch_seq_mask) + attn_out_lse = _process_attn_out_lse(attn_output, softmax_lse) attn_output = _npu_attention_update(self.kv_lora_rank, attn_out_lse) return self._v_up_proj(attn_output) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 38cc7fd336a..9a6b1a0f881 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -128,7 +128,6 @@ class AscendMLADecodeMetadata: sin: torch.Tensor = None cos: torch.Tensor = None cp_seq_len: torch.Tensor = None - batch_seq_mask: torch.Tensor = None @dataclass @@ -588,7 +587,7 @@ def build_decode_metadata( self.block_table = self.block_table[:self.graph_pad_size, ...] 
seq_lens_list = self.seq_lens.tolist() - cp_seq_len, batch_seq_mask = None, None + cp_seq_len = None if self.graph_pad_size > num_reqs: if self.speculative_config.disable_padded_drafter_batch: @@ -649,8 +648,7 @@ def build_decode_metadata( actual_seq_lengths_q=actual_seq_lengths_q, sin=sin[:self.num_decode_tokens, ...], cos=cos[:self.num_decode_tokens, ...], - cp_seq_len=cp_seq_len, - batch_seq_mask=batch_seq_mask) + cp_seq_len=cp_seq_len) return decode_metadata def build_for_graph_capture( diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index 29ec57930c0..73575fe3d86 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -473,38 +473,46 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context, graph_params.handles[runtime_shape], graph_params.events[runtime_shape], ): - (q_nope, q_pe, k_nope, k_pe, block_table, seq_len, num_heads, - scale, num_kv_heads, attn_output, softmax_lse) = param + + (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout, + spec_attn_mask, sparse_mode, scale, block_table, block_size, + actual_seq_lengths, actual_seq_lengths_kv, attn_output, + softmax_lse) = param decode_meta = forward_context.attn_metadata[key].decode seq_len = decode_meta.cp_seq_len + if isinstance(seq_len, torch.Tensor): + seq_len = seq_len.tolist() + actual_seq_lengths_kv = seq_len - # For pcp + spec decode, we flatten seq_lens - # to avoid irregular attn_mask shape, - # so there's no need to divide runtime_shape by spec_multiple - pad_length = runtime_shape - len(seq_len) - pad_tensor = torch.zeros(pad_length, - dtype=seq_len.dtype, - device=seq_len.device) - seq_len = torch.cat([seq_len, pad_tensor], dim=0) + pad_length = runtime_shape - len(actual_seq_lengths_kv) + if pad_length > 0: + actual_seq_lengths_kv = actual_seq_lengths_kv + [0] * ( + runtime_shape - len(actual_seq_lengths_kv)) torch.npu.graph_task_update_begin(update_stream, handle) - torch_npu.atb.npu_multi_head_latent_attention( + torch_npu.npu_fused_infer_attention_score.out( q_nope, - q_pe, k_nope, - k_pe, - block_table, - seq_len, - num_heads, - scale, - num_kv_heads, - return_lse=True, - calc_type="calc_type_ring", + k_nope, + query_rope=q_pe, + key_rope=k_pe, + num_heads=num_heads, + num_key_value_heads=num_kv_heads, + input_layout=input_layout, + atten_mask=spec_attn_mask, + sparse_mode=sparse_mode, + scale=scale, + antiquant_mode=0, + antiquant_scale=None, + softmax_lse_flag=True, + block_table=block_table, + block_size=block_size, + actual_seq_lengths_kv=actual_seq_lengths_kv, + actual_seq_lengths=actual_seq_lengths, workspace=graph_params.workspaces.get(runtime_shape), - output=attn_output, - lse=softmax_lse) + out=[attn_output, softmax_lse]) torch.npu.graph_task_update_end(update_stream) event.record(update_stream) diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 2f4ac4adc75..ad8f5784002 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -508,7 +508,7 @@ def _propose( self.positions[:batch_size] = clamped_positions self.hidden_states[:hidden_states.shape[0]] = hidden_states if self.pcp_size * self.dcp_size > 1: - # update local seq_len and batch_seq_mask + # update local seq_len num_computed_tokens_of_pcp_dcp = self.runner.pcp_manager._get_cp_local_seq_lens( ori_seq_len + step + 1, self.pcp_size, @@ -517,14 +517,7 @@ def _propose( ) cp_seq_len = \ num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank] - 
batch_seq_mask = (cp_seq_len == 0) - builder.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_( - batch_seq_mask, non_blocking=True) - batch_seq_mask = builder.batch_seq_mask_buf[:batch_seq_mask. - shape[0]] - cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len) attn_metadata_i.decode.cp_seq_len = cp_seq_len - attn_metadata_i.decode.batch_seq_mask = batch_seq_mask # update slot_mapping slot_indices += self.pcp_size slot_mapping = mtp_slot_mapping[slot_indices]
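
Note: the patch removes the hand-rolled `_update_out_and_lse` helper and routes every attention-output/LSE merge through the fused `torch_npu.npu_attention_update` call with `update_type = 1`, which the updated unit tests mock as returning both the merged output and the merged LSE. The sketch below is only a plain-PyTorch numerical reference for that merge, taken from the formula documented on the removed helper (`LSE_final = log(sum(exp(LSE_i)))`, `O_final = sum(exp(LSE_i - LSE_final) * O_i)`); the function name `merge_out_and_lse` and the assumption that `update_type = 1` corresponds to this math are illustrative only, not part of the patch or of the torch_npu API.

    # Reference merge of per-chunk attention outputs and log-sum-exp values.
    # Assumption: this mirrors what the fused torch_npu.npu_attention_update
    # kernel computes when asked to return both the merged output and LSE.
    import torch


    def merge_out_and_lse(out_list: torch.Tensor,
                          lse_list: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """out_list: [N, tokens, num_heads, head_size]; lse_list: [N, tokens, num_heads, 1]."""
        # LSE_final = log(sum_i exp(LSE_i))
        lse_final = torch.logsumexp(lse_list, dim=0, keepdim=False)
        # O_final = sum_i exp(LSE_i - LSE_final) * O_i
        out_final = torch.sum(torch.exp(lse_list - lse_final) * out_list, dim=0)
        return out_final, lse_final


    if __name__ == "__main__":
        # Two chunks, 4 tokens, 8 heads, head_size 16 (shapes chosen for illustration).
        out = torch.randn(2, 4, 8, 16)
        lse = torch.randn(2, 4, 8, 1)
        merged_out, merged_lse = merge_out_and_lse(out, lse)
        assert merged_out.shape == (4, 8, 16)
        assert merged_lse.shape == (4, 8, 1)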