From 8ddc4035390afb4b0ddbd383d9b1ddd16bb0780e Mon Sep 17 00:00:00 2001 From: QiuChunshuo Date: Thu, 20 Nov 2025 09:04:44 +0800 Subject: [PATCH 1/7] [PCP] common supports for PCP Co-authored-by: QiuChunshuo Co-authored-by: FENP Co-authored-by: LookAround Co-authored-by: Jingchun Gao Co-authored-by: zhenwenqi2024 Signed-off-by: QiuChunshuo Signed-off-by: FENP Signed-off-by: LookAround Signed-off-by: Jingchun Gao Signed-off-by: zhenwenqi2024 --- tests/distributed/test_context_parallel.py | 2 +- vllm/config/parallel.py | 5 - vllm/v1/attention/backends/flash_attn.py | 4 +- vllm/v1/attention/backends/mla/common.py | 4 +- vllm/v1/attention/backends/utils.py | 45 ++- vllm/v1/spec_decode/eagle.py | 8 +- vllm/v1/worker/gpu_model_runner.py | 396 +++++++++++++++++++-- 7 files changed, 406 insertions(+), 58 deletions(-) diff --git a/tests/distributed/test_context_parallel.py b/tests/distributed/test_context_parallel.py index 7e4713b8aece..293423297e19 100644 --- a/tests/distributed/test_context_parallel.py +++ b/tests/distributed/test_context_parallel.py @@ -196,7 +196,7 @@ def _compare_cp_with_tp( str(pp_size), "--decode-context-parallel-size", str(dcp_size), - "--dcp-kv-cache-interleave-size", + "--cp-kv-cache-interleave-size", str(cp_kv_cache_interleave_size), "--distributed-executor-backend", distributed_backend, diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 4b0236d8de3f..39ffae4e9a41 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -324,11 +324,6 @@ def _validate_parallel_config(self) -> Self: "num_redundant_experts." ) - if self.prefill_context_parallel_size > 1: - raise ValueError( - "Prefill context parallelism is not fully supported. " - "Please set prefill_context_parallel_size to 1." 
- ) return self @property diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index cf3c1d05f5b3..e29d21f73068 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -45,7 +45,7 @@ AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, - get_dcp_local_seq_lens, + get_cp_local_seq_lens, get_kv_cache_layout, ) from vllm.v1.kv_cache_interface import AttentionSpec @@ -384,7 +384,7 @@ def schedule( ) dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu - dcp_context_kv_lens_cpu = get_dcp_local_seq_lens( + dcp_context_kv_lens_cpu = get_cp_local_seq_lens( dcp_context_kv_lens_cpu, self.dcp_world_size, self.dcp_rank, diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 32f406980f2e..1a7f3c1d0a59 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -225,7 +225,7 @@ from vllm.v1.attention.backends.utils import ( AttentionMetadataBuilder, CommonAttentionMetadata, - get_dcp_local_seq_lens, + get_cp_local_seq_lens, get_per_layer_parameters, infer_global_hyperparameters, split_decodes_and_prefills, @@ -832,7 +832,7 @@ def build( ) if self.dcp_world_size > 1: - local_context_lens_allranks = get_dcp_local_seq_lens( + local_context_lens_allranks = get_cp_local_seq_lens( context_lens_cpu, self.dcp_world_size, None, diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 540a8e2b1d01..bbad992975e6 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -48,6 +48,24 @@ def is_valid_kv_cache_layout(value: str) -> bool: return value in get_args(KVCacheLayoutType) +@dataclass +class PrefillContextParallelMetadata: + """ + Attention metadata for prefill context parallel + """ + allgather_restore_idx: torch.Tensor + """ + We split and concatenate the sequence in a head-tail style, + and use this variable 
to restore the original order. + """ + q_head_indices: torch.Tensor | None = None + q_tail_indices: torch.Tensor | None = None + q_head_start_loc: torch.Tensor | None = None + kv_for_head_indices: torch.Tensor | None = None + kv_for_tail_indices : torch.Tensor | None = None + kv_for_head_indptr: torch.Tensor | None = None + kv_for_tail_indptr: torch.Tensor | None = None + q_full_indices: torch.Tensor | None = None @dataclass class CommonAttentionMetadata: @@ -91,10 +109,11 @@ class CommonAttentionMetadata: # Needed by CrossAttentionBuilder encoder_seq_lens: np.ndarray | None = None - dcp_local_seq_lens: torch.Tensor | None = None - dcp_local_seq_lens_cpu: torch.Tensor | None = None - """Sequence lengths of the local rank in decode context parallelism world""" + cp_local_seq_lens: torch.Tensor | None = None + cp_local_seq_lens_cpu: torch.Tensor | None = None + """Sequence lengths of the local rank in prefill/decode context parallelism world""" + pcp_metadata: PrefillContextParallelMetadata | None = None def slice_query_start_locs( query_start_loc: torch.Tensor, @@ -1078,10 +1097,10 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor): return nums_dict, batch_ptr, token_chunk_offset_ptr -def get_dcp_local_seq_lens( +def get_cp_local_seq_lens( seq_lens: torch.Tensor, - dcp_size: int = 1, - dcp_rank: int | None = None, + cp_size: int = 1, + cp_rank: int | None = None, cp_kv_cache_interleave_size: int = 1, ) -> torch.Tensor: """While using dcp, kv_cache size stored on each rank may be different, @@ -1089,28 +1108,28 @@ def get_dcp_local_seq_lens( Only consider dcp now, we can extend the case of cp based on this. 
""" num_requests = seq_lens.size(0) - if dcp_rank is None: + if cp_rank is None: rank_offsets = ( - torch.arange(dcp_size, dtype=torch.int32) + torch.arange(cp_size, dtype=torch.int32) .unsqueeze(0) .repeat(num_requests, 1) ) else: - rank_offsets = torch.Tensor([[dcp_rank]]).to(dtype=torch.int32) + rank_offsets = torch.Tensor([[cp_rank]]).to(dtype=torch.int32) seq_lens_tiled = ( seq_lens.to(torch.int32).unsqueeze(-1).repeat(1, rank_offsets.shape[1]) ) base = ( seq_lens_tiled // cp_kv_cache_interleave_size - // dcp_size + // cp_size * cp_kv_cache_interleave_size ) - remainder = seq_lens_tiled - base * dcp_size + remainder = seq_lens_tiled - base * cp_size remainder = torch.clip( remainder - rank_offsets * cp_kv_cache_interleave_size, 0, cp_kv_cache_interleave_size, ) - dcp_local_seq_lens = base + remainder - return dcp_local_seq_lens.squeeze(1) + cp_local_seq_lens = base + remainder + return cp_local_seq_lens.squeeze(1) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 406bb696bd4c..5cab0f6adea1 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -644,7 +644,9 @@ def prepare_inputs_padded( block_table_tensor=common_attn_metadata.block_table_tensor, slot_mapping=common_attn_metadata.slot_mapping[token_indices], causal=True, - dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens, + cp_local_seq_lens=common_attn_metadata.cp_local_seq_lens, + cp_local_seq_lens_cpu=common_attn_metadata.cp_local_seq_lens_cpu, + pcp_metadata=common_attn_metadata.pcp_metadata, ) token_indices_to_sample = ( @@ -921,7 +923,9 @@ def prepare_inputs( block_table_tensor=common_attn_metadata.block_table_tensor, slot_mapping=common_attn_metadata.slot_mapping[token_indices], causal=True, - dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens, + cp_local_seq_lens=common_attn_metadata.cp_local_seq_lens, + cp_local_seq_lens_cpu=common_attn_metadata.cp_local_seq_lens_cpu, + pcp_metadata=common_attn_metadata.pcp_metadata, ) return 
spec_common_attn_metadata, token_indices diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 80f8344d4410..df41096f9352 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -41,6 +41,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks from vllm.distributed.parallel_state import ( get_dcp_group, + get_pcp_group, get_pp_group, get_tp_group, graph_capture, @@ -93,7 +94,8 @@ AttentionMetadataBuilder, CommonAttentionMetadata, create_fast_prefill_custom_backend, - get_dcp_local_seq_lens, + get_cp_local_seq_lens, + PrefillContextParallelMetadata, reorder_batch_to_split_decodes_and_prefills, split_attn_metadata, ) @@ -300,6 +302,10 @@ def __init__( self.calculate_kv_scales = self.cache_config.calculate_kv_scales self.dcp_world_size = self.parallel_config.decode_context_parallel_size self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group + self.pcp_world_size = self.parallel_config.prefill_context_parallel_size + self.pcp_rank = 0 if self.pcp_world_size <= 1 else get_pcp_group().rank_in_group + self.total_cp_world_size = self.pcp_world_size * self.dcp_world_size + self.total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_reqs = scheduler_config.max_num_seqs @@ -452,24 +458,30 @@ def __init__( # Cache the device properties. self._init_device_properties() + if self.pcp_world_size > 1: + # Note(qcs): we will pad the tokens of each request + # to a multiple of 2 * pcp_size. + max_num_tokens = self.max_num_tokens + self.max_num_reqs * 2 * self.pcp_world_size + else: + max_num_tokens = self.max_num_tokens # Persistent buffers for CUDA graphs. 
- self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32) - self.positions = self._make_buffer(self.max_num_tokens, dtype=torch.int64) + self.input_ids = self._make_buffer(max_num_tokens, dtype=torch.int32) + self.positions = self._make_buffer(max_num_tokens, dtype=torch.int64) self.query_start_loc = self._make_buffer( self.max_num_reqs + 1, dtype=torch.int32 ) self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32) - if self.dcp_world_size > 1: - self.dcp_local_seq_lens = self._make_buffer( + if self.total_cp_world_size > 1: + self.cp_local_seq_lens = self._make_buffer( self.max_num_reqs, dtype=torch.int32 ) # Because inputs_embeds may be bfloat16 and we don't need a numpy # version of this tensor, avoid a RuntimeError by not creating a # numpy buffer. self.inputs_embeds = self._make_buffer( - self.max_num_tokens, self.hidden_size, dtype=self.dtype, numpy=False + max_num_tokens, self.hidden_size, dtype=self.dtype, numpy=False ) - self.is_token_ids = self._make_buffer(self.max_num_tokens, dtype=torch.bool) + self.is_token_ids = self._make_buffer(max_num_tokens, dtype=torch.bool) self.discard_request_indices = self._make_buffer( self.max_num_reqs, dtype=torch.int64 ) @@ -484,7 +496,27 @@ def __init__( # Only relevant for multimodal models if self.supports_mm_inputs: - self.is_mm_embed = self._make_buffer(self.max_num_tokens, dtype=torch.bool) + self.is_mm_embed = self._make_buffer(max_num_tokens, dtype=torch.bool) + + # Persistent buffers for Prefill Context Parallism + if self.pcp_world_size > 1: + self.pcp_allgather_restore_idx = self._make_buffer( + max_num_tokens, + dtype=torch.int64 + ) + self.pcp_padded_slot_mapping = torch.empty( + (max_num_tokens,), + dtype=torch.int64, + device=self.device, + ) + self.num_pcp_pads_cpu_tensor = torch.zeros( + (self.max_num_reqs,), device="cpu", dtype=torch.int64, pin_memory=True + ) + self.num_pcp_pads_cpu = self.num_pcp_pads_cpu_tensor.numpy() + self.pcp_unpad_mask_cpu_tensor = 
torch.zeros( + (max_num_tokens,), device="cpu", dtype=torch.bool, pin_memory=True + ) + self.pcp_unpad_mask_cpu = self.pcp_unpad_mask_cpu_tensor.numpy() # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -499,7 +531,7 @@ def __init__( # 1D-RoPE. # See page 5 of https://arxiv.org/abs/2409.12191 self.mrope_positions = self._make_buffer( - (3, self.max_num_tokens + 1), dtype=torch.int64 + (3, max_num_tokens + 1), dtype=torch.int64 ) # None in the first PP rank. The rest are set after load_model. @@ -508,7 +540,7 @@ def __init__( # OPTIMIZATION: Cache the tensors rather than creating them every step. # Keep in int64 to avoid overflow with long context self.arange_np = np.arange( - max(self.max_num_reqs + 1, self.max_model_len, self.max_num_tokens), + max(self.max_num_reqs + 1, self.max_model_len, max_num_tokens), dtype=np.int64, ) @@ -522,7 +554,7 @@ def __init__( self.kv_sharing_fast_prefill_logits_indices = None if self.cache_config.kv_sharing_fast_prefill: self.kv_sharing_fast_prefill_logits_indices = torch.zeros( - self.max_num_tokens, dtype=torch.int32, device=self.device + max_num_tokens, dtype=torch.int32, device=self.device ) self.uniform_decode_query_len = 1 + self.num_spec_tokens @@ -1002,6 +1034,217 @@ def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs: dummy_modality = mm_budget.get_modality_with_max_tokens() return self._get_mm_dummy_batch(dummy_modality, num_seqs) + def _get_pcp_metadata( + self, + q_lens: np.ndarray, + kv_lens: np.ndarray, + allgather_restore_idx: torch.Tensor, + ) -> PrefillContextParallelMetadata: + """ + During the prefill phrase, the attention computation is divided into + two parts: q_head and q_tail. Here, we calculate the kv indices + corresponding to q_head or q_tail. Meawhile, the q and kv indptr are + also computed to build the attention wrapper. 
+ If the pcp_size is 2, the variables are following: + >>> q_lens [4, 8] kv_lens [8, 16] + >>> pcp_chunk_sizes[2, 4] + >>> q_indptr [0, 2, 4] + >>> q_head_indices [0, 1, 4, 5, 6, 7] q_tail_indices [2, 3, 8, 9, 10, 11] + >>> kv_head_len r0 [2, 4] / r1 [4, 8] + >>> kv_for_head_indptr r0 [0, 2, 6] / r1 [0, 4, 12] + >>> kv_for_head_indices r0 [0, 1, 8, 9, 10, 11] + >>> r1 [0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15] + >>> kv_tail_len r0 [8, 16] / r1 [6, 12] + >>> kv_for_tail_indptr r0 [0, 8, 24] / r1 [0, 6, 18] + >>> kv_for_tail_indices r0 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ..., 23] + >>> r1 [0, 1, 2, 3, 4, 5, 8, 9, ..., 19] + """ + if len(q_lens) == 0: + return PrefillContextParallelMetadata( + allgather_restore_idx=allgather_restore_idx, + ) + + def _get_partial_kv_idx(kv_len_per_pcp_chunk): + kv_partial_len = pcp_chunk_sizes * kv_len_per_pcp_chunk + kv_partial_indptr = np.zeros(len(kv_partial_len) + 1) + kv_partial_indptr[1:], kv_partial_arange = self._get_cumsum_and_arange(kv_partial_len) + kv_parial_indices = kv_partial_arange + np.repeat( + kv_start_loc, + kv_partial_len, + ) + return kv_partial_indptr, kv_parial_indices + + def _to_tensor(data, **kwargs): + return {k: torch.from_numpy(v).to(**kwargs) for k, v in data.items()} + + pcp_chunk_sizes = q_lens // 2 + q_indptr = np.zeros(len(pcp_chunk_sizes) + 1) + q_indptr[1:], q_chunk_arange = self._get_cumsum_and_arange(pcp_chunk_sizes) + + q_head_start_loc = np.roll(np.cumsum(q_lens), 1) + q_head_start_loc[0] = 0 + q_head_indices = q_chunk_arange + np.repeat( + q_head_start_loc, + pcp_chunk_sizes, + ) + + q_tail_start_loc = q_head_start_loc + pcp_chunk_sizes + q_tail_indices = q_chunk_arange + np.repeat( + q_tail_start_loc, + pcp_chunk_sizes, + ) + + kv_start_loc = np.roll(np.cumsum(kv_lens), 1) + kv_start_loc[0] = 0 + # kv_for_q_head + kv_for_head_indptr, kv_for_head_indices = _get_partial_kv_idx(self.pcp_rank + 1) + # kv_for_q_tail + kv_for_tail_indptr, kv_for_tail_indices = _get_partial_kv_idx( + 2 * 
self.pcp_world_size - self.pcp_rank + ) + + head_tail_indices = _to_tensor({ + "q_head": q_head_indices, + "q_tail": q_tail_indices, + "kv_head": kv_for_head_indices, + "kv_tail": kv_for_tail_indices, + }, device=self.device, dtype=torch.int64, non_blocking=True) + head_tail_indptr = _to_tensor({ + "q": q_indptr, + "kv_head": kv_for_head_indptr, + "kv_tail": kv_for_tail_indptr + }, dtype=torch.int64) + + q_full_indices = torch.cat([head_tail_indices["q_head"], head_tail_indices["q_tail"]]) + q_full_indices = q_full_indices.to(torch.float32).argsort().to(torch.int32) + + return PrefillContextParallelMetadata( + allgather_restore_idx=allgather_restore_idx, + q_head_indices=head_tail_indices["q_head"], + q_tail_indices=head_tail_indices["q_tail"], + q_head_start_loc=head_tail_indptr["q"], + kv_for_head_indices=head_tail_indices["kv_head"], + kv_for_tail_indices=head_tail_indices["kv_tail"], + kv_for_head_indptr=head_tail_indptr["kv_head"], + kv_for_tail_indptr=head_tail_indptr["kv_tail"], + q_full_indices=q_full_indices, + ) + + def _update_tokens_for_pcp( + self, + tokens: np.ndarray, + dummy_input: bool = False, + num_reqs: int | None = None, + num_decode_reqs: int | None = None, + ) -> tuple[np.ndarray, np.ndarray, PrefillContextParallelMetadata]: + """ + If prefill context parallelism is enabled, we will update + the number of `tokens` after sequence splitting. + Meanwhile, we will compute: + `positions` the new token positions, + `self.num_pcp_pads_cpu` the number of padding tokens + per request for alignment, + `self.pcp_unpad_mask_cpu` the mask for non-padded tokens, + `self.pcp_allgather_restore_idx` indices to restore the + original vector order after PCP allgather. 
+ Example: + >>> tokens = [1, 5, 8] + >>> pcp_world_size = 2 + >>> pcp_rank = 0 + >>> _update_tokens_for_pcp(tokens) + ([1, 4, 4], [0, 0, 1, 6, 7, 0, 1, 6, 7]) + >>> pcp_rank = 1 + >>> _update_tokens_for_pcp(tokens) + ([1, 4, 4], [0, 2, 3, 4, 5, 2, 3, 4, 5]) + >>> # the following results are same for each pcp rank + >>> self.num_pcp_pads_cpu + [1, 3, 0] + >>> self.pcp_unpad_mask_cpu + [True, False, True, True, True, True, True, False, False, + False, True, True, True, True, True, True, True, True] + >>> self.pcp_allgather_resotre_idx + [0, 9, 1, 2, 10, 11, 12, 13, 3, 4, 5, 6, 14, 15, 16, 17, 7, 8] + """ + if not dummy_input: + num_reqs = self.input_batch.num_reqs + num_decode_reqs = sum( + self.input_batch.num_computed_tokens_cpu[:num_reqs] + >= self.input_batch.num_prompt_tokens[:num_reqs] + ) + self.num_pcp_pads_cpu[:num_reqs] = 0 + + num_decode_tokens = sum(tokens[:num_decode_reqs]) + + num_padded_scheduled_tokens = np.ceil( + tokens / (2 * self.pcp_world_size) + ).astype(np.int32) * (2 * self.pcp_world_size) + # we duplicate scheduled tokens of decode reqs to pcp_world_size + num_padded_scheduled_tokens[:num_decode_reqs] = ( + tokens[:num_decode_reqs] * self.pcp_world_size + ) + self.num_pcp_pads_cpu[:num_reqs] = num_padded_scheduled_tokens - tokens + cu_padded_tokens, pcp_padded_arange = self._get_cumsum_and_arange( + num_padded_scheduled_tokens + ) + self.pcp_unpad_mask_cpu[: pcp_padded_arange.shape[0]] = ( + pcp_padded_arange < np.repeat(tokens, num_padded_scheduled_tokens) + ) + + pcp_tokens = num_padded_scheduled_tokens // self.pcp_world_size + pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1) + pcp_chunk_sizes[:num_decode_reqs] = pcp_tokens[:num_decode_reqs] + _, pcp_arange = self._get_cumsum_and_arange(pcp_tokens) + _, pcp_chunk_arange = self._get_cumsum_and_arange(pcp_chunk_sizes) + pcp_head_chunk_mask = pcp_arange < np.repeat(pcp_chunk_sizes, pcp_tokens) + + def get_current_rank_positions( + positions_start_loc: int | np.ndarray, rank: int + ): + 
positions = np.zeros(len(pcp_head_chunk_mask), dtype=np.int32) + head_start_loc = positions_start_loc + rank * pcp_chunk_sizes + tail_start_loc = ( + positions_start_loc + + (2 * self.pcp_world_size - rank - 1) * pcp_chunk_sizes + ) + positions[pcp_head_chunk_mask] = pcp_chunk_arange + np.repeat( + head_start_loc, pcp_chunk_sizes + ) + # Decode reqs do not have tail chunks. + positions[~pcp_head_chunk_mask] = ( + pcp_chunk_arange[num_decode_tokens:] + + np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_tokens:] + ) + return positions + + positions = get_current_rank_positions(0, self.pcp_rank) + # Decode tokens are duplicated only after AG. But their positions are + # same without prefill context parallel. + if num_decode_reqs > 0: + positions[:num_decode_tokens] = self._get_cumsum_and_arange( + tokens[:num_decode_reqs] + )[1] + + padded_pos_start_loc = np.roll(cu_padded_tokens, 1) + padded_pos_start_loc[0] = 0 + all_positions_lst = [ + get_current_rank_positions(padded_pos_start_loc, rank_i) + for rank_i in range(self.pcp_world_size) + ] + all_positions = np.concatenate(all_positions_lst) + self.pcp_allgather_restore_idx.np[: all_positions.shape[0]] = ( + all_positions.argsort() + ) + self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0]) + return ( + pcp_tokens[:num_reqs], + positions, + self._get_pcp_metadata( + pcp_tokens[num_decode_reqs:], + num_padded_scheduled_tokens[num_decode_reqs:], + self.pcp_allgather_restore_idx.gpu[: all_positions.shape[0]] + ) + ) + def _get_cumsum_and_arange( self, num_tokens: np.ndarray, @@ -1172,6 +1415,7 @@ def _prepare_inputs( SpecDecodeMetadata | None, UBatchSlices | None, torch.Tensor | None, + PrefillContextParallelMetadata | None, ]: """ :return: tuple[ @@ -1204,6 +1448,29 @@ def _prepare_inputs( out=positions_np, ) + self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np) + self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens) + + pcp_metadata = None + if 
self.pcp_world_size > 1: + num_scheduled_tokens[:num_reqs], pcp_positions, pcp_metadata = \ + self._update_tokens_for_pcp( + num_scheduled_tokens[:num_reqs] + ) + + # Re-update after PCP split sequences. + total_num_scheduled_tokens = sum(num_scheduled_tokens) + max_num_scheduled_tokens = max(num_scheduled_tokens) + + req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) + cu_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens) + positions_np = self.positions.np[:total_num_scheduled_tokens] + np.add( + self.input_batch.num_computed_tokens_cpu[req_indices], + pcp_positions[:total_num_scheduled_tokens], + out=positions_np, + ) + # Calculate M-RoPE positions. # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -1274,9 +1541,6 @@ def _prepare_inputs( output_idx += num_sched - self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np) - self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens) - # Prepare the attention metadata. self.query_start_loc.np[0] = 0 self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens @@ -1320,7 +1584,14 @@ def _prepare_inputs( # Record the index of requests that should not be sampled, # so that we could clear the sampled tokens before returning - discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np + if self.pcp_world_size > 1: + discard_requests_mask = ( + self.input_batch.num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens * self.pcp_world_size + - self.num_pcp_pads_cpu[:num_reqs] + ) < num_tokens_np + else: + discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np discard_request_indices = np.nonzero(discard_requests_mask)[0] self.num_discarded_requests = len(discard_request_indices) self.discard_request_indices.np[: self.num_discarded_requests] = ( @@ -1353,11 +1624,19 @@ def _prepare_inputs( # from these partial requests, we do so for simplicity. 
# We will ignore the sampled tokens from the partial requests. # TODO: Support prompt logprobs. - logits_indices = query_start_loc[1:] - 1 + if self.pcp_world_size > 1: + logits_indices = ( + torch.from_numpy(cu_num_tokens) * self.pcp_world_size + - self.num_pcp_pads_cpu_tensor[:num_reqs] + - 1 + ) + else: + logits_indices = query_start_loc[1:] - 1 num_draft_tokens = None spec_decode_metadata = None num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) else: + assert self.pcp_world_size == 1, "PCP not support spec decode now" # Get the number of draft tokens for each request. # Iterate over the dictionary rather than all requests since not all # requests have draft tokens. @@ -1404,6 +1683,7 @@ def _prepare_inputs( spec_decode_metadata, ubatch_slices, num_tokens_across_dp, + pcp_metadata, ) def _build_attention_metadata( @@ -1417,6 +1697,7 @@ def _build_attention_metadata( for_cudagraph_capture: bool = False, scheduled_encoder_inputs: dict[str, list[int]] | None = None, cascade_attn_prefix_lens: list[list[int]] | None = None, + pcp_metadata: PrefillContextParallelMetadata | None = None, ) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]: """ :return: tuple[attn_metadata, spec_decode_common_attn_metadata] @@ -1431,14 +1712,14 @@ def _build_attention_metadata( ) # update seq_lens of decode reqs under DCP. 
- if self.dcp_world_size > 1: - self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens( + if self.total_cp_world_size > 1: + self.cp_local_seq_lens.cpu[:num_reqs] = get_cp_local_seq_lens( self.seq_lens.cpu[:num_reqs], - self.dcp_world_size, - self.dcp_rank, + self.total_cp_world_size, + self.total_cp_rank, self.parallel_config.cp_kv_cache_interleave_size, ) - self.dcp_local_seq_lens.copy_to_gpu(num_reqs) + self.cp_local_seq_lens.copy_to_gpu(num_reqs) attn_metadata: PerLayerAttnMetadata = {} if ubatch_slices is not None: @@ -1453,10 +1734,10 @@ def _build_attention_metadata( :num_reqs ] - dcp_local_seq_lens, dcp_local_seq_lens_cpu = None, None - if self.dcp_world_size > 1: - dcp_local_seq_lens = self.dcp_local_seq_lens.gpu[:num_reqs] - dcp_local_seq_lens_cpu = self.dcp_local_seq_lens.cpu[:num_reqs] + cp_local_seq_lens, cp_local_seq_lens_cpu = None, None + if self.total_cp_world_size > 1: + cp_local_seq_lens = self.cp_local_seq_lens.gpu[:num_reqs] + cp_local_seq_lens_cpu = self.cp_local_seq_lens.cpu[:num_reqs] spec_decode_common_attn_metadata = None @@ -1486,6 +1767,12 @@ def _build_attention_metadata( num_reqs, ) + slot_mapping_size = ( + total_num_scheduled_tokens + if self.pcp_world_size == 1 + else total_num_scheduled_tokens * self.pcp_world_size + - sum(self.num_pcp_pads_cpu[:num_reqs]) + ) if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec): # Encoder-only layers do not have KV cache, so we need to # create a dummy block table and slot mapping for them. @@ -1495,18 +1782,32 @@ def _build_attention_metadata( device=self.device, ) slot_mapping = torch.zeros( - (total_num_scheduled_tokens,), + (slot_mapping_size,), dtype=torch.int64, device=self.device, ) else: blk_table = self.input_batch.block_table[kv_cache_gid] blk_table_tensor = blk_table.get_device_tensor(num_reqs) - slot_mapping = blk_table.slot_mapping.gpu[:total_num_scheduled_tokens] + + slot_mapping = blk_table.slot_mapping.gpu[:slot_mapping_size] # Fill unused with -1. 
Needed for reshape_and_cache in full cuda # graph mode. - blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(-1) + blk_table.slot_mapping.gpu[slot_mapping_size:].fill_(-1) + + if self.pcp_world_size > 1: + # After pcp allgather and restore, there are padded tokens in + # kv, so we need pad slotmapping for alignment. + pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[ + : total_num_scheduled_tokens * self.pcp_world_size + ] + cp_unpad_mask = self.pcp_unpad_mask_cpu_tensor[ + : total_num_scheduled_tokens * self.pcp_world_size + ] + pcp_padded_slot_mapping.fill_(-1) + pcp_padded_slot_mapping[cp_unpad_mask] = slot_mapping + slot_mapping = pcp_padded_slot_mapping common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, @@ -1524,8 +1825,9 @@ def _build_attention_metadata( num_logits_indices=num_logits_indices, causal=True, encoder_seq_lens=encoder_seq_lens, - dcp_local_seq_lens=dcp_local_seq_lens, - dcp_local_seq_lens_cpu=dcp_local_seq_lens_cpu, + cp_local_seq_lens=cp_local_seq_lens, + cp_local_seq_lens_cpu=cp_local_seq_lens_cpu, + pcp_metadata=pcp_metadata, ) if self.speculative_config and spec_decode_common_attn_metadata is None: @@ -2723,6 +3025,7 @@ def execute_model( spec_decode_metadata, ubatch_slices, num_tokens_across_dp, + pcp_metadata, ) = self._prepare_inputs( scheduler_output, num_scheduled_tokens_np, max_num_scheduled_tokens ) @@ -2744,7 +3047,9 @@ def execute_model( use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 attn_metadata, spec_decode_common_attn_metadata = ( self._build_attention_metadata( - total_num_scheduled_tokens=total_num_scheduled_tokens, + total_num_scheduled_tokens=total_num_scheduled_tokens + if self.pcp_world_size == 1 + else num_scheduled_tokens_np.sum(), max_num_scheduled_tokens=max_num_scheduled_tokens, num_reqs=num_reqs, ubatch_slices=ubatch_slices, @@ -2752,6 +3057,7 @@ def execute_model( use_spec_decode=use_spec_decode, 
scheduled_encoder_inputs=scheduler_output.scheduled_encoder_inputs, cascade_attn_prefix_lens=cascade_attn_prefix_lens, + pcp_metadata=pcp_metadata, ) ) @@ -2833,6 +3139,19 @@ def execute_model( hidden_states = model_output aux_hidden_states = None + if self.pcp_world_size > 1: + # NOTE we must `slice` hidden_states because pcp_allgather_restore_idx + # ignores the padding from CUDA Graph. + hidden_states = get_pcp_group().all_gather( + hidden_states[:num_scheduled_tokens_np.sum()], + 0, + ) + hidden_states = torch.index_select( + hidden_states, + 0, + self.pcp_allgather_restore_idx.gpu[: hidden_states.shape[0]], + ) + if not self.broadcast_pp_output: # Common case. if not get_pp_group().is_last_rank: @@ -3731,6 +4050,16 @@ def _dummy_run( assert sum(num_scheduled_tokens_list) == num_tokens assert len(num_scheduled_tokens_list) == num_reqs num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) + pcp_metadata = None + if self.pcp_world_size > 1 and force_attention: + num_decode_reqs = sum(num_scheduled_tokens == 1) + num_scheduled_tokens[:num_reqs], _, pcp_metadata = \ + self._update_tokens_for_pcp( + num_scheduled_tokens[:num_reqs], + dummy_input=True, + num_reqs=num_reqs, + num_decode_reqs=num_decode_reqs, + ) total_num_scheduled_tokens = int(num_scheduled_tokens.sum()) num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) @@ -3748,7 +4077,7 @@ def _dummy_run( uniform_decode=uniform_decode, num_scheduled_tokens_per_request=num_scheduled_tokens, ) - num_tokens_after_padding = num_tokens + num_tokens_after_padding = total_num_scheduled_tokens if num_tokens_across_dp is not None: dp_rank = self.parallel_config.data_parallel_rank num_tokens_after_padding = int(num_tokens_across_dp[dp_rank]) @@ -3774,11 +4103,12 @@ def _dummy_run( self.query_start_loc.copy_to_gpu() attn_metadata, _ = self._build_attention_metadata( - total_num_scheduled_tokens=num_tokens, + total_num_scheduled_tokens=total_num_scheduled_tokens, 
max_num_scheduled_tokens=max_query_len, num_reqs=num_reqs, ubatch_slices=ubatch_slices, for_cudagraph_capture=True, + pcp_metadata=pcp_metadata if self.pcp_world_size > 1 else None, ) with self.maybe_dummy_run_with_lora( From 1cac317eadbe59b473af704ef5d4ba3153d4e62d Mon Sep 17 00:00:00 2001 From: QiuChunshuo Date: Thu, 20 Nov 2025 09:07:53 +0800 Subject: [PATCH 2/7] [PCP] [PCP] flashinfer support for PCP Co-authored-by: QiuChunshuo Co-authored-by: FENP Co-authored-by: LookAround Co-authored-by: Jingchun Gao Co-authored-by: zhenwenqi2024 Signed-off-by: QiuChunshuo Signed-off-by: FENP Signed-off-by: LookAround Signed-off-by: Jingchun Gao Signed-off-by: zhenwenqi2024 --- vllm/v1/attention/backends/flashinfer.py | 372 ++++++++++++++++------- 1 file changed, 263 insertions(+), 109 deletions(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 4da1637d96eb..dfd11703879c 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -25,11 +25,11 @@ AttentionType, MultipleOf, ) -from vllm.attention.ops.common import cp_lse_ag_out_rs +from vllm.attention.ops.common import cp_lse_ag_out_ar, cp_lse_ag_out_rs from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.config import CUDAGraphMode, VllmConfig from vllm.config.cache import CacheDType -from vllm.distributed.parallel_state import get_dcp_group +from vllm.distributed.parallel_state import get_pcp_group, get_dcp_group from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, @@ -54,11 +54,12 @@ AttentionMetadataBuilder, CommonAttentionMetadata, KVCacheLayoutType, - get_dcp_local_seq_lens, + get_cp_local_seq_lens, get_kv_cache_layout, get_per_layer_parameters, infer_global_hyperparameters, split_decodes_and_prefills, + PrefillContextParallelMetadata, ) from vllm.v1.kv_cache_interface import AttentionSpec @@ -165,17 +166,29 @@ def 
trtllm_prefill_attn_kvfp8_dequant( return mock_kv_cache, mock_block_table -class BatchDCPPrefillWrapper: +class BatchCPPrefillWrapper: def __init__( self, + dcp_world_size, + pcp_world_size, workspace_buffer: torch.Tensor | None = None, ): + self.dcp_world_size = dcp_world_size + self.pcp_world_size = pcp_world_size self._context = BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, get_kv_cache_layout() ) - self._new_tokens = BatchPrefillWithRaggedKVCacheWrapper( - workspace_buffer, get_kv_cache_layout() - ) + if self.pcp_world_size > 1: + self._new_tokens_head = BatchPrefillWithRaggedKVCacheWrapper( + workspace_buffer, get_kv_cache_layout() + ) + self._new_tokens_tail = BatchPrefillWithRaggedKVCacheWrapper( + workspace_buffer, get_kv_cache_layout() + ) + else: + self._new_tokens = BatchPrefillWithRaggedKVCacheWrapper( + workspace_buffer, get_kv_cache_layout() + ) def plan( self, @@ -186,7 +199,6 @@ def plan( prefill_start: int, page_size: int, num_qo_heads: int, - dcp_world_size: int, num_kv_heads: int, head_dim: int, sm_scale: float, @@ -196,39 +208,126 @@ def plan( kv_cache_dtype: torch.dtype, prefill_fixed_split_size: int, disable_split_kv: bool, + pcp_metadata: PrefillContextParallelMetadata | None, ): """Plan the prefill operation with given parameters.""" + common_args = [ + num_kv_heads, + head_dim, + ] + common_kwargs = { + "sm_scale": sm_scale, + "window_left": window_left, + "logits_soft_cap": logits_soft_cap, + "q_data_type": q_data_type, + "kv_data_type": kv_cache_dtype, + } self._context.plan( qo_indptr_cpu, paged_kv_indptr_cpu, paged_kv_indices, paged_kv_last_page_len_cpu[prefill_start:], - num_qo_heads * dcp_world_size, - num_kv_heads, - head_dim, + num_qo_heads * self.dcp_world_size, + *common_args, page_size, causal=False, # This is context run - sm_scale=sm_scale, - window_left=window_left, - logits_soft_cap=logits_soft_cap, - q_data_type=q_data_type, - kv_data_type=kv_cache_dtype, fixed_split_size=prefill_fixed_split_size, 
disable_split_kv=disable_split_kv, + **common_kwargs, ) - self._new_tokens.plan( - qo_indptr=qo_indptr_cpu, - kv_indptr=qo_indptr_cpu, - num_qo_heads=num_qo_heads, - num_kv_heads=num_kv_heads, - head_dim_qk=head_dim, - head_dim_vo=head_dim, - causal=True, # This is newtokens run - sm_scale=sm_scale, - window_left=window_left, - logits_soft_cap=logits_soft_cap, - q_data_type=q_data_type, + if self.pcp_world_size > 1: + assert pcp_metadata is not None + qo_indptr_cpu = pcp_metadata.q_head_start_loc + kv_for_head_indptr = pcp_metadata.kv_for_head_indptr + kv_for_tail_indptr = pcp_metadata.kv_for_tail_indptr + self._new_tokens_head.plan( + qo_indptr_cpu, + kv_for_head_indptr, + num_qo_heads, + *common_args, + causal=True, + **common_kwargs, + ) + self._new_tokens_tail.plan( + qo_indptr_cpu, + kv_for_tail_indptr, + num_qo_heads, + *common_args, + causal=True, + **common_kwargs, + ) + else: + self._new_tokens.plan( + qo_indptr_cpu, + qo_indptr_cpu, + num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim_qk=head_dim, + head_dim_vo=head_dim, + causal=True, # This is newtokens run + **common_kwargs, + ) + + def _attention_with_head_and_tail( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + metadata: PrefillContextParallelMetadata, + return_lse: bool = False, + ): + """ + For prompt with tokens [T0, T1, T2, T3], the query on PCP0 is [Q0, Q3] + and we all-gather full K as [K0, K1, K2, K3]. + There are two attn ops. The "head" is [Q0]x[K0] and the "tail" is + [Q3]x[K0, K1, K2, K3]. 
+ """ + q_head_indices = metadata.q_head_indices + q_tail_indices = metadata.q_tail_indices + kv_for_head_indices = metadata.kv_for_head_indices + kv_for_tail_indices = metadata.kv_for_tail_indices + q_full_indices = metadata.q_full_indices + + q_head = torch.index_select(query, 0, q_head_indices) + q_tail = torch.index_select(query, 0, q_tail_indices) + k_head = torch.index_select(key, 0, kv_for_head_indices) + v_head = torch.index_select(value, 0, kv_for_head_indices) + k_tail = torch.index_select(key, 0, kv_for_tail_indices) + v_tail = torch.index_select(value, 0, kv_for_tail_indices) + + output_head = self._new_tokens_head.run( + q_head, + k_head, + v_head, + return_lse=return_lse, + ) + output_tail = self._new_tokens_tail.run( + q_tail, + k_tail, + v_tail, + return_lse=return_lse, ) + if return_lse: + output_head, lse_head = output_head + output_tail, lse_tail = output_tail + output = torch.index_select( + torch.cat([output_head, output_tail], dim=0), + 0, + q_full_indices, + ) + lse = torch.index_select( + torch.cat([lse_head, lse_tail], dim=0), + 0, + q_full_indices, + ) + return output, lse + else: + output = torch.index_select( + torch.cat([output_head, output_tail], dim=0), + 0, + q_full_indices, + ) + return output def run( self, @@ -238,37 +337,48 @@ def run( key: torch.Tensor, value: torch.Tensor, out: torch.Tensor, + pcp_metadata: PrefillContextParallelMetadata | None, ): - prefill_query_across_dcp = get_dcp_group().all_gather( - prefill_query.contiguous(), dim=1 - ) - output_context_tmp, lse_context_tmp = self._context.run( - prefill_query_across_dcp, - kv_cache_permute, - k_scale=layer._k_scale_float, - v_scale=layer._v_scale_float, - return_lse=True, - ) - output_context, lse_context = cp_lse_ag_out_rs( - output_context_tmp, lse_context_tmp, get_dcp_group(), return_lse=True - ) - lse_context = lse_context.transpose(0, 1).contiguous() + if self.pcp_world_size > 1: + assert pcp_metadata is not None + out[:], lse_query = 
self._attention_with_head_and_tail( + prefill_query, + key, + value, + pcp_metadata, + return_lse=True, + ) + else: + prefill_query_across_dcp = get_dcp_group().all_gather( + prefill_query.contiguous(), dim=1 + ) + output_context_tmp, lse_context_tmp = self._context.run( + prefill_query_across_dcp, + kv_cache_permute, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + return_lse=True, + ) + output_context, lse_context = cp_lse_ag_out_rs( + output_context_tmp, lse_context_tmp, get_dcp_group(), return_lse=True + ) + lse_context = lse_context.transpose(0, 1).contiguous() - output_query, lse_query = self._new_tokens.run( - prefill_query, - key, - value, - return_lse=True, - ) - lse_query = lse_query.transpose(0, 1).contiguous() - - merge_attn_states( - out, - output_context, - lse_context, - output_query, - lse_query, - ) + output_query, lse_query = self._new_tokens.run( + prefill_query, + key, + value, + return_lse=True, + ) + lse_query = lse_query.transpose(0, 1).contiguous() + + merge_attn_states( + out, + output_context, + lse_context, + output_query, + lse_query, + ) return out @@ -394,7 +504,7 @@ class FlashInferMetadata: use_cascade: bool prefill_wrapper: ( - BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper | None + BatchPrefillWithPagedKVCacheWrapper | BatchCPPrefillWrapper | None ) = None decode_wrapper: BatchDecodeWithPagedKVCacheWrapper | None = None cascade_wrapper: MultiLevelCascadeAttentionWrapper | None = None @@ -402,6 +512,9 @@ class FlashInferMetadata: qo_indptr_gpu: torch.Tensor | None = None paged_kv_indptr_gpu: torch.Tensor | None = None + # For context parallel + pcp_metadata: PrefillContextParallelMetadata | None = None + class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): reorder_batch_threshold: int = 1 @@ -418,7 +531,7 @@ def __init__( self.model_config = vllm_config.model_config self._workspace_buffer = None self._prefill_wrapper: ( - BatchPrefillWithPagedKVCacheWrapper | 
BatchDCPPrefillWrapper | None + BatchPrefillWithPagedKVCacheWrapper | BatchCPPrefillWrapper | None ) = None # Wrapper for prefill/append self._decode_wrapper = None # Wrapper for decode (general shape) @@ -457,21 +570,28 @@ def __init__( self.compilation_config.max_cudagraph_capture_size, ) + try: + self.pcp_world_size = get_pcp_group().world_size + self.pcp_rank = get_pcp_group().rank_in_group + except AssertionError: + # PCP might not be initialized in testing + self.pcp_world_size = 1 + self.pcp_rank = 0 + try: self.dcp_world_size = get_dcp_group().world_size self.dcp_rank = get_dcp_group().rank_in_group - self.dcp_kv_cache_interleave_size = ( - vllm_config.parallel_config.dcp_kv_cache_interleave_size - ) + self.cp_kv_cache_interleave_size = vllm_config.parallel_config.cp_kv_cache_interleave_size except AssertionError: # DCP might not be initialized in testing self.dcp_world_size = 1 self.dcp_rank = 0 - self.dcp_kv_cache_interleave_size = 1 + self.cp_kv_cache_interleave_size = 1 + self.total_cp_world_size = self.pcp_world_size * self.dcp_world_size + self.total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank self.num_qo_heads = ( self.model_config.get_num_attention_heads(self.vllm_config.parallel_config) - * self.dcp_world_size ) self.num_kv_heads = self.kv_cache_spec.num_kv_heads @@ -587,11 +707,13 @@ def _get_workspace_buffer(self): def _get_prefill_wrapper( self, - ) -> BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper: + ) -> BatchPrefillWithPagedKVCacheWrapper | BatchCPPrefillWrapper: if self._prefill_wrapper is None: - if self.dcp_world_size > 1: - self._prefill_wrapper = BatchDCPPrefillWrapper( - workspace_buffer=self._get_workspace_buffer(), + if self.total_cp_world_size > 1: + self._prefill_wrapper = BatchCPPrefillWrapper( + self.dcp_world_size, + self.pcp_world_size, + self._get_workspace_buffer(), ) else: self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( @@ -667,7 +789,7 @@ def build( block_table_tensor = 
common_attn_metadata.block_table_tensor qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu - if self.dcp_world_size > 1: + if self.total_cp_world_size > 1: if num_prefills > 0: qo_indptr_prefill_cpu = ( qo_indptr_cpu[num_decodes:] - qo_indptr_cpu[num_decodes] @@ -679,11 +801,11 @@ def build( seq_lens_cpu[num_decodes:] - query_lens_prefill_cpu ) - seq_lens_cpu = get_dcp_local_seq_lens( + seq_lens_cpu = get_cp_local_seq_lens( seq_lens_cpu, - self.dcp_world_size, - self.dcp_rank, - self.dcp_kv_cache_interleave_size, + self.total_cp_world_size, + self.total_cp_rank, + self.cp_kv_cache_interleave_size, ) seq_lens_np = seq_lens_cpu.numpy() @@ -766,7 +888,7 @@ def build( has_spec=uses_spec_reorder, ) decode_use_trtllm = ( - self.use_trtllm_decode_attention and self.dcp_world_size <= 1 + self.use_trtllm_decode_attention and self.total_cp_world_size <= 1 ) if not (prefill_use_trtllm and decode_use_trtllm): @@ -810,6 +932,7 @@ def build( num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, use_cascade=use_cascade, + pcp_metadata=common_attn_metadata.pcp_metadata, ) paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[: 1 + num_reqs] @@ -863,9 +986,9 @@ def build( attn_metadata.max_q_len_prefill = int(query_lens_prefill.max().item()) if not attn_metadata.prefill_use_trtllm: - if self.dcp_world_size > 1: + if self.total_cp_world_size > 1: assert isinstance( - attn_metadata.prefill_wrapper, BatchDCPPrefillWrapper + attn_metadata.prefill_wrapper, BatchCPPrefillWrapper ) attn_metadata.prefill_wrapper.plan( qo_indptr_cpu=qo_indptr_cpu, @@ -875,7 +998,6 @@ def build( prefill_start=prefill_start, page_size=self.page_size, num_qo_heads=self.num_qo_heads, - dcp_world_size=self.dcp_world_size, num_kv_heads=self.num_kv_heads, head_dim=self.head_dim, sm_scale=self.sm_scale, @@ -885,6 +1007,7 @@ def build( kv_cache_dtype=self.kv_cache_dtype, prefill_fixed_split_size=self.prefill_fixed_split_size, disable_split_kv=self.disable_split_kv, + 
pcp_metadata=common_attn_metadata.pcp_metadata if self.pcp_world_size > 1 else None, ) else: assert isinstance( @@ -979,6 +1102,8 @@ def use_cascade_attention(self, *args, **kwargs) -> bool: # TODO: The cascade wrapper currently does not support setting # kv cache dtype to something different from query dtype. return False + if self.pcp_world_size > 1: + return False # TODO: Cascade attention doesn't work, disable it for now # return use_cascade_attention(*args, **kwargs) return False @@ -1150,6 +1275,27 @@ def forward( num_actual_tokens = attn_metadata.num_actual_tokens + if (self.pcp_world_size > 1): + assert attn_metadata.pcp_metadata is not None + pcp_allgather_restore_idx = attn_metadata.pcp_metadata.allgather_restore_idx + assert pcp_allgather_restore_idx is not None + # NOTE(yyj): we must `slice` key and value because pcp_allgather_restore_idx + # ignores the padding from CUDA Graph. To be optimized for performance! + key_across_pcp = get_pcp_group().all_gather( + key[:num_actual_tokens].contiguous(), dim=0 + ) + value_across_pcp = get_pcp_group().all_gather( + value[:num_actual_tokens].contiguous(), dim=0 + ) + # Reorder kv after pcp allgather. + # Note that there are duplicate decoding tokens, + # but we only save the first one in kvcache. + key = torch.index_select( + key_across_pcp, 0, pcp_allgather_restore_idx + ) + value = torch.index_select( + value_across_pcp, 0, pcp_allgather_restore_idx + ) if self.kv_sharing_target_layer_name is None: # Reshape the input keys and values and store them in the cache. # Skip this if sharing KV cache with an earlier attention layer. 
@@ -1179,8 +1325,8 @@ def forward( # Inputs and outputs may be padded for CUDA graphs query = query[:num_actual_tokens] - key = key[:num_actual_tokens] - value = value[:num_actual_tokens] + key = key[:num_actual_tokens*self.pcp_world_size] + value = value[:num_actual_tokens*self.pcp_world_size] output_padded = output output = output[:num_actual_tokens] @@ -1207,28 +1353,33 @@ def forward( assert prefill_wrapper is not None if not attn_metadata.prefill_use_trtllm: - if self.dcp_world_size > 1: - assert isinstance(prefill_wrapper, BatchDCPPrefillWrapper) - assert prefill_wrapper._context._window_left == self.window_left - assert prefill_wrapper._context._logits_soft_cap == ( - self.logits_soft_cap or 0.0 - ) - assert prefill_wrapper._context._sm_scale == self.scale - assert not prefill_wrapper._context._causal - assert prefill_wrapper._new_tokens._window_left == self.window_left - assert prefill_wrapper._new_tokens._logits_soft_cap == ( - self.logits_soft_cap or 0.0 - ) - assert prefill_wrapper._new_tokens._sm_scale == self.scale - assert prefill_wrapper._new_tokens._causal + if self.total_cp_world_size > 1: + assert isinstance(prefill_wrapper, BatchCPPrefillWrapper) + expected_logits_soft_cap = self.logits_soft_cap or 0.0 + + wrappers_to_check = [(prefill_wrapper._context, False)] + if self.pcp_world_size > 1: + wrappers_to_check.extend([ + (prefill_wrapper._new_tokens_head, True), + (prefill_wrapper._new_tokens_tail, True) + ]) + else: + wrappers_to_check.append((prefill_wrapper._new_tokens, True)) + + for wrapper, expected_causal in wrappers_to_check: + assert wrapper._window_left == self.window_left + assert wrapper._logits_soft_cap == expected_logits_soft_cap + assert wrapper._sm_scale == self.scale + assert wrapper._causal == expected_causal prefill_wrapper.run( layer, prefill_query, kv_cache_permute, - key[num_decode_tokens:], - value[num_decode_tokens:], + key[num_decode_tokens * self.pcp_world_size:], + value[num_decode_tokens * self.pcp_world_size:], 
out=output[num_decode_tokens:], + pcp_metadata=attn_metadata.pcp_metadata if self.pcp_world_size > 1 else None, ) else: assert isinstance( @@ -1323,28 +1474,31 @@ def forward( assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0) assert decode_wrapper._sm_scale == self.scale - if self.dcp_world_size > 1: + if self.total_cp_world_size > 1: decode_query = get_dcp_group().all_gather( - decode_query.contiguous(), dim=-2 + decode_query.contiguous(), dim=1 ) - output_tmp = torch.empty_like(decode_query) - lse = torch.empty( - (decode_query.size(0), decode_query.size(1)), - dtype=torch.float32, - device=decode_query.device, - ) - decode_wrapper.run( + out, lse = decode_wrapper.run( decode_query, kv_cache_permute, k_scale=layer._k_scale_float, v_scale=layer._v_scale_float, - out=output_tmp, - lse=lse, return_lse=True, ) - output[:num_decode_tokens] = cp_lse_ag_out_rs( - output_tmp, lse, get_dcp_group() - ) + if self.dcp_world_size > 1: + out, lse = cp_lse_ag_out_rs( + out, lse, get_dcp_group(), + return_lse=True, + ) + else: + output[:num_decode_tokens] = out + + if self.pcp_world_size > 1: + output[:num_decode_tokens] = cp_lse_ag_out_ar( + out, lse, get_pcp_group() + ) + else: + output[:num_decode_tokens] = out else: decode_wrapper.run( decode_query, From eb5d34baee99df4ffcb74b334c1926e0dfc13cbe Mon Sep 17 00:00:00 2001 From: Jingchun Gao Date: Sun, 23 Nov 2025 23:04:59 +0800 Subject: [PATCH 3/7] [Fix] common support for pcp Signed-off-by: Jingchun Gao --- vllm/distributed/parallel_state.py | 10 +++ vllm/v1/worker/gpu_model_runner.py | 104 ++++++++++++++++++----------- 2 files changed, 74 insertions(+), 40 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index f81612fd1f4a..2d9da36f005e 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1549,6 +1549,16 @@ def get_decode_context_model_parallel_rank(): return get_dcp_group().rank_in_group +def 
get_prefill_context_model_parallel_world_size():
+    """Return world size for the prefill context model parallel group."""
+    return get_pcp_group().world_size
+
+
+def get_prefill_context_model_parallel_rank():
+    """Return my rank for the prefill context model parallel group."""
+    return get_pcp_group().rank_in_group
+
+
 def get_node_count() -> int:
     """Return the total number of nodes in the distributed environment."""
     assert _NODE_COUNT is not None, "distributed environment is not initialized"
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index df41096f9352..cb447c89072b 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -504,6 +504,22 @@ def __init__(
             max_num_tokens, dtype=torch.int64
         )
+        self.q_head_indices = self._make_buffer(
+            max_num_tokens,
+            dtype=torch.int64
+        )
+        self.q_tail_indices = self._make_buffer(
+            max_num_tokens,
+            dtype=torch.int64
+        )
+        self.kv_for_head_indices = self._make_buffer(
+            max_num_tokens,
+            dtype=torch.int64
+        )
+        self.kv_for_tail_indices = self._make_buffer(
+            max_num_tokens,
+            dtype=torch.int64
+        )
         self.pcp_padded_slot_mapping = torch.empty(
             (max_num_tokens,),
             dtype=torch.int64,
@@ -517,6 +533,18 @@ def __init__(
             (max_num_tokens,), device="cpu", dtype=torch.bool, pin_memory=True
         )
         self.pcp_unpad_mask_cpu = self.pcp_unpad_mask_cpu_tensor.numpy()
+        self.q_indptr_cpu_tensor = torch.zeros(
+            (self.max_num_reqs + 1,), device="cpu", dtype=torch.int64, pin_memory=True
+        )
+        self.q_indptr_cpu = self.q_indptr_cpu_tensor.numpy()
+        self.kv_for_head_indptr_cpu_tensor = torch.zeros(
+            (self.max_num_reqs + 1,), device="cpu", dtype=torch.int64, pin_memory=True
+        )
+        self.kv_for_head_indptr_cpu = self.kv_for_head_indptr_cpu_tensor.numpy()
+        self.kv_for_tail_indptr_cpu_tensor = torch.zeros(
+            (self.max_num_reqs + 1,), device="cpu", dtype=torch.int64, pin_memory=True
+        )
+        self.kv_for_tail_indptr_cpu = self.kv_for_tail_indptr_cpu_tensor.numpy()

         # Only relevant for models
using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -1064,69 +1092,59 @@ def _get_pcp_metadata( allgather_restore_idx=allgather_restore_idx, ) - def _get_partial_kv_idx(kv_len_per_pcp_chunk): - kv_partial_len = pcp_chunk_sizes * kv_len_per_pcp_chunk - kv_partial_indptr = np.zeros(len(kv_partial_len) + 1) - kv_partial_indptr[1:], kv_partial_arange = self._get_cumsum_and_arange(kv_partial_len) - kv_parial_indices = kv_partial_arange + np.repeat( + def _get_partial_kv_idx(kv_partial_len, kv_partial_indptr, kv_parial_indices): + kv_partial_indptr[1 : len(kv_partial_len) + 1], kv_partial_arange = self._get_cumsum_and_arange(kv_partial_len) + kv_parial_indices.np[: kv_partial_arange.shape[0]] = kv_partial_arange + np.repeat( kv_start_loc, kv_partial_len, ) - return kv_partial_indptr, kv_parial_indices - - def _to_tensor(data, **kwargs): - return {k: torch.from_numpy(v).to(**kwargs) for k, v in data.items()} + return kv_partial_arange.shape[0] - pcp_chunk_sizes = q_lens // 2 - q_indptr = np.zeros(len(pcp_chunk_sizes) + 1) - q_indptr[1:], q_chunk_arange = self._get_cumsum_and_arange(pcp_chunk_sizes) + pcp_chunk_sizes = q_lens // 2 + self.q_indptr_cpu[1 : len(pcp_chunk_sizes) + 1], q_chunk_arange = self._get_cumsum_and_arange(pcp_chunk_sizes) q_head_start_loc = np.roll(np.cumsum(q_lens), 1) q_head_start_loc[0] = 0 - q_head_indices = q_chunk_arange + np.repeat( + self.q_head_indices.np[: q_chunk_arange.shape[0]] = q_chunk_arange + np.repeat( q_head_start_loc, pcp_chunk_sizes, ) + self.q_head_indices.copy_to_gpu(q_chunk_arange.shape[0]) q_tail_start_loc = q_head_start_loc + pcp_chunk_sizes - q_tail_indices = q_chunk_arange + np.repeat( + self.q_tail_indices.np[: q_chunk_arange.shape[0]] = q_chunk_arange + np.repeat( q_tail_start_loc, pcp_chunk_sizes, ) + self.q_tail_indices.copy_to_gpu(q_chunk_arange.shape[0]) kv_start_loc = np.roll(np.cumsum(kv_lens), 1) kv_start_loc[0] = 0 # kv_for_q_head - kv_for_head_indptr, kv_for_head_indices = _get_partial_kv_idx(self.pcp_rank + 1) + 
kv_for_head_len = (self.pcp_rank + 1) * pcp_chunk_sizes + kv_head_tokens_sum = _get_partial_kv_idx(kv_for_head_len, self.kv_for_head_indptr_cpu, self.kv_for_head_indices) + self.kv_for_head_indices.copy_to_gpu(kv_head_tokens_sum) # kv_for_q_tail - kv_for_tail_indptr, kv_for_tail_indices = _get_partial_kv_idx( - 2 * self.pcp_world_size - self.pcp_rank - ) - - head_tail_indices = _to_tensor({ - "q_head": q_head_indices, - "q_tail": q_tail_indices, - "kv_head": kv_for_head_indices, - "kv_tail": kv_for_tail_indices, - }, device=self.device, dtype=torch.int64, non_blocking=True) - head_tail_indptr = _to_tensor({ - "q": q_indptr, - "kv_head": kv_for_head_indptr, - "kv_tail": kv_for_tail_indptr - }, dtype=torch.int64) - - q_full_indices = torch.cat([head_tail_indices["q_head"], head_tail_indices["q_tail"]]) - q_full_indices = q_full_indices.to(torch.float32).argsort().to(torch.int32) + kv_for_tail_len = (2 * self.pcp_world_size - self.pcp_rank) * pcp_chunk_sizes + kv_tail_tokens_sum = _get_partial_kv_idx(kv_for_tail_len, self.kv_for_tail_indptr_cpu, self.kv_for_tail_indices) + self.kv_for_tail_indices.copy_to_gpu(kv_tail_tokens_sum) + + q_full_indices = torch.cat( + [ + self.q_head_indices.gpu[: q_chunk_arange.shape[0]], + self.q_tail_indices.gpu[: q_chunk_arange.shape[0]] + ] + ).argsort() return PrefillContextParallelMetadata( allgather_restore_idx=allgather_restore_idx, - q_head_indices=head_tail_indices["q_head"], - q_tail_indices=head_tail_indices["q_tail"], - q_head_start_loc=head_tail_indptr["q"], - kv_for_head_indices=head_tail_indices["kv_head"], - kv_for_tail_indices=head_tail_indices["kv_tail"], - kv_for_head_indptr=head_tail_indptr["kv_head"], - kv_for_tail_indptr=head_tail_indptr["kv_tail"], + q_head_indices=self.q_head_indices.gpu[: q_chunk_arange.shape[0]], + q_tail_indices=self.q_tail_indices.gpu[: q_chunk_arange.shape[0]], + q_head_start_loc=self.q_indptr_cpu_tensor[: len(pcp_chunk_sizes) + 1], + kv_for_head_indices=self.kv_for_head_indices.gpu[: 
kv_head_tokens_sum], + kv_for_tail_indices=self.kv_for_tail_indices.gpu[: kv_tail_tokens_sum], + kv_for_head_indptr=self.kv_for_head_indptr_cpu_tensor[: len(kv_for_head_len) + 1], + kv_for_tail_indptr=self.kv_for_tail_indptr_cpu_tensor[: len(kv_for_tail_len) + 1], q_full_indices=q_full_indices, ) @@ -3068,6 +3086,12 @@ def execute_model( self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens) elif num_tokens_across_dp is not None: num_input_tokens = int(num_tokens_across_dp[dp_rank].item()) + elif self.pcp_world_size > 1: + # NOTE(qcs): For PCP, we pad num_scheduled_tokens_np but + # do not update total_num_scheduled_tokens in scheduler_output + num_input_tokens = self._get_num_input_tokens( + sum(num_scheduled_tokens_np) + ) else: num_input_tokens = self._get_num_input_tokens( scheduler_output.total_num_scheduled_tokens From 3d653300404121fa3dc7036e287681384c76beb0 Mon Sep 17 00:00:00 2001 From: Jingchun Gao Date: Sun, 23 Nov 2025 23:10:21 +0800 Subject: [PATCH 4/7] [Fix] flashinfer support for pcp Signed-off-by: Jingchun Gao --- vllm/v1/attention/backends/flashinfer.py | 66 ++++++++++++------------ 1 file changed, 34 insertions(+), 32 deletions(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index dfd11703879c..f03064021350 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -29,7 +29,7 @@ from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.config import CUDAGraphMode, VllmConfig from vllm.config.cache import CacheDType -from vllm.distributed.parallel_state import get_pcp_group, get_dcp_group +from vllm.distributed.parallel_state import get_dcp_group, get_pcp_group from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, @@ -54,12 +54,12 @@ AttentionMetadataBuilder, CommonAttentionMetadata, KVCacheLayoutType, + PrefillContextParallelMetadata, get_cp_local_seq_lens, 
get_kv_cache_layout, get_per_layer_parameters, infer_global_hyperparameters, split_decodes_and_prefills, - PrefillContextParallelMetadata, ) from vllm.v1.kv_cache_interface import AttentionSpec @@ -267,7 +267,7 @@ def plan( causal=True, # This is newtokens run **common_kwargs, ) - + def _attention_with_head_and_tail( self, query: torch.Tensor, @@ -279,7 +279,7 @@ def _attention_with_head_and_tail( """ For prompt with tokens [T0, T1, T2, T3], the query on PCP0 is [Q0, Q3] and we all-gather full K as [K0, K1, K2, K3]. - There are two attn ops. The "head" is [Q0]x[K0] and the "tail" is + There are two attn ops. The "head" is [Q0]x[K0] and the "tail" is [Q3]x[K0, K1, K2, K3]. """ q_head_indices = metadata.q_head_indices @@ -581,7 +581,9 @@ def __init__( try: self.dcp_world_size = get_dcp_group().world_size self.dcp_rank = get_dcp_group().rank_in_group - self.cp_kv_cache_interleave_size = vllm_config.parallel_config.cp_kv_cache_interleave_size + self.cp_kv_cache_interleave_size = ( + vllm_config.parallel_config.cp_kv_cache_interleave_size + ) except AssertionError: # DCP might not be initialized in testing self.dcp_world_size = 1 @@ -590,8 +592,8 @@ def __init__( self.total_cp_world_size = self.pcp_world_size * self.dcp_world_size self.total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank - self.num_qo_heads = ( - self.model_config.get_num_attention_heads(self.vllm_config.parallel_config) + self.num_qo_heads = self.model_config.get_num_attention_heads( + self.vllm_config.parallel_config ) self.num_kv_heads = self.kv_cache_spec.num_kv_heads @@ -880,7 +882,7 @@ def build( self.num_kv_heads, num_prefill_tokens, max_seq_len, - self.dcp_world_size, + self.total_cp_world_size, self.cache_dtype, self.q_data_type, is_prefill=True, @@ -1007,7 +1009,9 @@ def build( kv_cache_dtype=self.kv_cache_dtype, prefill_fixed_split_size=self.prefill_fixed_split_size, disable_split_kv=self.disable_split_kv, - pcp_metadata=common_attn_metadata.pcp_metadata if self.pcp_world_size 
> 1 else None, + pcp_metadata=common_attn_metadata.pcp_metadata + if self.pcp_world_size > 1 + else None, ) else: assert isinstance( @@ -1275,18 +1279,14 @@ def forward( num_actual_tokens = attn_metadata.num_actual_tokens - if (self.pcp_world_size > 1): + if self.pcp_world_size > 1: assert attn_metadata.pcp_metadata is not None pcp_allgather_restore_idx = attn_metadata.pcp_metadata.allgather_restore_idx assert pcp_allgather_restore_idx is not None # NOTE(yyj): we must `slice` key and value because pcp_allgather_restore_idx # ignores the padding from CUDA Graph. To be optimized for performance! - key_across_pcp = get_pcp_group().all_gather( - key[:num_actual_tokens].contiguous(), dim=0 - ) - value_across_pcp = get_pcp_group().all_gather( - value[:num_actual_tokens].contiguous(), dim=0 - ) + key_across_pcp = get_pcp_group().all_gather(key[:num_actual_tokens].contiguous(), dim=0) + value_across_pcp = get_pcp_group().all_gather(value[:num_actual_tokens].contiguous(), dim=0) # Reorder kv after pcp allgather. # Note that there are duplicate decoding tokens, # but we only save the first one in kvcache. 
@@ -1325,8 +1325,8 @@ def forward( # Inputs and outputs may be padded for CUDA graphs query = query[:num_actual_tokens] - key = key[:num_actual_tokens*self.pcp_world_size] - value = value[:num_actual_tokens*self.pcp_world_size] + key = key[: num_actual_tokens * self.pcp_world_size] + value = value[: num_actual_tokens * self.pcp_world_size] output_padded = output output = output[:num_actual_tokens] @@ -1359,10 +1359,12 @@ def forward( wrappers_to_check = [(prefill_wrapper._context, False)] if self.pcp_world_size > 1: - wrappers_to_check.extend([ - (prefill_wrapper._new_tokens_head, True), - (prefill_wrapper._new_tokens_tail, True) - ]) + wrappers_to_check.extend( + [ + (prefill_wrapper._new_tokens_head, True), + (prefill_wrapper._new_tokens_tail, True), + ] + ) else: wrappers_to_check.append((prefill_wrapper._new_tokens, True)) @@ -1376,10 +1378,12 @@ def forward( layer, prefill_query, kv_cache_permute, - key[num_decode_tokens * self.pcp_world_size:], - value[num_decode_tokens * self.pcp_world_size:], + key[num_decode_tokens * self.pcp_world_size :], + value[num_decode_tokens * self.pcp_world_size :], out=output[num_decode_tokens:], - pcp_metadata=attn_metadata.pcp_metadata if self.pcp_world_size > 1 else None, + pcp_metadata=attn_metadata.pcp_metadata + if self.pcp_world_size > 1 + else None, ) else: assert isinstance( @@ -1487,18 +1491,16 @@ def forward( ) if self.dcp_world_size > 1: out, lse = cp_lse_ag_out_rs( - out, lse, get_dcp_group(), + out, + lse, + get_dcp_group(), return_lse=True, ) - else: - output[:num_decode_tokens] = out - if self.pcp_world_size > 1: - output[:num_decode_tokens] = cp_lse_ag_out_ar( + out = cp_lse_ag_out_ar( out, lse, get_pcp_group() ) - else: - output[:num_decode_tokens] = out + output[:num_decode_tokens] = out else: decode_wrapper.run( decode_query, From df36e767a32c264c45caa4876f37d941e5f062db Mon Sep 17 00:00:00 2001 From: Jingchun Gao Date: Mon, 24 Nov 2025 00:35:02 +0800 Subject: [PATCH 5/7] [Lint] Signed-off-by: Jingchun Gao --- 
vllm/v1/attention/backends/flashinfer.py | 24 ++--- vllm/v1/attention/backends/mla/common.py | 4 +- vllm/v1/attention/backends/utils.py | 6 +- vllm/v1/worker/gpu_model_runner.py | 123 ++++++++++++++--------- 4 files changed, 94 insertions(+), 63 deletions(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index f03064021350..593dc2bab7c7 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -1285,17 +1285,17 @@ def forward( assert pcp_allgather_restore_idx is not None # NOTE(yyj): we must `slice` key and value because pcp_allgather_restore_idx # ignores the padding from CUDA Graph. To be optimized for performance! - key_across_pcp = get_pcp_group().all_gather(key[:num_actual_tokens].contiguous(), dim=0) - value_across_pcp = get_pcp_group().all_gather(value[:num_actual_tokens].contiguous(), dim=0) + key_across_pcp = get_pcp_group().all_gather( + key[:num_actual_tokens].contiguous(), dim=0 + ) + value_across_pcp = get_pcp_group().all_gather( + value[:num_actual_tokens].contiguous(), dim=0 + ) # Reorder kv after pcp allgather. # Note that there are duplicate decoding tokens, # but we only save the first one in kvcache. - key = torch.index_select( - key_across_pcp, 0, pcp_allgather_restore_idx - ) - value = torch.index_select( - value_across_pcp, 0, pcp_allgather_restore_idx - ) + key = torch.index_select(key_across_pcp, 0, pcp_allgather_restore_idx) + value = torch.index_select(value_across_pcp, 0, pcp_allgather_restore_idx) if self.kv_sharing_target_layer_name is None: # Reshape the input keys and values and store them in the cache. # Skip this if sharing KV cache with an earlier attention layer. 
@@ -1356,7 +1356,7 @@ def forward( if self.total_cp_world_size > 1: assert isinstance(prefill_wrapper, BatchCPPrefillWrapper) expected_logits_soft_cap = self.logits_soft_cap or 0.0 - + wrappers_to_check = [(prefill_wrapper._context, False)] if self.pcp_world_size > 1: wrappers_to_check.extend( @@ -1367,7 +1367,7 @@ def forward( ) else: wrappers_to_check.append((prefill_wrapper._new_tokens, True)) - + for wrapper, expected_causal in wrappers_to_check: assert wrapper._window_left == self.window_left assert wrapper._logits_soft_cap == expected_logits_soft_cap @@ -1497,9 +1497,7 @@ def forward( return_lse=True, ) if self.pcp_world_size > 1: - out = cp_lse_ag_out_ar( - out, lse, get_pcp_group() - ) + out = cp_lse_ag_out_ar(out, lse, get_pcp_group()) output[:num_decode_tokens] = out else: decode_wrapper.run( diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 1a7f3c1d0a59..8e58beeb4814 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -754,8 +754,8 @@ def build( query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu seq_lens = common_attn_metadata.seq_lens seq_lens_cpu = common_attn_metadata.seq_lens_cpu - dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens - dcp_local_seq_lens_cpu = common_attn_metadata.dcp_local_seq_lens_cpu + dcp_local_seq_lens = common_attn_metadata.cp_local_seq_lens + dcp_local_seq_lens_cpu = common_attn_metadata.cp_local_seq_lens_cpu query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index bbad992975e6..d52a27f68e9d 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -48,11 +48,13 @@ def is_valid_kv_cache_layout(value: str) -> bool: return value in get_args(KVCacheLayoutType) + @dataclass class PrefillContextParallelMetadata: """ Attention metadata for prefill context parallel """ 
+ allgather_restore_idx: torch.Tensor """ We split and concatenate the sequence in a head-tail style, @@ -62,11 +64,12 @@ class PrefillContextParallelMetadata: q_tail_indices: torch.Tensor | None = None q_head_start_loc: torch.Tensor | None = None kv_for_head_indices: torch.Tensor | None = None - kv_for_tail_indices : torch.Tensor | None = None + kv_for_tail_indices: torch.Tensor | None = None kv_for_head_indptr: torch.Tensor | None = None kv_for_tail_indptr: torch.Tensor | None = None q_full_indices: torch.Tensor | None = None + @dataclass class CommonAttentionMetadata: """ @@ -115,6 +118,7 @@ class CommonAttentionMetadata: pcp_metadata: PrefillContextParallelMetadata | None = None + def slice_query_start_locs( query_start_loc: torch.Tensor, request_slice: slice, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index cb447c89072b..675b302858b4 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -93,9 +93,9 @@ AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, + PrefillContextParallelMetadata, create_fast_prefill_custom_backend, get_cp_local_seq_lens, - PrefillContextParallelMetadata, reorder_batch_to_split_decodes_and_prefills, split_attn_metadata, ) @@ -461,7 +461,9 @@ def __init__( if self.pcp_world_size > 1: # Note(qcs): we will pad the tokens of each request # to a multiple of 2 * pcp_size. - max_num_tokens = self.max_num_tokens + self.max_num_reqs * 2 * self.pcp_world_size + max_num_tokens = ( + self.max_num_tokens + self.max_num_reqs * 2 * self.pcp_world_size + ) else: max_num_tokens = self.max_num_tokens # Persistent buffers for CUDA graphs. 
@@ -501,24 +503,15 @@ def __init__( # Persistent buffers for Prefill Context Parallism if self.pcp_world_size > 1: self.pcp_allgather_restore_idx = self._make_buffer( - max_num_tokens, - dtype=torch.int64 - ) - self.q_head_indices = self._make_buffer( - max_num_tokens, - dtype=torch.int64 - ) - self.q_tail_indices = self._make_buffer( - max_num_tokens, - dtype=torch.int64 + max_num_tokens, dtype=torch.int64 ) + self.q_head_indices = self._make_buffer(max_num_tokens, dtype=torch.int64) + self.q_tail_indices = self._make_buffer(max_num_tokens, dtype=torch.int64) self.kv_for_head_indices = self._make_buffer( - max_num_tokens, - dtype=torch.int64 + max_num_tokens, dtype=torch.int64 ) self.kv_for_tail_indices = self._make_buffer( - max_num_tokens, - dtype=torch.int64 + max_num_tokens, dtype=torch.int64 ) self.pcp_padded_slot_mapping = torch.empty( (max_num_tokens,), @@ -534,15 +527,24 @@ def __init__( ) self.pcp_unpad_mask_cpu = self.pcp_unpad_mask_cpu_tensor.numpy() self.q_indptr_cpu_tensor = torch.zeros( - (self.max_num_reqs + 1,), device="cpu", dtype=torch.int64, pin_memory=True + (self.max_num_reqs + 1,), + device="cpu", + dtype=torch.int64, + pin_memory=True, ) self.q_indptr_cpu = self.q_indptr_cpu_tensor.numpy() self.kv_for_head_indptr_cpu_tensor = torch.zeros( - (self.max_num_reqs + 1,), device="cpu", dtype=torch.int64, pin_memory=True + (self.max_num_reqs + 1,), + device="cpu", + dtype=torch.int64, + pin_memory=True, ) self.kv_for_head_indptr_cpu = self.kv_for_head_indptr_cpu_tensor.numpy() self.kv_for_tail_indptr_cpu_tensor = torch.zeros( - (self.max_num_reqs + 1,), device="cpu", dtype=torch.int64, pin_memory=True + (self.max_num_reqs + 1,), + device="cpu", + dtype=torch.int64, + pin_memory=True, ) self.kv_for_tail_indptr_cpu = self.kv_for_tail_indptr_cpu_tensor.numpy() @@ -1070,22 +1072,22 @@ def _get_pcp_metadata( ) -> PrefillContextParallelMetadata: """ During the prefill phrase, the attention computation is divided into - two parts: q_head and q_tail. 
Here, we calculate the kv indices - corresponding to q_head or q_tail. Meawhile, the q and kv indptr are + two parts: q_head and q_tail. Here, we calculate the kv indices + corresponding to q_head or q_tail. Meawhile, the q and kv indptr are also computed to build the attention wrapper. If the pcp_size is 2, the variables are following: >>> q_lens [4, 8] kv_lens [8, 16] >>> pcp_chunk_sizes[2, 4] - >>> q_indptr [0, 2, 4] + >>> q_indptr[0, 2, 4] >>> q_head_indices [0, 1, 4, 5, 6, 7] q_tail_indices [2, 3, 8, 9, 10, 11] >>> kv_head_len r0 [2, 4] / r1 [4, 8] >>> kv_for_head_indptr r0 [0, 2, 6] / r1 [0, 4, 12] >>> kv_for_head_indices r0 [0, 1, 8, 9, 10, 11] - >>> r1 [0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15] + >>> r1[0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15] >>> kv_tail_len r0 [8, 16] / r1 [6, 12] >>> kv_for_tail_indptr r0 [0, 8, 24] / r1 [0, 6, 18] >>> kv_for_tail_indices r0 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ..., 23] - >>> r1 [0, 1, 2, 3, 4, 5, 8, 9, ..., 19] + >>> r1[0, 1, 2, 3, 4, 5, 8, 9, ..., 19] """ if len(q_lens) == 0: return PrefillContextParallelMetadata( @@ -1093,15 +1095,22 @@ def _get_pcp_metadata( ) def _get_partial_kv_idx(kv_partial_len, kv_partial_indptr, kv_parial_indices): - kv_partial_indptr[1 : len(kv_partial_len) + 1], kv_partial_arange = self._get_cumsum_and_arange(kv_partial_len) - kv_parial_indices.np[: kv_partial_arange.shape[0]] = kv_partial_arange + np.repeat( - kv_start_loc, - kv_partial_len, + kv_partial_indptr[1 : len(kv_partial_len) + 1], kv_partial_arange = ( + self._get_cumsum_and_arange(kv_partial_len) + ) + kv_parial_indices.np[: kv_partial_arange.shape[0]] = ( + kv_partial_arange + + np.repeat( + kv_start_loc, + kv_partial_len, + ) ) return kv_partial_arange.shape[0] pcp_chunk_sizes = q_lens // 2 - self.q_indptr_cpu[1 : len(pcp_chunk_sizes) + 1], q_chunk_arange = self._get_cumsum_and_arange(pcp_chunk_sizes) + self.q_indptr_cpu[1 : len(pcp_chunk_sizes) + 1], q_chunk_arange = ( + self._get_cumsum_and_arange(pcp_chunk_sizes) + ) 
q_head_start_loc = np.roll(np.cumsum(q_lens), 1) q_head_start_loc[0] = 0 @@ -1109,6 +1118,7 @@ def _get_partial_kv_idx(kv_partial_len, kv_partial_indptr, kv_parial_indices): q_head_start_loc, pcp_chunk_sizes, ) + self.q_head_indices.copy_to_gpu(q_chunk_arange.shape[0]) q_tail_start_loc = q_head_start_loc + pcp_chunk_sizes @@ -1122,17 +1132,25 @@ def _get_partial_kv_idx(kv_partial_len, kv_partial_indptr, kv_parial_indices): kv_start_loc[0] = 0 # kv_for_q_head kv_for_head_len = (self.pcp_rank + 1) * pcp_chunk_sizes - kv_head_tokens_sum = _get_partial_kv_idx(kv_for_head_len, self.kv_for_head_indptr_cpu, self.kv_for_head_indices) + kv_head_tokens_sum = _get_partial_kv_idx( + kv_for_head_len, + self.kv_for_head_indptr_cpu, + self.kv_for_head_indices, + ) self.kv_for_head_indices.copy_to_gpu(kv_head_tokens_sum) # kv_for_q_tail kv_for_tail_len = (2 * self.pcp_world_size - self.pcp_rank) * pcp_chunk_sizes - kv_tail_tokens_sum = _get_partial_kv_idx(kv_for_tail_len, self.kv_for_tail_indptr_cpu, self.kv_for_tail_indices) + kv_tail_tokens_sum = _get_partial_kv_idx( + kv_for_tail_len, + self.kv_for_tail_indptr_cpu, + self.kv_for_tail_indices, + ) self.kv_for_tail_indices.copy_to_gpu(kv_tail_tokens_sum) q_full_indices = torch.cat( [ self.q_head_indices.gpu[: q_chunk_arange.shape[0]], - self.q_tail_indices.gpu[: q_chunk_arange.shape[0]] + self.q_tail_indices.gpu[: q_chunk_arange.shape[0]], ] ).argsort() @@ -1141,13 +1159,17 @@ def _get_partial_kv_idx(kv_partial_len, kv_partial_indptr, kv_parial_indices): q_head_indices=self.q_head_indices.gpu[: q_chunk_arange.shape[0]], q_tail_indices=self.q_tail_indices.gpu[: q_chunk_arange.shape[0]], q_head_start_loc=self.q_indptr_cpu_tensor[: len(pcp_chunk_sizes) + 1], - kv_for_head_indices=self.kv_for_head_indices.gpu[: kv_head_tokens_sum], - kv_for_tail_indices=self.kv_for_tail_indices.gpu[: kv_tail_tokens_sum], - kv_for_head_indptr=self.kv_for_head_indptr_cpu_tensor[: len(kv_for_head_len) + 1], - 
kv_for_tail_indptr=self.kv_for_tail_indptr_cpu_tensor[: len(kv_for_tail_len) + 1], + kv_for_head_indices=self.kv_for_head_indices.gpu[:kv_head_tokens_sum], + kv_for_tail_indices=self.kv_for_tail_indices.gpu[:kv_tail_tokens_sum], + kv_for_head_indptr=( + self.kv_for_head_indptr_cpu_tensor[: len(kv_for_head_len) + 1] + ), + kv_for_tail_indptr=( + self.kv_for_tail_indptr_cpu_tensor[: len(kv_for_tail_len) + 1] + ), q_full_indices=q_full_indices, ) - + def _update_tokens_for_pcp( self, tokens: np.ndarray, @@ -1189,8 +1211,15 @@ def _update_tokens_for_pcp( self.input_batch.num_computed_tokens_cpu[:num_reqs] >= self.input_batch.num_prompt_tokens[:num_reqs] ) + else: + if num_reqs is None or num_decode_reqs is None: + raise ValueError( + "num_reqs and num_decode_reqs must be provided for dummy input" + ) + assert num_reqs is not None + assert num_decode_reqs is not None self.num_pcp_pads_cpu[:num_reqs] = 0 - + num_decode_tokens = sum(tokens[:num_decode_reqs]) num_padded_scheduled_tokens = np.ceil( @@ -1259,8 +1288,8 @@ def get_current_rank_positions( self._get_pcp_metadata( pcp_tokens[num_decode_reqs:], num_padded_scheduled_tokens[num_decode_reqs:], - self.pcp_allgather_restore_idx.gpu[: all_positions.shape[0]] - ) + self.pcp_allgather_restore_idx.gpu[: all_positions.shape[0]], + ), ) def _get_cumsum_and_arange( @@ -1471,10 +1500,9 @@ def _prepare_inputs( pcp_metadata = None if self.pcp_world_size > 1: - num_scheduled_tokens[:num_reqs], pcp_positions, pcp_metadata = \ - self._update_tokens_for_pcp( - num_scheduled_tokens[:num_reqs] - ) + num_scheduled_tokens[:num_reqs], pcp_positions, pcp_metadata = ( + self._update_tokens_for_pcp(num_scheduled_tokens[:num_reqs]) + ) # Re-update after PCP split sequences. 
total_num_scheduled_tokens = sum(num_scheduled_tokens) @@ -1605,7 +1633,7 @@ def _prepare_inputs( if self.pcp_world_size > 1: discard_requests_mask = ( self.input_batch.num_computed_tokens_cpu[:num_reqs] - + num_scheduled_tokens * self.pcp_world_size + + num_scheduled_tokens * self.pcp_world_size - self.num_pcp_pads_cpu[:num_reqs] ) < num_tokens_np else: @@ -3167,7 +3195,7 @@ def execute_model( # NOTE we must `slice` hidden_states because pcp_allgather_restore_idx # ignores the padding from CUDA Graph. hidden_states = get_pcp_group().all_gather( - hidden_states[:num_scheduled_tokens_np.sum()], + hidden_states[: num_scheduled_tokens_np.sum()], 0, ) hidden_states = torch.index_select( @@ -4077,13 +4105,14 @@ def _dummy_run( pcp_metadata = None if self.pcp_world_size > 1 and force_attention: num_decode_reqs = sum(num_scheduled_tokens == 1) - num_scheduled_tokens[:num_reqs], _, pcp_metadata = \ + num_scheduled_tokens[:num_reqs], _, pcp_metadata = ( self._update_tokens_for_pcp( num_scheduled_tokens[:num_reqs], dummy_input=True, num_reqs=num_reqs, num_decode_reqs=num_decode_reqs, ) + ) total_num_scheduled_tokens = int(num_scheduled_tokens.sum()) num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) From 3e2f6fd031061eaf00b4f0809176bbb2cfc1338c Mon Sep 17 00:00:00 2001 From: Jingchun Gao Date: Tue, 25 Nov 2025 21:07:00 +0800 Subject: [PATCH 6/7] Move flashinfer-related params out of utils Signed-off-by: Jingchun Gao --- vllm/v1/attention/backends/flashinfer.py | 320 +++++++++++++---------- vllm/v1/attention/backends/utils.py | 69 ++++- vllm/v1/spec_decode/eagle.py | 6 - vllm/v1/worker/gpu_model_runner.py | 142 +--------- 4 files changed, 262 insertions(+), 275 deletions(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 593dc2bab7c7..cc2ed9e68dd4 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from typing import 
ClassVar +from xxlimited import Str import numpy as np import torch @@ -57,7 +58,9 @@ PrefillContextParallelMetadata, get_cp_local_seq_lens, get_kv_cache_layout, + get_kv_indices, get_per_layer_parameters, + get_q_indices, infer_global_hyperparameters, split_decodes_and_prefills, ) @@ -171,6 +174,7 @@ def __init__( self, dcp_world_size, pcp_world_size, + max_num_reqs, workspace_buffer: torch.Tensor | None = None, ): self.dcp_world_size = dcp_world_size @@ -185,11 +189,62 @@ def __init__( self._new_tokens_tail = BatchPrefillWithRaggedKVCacheWrapper( workspace_buffer, get_kv_cache_layout() ) + pin_memory = is_pin_memory_available() + self.pcp_q_indptr_cpu = torch.zeros( + max_num_reqs + 1, dtype=torch.int32, device="cpu", pin_memory=pin_memory + ) + self.pcp_q_indptr_np = self.pcp_q_indptr_cpu.numpy() + self.kv_for_head_indptr_cpu = torch.zeros( + max_num_reqs + 1, dtype=torch.int32, device="cpu", pin_memory=pin_memory + ) + self.kv_for_head_indptr_np = self.kv_for_head_indptr_cpu.numpy() + self.kv_for_tail_indptr_cpu = torch.zeros( + max_num_reqs + 1, dtype=torch.int32, device="cpu", pin_memory=pin_memory + ) + self.kv_for_tail_indptr_np = self.kv_for_tail_indptr_cpu.numpy() else: self._new_tokens = BatchPrefillWithRaggedKVCacheWrapper( workspace_buffer, get_kv_cache_layout() ) + def _iter_wrappers(self) -> list: + wrappers = [self._context] + if self.pcp_world_size > 1: + wrappers.extend([self._new_tokens_head, self._new_tokens_tail]) + else: + wrappers.append(self._new_tokens) + return wrappers + + def _get_consistent_attr(self, attr_name: str): + wrappers = self._iter_wrappers() + base_value = getattr(wrappers[0], attr_name) + for wrapper in wrappers[1:]: + value = getattr(wrapper, attr_name) + assert ( + value == base_value + ), f"Inconsistent {attr_name} detected across CP prefill wrappers." 
+ return base_value + + @property + def _window_left(self): + return self._get_consistent_attr("_window_left") + + @property + def _logits_soft_cap(self): + return self._get_consistent_attr("_logits_soft_cap") + + @property + def _sm_scale(self): + return self._get_consistent_attr("_sm_scale") + + def _assert_causal(self) -> None: + assert not self._context._causal + if self.pcp_world_size > 1: + assert self._new_tokens_head._causal + assert self._new_tokens_tail._causal + else: + assert self._new_tokens._causal + def plan( self, qo_indptr_cpu: torch.Tensor, @@ -208,8 +263,8 @@ def plan( kv_cache_dtype: torch.dtype, prefill_fixed_split_size: int, disable_split_kv: bool, - pcp_metadata: PrefillContextParallelMetadata | None, - ): + device: str, + ) -> PrefillContextParallelMetadata: """Plan the prefill operation with given parameters.""" common_args = [ num_kv_heads, @@ -236,13 +291,30 @@ def plan( **common_kwargs, ) if self.pcp_world_size > 1: - assert pcp_metadata is not None - qo_indptr_cpu = pcp_metadata.q_head_start_loc - kv_for_head_indptr = pcp_metadata.kv_for_head_indptr - kv_for_tail_indptr = pcp_metadata.kv_for_tail_indptr - self._new_tokens_head.plan( + self.pcp_q_indptr_cpu[: qo_indptr_cpu.shape[0]] = qo_indptr_cpu // 2 + pcp_q_indptr_cpu = self.pcp_q_indptr_cpu[: qo_indptr_cpu.shape[0]] + self.kv_for_head_indptr_cpu[: qo_indptr_cpu.shape[0]] = ( + (self.pcp_rank + 1) * pcp_q_indptr_cpu + ) + self.kv_for_tail_indptr_cpu[: qo_indptr_cpu.shape[0]] = ( + (2 * self.pcp_world_size - self.pcp_rank) * pcp_q_indptr_cpu + ) + q_head_indices, q_tail_indices = get_q_indices( qo_indptr_cpu, - kv_for_head_indptr, + self.pcp_q_indptr_np[: qo_indptr_cpu.shape[0]], + ) + q_head_indices_gpu = q_head_indices.to(device) + q_tail_indices_gpu = q_tail_indices.to(device) + q_full_indices = torch.cat([q_head_indices_gpu, q_tail_indices_gpu]).argsort() + kv_for_head_indices, kv_for_tail_indices = get_kv_indices( + qo_indptr_cpu * self.pcp_world_size, + 
self.kv_for_head_indptr_np[: qo_indptr_cpu.shape[0]], + self.kv_for_tail_indptr_np[: qo_indptr_cpu.shape[0]], + ) + + self._new_tokens_head.plan( + pcp_q_indptr_cpu, + self.kv_for_head_indptr_cpu[: qo_indptr_cpu.shape[0]], num_qo_heads, *common_args, causal=True, @@ -250,12 +322,19 @@ def plan( ) self._new_tokens_tail.plan( qo_indptr_cpu, - kv_for_tail_indptr, + self.kv_for_tail_indptr_cpu[: qo_indptr_cpu.shape[0]], num_qo_heads, *common_args, causal=True, **common_kwargs, ) + return PrefillContextParallelMetadata( + q_head_indices=q_head_indices, + q_tail_indices=q_tail_indices, + kv_for_head_indices=kv_for_head_indices.to(device, non_blocking=True), + kv_for_tail_indices=kv_for_tail_indices.to(device, non_blocking=True), + q_full_indices=q_full_indices, + ) else: self._new_tokens.plan( qo_indptr_cpu, @@ -267,90 +346,86 @@ def plan( causal=True, # This is newtokens run **common_kwargs, ) - - def _attention_with_head_and_tail( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - metadata: PrefillContextParallelMetadata, - return_lse: bool = False, - ): - """ - For prompt with tokens [T0, T1, T2, T3], the query on PCP0 is [Q0, Q3] - and we all-gather full K as [K0, K1, K2, K3]. - There are two attn ops. The "head" is [Q0]x[K0] and the "tail" is - [Q3]x[K0, K1, K2, K3]. 
- """ - q_head_indices = metadata.q_head_indices - q_tail_indices = metadata.q_tail_indices - kv_for_head_indices = metadata.kv_for_head_indices - kv_for_tail_indices = metadata.kv_for_tail_indices - q_full_indices = metadata.q_full_indices - - q_head = torch.index_select(query, 0, q_head_indices) - q_tail = torch.index_select(query, 0, q_tail_indices) - k_head = torch.index_select(key, 0, kv_for_head_indices) - v_head = torch.index_select(value, 0, kv_for_head_indices) - k_tail = torch.index_select(key, 0, kv_for_tail_indices) - v_tail = torch.index_select(value, 0, kv_for_tail_indices) - - output_head = self._new_tokens_head.run( - q_head, - k_head, - v_head, - return_lse=return_lse, - ) - output_tail = self._new_tokens_tail.run( - q_tail, - k_tail, - v_tail, - return_lse=return_lse, - ) - if return_lse: - output_head, lse_head = output_head - output_tail, lse_tail = output_tail - output = torch.index_select( - torch.cat([output_head, output_tail], dim=0), - 0, - q_full_indices, - ) - lse = torch.index_select( - torch.cat([lse_head, lse_tail], dim=0), - 0, - q_full_indices, - ) - return output, lse - else: - output = torch.index_select( - torch.cat([output_head, output_tail], dim=0), - 0, - q_full_indices, - ) - return output + return None def run( self, + num_actual_tokens: int, + num_decode_tokens: int, layer: torch.nn.Module, - prefill_query: torch.Tensor, + query: torch.Tensor, kv_cache_permute: torch.Tensor, key: torch.Tensor, value: torch.Tensor, out: torch.Tensor, + allgather_restore_idx: torch.Tensor | None, pcp_metadata: PrefillContextParallelMetadata | None, + return_lse: bool = False, ): if self.pcp_world_size > 1: assert pcp_metadata is not None - out[:], lse_query = self._attention_with_head_and_tail( - prefill_query, - key, - value, - pcp_metadata, - return_lse=True, + # NOTE(yyj): we must `slice` key and value because pcp_allgather_restore_idx + # ignores the padding from CUDA Graph. To be optimized for performance! 
+ key_across_pcp = get_pcp_group().all_gather(key.contiguous(), dim=0) + value_across_pcp = get_pcp_group().all_gather(value.contiguous(), dim=0) + # Reorder kv after pcp allgather. + # Note that there are duplicate decoding tokens, + # but we only save the first one in kvcache. + key = torch.index_select(key_across_pcp, 0, allgather_restore_idx) + value = torch.index_select(value_across_pcp, 0, allgather_restore_idx) + key = key[: num_actual_tokens * self.pcp_world_size] + value = value[: num_actual_tokens * self.pcp_world_size] + + key = key[num_decode_tokens * self.pcp_world_size :] + value = value[num_decode_tokens * self.pcp_world_size :] + """ + For prompt with tokens [T0, T1, T2, T3], the query on PCP0 is [Q0, Q3] + and we all-gather full K as [K0, K1, K2, K3]. + There are two attn ops. The "head" is [Q0]x[K0] and the "tail" is + [Q3]x[K0, K1, K2, K3]. + """ + q_head_indices = pcp_metadata.q_head_indices + q_tail_indices = pcp_metadata.q_tail_indices + kv_for_head_indices = pcp_metadata.kv_for_head_indices + kv_for_tail_indices = pcp_metadata.kv_for_tail_indices + q_full_indices = pcp_metadata.q_full_indices + + q_head = torch.index_select(query, 0, q_head_indices) + q_tail = torch.index_select(query, 0, q_tail_indices) + k_head = torch.index_select(key, 0, kv_for_head_indices) + v_head = torch.index_select(value, 0, kv_for_head_indices) + k_tail = torch.index_select(key, 0, kv_for_tail_indices) + v_tail = torch.index_select(value, 0, kv_for_tail_indices) + + output_head = self._new_tokens_head.run( + q_head, + k_head, + v_head, + return_lse=return_lse, + ) + output_tail = self._new_tokens_tail.run( + q_tail, + k_tail, + v_tail, + return_lse=return_lse, + ) + if return_lse: + output_head, lse_head = output_head + output_tail, lse_tail = output_tail + lse = torch.index_select( + torch.cat([lse_head, lse_tail], dim=0), + 0, + q_full_indices, + ) + + out[:] = torch.index_select( + torch.cat([output_head, output_tail], dim=0), + 0, + q_full_indices, ) else: 
prefill_query_across_dcp = get_dcp_group().all_gather( - prefill_query.contiguous(), dim=1 + query.contiguous(), dim=1 ) output_context_tmp, lse_context_tmp = self._context.run( prefill_query_across_dcp, @@ -365,7 +440,7 @@ def run( lse_context = lse_context.transpose(0, 1).contiguous() output_query, lse_query = self._new_tokens.run( - prefill_query, + query, key, value, return_lse=True, @@ -591,6 +666,7 @@ def __init__( self.cp_kv_cache_interleave_size = 1 self.total_cp_world_size = self.pcp_world_size * self.dcp_world_size self.total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank + self.max_num_reqs = max_num_reqs self.num_qo_heads = self.model_config.get_num_attention_heads( self.vllm_config.parallel_config @@ -715,6 +791,7 @@ def _get_prefill_wrapper( self._prefill_wrapper = BatchCPPrefillWrapper( self.dcp_world_size, self.pcp_world_size, + self.max_num_reqs, self._get_workspace_buffer(), ) else: @@ -934,7 +1011,6 @@ def build( num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, use_cascade=use_cascade, - pcp_metadata=common_attn_metadata.pcp_metadata, ) paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[: 1 + num_reqs] @@ -992,7 +1068,7 @@ def build( assert isinstance( attn_metadata.prefill_wrapper, BatchCPPrefillWrapper ) - attn_metadata.prefill_wrapper.plan( + pcp_metadata = attn_metadata.prefill_wrapper.plan( qo_indptr_cpu=qo_indptr_cpu, paged_kv_indptr_cpu=paged_kv_indptr_cpu, paged_kv_indices=paged_kv_indices, @@ -1009,10 +1085,9 @@ def build( kv_cache_dtype=self.kv_cache_dtype, prefill_fixed_split_size=self.prefill_fixed_split_size, disable_split_kv=self.disable_split_kv, - pcp_metadata=common_attn_metadata.pcp_metadata - if self.pcp_world_size > 1 - else None, + device=self.device, ) + attn_metadata.pcp_metadata = pcp_metadata else: assert isinstance( attn_metadata.prefill_wrapper, @@ -1168,6 +1243,7 @@ def __init__( ) self.sinks = sinks + self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" self.support_trtllm_attn = 
can_use_trtllm_attention(num_heads, num_kv_heads) self.bmm1_scale: float | None = None self.bmm2_scale: float | None = None @@ -1279,23 +1355,6 @@ def forward( num_actual_tokens = attn_metadata.num_actual_tokens - if self.pcp_world_size > 1: - assert attn_metadata.pcp_metadata is not None - pcp_allgather_restore_idx = attn_metadata.pcp_metadata.allgather_restore_idx - assert pcp_allgather_restore_idx is not None - # NOTE(yyj): we must `slice` key and value because pcp_allgather_restore_idx - # ignores the padding from CUDA Graph. To be optimized for performance! - key_across_pcp = get_pcp_group().all_gather( - key[:num_actual_tokens].contiguous(), dim=0 - ) - value_across_pcp = get_pcp_group().all_gather( - value[:num_actual_tokens].contiguous(), dim=0 - ) - # Reorder kv after pcp allgather. - # Note that there are duplicate decoding tokens, - # but we only save the first one in kvcache. - key = torch.index_select(key_across_pcp, 0, pcp_allgather_restore_idx) - value = torch.index_select(value_across_pcp, 0, pcp_allgather_restore_idx) if self.kv_sharing_target_layer_name is None: # Reshape the input keys and values and store them in the cache. # Skip this if sharing KV cache with an earlier attention layer. 
@@ -1325,8 +1384,8 @@ def forward( # Inputs and outputs may be padded for CUDA graphs query = query[:num_actual_tokens] - key = key[: num_actual_tokens * self.pcp_world_size] - value = value[: num_actual_tokens * self.pcp_world_size] + key = key[: num_actual_tokens] + value = value[: num_actual_tokens] output_padded = output output = output[:num_actual_tokens] @@ -1353,34 +1412,29 @@ def forward( assert prefill_wrapper is not None if not attn_metadata.prefill_use_trtllm: + expected_logits_soft_cap = self.logits_soft_cap or 0.0 if self.total_cp_world_size > 1: assert isinstance(prefill_wrapper, BatchCPPrefillWrapper) - expected_logits_soft_cap = self.logits_soft_cap or 0.0 - - wrappers_to_check = [(prefill_wrapper._context, False)] - if self.pcp_world_size > 1: - wrappers_to_check.extend( - [ - (prefill_wrapper._new_tokens_head, True), - (prefill_wrapper._new_tokens_tail, True), - ] + if self.is_debugging_mode: + assert prefill_wrapper._window_left == self.window_left + assert ( + prefill_wrapper._logits_soft_cap == expected_logits_soft_cap ) - else: - wrappers_to_check.append((prefill_wrapper._new_tokens, True)) - - for wrapper, expected_causal in wrappers_to_check: - assert wrapper._window_left == self.window_left - assert wrapper._logits_soft_cap == expected_logits_soft_cap - assert wrapper._sm_scale == self.scale - assert wrapper._causal == expected_causal - + assert prefill_wrapper._sm_scale == self.scale + prefill_wrapper._assert_causal() + assert attn_metadata.pcp_metadata is not None + pcp_allgather_restore_idx = attn_metadata.pcp_metadata.allgather_restore_idx + assert pcp_allgather_restore_idx is not None prefill_wrapper.run( + num_actual_tokens, + num_decode_tokens, layer, prefill_query, kv_cache_permute, - key[num_decode_tokens * self.pcp_world_size :], - value[num_decode_tokens * self.pcp_world_size :], + key, + value, out=output[num_decode_tokens:], + allgather_restore_idx=pcp_allgather_restore_idx, pcp_metadata=attn_metadata.pcp_metadata if 
self.pcp_world_size > 1 else None, @@ -1389,12 +1443,13 @@ def forward( assert isinstance( prefill_wrapper, BatchPrefillWithPagedKVCacheWrapper ) - assert prefill_wrapper._window_left == self.window_left - assert prefill_wrapper._logits_soft_cap == ( - self.logits_soft_cap or 0.0 - ) - assert prefill_wrapper._sm_scale == self.scale - assert prefill_wrapper._causal + if self.is_debugging_mode: + assert prefill_wrapper._window_left == self.window_left + assert ( + prefill_wrapper._logits_soft_cap == expected_logits_soft_cap + ) + assert prefill_wrapper._sm_scale == self.scale + assert prefill_wrapper._causal prefill_wrapper.run( prefill_query, kv_cache_permute, @@ -1490,14 +1545,17 @@ def forward( return_lse=True, ) if self.dcp_world_size > 1: - out, lse = cp_lse_ag_out_rs( + out = cp_lse_ag_out_rs( out, lse, get_dcp_group(), - return_lse=True, + return_lse=self.pcp_world_size > 1, ) if self.pcp_world_size > 1: + if isinstance(out, tuple): + out, lse = out out = cp_lse_ag_out_ar(out, lse, get_pcp_group()) + assert isinstance(out, torch.Tensor) output[:num_decode_tokens] = out else: decode_wrapper.run( diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index d52a27f68e9d..a5173d9ed3cc 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -62,11 +62,8 @@ class PrefillContextParallelMetadata: """ q_head_indices: torch.Tensor | None = None q_tail_indices: torch.Tensor | None = None - q_head_start_loc: torch.Tensor | None = None kv_for_head_indices: torch.Tensor | None = None kv_for_tail_indices: torch.Tensor | None = None - kv_for_head_indptr: torch.Tensor | None = None - kv_for_tail_indptr: torch.Tensor | None = None q_full_indices: torch.Tensor | None = None @@ -1137,3 +1134,69 @@ def get_cp_local_seq_lens( ) cp_local_seq_lens = base + remainder return cp_local_seq_lens.squeeze(1) + +def get_pcp_metadata( + q_start_loc_cpu: torch.Tensor, + q_indptr_np: np.ndarray, + allgather_restore_idx: 
torch.Tensor, + kv_start_loc_cpu: torch.Tensor, + kv_for_head_indptr_np: np.ndarray, + kv_for_tail_indptr_np: np.ndarray, +) -> PrefillContextParallelMetadata: + """ + During the prefill phrase, the attention computation is divided into + two parts: q_head and q_tail. Here, we calculate the kv indices + corresponding to q_head or q_tail. Meawhile, the q and kv indptr are + also computed to build the attention wrapper. + If the pcp_size is 2, the variables are following: + >>> q_lens [4, 8] kv_lens [8, 16] + >>> pcp_chunk_sizes[2, 4] + >>> q_indptr[0, 2, 4] + >>> q_head_indices [0, 1, 4, 5, 6, 7] q_tail_indices [2, 3, 8, 9, 10, 11] + >>> kv_head_len r0 [2, 4] / r1 [4, 8] + >>> kv_for_head_indptr r0 [0, 2, 6] / r1 [0, 4, 12] + >>> kv_for_head_indices r0 [0, 1, 8, 9, 10, 11] + >>> r1[0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15] + >>> kv_tail_len r0 [8, 16] / r1 [6, 12] + >>> kv_for_tail_indptr r0 [0, 8, 24] / r1 [0, 6, 18] + >>> kv_for_tail_indices r0 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ..., 23] + >>> r1[0, 1, 2, 3, 4, 5, 8, 9, ..., 19] + """ + + +def get_q_indices( + q_start_loc_cpu: torch.Tensor, + q_indptr_np: np.ndarray, +) -> tuple[torch.Tensor, torch.Tensor]: + q_head_indices, q_selected_num_tokens = get_pcp_selected_indices(q_start_loc_cpu, q_indptr_np) + q_tail_indices = q_head_indices + np.repeat(q_selected_num_tokens, q_selected_num_tokens) + + return q_head_indices, q_tail_indices + + +def get_kv_indices( + kv_start_loc_cpu: torch.Tensor, + kv_for_head_indptr_np: np.ndarray, + kv_for_tail_indptr_np: np.ndarray, +) -> tuple[torch.Tensor, torch.Tensor]: + kv_for_head_indices = get_pcp_selected_indices(kv_start_loc_cpu, kv_for_head_indptr_np) + kv_for_tail_indices = get_pcp_selected_indices(kv_start_loc_cpu, kv_for_tail_indptr_np) + return kv_for_head_indices, kv_for_tail_indices + + +def get_pcp_selected_indices( + start_loc_cpu: torch.Tensor, + selected_cu_num_tokens: np.ndarray, + need_num_tokens: bool = False, +) -> torch.Tensor: + start_loc_np = 
start_loc_cpu.numpy() + selected_num_tokens = selected_cu_num_tokens[1:] - selected_cu_num_tokens[:-1] + cumsums_offsets = np.repeat(selected_cu_num_tokens[:-1], selected_num_tokens) + arange = np.arange(selected_cu_num_tokens[-1], dtype=np.int32) - cumsums_offsets + selected_indices = arange + np.repeat(start_loc_np, selected_num_tokens) + selected_indices_cpu = torch.from_numpy(selected_indices) + + if need_num_tokens: + return selected_indices_cpu, selected_num_tokens + else: + return selected_indices_cpu \ No newline at end of file diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 5cab0f6adea1..d3200cb94f05 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -644,9 +644,6 @@ def prepare_inputs_padded( block_table_tensor=common_attn_metadata.block_table_tensor, slot_mapping=common_attn_metadata.slot_mapping[token_indices], causal=True, - cp_local_seq_lens=common_attn_metadata.cp_local_seq_lens, - cp_local_seq_lens_cpu=common_attn_metadata.cp_local_seq_lens_cpu, - pcp_metadata=common_attn_metadata.pcp_metadata, ) token_indices_to_sample = ( @@ -923,9 +920,6 @@ def prepare_inputs( block_table_tensor=common_attn_metadata.block_table_tensor, slot_mapping=common_attn_metadata.slot_mapping[token_indices], causal=True, - cp_local_seq_lens=common_attn_metadata.cp_local_seq_lens, - cp_local_seq_lens_cpu=common_attn_metadata.cp_local_seq_lens_cpu, - pcp_metadata=common_attn_metadata.pcp_metadata, ) return spec_common_attn_metadata, token_indices diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 675b302858b4..babd93798b18 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -526,27 +526,6 @@ def __init__( (max_num_tokens,), device="cpu", dtype=torch.bool, pin_memory=True ) self.pcp_unpad_mask_cpu = self.pcp_unpad_mask_cpu_tensor.numpy() - self.q_indptr_cpu_tensor = torch.zeros( - (self.max_num_reqs + 1,), - device="cpu", - 
dtype=torch.int64, - pin_memory=True, - ) - self.q_indptr_cpu = self.q_indptr_cpu_tensor.numpy() - self.kv_for_head_indptr_cpu_tensor = torch.zeros( - (self.max_num_reqs + 1,), - device="cpu", - dtype=torch.int64, - pin_memory=True, - ) - self.kv_for_head_indptr_cpu = self.kv_for_head_indptr_cpu_tensor.numpy() - self.kv_for_tail_indptr_cpu_tensor = torch.zeros( - (self.max_num_reqs + 1,), - device="cpu", - dtype=torch.int64, - pin_memory=True, - ) - self.kv_for_tail_indptr_cpu = self.kv_for_tail_indptr_cpu_tensor.numpy() # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -1064,119 +1043,13 @@ def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs: dummy_modality = mm_budget.get_modality_with_max_tokens() return self._get_mm_dummy_batch(dummy_modality, num_seqs) - def _get_pcp_metadata( - self, - q_lens: np.ndarray, - kv_lens: np.ndarray, - allgather_restore_idx: torch.Tensor, - ) -> PrefillContextParallelMetadata: - """ - During the prefill phrase, the attention computation is divided into - two parts: q_head and q_tail. Here, we calculate the kv indices - corresponding to q_head or q_tail. Meawhile, the q and kv indptr are - also computed to build the attention wrapper. 
- If the pcp_size is 2, the variables are following: - >>> q_lens [4, 8] kv_lens [8, 16] - >>> pcp_chunk_sizes[2, 4] - >>> q_indptr[0, 2, 4] - >>> q_head_indices [0, 1, 4, 5, 6, 7] q_tail_indices [2, 3, 8, 9, 10, 11] - >>> kv_head_len r0 [2, 4] / r1 [4, 8] - >>> kv_for_head_indptr r0 [0, 2, 6] / r1 [0, 4, 12] - >>> kv_for_head_indices r0 [0, 1, 8, 9, 10, 11] - >>> r1[0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15] - >>> kv_tail_len r0 [8, 16] / r1 [6, 12] - >>> kv_for_tail_indptr r0 [0, 8, 24] / r1 [0, 6, 18] - >>> kv_for_tail_indices r0 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ..., 23] - >>> r1[0, 1, 2, 3, 4, 5, 8, 9, ..., 19] - """ - if len(q_lens) == 0: - return PrefillContextParallelMetadata( - allgather_restore_idx=allgather_restore_idx, - ) - - def _get_partial_kv_idx(kv_partial_len, kv_partial_indptr, kv_parial_indices): - kv_partial_indptr[1 : len(kv_partial_len) + 1], kv_partial_arange = ( - self._get_cumsum_and_arange(kv_partial_len) - ) - kv_parial_indices.np[: kv_partial_arange.shape[0]] = ( - kv_partial_arange - + np.repeat( - kv_start_loc, - kv_partial_len, - ) - ) - return kv_partial_arange.shape[0] - - pcp_chunk_sizes = q_lens // 2 - self.q_indptr_cpu[1 : len(pcp_chunk_sizes) + 1], q_chunk_arange = ( - self._get_cumsum_and_arange(pcp_chunk_sizes) - ) - - q_head_start_loc = np.roll(np.cumsum(q_lens), 1) - q_head_start_loc[0] = 0 - self.q_head_indices.np[: q_chunk_arange.shape[0]] = q_chunk_arange + np.repeat( - q_head_start_loc, - pcp_chunk_sizes, - ) - - self.q_head_indices.copy_to_gpu(q_chunk_arange.shape[0]) - - q_tail_start_loc = q_head_start_loc + pcp_chunk_sizes - self.q_tail_indices.np[: q_chunk_arange.shape[0]] = q_chunk_arange + np.repeat( - q_tail_start_loc, - pcp_chunk_sizes, - ) - self.q_tail_indices.copy_to_gpu(q_chunk_arange.shape[0]) - - kv_start_loc = np.roll(np.cumsum(kv_lens), 1) - kv_start_loc[0] = 0 - # kv_for_q_head - kv_for_head_len = (self.pcp_rank + 1) * pcp_chunk_sizes - kv_head_tokens_sum = _get_partial_kv_idx( - kv_for_head_len, - 
self.kv_for_head_indptr_cpu, - self.kv_for_head_indices, - ) - self.kv_for_head_indices.copy_to_gpu(kv_head_tokens_sum) - # kv_for_q_tail - kv_for_tail_len = (2 * self.pcp_world_size - self.pcp_rank) * pcp_chunk_sizes - kv_tail_tokens_sum = _get_partial_kv_idx( - kv_for_tail_len, - self.kv_for_tail_indptr_cpu, - self.kv_for_tail_indices, - ) - self.kv_for_tail_indices.copy_to_gpu(kv_tail_tokens_sum) - - q_full_indices = torch.cat( - [ - self.q_head_indices.gpu[: q_chunk_arange.shape[0]], - self.q_tail_indices.gpu[: q_chunk_arange.shape[0]], - ] - ).argsort() - - return PrefillContextParallelMetadata( - allgather_restore_idx=allgather_restore_idx, - q_head_indices=self.q_head_indices.gpu[: q_chunk_arange.shape[0]], - q_tail_indices=self.q_tail_indices.gpu[: q_chunk_arange.shape[0]], - q_head_start_loc=self.q_indptr_cpu_tensor[: len(pcp_chunk_sizes) + 1], - kv_for_head_indices=self.kv_for_head_indices.gpu[:kv_head_tokens_sum], - kv_for_tail_indices=self.kv_for_tail_indices.gpu[:kv_tail_tokens_sum], - kv_for_head_indptr=( - self.kv_for_head_indptr_cpu_tensor[: len(kv_for_head_len) + 1] - ), - kv_for_tail_indptr=( - self.kv_for_tail_indptr_cpu_tensor[: len(kv_for_tail_len) + 1] - ), - q_full_indices=q_full_indices, - ) - def _update_tokens_for_pcp( self, tokens: np.ndarray, dummy_input: bool = False, num_reqs: int | None = None, num_decode_reqs: int | None = None, - ) -> tuple[np.ndarray, np.ndarray, PrefillContextParallelMetadata]: + ) -> tuple[np.ndarray, np.ndarray]: """ If prefill context parallelism is enabled, we will update the number of `tokens` after sequence splitting. 
@@ -1285,11 +1158,6 @@ def get_current_rank_positions( return ( pcp_tokens[:num_reqs], positions, - self._get_pcp_metadata( - pcp_tokens[num_decode_reqs:], - num_padded_scheduled_tokens[num_decode_reqs:], - self.pcp_allgather_restore_idx.gpu[: all_positions.shape[0]], - ), ) def _get_cumsum_and_arange( @@ -1500,7 +1368,7 @@ def _prepare_inputs( pcp_metadata = None if self.pcp_world_size > 1: - num_scheduled_tokens[:num_reqs], pcp_positions, pcp_metadata = ( + num_scheduled_tokens[:num_reqs], pcp_positions = ( self._update_tokens_for_pcp(num_scheduled_tokens[:num_reqs]) ) @@ -1873,7 +1741,11 @@ def _build_attention_metadata( encoder_seq_lens=encoder_seq_lens, cp_local_seq_lens=cp_local_seq_lens, cp_local_seq_lens_cpu=cp_local_seq_lens_cpu, - pcp_metadata=pcp_metadata, + pcp_allgather_restore_idx=self.pcp_allgather_restore_idx.gpu[ + : total_num_scheduled_tokens * self.pcp_world_size + ] + if self.pcp_world_size > 1 + else None, ) if self.speculative_config and spec_decode_common_attn_metadata is None: From 07e78b14ed1dd6a0987a0ec5c8dde1fe1654d051 Mon Sep 17 00:00:00 2001 From: Jingchun Gao Date: Wed, 26 Nov 2025 09:05:01 +0800 Subject: [PATCH 7/7] fix bug&& add doc Signed-off-by: Jingchun Gao --- vllm/v1/attention/backends/flashinfer.py | 103 ++++++++++++++------- vllm/v1/attention/backends/utils.py | 112 +++++++++++++---------- vllm/v1/spec_decode/eagle.py | 2 + vllm/v1/worker/gpu_model_runner.py | 97 +++++++++++++++----- 4 files changed, 209 insertions(+), 105 deletions(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index cc2ed9e68dd4..7f3c845edde9 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -62,6 +62,7 @@ get_per_layer_parameters, get_q_indices, infer_global_hyperparameters, + pcp_kv_allgather_and_restore, split_decodes_and_prefills, ) from vllm.v1.kv_cache_interface import AttentionSpec @@ -174,11 +175,13 @@ def __init__( self, dcp_world_size, 
pcp_world_size,
+        pcp_rank,
         max_num_reqs,
         workspace_buffer: torch.Tensor | None = None,
     ):
         self.dcp_world_size = dcp_world_size
         self.pcp_world_size = pcp_world_size
+        self.pcp_rank = pcp_rank
         self._context = BatchPrefillWithPagedKVCacheWrapper(
             workspace_buffer, get_kv_cache_layout()
         )
@@ -291,6 +294,45 @@ def plan(
             **common_kwargs,
         )
         if self.pcp_world_size > 1:
+            """
+            During the prefill phase, the attention computation is divided into
+            head-tail style for load balancing. Here, we calculate the q and kv indptr
+            of head and tail for building attention wrappers. Then, the selected indices
+            of q, kv corresponding to head and tail are also computed.
+
+            Example:
+            If the pcp_size is 2, qo_indptr_cpu=[0, 4, 12], so
+            >>> query Rank0 [T0_0, T0_1, T0_6, T0_7 | T1_0, ..., T1_3, T1_12, ..., T1_15]
+                      Rank1 [T0_2, T0_3, T0_4, T0_5 | T1_4, ..., T1_7, T1_8, ..., T1_11]
+            >>> allgather_restored_kv [T0_0, T0_1, ..., T0_7 | T1_0, T1_1, ..., T1_15],
+                whose length is pcp_size(2) times query's.
+            T_m_n means the n-th token of m-th req. The variables are following:
+            >>> pcp_q_indptr [0, 2, 6], for both head and tail wrappers
+            >>> q_head_indices [0, 1, 4, 5, 6, 7] q_tail_indices [2, 3, 8, 9, 10, 11]
+            Rank0
+            >>> selected_q_head [T0_0, T0_1 | T1_0, T1_1, T1_2, T1_3]
+            >>> selected_q_tail [T0_6, T0_7 | T1_12, T1_13, T1_14, T1_15]
+
+            >>> kv_for_head_indptr (rank(0)+1)*pcp_q_indptr [0, 2, 6]
+            >>> kv_for_head_indices [0, 1, 8, 9, 10, 11]
+            >>> selected_kv_for_head [T0_0, T0_1 | T1_0, T1_1, T1_2, T1_3]
+
+            >>> kv_for_tail_indptr (2*pcp_size(2) - rank(0))*pcp_q_indptr [0, 8, 24]
+            >>> kv_for_tail_indices [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ..., 23]
+            >>> selected_kv_for_tail [T0_0, ..., T0_7 | T1_0, ... 
T1_15]
+
+            Rank1
+            >>> selected_q_head [T0_2, T0_3 | T1_4, T1_5, T1_6, T1_7]
+            >>> selected_q_tail [T0_4, T0_5 | T1_8, T1_9, T1_10, T1_11]
+
+            >>> kv_for_head_indptr (rank(1)+1)*pcp_q_indptr [0, 4, 12]
+            >>> kv_for_head_indices [0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15]
+            >>> selected_kv_for_head [T0_0, ..., T0_3 | T1_0, ..., T1_7]
+
+            >>> kv_for_tail_indptr (2*pcp_size(2) - rank(1))*pcp_q_indptr [0, 6, 18]
+            >>> kv_for_tail_indices [0, 1, 2, 3, 4, 5, 8, 9, ..., 19]
+            >>> selected_kv_for_tail [T0_0, ..., T0_5 | T1_0, ... T1_11]
+            """
             self.pcp_q_indptr_cpu[: qo_indptr_cpu.shape[0]] = qo_indptr_cpu // 2
             pcp_q_indptr_cpu = self.pcp_q_indptr_cpu[: qo_indptr_cpu.shape[0]]
             self.kv_for_head_indptr_cpu[: qo_indptr_cpu.shape[0]] = (
@@ -299,19 +341,22 @@ def plan(
             self.kv_for_tail_indptr_cpu[: qo_indptr_cpu.shape[0]] = (
                 (2 * self.pcp_world_size - self.pcp_rank) * pcp_q_indptr_cpu
             )
+            # Obtain selected q_indices based on complete and head-tail style q_indptrs
             q_head_indices, q_tail_indices = get_q_indices(
-                qo_indptr_cpu,
+                qo_indptr_cpu[:-1],
                 self.pcp_q_indptr_np[: qo_indptr_cpu.shape[0]],
             )
             q_head_indices_gpu = q_head_indices.to(device)
             q_tail_indices_gpu = q_tail_indices.to(device)
+            # This variable restores the original sequence of query
             q_full_indices = torch.cat([q_head_indices_gpu, q_tail_indices_gpu]).argsort()
+            # Obtain selected kv_indices based on complete and head-tail style kv_indptrs
             kv_for_head_indices, kv_for_tail_indices = get_kv_indices(
-                qo_indptr_cpu * self.pcp_world_size,
+                qo_indptr_cpu[:-1] * self.pcp_world_size,
                 self.kv_for_head_indptr_np[: qo_indptr_cpu.shape[0]],
                 self.kv_for_tail_indptr_np[: qo_indptr_cpu.shape[0]],
             )
-
+            self._new_tokens_head.plan(
                 pcp_q_indptr_cpu,
                 self.kv_for_head_indptr_cpu[: qo_indptr_cpu.shape[0]],
                 num_qo_heads,
                 *common_args,
                 **common_kwargs,
             )
             self._new_tokens_tail.plan(
-                qo_indptr_cpu,
+                pcp_q_indptr_cpu,
                 self.kv_for_tail_indptr_cpu[: qo_indptr_cpu.shape[0]],
                 num_qo_heads,
                 *common_args,
**common_kwargs, ) return PrefillContextParallelMetadata( - q_head_indices=q_head_indices, - q_tail_indices=q_tail_indices, + q_head_indices=q_head_indices_gpu, + q_tail_indices=q_tail_indices_gpu, kv_for_head_indices=kv_for_head_indices.to(device, non_blocking=True), kv_for_tail_indices=kv_for_tail_indices.to(device, non_blocking=True), q_full_indices=q_full_indices, @@ -350,34 +395,17 @@ def plan( def run( self, - num_actual_tokens: int, - num_decode_tokens: int, layer: torch.nn.Module, query: torch.Tensor, kv_cache_permute: torch.Tensor, key: torch.Tensor, value: torch.Tensor, out: torch.Tensor, - allgather_restore_idx: torch.Tensor | None, pcp_metadata: PrefillContextParallelMetadata | None, return_lse: bool = False, ): if self.pcp_world_size > 1: assert pcp_metadata is not None - # NOTE(yyj): we must `slice` key and value because pcp_allgather_restore_idx - # ignores the padding from CUDA Graph. To be optimized for performance! - key_across_pcp = get_pcp_group().all_gather(key.contiguous(), dim=0) - value_across_pcp = get_pcp_group().all_gather(value.contiguous(), dim=0) - # Reorder kv after pcp allgather. - # Note that there are duplicate decoding tokens, - # but we only save the first one in kvcache. - key = torch.index_select(key_across_pcp, 0, allgather_restore_idx) - value = torch.index_select(value_across_pcp, 0, allgather_restore_idx) - key = key[: num_actual_tokens * self.pcp_world_size] - value = value[: num_actual_tokens * self.pcp_world_size] - - key = key[num_decode_tokens * self.pcp_world_size :] - value = value[num_decode_tokens * self.pcp_world_size :] """ For prompt with tokens [T0, T1, T2, T3], the query on PCP0 is [Q0, Q3] and we all-gather full K as [K0, K1, K2, K3]. 
@@ -588,6 +616,7 @@ class FlashInferMetadata: paged_kv_indptr_gpu: torch.Tensor | None = None # For context parallel + pcp_allgather_restore_idx: torch.Tensor | None = None pcp_metadata: PrefillContextParallelMetadata | None = None @@ -791,6 +820,7 @@ def _get_prefill_wrapper( self._prefill_wrapper = BatchCPPrefillWrapper( self.dcp_world_size, self.pcp_world_size, + self.pcp_rank, self.max_num_reqs, self._get_workspace_buffer(), ) @@ -1011,6 +1041,7 @@ def build( num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, use_cascade=use_cascade, + pcp_allgather_restore_idx=common_attn_metadata.pcp_allgather_restore_idx ) paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[: 1 + num_reqs] @@ -1355,6 +1386,17 @@ def forward( num_actual_tokens = attn_metadata.num_actual_tokens + if self.pcp_world_size > 1: + pcp_allgather_restore_idx = attn_metadata.pcp_allgather_restore_idx + assert pcp_allgather_restore_idx is not None + key, value = pcp_kv_allgather_and_restore( + key, + value, + num_actual_tokens, + pcp_allgather_restore_idx, + get_pcp_group(), + ) + if self.kv_sharing_target_layer_name is None: # Reshape the input keys and values and store them in the cache. # Skip this if sharing KV cache with an earlier attention layer. 
@@ -1384,8 +1426,8 @@ def forward( # Inputs and outputs may be padded for CUDA graphs query = query[:num_actual_tokens] - key = key[: num_actual_tokens] - value = value[: num_actual_tokens] + key = key[: num_actual_tokens * self.pcp_world_size] + value = value[: num_actual_tokens * self.pcp_world_size] output_padded = output output = output[:num_actual_tokens] @@ -1422,19 +1464,14 @@ def forward( ) assert prefill_wrapper._sm_scale == self.scale prefill_wrapper._assert_causal() - assert attn_metadata.pcp_metadata is not None - pcp_allgather_restore_idx = attn_metadata.pcp_metadata.allgather_restore_idx - assert pcp_allgather_restore_idx is not None + prefill_wrapper.run( - num_actual_tokens, - num_decode_tokens, layer, prefill_query, kv_cache_permute, - key, - value, + key[num_decode_tokens * self.pcp_world_size :], + value[num_decode_tokens * self.pcp_world_size :], out=output[num_decode_tokens:], - allgather_restore_idx=pcp_allgather_restore_idx, pcp_metadata=attn_metadata.pcp_metadata if self.pcp_world_size > 1 else None, diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index a5173d9ed3cc..ad12cb8259af 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -33,6 +33,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import ( get_kv_connector_cache_layout, ) +from vllm.distributed.parallel_state import GroupCoordinator from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.v1.kv_cache_interface import AttentionSpec @@ -53,10 +54,7 @@ def is_valid_kv_cache_layout(value: str) -> bool: class PrefillContextParallelMetadata: """ Attention metadata for prefill context parallel - """ - allgather_restore_idx: torch.Tensor - """ We split and concatenate the sequence in a head-tail style, and use this variable to restore the original order. 
""" @@ -113,7 +111,7 @@ class CommonAttentionMetadata: cp_local_seq_lens_cpu: torch.Tensor | None = None """Sequence lengths of the local rank in prefill/decode context parallelism world""" - pcp_metadata: PrefillContextParallelMetadata | None = None + pcp_allgather_restore_idx: torch.Tensor | None = None def slice_query_start_locs( @@ -1135,68 +1133,84 @@ def get_cp_local_seq_lens( cp_local_seq_lens = base + remainder return cp_local_seq_lens.squeeze(1) -def get_pcp_metadata( - q_start_loc_cpu: torch.Tensor, - q_indptr_np: np.ndarray, - allgather_restore_idx: torch.Tensor, - kv_start_loc_cpu: torch.Tensor, - kv_for_head_indptr_np: np.ndarray, - kv_for_tail_indptr_np: np.ndarray, -) -> PrefillContextParallelMetadata: +def pcp_kv_allgather_and_restore( + key: torch.Tensor, + value: torch.Tensor, + num_actual_tokens: int, + pcp_allgather_restore_idx: torch.Tensor, + pcp_group: GroupCoordinator, +): """ - During the prefill phrase, the attention computation is divided into - two parts: q_head and q_tail. Here, we calculate the kv indices - corresponding to q_head or q_tail. Meawhile, the q and kv indptr are - also computed to build the attention wrapper. - If the pcp_size is 2, the variables are following: - >>> q_lens [4, 8] kv_lens [8, 16] - >>> pcp_chunk_sizes[2, 4] - >>> q_indptr[0, 2, 4] - >>> q_head_indices [0, 1, 4, 5, 6, 7] q_tail_indices [2, 3, 8, 9, 10, 11] - >>> kv_head_len r0 [2, 4] / r1 [4, 8] - >>> kv_for_head_indptr r0 [0, 2, 6] / r1 [0, 4, 12] - >>> kv_for_head_indices r0 [0, 1, 8, 9, 10, 11] - >>> r1[0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15] - >>> kv_tail_len r0 [8, 16] / r1 [6, 12] - >>> kv_for_tail_indptr r0 [0, 8, 24] / r1 [0, 6, 18] - >>> kv_for_tail_indices r0 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ..., 23] - >>> r1[0, 1, 2, 3, 4, 5, 8, 9, ..., 19] + All-gather key and value tensors across PCP ranks and restore the original order. + Args: + key: key tensor for the current pcp rank. + value: value tensor for the current pcp rank. 
+ num_actual_tokens: number of actual tokens (Exclude graph padding tokens). + pcp_allgather_restore_idx: indices to restore the original order. + pcp_group: PCP group coordinator. + Returns: + key: all-gathered and restored key tensor. + value: all-gathered and restored value tensor. """ + # NOTE(yyj): we must `slice` key and value because pcp_allgather_restore_idx + # ignores the padding from CUDA Graph. + # TODO(yyj) Batch all-gather operations to reduce launch overhead. + # Be careful about the dimensions of key and value. + key_across_cp = pcp_group.all_gather(key[:num_actual_tokens].contiguous(), dim=0) + value_across_cp = pcp_group.all_gather( + value[:num_actual_tokens].contiguous(), dim=0 + ) + # Reorder kv after pcp allgather. + # Note that there are duplicate decoding tokens after allgather. + key = torch.index_select(key_across_cp, 0, pcp_allgather_restore_idx) + value = torch.index_select(value_across_cp, 0, pcp_allgather_restore_idx) + return key, value + +def get_pcp_selected_indices( + start_loc_cpu: torch.Tensor, + selected_cu_num_tokens: np.ndarray, + need_num_tokens: bool = False, +) -> torch.Tensor: + """Get the selected indices when using "DualChunkSwap" style for load balancing. + e.g: [0, 4], [0, 2, 6] -> [0, 1, 4, 5, 6, 7], optional([2, 4]) + """ + start_loc_np = start_loc_cpu.numpy() + # Step1. 
Get batched selected_num_tokens array: [0, 2, 6] -> [2, 4] + selected_num_tokens = selected_cu_num_tokens[1:] - selected_cu_num_tokens[:-1] + # Step2: Get batched arrange of given selected_num_tokens + # [0, 2, 6], [2, 4] -> [0, 1, 0, 1, 2, 3] + cumsums_offsets = np.repeat(selected_cu_num_tokens[:-1], selected_num_tokens) + arange = np.arange(selected_cu_num_tokens[-1], dtype=np.int32) - cumsums_offsets + # Step3: Add complete req offsets + # [0, 1, 0, 1, 2, 3] + [0, 0, 4, 4, 4, 4] = [0, 1, 4, 5, 6, 7] + selected_indices = arange + np.repeat(start_loc_np, selected_num_tokens) + selected_indices_cpu = torch.from_numpy(selected_indices) + if need_num_tokens: + return selected_indices_cpu, selected_num_tokens + else: + return selected_indices_cpu def get_q_indices( q_start_loc_cpu: torch.Tensor, q_indptr_np: np.ndarray, ) -> tuple[torch.Tensor, torch.Tensor]: - q_head_indices, q_selected_num_tokens = get_pcp_selected_indices(q_start_loc_cpu, q_indptr_np) + """Get selected q_indices of both head and tail parts""" + q_head_indices, q_selected_num_tokens = ( + get_pcp_selected_indices(q_start_loc_cpu, q_indptr_np, need_num_tokens=True) + ) + # Get q_tail_indices by add q_selected_num_tokens offset q_tail_indices = q_head_indices + np.repeat(q_selected_num_tokens, q_selected_num_tokens) return q_head_indices, q_tail_indices - def get_kv_indices( kv_start_loc_cpu: torch.Tensor, kv_for_head_indptr_np: np.ndarray, kv_for_tail_indptr_np: np.ndarray, ) -> tuple[torch.Tensor, torch.Tensor]: + """Get selected kv_indices of both head and tail parts""" kv_for_head_indices = get_pcp_selected_indices(kv_start_loc_cpu, kv_for_head_indptr_np) kv_for_tail_indices = get_pcp_selected_indices(kv_start_loc_cpu, kv_for_tail_indptr_np) - return kv_for_head_indices, kv_for_tail_indices - -def get_pcp_selected_indices( - start_loc_cpu: torch.Tensor, - selected_cu_num_tokens: np.ndarray, - need_num_tokens: bool = False, -) -> torch.Tensor: - start_loc_np = start_loc_cpu.numpy() - 
selected_num_tokens = selected_cu_num_tokens[1:] - selected_cu_num_tokens[:-1] - cumsums_offsets = np.repeat(selected_cu_num_tokens[:-1], selected_num_tokens) - arange = np.arange(selected_cu_num_tokens[-1], dtype=np.int32) - cumsums_offsets - selected_indices = arange + np.repeat(start_loc_np, selected_num_tokens) - selected_indices_cpu = torch.from_numpy(selected_indices) - - if need_num_tokens: - return selected_indices_cpu, selected_num_tokens - else: - return selected_indices_cpu \ No newline at end of file + return kv_for_head_indices, kv_for_tail_indices diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index d3200cb94f05..406bb696bd4c 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -644,6 +644,7 @@ def prepare_inputs_padded( block_table_tensor=common_attn_metadata.block_table_tensor, slot_mapping=common_attn_metadata.slot_mapping[token_indices], causal=True, + dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens, ) token_indices_to_sample = ( @@ -920,6 +921,7 @@ def prepare_inputs( block_table_tensor=common_attn_metadata.block_table_tensor, slot_mapping=common_attn_metadata.slot_mapping[token_indices], causal=True, + dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens, ) return spec_common_attn_metadata, token_indices diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index babd93798b18..6d1a53001331 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1051,25 +1051,43 @@ def _update_tokens_for_pcp( num_decode_reqs: int | None = None, ) -> tuple[np.ndarray, np.ndarray]: """ - If prefill context parallelism is enabled, we will update - the number of `tokens` after sequence splitting. 
- Meanwhile, we will compute: - `positions` the new token positions, - `self.num_pcp_pads_cpu` the number of padding tokens - per request for alignment, - `self.pcp_unpad_mask_cpu` the mask for non-padded tokens, - `self.pcp_allgather_restore_idx` indices to restore the - original vector order after PCP allgather. + Update token counts and positions for Prefill Context Parallelism (PCP). + + When using Prefill Context Parallelism, each request's prefill sequence is + split across multiple PCP ranks. The splitting strategy used here is the + "DualChunkSwap" style: each request's (padded) sequence is split into + 2 * pcp_world_size chunks and ranks are assigned chunks in an interleaved + head/tail pattern to balance load. + + This function: + - Computes how many tokens each request should be processed by the current + PCP rank (pcp_tokens). + - Computes the flattened positions of those tokens within the local + padded buffer (pcp_positions). + - Updates runner state arrays used to restore original order and mask out + padded tokens after allgather: + - self.num_pcp_pads_cpu: number of pads added per request + - self.pcp_unpad_mask_cpu: boolean mask marking real tokens in the + padded allgather buffer + - self.pcp_allgather_restore_idx: index array used to restore original + ordering after per-rank allgather and interleaving. + + Args: + tokens: 1D numpy array of length num_reqs containing the number of new + tokens scheduled for each request (before PCP splitting). + Returns: + Tuple (pcp_tokens, pcp_positions): + - pcp_tokens: number of tokens per request that this PCP rank will + actually process (after splitting / replication). + - pcp_positions: flattened positions for those tokens on this rank, + used to build the positions buffer for the model. 
+ + Example: - >>> tokens = [1, 5, 8] - >>> pcp_world_size = 2 - >>> pcp_rank = 0 - >>> _update_tokens_for_pcp(tokens) - ([1, 4, 4], [0, 0, 1, 6, 7, 0, 1, 6, 7]) - >>> pcp_rank = 1 - >>> _update_tokens_for_pcp(tokens) - ([1, 4, 4], [0, 2, 3, 4, 5, 2, 3, 4, 5]) - >>> # the following results are same for each pcp rank + >>> Assume tokens = [1, 5, 8], pcp_world_size = 2. After _update_tokens_for_pcp. + >>> pcp_rank = 0 get ([1, 4, 4], [0, 0, 1, 6, 7, 0, 1, 6, 7]) + >>> pcp_rank = 1 get ([1, 4, 4], [0, 2, 3, 4, 5, 2, 3, 4, 5]) + >>> Meanwhile, the following results are same for each pcp rank >>> self.num_pcp_pads_cpu [1, 3, 0] >>> self.pcp_unpad_mask_cpu @@ -1094,42 +1112,74 @@ def _update_tokens_for_pcp( self.num_pcp_pads_cpu[:num_reqs] = 0 num_decode_tokens = sum(tokens[:num_decode_reqs]) - + # DualChunkSwap requires alignment to a multiple of (2 * pcp_world_size). + # We first pad each request's token count up to that multiple. num_padded_scheduled_tokens = np.ceil( tokens / (2 * self.pcp_world_size) ).astype(np.int32) * (2 * self.pcp_world_size) - # we duplicate scheduled tokens of decode reqs to pcp_world_size + + # PCP does not split decode requests. For decode requests, we instead + # duplicate the scheduled tokens across the pcp_world_size ranks. num_padded_scheduled_tokens[:num_decode_reqs] = ( tokens[:num_decode_reqs] * self.pcp_world_size ) + + # Record how many pads were added per request (padded - original). self.num_pcp_pads_cpu[:num_reqs] = num_padded_scheduled_tokens - tokens + # cu_padded_tokens: cumulative sum of padded token counts, + # pcp_padded_arange: per-request arange flattened for padded tokens. cu_padded_tokens, pcp_padded_arange = self._get_cumsum_and_arange( num_padded_scheduled_tokens ) + # Build the mask that marks which positions in the padded allgather buffer + # correspond to real (unpadded) tokens. 
self.pcp_unpad_mask_cpu[: pcp_padded_arange.shape[0]] = ( pcp_padded_arange < np.repeat(tokens, num_padded_scheduled_tokens) ) pcp_tokens = num_padded_scheduled_tokens // self.pcp_world_size + + # Compute per-request "chunk sizes" for the head/tail splitting. + # For prefill requests, we further split the pcp_tokens into two chunks + # (head and tail). For decode requests, the chunk equals pcp_tokens. pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1) pcp_chunk_sizes[:num_decode_reqs] = pcp_tokens[:num_decode_reqs] + + # Build arange-style helpers for pcp tokens and chunk sizes: + # - pcp_arange gives indices repeated for each token in pcp_tokens + # - pcp_chunk_arange gives indices repeated for each position inside chunks _, pcp_arange = self._get_cumsum_and_arange(pcp_tokens) _, pcp_chunk_arange = self._get_cumsum_and_arange(pcp_chunk_sizes) + + # Mask that marks whether a position belongs to the head chunk (True) + # or the tail chunk (False). For decode requests, tail chunk won't exist + # and is handled specially below. pcp_head_chunk_mask = pcp_arange < np.repeat(pcp_chunk_sizes, pcp_tokens) def get_current_rank_positions( positions_start_loc: int | np.ndarray, rank: int ): + """ + Compute flattened positions for the given rank with a given start + offset for each request (positions_start_loc). + - For head chunks: start at positions_start_loc + rank * chunk_size. + - For tail chunks: start at positions_start_loc + (2*pcp_world_size- rank - + 1) * chunk_size. + - For decode requests: no tail chunks; their positions are filled from the + contiguous (unpadded) `tokens` arange instead (handled after). + """ positions = np.zeros(len(pcp_head_chunk_mask), dtype=np.int32) head_start_loc = positions_start_loc + rank * pcp_chunk_sizes tail_start_loc = ( positions_start_loc + (2 * self.pcp_world_size - rank - 1) * pcp_chunk_sizes ) + # Fill head positions using chunk arange offset by head_start_loc. 
positions[pcp_head_chunk_mask] = pcp_chunk_arange + np.repeat( head_start_loc, pcp_chunk_sizes ) - # Decode reqs do not have tail chunks. + # Fill tail positions. Note decode requests do not have tail chunks, + # so the tail filling is only for prefill positions. positions[~pcp_head_chunk_mask] = ( pcp_chunk_arange[num_decode_tokens:] + np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_tokens:] @@ -1144,6 +1194,7 @@ def get_current_rank_positions( tokens[:num_decode_reqs] )[1] + # Build the restore index used after allgather. padded_pos_start_loc = np.roll(cu_padded_tokens, 1) padded_pos_start_loc[0] = 0 all_positions_lst = [ @@ -1155,6 +1206,7 @@ def get_current_rank_positions( all_positions.argsort() ) self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0]) + return ( pcp_tokens[:num_reqs], positions, @@ -3977,7 +4029,7 @@ def _dummy_run( pcp_metadata = None if self.pcp_world_size > 1 and force_attention: num_decode_reqs = sum(num_scheduled_tokens == 1) - num_scheduled_tokens[:num_reqs], _, pcp_metadata = ( + num_scheduled_tokens[:num_reqs], _ = ( self._update_tokens_for_pcp( num_scheduled_tokens[:num_reqs], dummy_input=True, @@ -4033,7 +4085,6 @@ def _dummy_run( num_reqs=num_reqs, ubatch_slices=ubatch_slices, for_cudagraph_capture=True, - pcp_metadata=pcp_metadata if self.pcp_world_size > 1 else None, ) with self.maybe_dummy_run_with_lora(