diff --git a/docs/source/user_guide/feature_guide/context_parallel.md b/docs/source/user_guide/feature_guide/context_parallel.md index c7df7df8dfc..04e004d7d24 100644 --- a/docs/source/user_guide/feature_guide/context_parallel.md +++ b/docs/source/user_guide/feature_guide/context_parallel.md @@ -16,8 +16,8 @@ To learn more about the theory and implementation details of context parallel, p Currently context parallel can be used together with most other features, supported features are as follows: | | Eager | Graph | Prefix
Cache | Chunked
Prefill | SpecDecode
(MTP) | PD
disaggregation | MLAPO | | ------- | ----- | ----- | ------ | ------ | ----- | ----- | ----- | -| **PCP** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | -| **DCP** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | +| **PCP** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅| +| **DCP** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ## How to use Context Parallel You can enable `PCP` and `DCP` by `prefill_context_parallel_size` and `decode_context_parallel_size`, refer to the following example: diff --git a/tests/e2e/multicard/4-cards/spec_decode/test_mtp_qwen3_next.py b/tests/e2e/multicard/4-cards/spec_decode/test_mtp_qwen3_next.py index 816c25f0512..4053ccd2bb4 100644 --- a/tests/e2e/multicard/4-cards/spec_decode/test_mtp_qwen3_next.py +++ b/tests/e2e/multicard/4-cards/spec_decode/test_mtp_qwen3_next.py @@ -79,7 +79,7 @@ def test_qwen3_next_mtp_acceptance_tp4(model_name): for num_accepted_tokens in num_accepted_tokens_per_pos ] - match = all(abs(a - b) < 0.05 for a, b in zip(acceptance_per_pos, golden)) + match = all(abs(a - b) < 0.06 for a, b in zip(acceptance_per_pos, golden)) if not match: print(f"acceptance_per_pos: {acceptance_per_pos}") print(f"golden: {golden}") diff --git a/tests/ut/spec_decode/test_mtp_proposer.py b/tests/ut/spec_decode/test_mtp_proposer.py index 307daacf76a..7c69c12c9a8 100644 --- a/tests/ut/spec_decode/test_mtp_proposer.py +++ b/tests/ut/spec_decode/test_mtp_proposer.py @@ -278,6 +278,7 @@ def test_prepare_inputs_padded(self, mock_cpu_gpu_buffer): [0, 8, 16, 24], dtype=torch.int32) mock_common_attn_metadata.seq_lens = torch.tensor([8, 8, 8], dtype=torch.int32) + mock_common_attn_metadata.num_actual_tokens = 24 mock_common_attn_metadata.num_reqs = 3 mock_common_attn_metadata.num_computed_tokens_cpu = torch.tensor( [5, 6, 7], dtype=torch.int32) @@ -293,10 +294,12 @@ def test_prepare_inputs_padded(self, mock_cpu_gpu_buffer): mock_runner.actual_seq_lengths_q = MagicMock() mock_runner.attn_state = MagicMock() mock_runner.graph_pad_size = 0 + mock_runner.pcp_size = 1 mock_runner.decode_token_per_req = MagicMock() proposer = MagicMock(spec=MtpProposer) proposer.runner = mock_runner + proposer.pcp_size = 1 proposer.arange = torch.arange(100, dtype=torch.int32) proposer.prepare_inputs_padded = MtpProposer.prepare_inputs_padded.__get__( proposer) diff --git a/vllm_ascend/attention/context_parallel/mla_cp.py b/vllm_ascend/attention/context_parallel/mla_cp.py index 8b37765cf9c..6ff20557336 100644 --- a/vllm_ascend/attention/context_parallel/mla_cp.py +++ b/vllm_ascend/attention/context_parallel/mla_cp.py @@ -70,6 +70,26 @@ def __init__( dtype=torch.uint8, device=device) + def build( + self, + common_prefix_len: int, + common_attn_metadata: AscendCommonAttentionMetadata, + fast_build: bool = False, + ) -> AscendMLAMetadata: + metadata_cls = super().build(common_prefix_len, common_attn_metadata) + if self.num_prefills == 0 and self.pcp_size > 1: + self.slot_mapping[:self. + num_decode_tokens] = self.slot_mapping[:self. + num_decode_tokens + * self. + pcp_size: + self. + pcp_size] + self.slot_mapping[self.num_decode_tokens:self.num_decode_tokens * + self.pcp_size].fill_(-1) + metadata_cls.slot_mapping = self.slot_mapping + return metadata_cls + @classmethod def get_cudagraph_support( cls: type["AscendMlaCPMetadataBuilder"], @@ -363,8 +383,7 @@ def mla_preprocess_decode(self, q_c, kv_no_split, kv_cache, attn_metadata): decode_ql_nope, decode_q_pe = self.reorg_decode_q( decode_ql_nope, decode_q_pe) decode_q_pe = self.rope_single(decode_q_pe, cos, sin) - decode_slots = attn_metadata.slot_mapping[:num_decode_tokens * - self.pcp_size:self.pcp_size] + decode_slots = attn_metadata.slot_mapping[:num_decode_tokens] decode_kv_no_split = kv_no_split[:num_decode_tokens] decode_k_pe, decode_k_nope = self.exec_kv_decode( decode_kv_no_split, cos, sin, kv_cache, decode_slots) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 04b59dd5c03..c00e5777ba2 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -439,7 +439,6 @@ def build( if self.num_decodes > 0: decode_metadata = self.build_decode_metadata( common_prefix_len, common_attn_metadata) - return self.metadata_cls( # type: ignore num_actual_tokens_pcp_padded=self.num_actual_tokens, num_input_tokens=common_attn_metadata.num_input_tokens, @@ -1334,7 +1333,7 @@ def _mla_preprocess_only_decode(self, hidden_states, kv_cache, self.W_UK_T, decode_k_nope, decode_k_pe, - attn_metadata.slot_mapping[:bsz].flatten(), + attn_metadata.slot_mapping[:bsz], quant_scale0=self.quant_scale0, quant_offset0=self.quant_offset0, bias0=self.quant_bias_qkv, diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index 1684ab562a1..985a7efe76f 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -750,7 +750,8 @@ def prepare_inputs_padded( query_start_loc_cpu=query_start_loc_cpu, seq_lens_cpu=common_attn_metadata.seq_lens_cpu, num_reqs=common_attn_metadata.num_reqs, - num_actual_tokens=total_num_tokens, + num_actual_tokens=common_attn_metadata.num_actual_tokens + if self.pcp_size > 1 else total_num_tokens, num_input_tokens=common_attn_metadata.num_input_tokens, max_query_len=new_query_len_per_req.max().item(), actual_seq_lengths_q=self.runner.actual_seq_lengths_q, diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 5b16e433fe4..b88af20c7c8 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -912,12 +912,15 @@ def _prepare_inputs( self.input_batch) blk_table.slot_mapping.gpu[maybe_pcp_full_tokens:].fill_(-1) if self.pcp_size > 1: - slot_mapping = self.pcp_manager.get_padded_slot_mapping( + slot_mapping_pcp = self.pcp_manager.get_padded_slot_mapping( total_num_scheduled_tokens, slot_mapping, ) blk_table.slot_mapping.gpu[:self.pcp_manager. - num_actual_tokens_pcp_padded] = slot_mapping + num_actual_tokens_pcp_padded] = slot_mapping_pcp + slot_mapping = blk_table.slot_mapping.gpu[:self. + pcp_manager. + num_actual_tokens_pcp_padded] # NOTE: This is a temporary hack, now in GPUModelRunner, this prepare_inputs # has been split to multiple parts, and there are 3 parts that is related to this