diff --git a/tests/e2e/multicard/4-cards/long_sequence/test_accuracy.py b/tests/e2e/multicard/4-cards/long_sequence/test_accuracy.py
index 61d99c97e8e..7df19de43b1 100644
--- a/tests/e2e/multicard/4-cards/long_sequence/test_accuracy.py
+++ b/tests/e2e/multicard/4-cards/long_sequence/test_accuracy.py
@@ -210,3 +210,72 @@ def test_accuracy_pcp_only(max_tokens: int, ) -> None:
         name_0="vllm_eager_outputs",
         name_1="vllm_pcp_only_outputs",
     )
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("max_tokens", [10])
+def test_models_long_sequence_cp_kv_interleave_size_output_between_tp_and_cp(
+        model: str,
+        max_tokens: int,
+) -> None:
+    prompts = [
+        "The president of the United States is", "The capital of France is"
+    ]
+
+    common_kwargs = {
+        "max_model_len": 1024,
+    }
+
+    if model == "vllm-ascend/DeepSeek-V2-Lite-W8A8":
+        cp_kwargs = {
+            "tensor_parallel_size": 2,
+            "decode_context_parallel_size": 2,
+            "prefill_context_parallel_size": 2,
+            "enable_expert_parallel": True,
+            "cp_kv_cache_interleave_size": 128,
+            "enforce_eager": True,
+            "quantization": "ascend",
+        }
+        tp_kwargs = {
+            "tensor_parallel_size": 4,
+            "enable_expert_parallel": True,
+            "enforce_eager": True,
+            "quantization": "ascend",
+        }
+
+    else:
+        cp_kwargs = {
+            "tensor_parallel_size": 1,
+            "decode_context_parallel_size": 1,
+            "prefill_context_parallel_size": 2,
+            "cp_kv_cache_interleave_size": 128,
+            "compilation_config": {
+                "cudagraph_mode": "FULL_DECODE_ONLY",
+                "cudagraph_capture_sizes": [4, 8, 24, 48, 60]
+            },
+        }
+        tp_kwargs = {
+            "tensor_parallel_size": 2,
+            "enforce_eager": True,
+        }
+
+    cp_full_kwargs = {}
+    cp_full_kwargs.update(common_kwargs)  # type: ignore
+    cp_full_kwargs.update(cp_kwargs)  # type: ignore
+
+    tp_full_kwargs = {}
+    tp_full_kwargs.update(common_kwargs)  # type: ignore
+    tp_full_kwargs.update(tp_kwargs)  # type: ignore
+    with VllmRunner(model, **cp_full_kwargs) as runner:  # type: ignore
+        vllm_context_parallel_outputs = runner.generate_greedy(
+            prompts, max_tokens)
+
+    with VllmRunner(model, **tp_full_kwargs) as runner:  # type: ignore
+        vllm_eager_outputs = runner.generate_greedy(prompts, max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=vllm_eager_outputs,
+        outputs_1_lst=vllm_context_parallel_outputs,
+        name_0="vllm_eager_outputs",
+        name_1="vllm_context_parallel_outputs",
+    )
\ No newline at end of file
diff --git a/tests/ut/attention/test_attention_cp.py b/tests/ut/attention/test_attention_cp.py
index cc518fdab53..699f829acd2 100644
--- a/tests/ut/attention/test_attention_cp.py
+++ b/tests/ut/attention/test_attention_cp.py
@@ -113,8 +113,6 @@ def mock_npu_fused_infer_attention_score_func(query, k_nope, value,
 
         attn_metadata = MagicMock()
         attn_metadata.decode_meta = MagicMock()
-        attn_metadata.decode_meta.batch_seq_mask = torch.tensor(
-            [1, 0], dtype=torch.bool)
 
         output = self.impl._forward_decode_pcp_dcp(query, attn_metadata)
         self.assertEqual(output.shape[0], 2)
diff --git a/tests/ut/attention/test_mla_cp.py b/tests/ut/attention/test_mla_cp.py
index f02c1b88f14..3cb0cee18b5 100755
--- a/tests/ut/attention/test_mla_cp.py
+++ b/tests/ut/attention/test_mla_cp.py
@@ -439,11 +439,7 @@ def test_process_attn_out_lse(self):
         decode_metadata = MagicMock()
         decode_metadata.actual_seq_lengths_q = MagicMock()
         decode_metadata.seq_lens_list = MagicMock()
-        decode_metadata.batch_seq_mask = torch.tensor([True, False],
-                                                      dtype=torch.bool)
-
-        result = _process_attn_out_lse(attn_output, softmax_lse,
-                                       decode_metadata.batch_seq_mask)
+        result = _process_attn_out_lse(attn_output, softmax_lse)
 
         self.assertEqual(result.shape[0], B * self.impl.pcp_size)
         self.assertEqual(result.shape[1], N)
@@ -478,8 +474,6 @@ def test_forward_decode_pcp_dcp(self, mock_npu_attention_update,
         attn_metadata.decode = MagicMock()
         attn_metadata.decode.actual_seq_lengths_q = MagicMock()
         attn_metadata.decode.seq_lens_list = MagicMock()
-        attn_metadata.decode.batch_seq_mask = torch.tensor([False, False],
-                                                           dtype=torch.bool)
 
         self.impl.enable_kv_nz = True
 
@@ -886,12 +880,8 @@ def test_process_attn_out_lse_with_dcp_pcp(self, mock_all_to_all, mock_dcp,
         # Inputs
         attn_output = torch.randn(B, H, D)
         softmax_lse = torch.randn(B, H, 1)
-        batch_seq_mask = torch.tensor([False, True, False, False])  # [B]
-        decode_meta = MagicMock()
-        decode_meta.batch_seq_mask = batch_seq_mask
 
-        result = _process_attn_out_lse(attn_output, softmax_lse,
-                                       batch_seq_mask)
+        result = _process_attn_out_lse(attn_output, softmax_lse)
         # [PCP * S, DCP * H, D + 1]
         self.assertIsInstance(result, torch.Tensor)
         assert result.shape == (B * self.impl.pcp_size, H, D + 1)
diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py
index 46a58626753..e7bea5135dd 100755
--- a/tests/ut/attention/test_mla_v1.py
+++ b/tests/ut/attention/test_mla_v1.py
@@ -121,7 +121,6 @@ def test_ascend_mla_decode_metadata_default(self):
         seq_lens_list = [2, 3]
         attn_mask = None
         cp_seq_len = torch.tensor([2, 3])
-        batch_seq_mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]])
 
         metadata = AscendMLADecodeMetadata(input_positions=input_positions,
                                            block_table=block_table,
@@ -129,8 +128,7 @@
                                            max_seq_lens=max_seq_lens,
                                            seq_lens_list=seq_lens_list,
                                            attn_mask=attn_mask,
-                                           cp_seq_len=cp_seq_len,
-                                           batch_seq_mask=batch_seq_mask)
+                                           cp_seq_len=cp_seq_len)
 
         self.assertIs(metadata.input_positions, input_positions)
         self.assertIs(metadata.block_table, block_table)
@@ -139,7 +137,6 @@
         self.assertEqual(metadata.seq_lens_list, seq_lens_list)
         self.assertIsNone(attn_mask)
         self.assertIs(metadata.cp_seq_len, cp_seq_len)
-        self.assertIs(metadata.batch_seq_mask, batch_seq_mask)
 
 
 class TestAscendMLAMetadata(TestBase):
diff --git a/vllm_ascend/attention/context_parallel/attention_cp.py b/vllm_ascend/attention/context_parallel/attention_cp.py
index 29c0d8072a6..7dcd84f217d 100644
--- a/vllm_ascend/attention/context_parallel/attention_cp.py
+++ b/vllm_ascend/attention/context_parallel/attention_cp.py
@@ -59,10 +59,6 @@ def __init__(
         device: torch.device,
     ):
         super().__init__(kv_cache_spec, layer_names, vllm_config, device)
-        self.batch_seq_mask_buf = torch.empty(
-            vllm_config.scheduler_config.max_num_batched_tokens,
-            dtype=torch.uint8,
-            device=device)
         self.pcp_size = get_pcp_group().world_size
         self.pcp_rank = get_pcp_group(
         ).rank_in_group if self.pcp_size > 1 else 0
@@ -229,16 +225,11 @@ def build(
             num_computed_tokens_array = np.array(
                 num_computed_tokens_of_pcp_dcp)
             num_computed_tokens_array = num_computed_tokens_array[:num_decodes]
-            batch_seq_mask = (num_computed_tokens_array[:, self.pcp_rank,
-                                                        self.dcp_rank] == 0)
             # TODO: numpy array mode of the shared memory is used to improve performance
-            self.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_(
-                torch.from_numpy(batch_seq_mask), non_blocking=True)
             decode_metadata = AscendMetadataForDecode(
                 num_computed_tokens_of_pcp_dcp=num_computed_tokens_array,
-                batch_seq_mask=self.batch_seq_mask_buf[:batch_seq_mask.
-                                                       shape[0]],
-                block_tables=block_table[:num_decodes])
+                block_tables=block_table[:num_decodes],
+            )
 
         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
@@ -549,8 +540,7 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
         else:
             attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(
                 query, k_nope, value, **common_kwargs)
-        attn_out_lse = _process_attn_out_lse(
-            attn_out, attn_lse, attn_metadata.decode_meta.batch_seq_mask)
+        attn_out_lse = _process_attn_out_lse(attn_out, attn_lse)
         attn_out = _npu_attention_update(self.head_size, attn_out_lse)
         return attn_out
 
@@ -661,11 +651,8 @@ def _compute_prefill_context(self, query: torch.Tensor,
             actual_seq_lengths_kv=prefill_metadata.chunked_context.
             actual_seq_lengths_kv,
             actual_seq_lengths=attn_metadata.prefill.chunked_context.
-            actual_chunk_seq_lengths)
-        batch_chunk_seq_mask = attn_metadata.prefill.chunked_context.batch_chunk_seq_mask
-        lse_mask = batch_chunk_seq_mask[:, None,
-                                        None].expand_as(prefix_chunk_lse)
-        prefix_chunk_lse = torch.where(lse_mask, -torch.inf, prefix_chunk_lse)
+            actual_chunk_seq_lengths,
+        )
 
         return prefix_chunk_output, prefix_chunk_lse
 
diff --git a/vllm_ascend/attention/context_parallel/common_cp.py b/vllm_ascend/attention/context_parallel/common_cp.py
index 018919c01dd..a23e9c23d72 100644
--- a/vllm_ascend/attention/context_parallel/common_cp.py
+++ b/vllm_ascend/attention/context_parallel/common_cp.py
@@ -69,21 +69,17 @@ class ChunkedContextMetadata:
 
 @dataclass
 class AscendMetadataForDecode:
-    """ Decode Specific Metadata for Ascend"""
-    num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
-    batch_seq_mask: torch.Tensor = None
+    """Decode-specific metadata for Ascend attention with Context Parallelism."""
+
+    num_computed_tokens_of_pcp_dcp: list[list[list[int]]] | None = None
     block_tables: torch.Tensor = None
 
 
-def _process_attn_out_lse(attn_output: torch.Tensor, softmax_lse: torch.Tensor,
-                          batch_seq_mask: torch.Tensor) -> torch.Tensor:
+def _process_attn_out_lse(attn_output: torch.Tensor,
+                          softmax_lse: torch.Tensor) -> torch.Tensor:
     pcp_size = get_pcp_group().world_size
     dcp_size = get_decode_context_model_parallel_world_size()
     dcp_group = get_dcp_group().device_group if dcp_size > 1 else None
-    out_mask = batch_seq_mask[:, None, None].expand_as(attn_output)
-    attn_output = torch.where(out_mask, 0, attn_output)
-    lse_mask = batch_seq_mask[:, None, None].expand_as(softmax_lse)
-    softmax_lse = torch.where(lse_mask, -torch.inf, softmax_lse)
     softmax_lse = softmax_lse.to(torch.float32)
     attn_output = attn_output.to(torch.float32)
     # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
diff --git a/vllm_ascend/attention/context_parallel/mla_cp.py b/vllm_ascend/attention/context_parallel/mla_cp.py
index 0a75323b353..539f2a330e6 100644
--- a/vllm_ascend/attention/context_parallel/mla_cp.py
+++ b/vllm_ascend/attention/context_parallel/mla_cp.py
@@ -63,14 +63,6 @@ def __init__(
         ) if self.dcp_size > 1 else 0
         self.cp_local_block_size = vllm_config.parallel_config.cp_kv_cache_interleave_size
         self.cp_virtual_block_size = self.cp_local_block_size * self.dcp_size * self.pcp_size
-        scheduler_config = vllm_config.scheduler_config
-        decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs',
-                                      0)
-        max_num_seqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs)
-        self.batch_seq_mask_buf = torch.empty(max_num_seqs *
-                                              self.decode_threshold,
-                                              dtype=torch.uint8,
-                                              device=device)
         self.block_size = (self.block_size *
                            self.cp_virtual_block_size) // np.gcd(
                                self.block_size, self.cp_virtual_block_size)
@@ -256,13 +248,7 @@ def build_decode_metadata(
         cp_seq_len = num_computed_tokens_of_cp_dcp_array[:, self.pcp_rank,
                                                          self.dcp_rank]
         cp_seq_len = torch.tensor(cp_seq_len, dtype=torch.int32)
-        batch_seq_mask = (cp_seq_len == 0)
-        self.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_(
-            batch_seq_mask, non_blocking=True)
-        batch_seq_mask = self.batch_seq_mask_buf[:batch_seq_mask.shape[0]]
-        cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
         decode_metadata.cp_seq_len = cp_seq_len.tolist()
-        decode_metadata.batch_seq_mask = batch_seq_mask
 
         actual_seq_lengths_q = torch.arange(self.num_decodes_flatten) + 1
         decode_metadata.actual_seq_lengths_q = actual_seq_lengths_q
@@ -688,8 +674,7 @@ def _forward_decode(
                                          B_lse * Q_S, N_lse, 1)
 
         # Update out&lse
-        attn_out_lse = _process_attn_out_lse(attn_output, softmax_lse,
-                                             decode_meta.batch_seq_mask)
+        attn_out_lse = _process_attn_out_lse(attn_output, softmax_lse)
         attn_output = _npu_attention_update(self.kv_lora_rank, attn_out_lse)
 
         return self._v_up_proj(attn_output)
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 5f77bfea222..3ddbbc1c2df 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -130,7 +130,6 @@ class AscendMLADecodeMetadata:
     sin: torch.Tensor = None
     cos: torch.Tensor = None
     cp_seq_len: torch.Tensor = None
-    batch_seq_mask: torch.Tensor = None
 
 
 @dataclass
@@ -590,7 +589,7 @@ def build_decode_metadata(
             self.block_table = self.block_table[:self.graph_pad_size, ...]
 
         seq_lens_list = self.seq_lens.tolist()
-        cp_seq_len, batch_seq_mask = None, None
+        cp_seq_len = None
 
         if self.graph_pad_size > num_reqs:
             if self.speculative_config.disable_padded_drafter_batch:
@@ -651,8 +650,7 @@
             actual_seq_lengths_q=actual_seq_lengths_q,
             sin=sin[:self.num_decode_tokens, ...],
             cos=cos[:self.num_decode_tokens, ...],
-            cp_seq_len=cp_seq_len,
-            batch_seq_mask=batch_seq_mask)
+            cp_seq_len=cp_seq_len)
         return decode_metadata
 
     def build_for_graph_capture(
diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py
index e61692532f8..9a2476f3456 100644
--- a/vllm_ascend/spec_decode/mtp_proposer.py
+++ b/vllm_ascend/spec_decode/mtp_proposer.py
@@ -512,7 +512,7 @@ def _propose(
             self.positions[:batch_size] = clamped_positions
             self.hidden_states[:hidden_states.shape[0]] = hidden_states
             if self.pcp_size * self.dcp_size > 1:
-                # update local seq_len and batch_seq_mask
+                # update local seq_len
                 num_computed_tokens_of_pcp_dcp = self.runner.pcp_manager._get_cp_local_seq_lens(
                     ori_seq_len + step + 1,
                     self.pcp_size,
@@ -521,14 +521,7 @@ def _propose(
                 )
                 cp_seq_len = \
                     num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank]
-                batch_seq_mask = (cp_seq_len == 0)
-                builder.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_(
-                    batch_seq_mask, non_blocking=True)
-                batch_seq_mask = builder.batch_seq_mask_buf[:batch_seq_mask.
-                                                            shape[0]]
-                cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
                 attn_metadata_i.decode.cp_seq_len = cp_seq_len
-                attn_metadata_i.decode.batch_seq_mask = batch_seq_mask
                 # update slot_mapping
                 slot_indices += self.pcp_size
                 slot_mapping = mtp_slot_mapping[slot_indices]