diff --git a/vllm/v1/worker/cp_utils.py b/vllm/v1/worker/cp_utils.py index aac94721bacf..c5a07c15b782 100644 --- a/vllm/v1/worker/cp_utils.py +++ b/vllm/v1/worker/cp_utils.py @@ -79,69 +79,65 @@ def _get_cumsum_and_arange( arange = arange_np[:total_num_tokens] - cumsums_offsets return cu_num_tokens, arange - def update_tokens_for_pcp( + def compute_rank_indices( self, num_scheduled_tokens: np.ndarray, + num_computed_tokens: np.ndarray, arange_np: np.ndarray, - num_reqs: int, - reorder_batch_threshold: int | None = None, - ) -> tuple[np.ndarray, np.ndarray]: + reorder_batch_threshold: int | None, + ) -> tuple[np.ndarray, torch.Tensor, torch.Tensor]: """ - Update token counts and positions for Prefill Context Parallelism (PCP). - - When using Prefill Context Parallelism, each request's prefill sequence is - split across multiple PCP ranks. The splitting strategy used here is the - "DualChunkSwap" style: each request's (padded) sequence is split into - 2 * pcp_world_size chunks and ranks are assigned chunks in an interleaved - head/tail pattern to balance load. - - This function: - - Computes how many tokens each request should be processed by the current - PCP rank (pcp_tokens). - - Computes the flattened positions of those tokens within the local - padded buffer (pcp_positions). - - Updates runner state arrays used to restore original order and mask out - padded tokens after allgather: + Compute which tokens this PCP rank processes and their indices + into the original batch. + + When using Prefill Context Parallelism (PCP), each request's prefill + sequence is split across multiple PCP ranks. The splitting strategy + used here is the "DualChunkSwap" style: each request's (padded) + sequence is split into 2 * pcp_world_size chunks and ranks are + assigned chunks in an interleaved head/tail pattern to balance load. + + This method: + - Computes how many tokens each request should be processed by the + current PCP rank (pcp_num_scheduled). 
+ - Computes indices into the original batch for gathering real tokens. + - Updates runner state arrays used to restore original order and mask + out padded tokens after allgather: - self.num_pcp_pads_cpu: number of pads added per request - self.pcp_unpad_mask_cpu: boolean mask marking real tokens in the padded allgather buffer - - self.pcp_allgather_restore_idx: index array used to restore original - ordering after per-rank allgather and interleaving. + - self.pcp_allgather_restore_idx: index array used to restore + original ordering after per-rank allgather and interleaving. Args: - num_scheduled_tokens: 1D numpy array of length num_reqs containing - the number of new tokens scheduled per request. - arange_np: 1D numpy array of length max_padded_num_tokens used for - efficient batched arange operations. - num_reqs: Total number of requests in the batch. + num_scheduled_tokens: 1D numpy array of per-request token counts. + num_computed_tokens: 1D numpy array of already computed tokens per + request. + arange_np: Pre-allocated arange buffer for efficient operations. reorder_batch_threshold: Threshold for decode vs prefill requests. Returns: - Tuple (pcp_tokens, pcp_positions): - - pcp_tokens: number of tokens per request that this PCP rank will - actually process (after splitting / replication). - - pcp_positions: flattened positions for those tokens on this rank, - used to build the positions buffer for the model. + Tuple (pcp_num_scheduled, pcp_rank_indices, padding_mask): + - pcp_num_scheduled: per-request token counts for this rank + - pcp_rank_indices: GPU tensor of indices into original batch for + real tokens (used with advanced indexing to gather tokens) + - padding_mask: GPU tensor, True for padding slots in PCP batch Example: - >>> Assume tokens = [1, 5, 8], pcp_world_size = 2. After _update_tokens_for_pcp. 
- >>> pcp_rank = 0 get ([1, 4, 4], [0, 0, 1, 6, 7, 0, 1, 6, 7]) - >>> pcp_rank = 1 get ([1, 4, 4], [0, 2, 3, 4, 5, 2, 3, 4, 5]) - >>> Meanwhile, the following results are same for each pcp rank - >>> self.num_pcp_pads_cpu - [1, 3, 0] - >>> self.pcp_unpad_mask_cpu - [True, False, True, True, True, True, True, False, False, - False, True, True, True, True, True, True, True, True] - >>> self.pcp_allgather_resotre_idx - [0, 9, 1, 2, 10, 11, 12, 13, 3, 4, 5, 6, 14, 15, 16, 17, 7, 8] + Assume tokens = [1, 5, 8], pcp_world_size = 2: + - pcp_rank=0 gets pcp_num_scheduled=[1, 4, 4], positions=[0,0,1,6,7,0,1,6,7] + - pcp_rank=1 gets pcp_num_scheduled=[1, 4, 4], positions=[0,2,3,4,5,2,3,4,5] + Meanwhile, these are same for each pcp rank: + - num_pcp_pads_cpu: [1, 3, 0] + - pcp_unpad_mask_cpu: [T,F,T,T,T,T,T,F,F,F,T,T,T,T,T,T,T,T] + - pcp_allgather_restore_idx: [0,9,1,2,10,11,12,13,3,4,5,6,14,15,16,17,7,8] """ - assert reorder_batch_threshold is not None, ( "PCP depends on reorder batch to split decode and prefill requests." ) + num_reqs = len(num_scheduled_tokens) num_decode_reqs = sum(num_scheduled_tokens <= reorder_batch_threshold) num_decode_tokens = sum(num_scheduled_tokens[:num_decode_reqs]) + # DualChunkSwap requires alignment to a multiple of (2 * pcp_world_size). # We first pad each request's token count up to that multiple. num_padded_scheduled_tokens = np.ceil( @@ -157,19 +153,23 @@ def update_tokens_for_pcp( self.num_pcp_pads_cpu[:num_reqs] = ( num_padded_scheduled_tokens - num_scheduled_tokens ) + # cu_padded_tokens: cumulative sum of padded token counts, # pcp_padded_arange: per-request arange flattened for padded tokens. cu_padded_tokens, pcp_padded_arange = self._get_cumsum_and_arange( num_padded_scheduled_tokens, arange_np ) - # Build the mask that marks which positions in the padded allgather buffer - # correspond to real (unpadded) tokens. + + # Build the mask that marks which positions in the padded allgather + # buffer correspond to real (unpadded) tokens. 
self.pcp_unpad_mask_cpu[: pcp_padded_arange.shape[0]] = ( pcp_padded_arange < np.repeat(num_scheduled_tokens, num_padded_scheduled_tokens) ) + # pcp_tokens: tokens per request for this rank after splitting pcp_tokens = num_padded_scheduled_tokens // self.pcp_world_size + # Compute per-request "chunk sizes" for the head/tail splitting. # For prefill requests, we further split the pcp_tokens into two chunks # (head and tail). For decode requests, the chunk equals pcp_tokens. @@ -189,16 +189,16 @@ def update_tokens_for_pcp( def get_current_rank_positions( positions_start_loc: int | np.ndarray, rank: int - ): + ) -> np.ndarray: """ Compute flattened positions for the given rank with a given start offset for each request (positions_start_loc). - For head chunks: start at positions_start_loc + rank * chunk_size. - - For tail chunks: start at positions_start_loc + (2*pcp_world_size- rank - - 1) * chunk_size. - - For decode requests: no tail chunks; their positions are filled from the - contiguous (unpadded) `tokens` arange instead (handled after). + - For tail chunks: start at positions_start_loc + + (2*pcp_world_size - rank - 1) * chunk_size. + - For decode requests: no tail chunks; their positions are filled + from the contiguous (unpadded) `tokens` arange instead. """ positions = np.zeros(len(pcp_head_chunk_mask), dtype=np.int32) head_start_loc = positions_start_loc + rank * pcp_chunk_sizes @@ -218,9 +218,10 @@ def get_current_rank_positions( ) return positions + # Get positions for this rank (position VALUES, not indices) positions = get_current_rank_positions(0, self.pcp_rank) - # Decode tokens are duplicated only after AG. But their positions are - # same without prefill context parallel. + # Decode tokens are duplicated only after allgather. But their positions + # are the same without prefill context parallel. 
if num_decode_reqs > 0: positions[:num_decode_tokens] = self._get_cumsum_and_arange( num_scheduled_tokens[:num_decode_reqs], arange_np @@ -238,10 +239,41 @@ def get_current_rank_positions( all_positions.argsort() ) self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0]) - return ( - pcp_tokens[:num_reqs], - positions, + + # Now compute indices into original batch + # positions[i] is the position VALUE for PCP token i + # We need to find which original token that corresponds to + + # Compute cumsum of original tokens for request start offsets + cu_orig_tokens = np.cumsum(num_scheduled_tokens) + orig_start_offsets = np.concatenate([[0], cu_orig_tokens[:-1]]) + + pcp_total_tokens = int(pcp_tokens.sum()) + + # padding_mask: True if position >= original seq len for that request + orig_seq_lens_expanded = np.repeat(num_scheduled_tokens, pcp_tokens) + padding_mask_np = positions[:pcp_total_tokens] >= orig_seq_lens_expanded + + # For real tokens, compute index into original batch + # original_index = orig_start_offset[req] + position + orig_start_expanded = np.repeat(orig_start_offsets, pcp_tokens) + # Only compute for non-padding tokens + real_indices = orig_start_expanded + positions[:pcp_total_tokens] + # Clamp padding indices to 0 (they won't be used anyway) + real_indices = np.where(padding_mask_np, 0, real_indices) + + # Extract just the indices for real tokens + pcp_rank_indices = real_indices[~padding_mask_np] + + # Convert to GPU tensors + pcp_rank_indices_gpu = torch.from_numpy(pcp_rank_indices.astype(np.int64)).to( + self.device, non_blocking=True ) + padding_mask_gpu = torch.from_numpy(padding_mask_np.copy()).to( + self.device, non_blocking=True + ) + + return pcp_tokens[:num_reqs], pcp_rank_indices_gpu, padding_mask_gpu def get_logits_indices(self, cu_num_tokens: np.ndarray, num_reqs: int): return ( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index cc33bce7dbd1..688d3afc5f3c 100644 --- 
a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1363,29 +1363,6 @@ def _prepare_inputs( self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np) self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens) - if self.pcp_world_size > 1: - num_scheduled_tokens[:num_reqs], pcp_positions = ( - self.pcp_manager.update_tokens_for_pcp( - num_scheduled_tokens[:num_reqs], - self.arange_np, - self.input_batch.num_reqs, - self.reorder_batch_threshold, - ) - ) - - # Re-update after PCP split sequences. - total_num_scheduled_tokens = sum(num_scheduled_tokens) - scheduler_output.total_num_scheduled_tokens = total_num_scheduled_tokens - - req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) - cu_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens) - positions_np = self.positions.np[:total_num_scheduled_tokens] - np.add( - self.input_batch.num_computed_tokens_cpu[req_indices], - pcp_positions[:total_num_scheduled_tokens], - out=positions_np, - ) - # Calculate M-RoPE positions. # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -1481,20 +1458,11 @@ def _prepare_inputs( num_tokens_np = np.array(num_tokens, dtype=np.int32) # Record which requests should not be sampled, - # so that we could clear the sampled tokens before returning - if self.pcp_world_size > 1: - self.discard_request_mask.np[:num_reqs] = ( - self.pcp_manager.get_discard_request_mask( - num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu, - num_scheduled_tokens=num_scheduled_tokens, - num_reqs=num_reqs, - num_tokens_np=num_tokens_np, - ) - ) - else: - self.discard_request_mask.np[:num_reqs] = ( - self.seq_lens.np[:num_reqs] < num_tokens_np - ) + # so that we could clear the sampled tokens before returning. + # Note: PCP updates this in _partition_batch_for_pcp. 
+ self.discard_request_mask.np[:num_reqs] = ( + self.seq_lens.np[:num_reqs] < num_tokens_np + ) self.discard_request_mask.copy_to_gpu(num_reqs) # Copy the tensors to the GPU. @@ -1527,11 +1495,8 @@ def _prepare_inputs( # from these partial requests, we do so for simplicity. # We will ignore the sampled tokens from the partial requests. # TODO: Support prompt logprobs. + # Note: PCP updates logits_indices in _partition_batch_for_pcp. logits_indices = query_start_loc[1:] - 1 - if self.pcp_world_size > 1: - logits_indices = self.pcp_manager.get_logits_indices( - cu_num_tokens, num_reqs - ) num_draft_tokens = None spec_decode_metadata = None num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) @@ -1583,6 +1548,96 @@ def _prepare_inputs( spec_decode_metadata, ) + def _partition_batch_for_pcp( + self, + num_scheduled_tokens_np: np.ndarray, + num_reqs: int, + original_total_tokens: int, + ) -> tuple[int, int, torch.Tensor]: + """ + Partition batch for Prefill Context Parallelism. Mutates GPU buffers + in-place. + + This method applies PCP partitioning by: + 1. Computing which tokens this PCP rank processes + 2. Using index_select on GPU to gather real tokens from original batch + 3. Filling padding slots with placeholder values + 4. Updating query_start_loc, seq_lens, and discard_request_mask + + Args: + num_scheduled_tokens_np: Per-request token counts (original). + num_reqs: Number of requests in the batch. + original_total_tokens: Total tokens in the original batch. + + Returns: + Tuple (total_num_scheduled_tokens, max_num_scheduled_tokens, logits_indices) + """ + + # 1. 
Compute PCP token assignments and indices + pcp_num_scheduled, pcp_rank_indices_gpu, padding_mask_gpu = ( + self.pcp_manager.compute_rank_indices( + num_scheduled_tokens_np[:num_reqs], + self.input_batch.num_computed_tokens_cpu[:num_reqs], + self.arange_np, + self.reorder_batch_threshold, + ) + ) + + pcp_total_tokens = int(pcp_num_scheduled.sum()) + max_scheduled = int(pcp_num_scheduled.max()) + + # 2. GPU index_select for real tokens + real_mask_gpu = ~padding_mask_gpu + self.input_ids.gpu[:pcp_total_tokens][real_mask_gpu] = self.input_ids.gpu[ + :original_total_tokens + ][pcp_rank_indices_gpu] + self.positions.gpu[:pcp_total_tokens][real_mask_gpu] = self.positions.gpu[ + :original_total_tokens + ][pcp_rank_indices_gpu] + + # 3. Fill padding slots with placeholder + self.input_ids.gpu[:pcp_total_tokens][padding_mask_gpu] = 0 + self.positions.gpu[:pcp_total_tokens][padding_mask_gpu] = 0 + + # 4. Handle inputs_embeds if present + if self.inputs_embeds is not None: + self.inputs_embeds.gpu[:pcp_total_tokens][real_mask_gpu] = ( + self.inputs_embeds.gpu[:original_total_tokens][pcp_rank_indices_gpu] + ) + self.inputs_embeds.gpu[:pcp_total_tokens][padding_mask_gpu] = 0 + + # 5. Update query_start_loc (CPU + GPU mirror) + cu_pcp_tokens = np.cumsum(pcp_num_scheduled) + self.query_start_loc.np[0] = 0 + self.query_start_loc.np[1 : num_reqs + 1] = cu_pcp_tokens + self.query_start_loc.copy_to_gpu() + + # 6. Update seq_lens (CPU + GPU mirror) + self.seq_lens.np[:num_reqs] = ( + self.input_batch.num_computed_tokens_cpu[:num_reqs] + pcp_num_scheduled + ) + self.seq_lens.copy_to_gpu() + + # 7. 
Update discard_request_mask + num_tokens = [ + self.requests[r].num_tokens for r in self.input_batch.req_ids[:num_reqs] + ] + num_tokens_np = np.array(num_tokens, dtype=np.int32) + self.discard_request_mask.np[:num_reqs] = ( + self.pcp_manager.get_discard_request_mask( + num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu, + num_scheduled_tokens=pcp_num_scheduled, + num_reqs=num_reqs, + num_tokens_np=num_tokens_np, + ) + ) + self.discard_request_mask.copy_to_gpu(num_reqs) + + # 8. Compute logits_indices + logits_indices = self.pcp_manager.get_logits_indices(cu_pcp_tokens, num_reqs) + + return pcp_total_tokens, max_scheduled, logits_indices + def _build_attention_metadata( self, num_tokens: int, @@ -3202,9 +3257,19 @@ def execute_model( scheduler_output, num_scheduled_tokens_np, ) + + total_num_scheduled_tokens = num_tokens_unpadded if self.pcp_world_size > 1: - max_num_scheduled_tokens = int(num_scheduled_tokens_np.max()) - num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens + # Apply PCP partitioning (mutates GPU buffers in-place) + ( + total_num_scheduled_tokens, + max_num_scheduled_tokens, + logits_indices, + ) = self._partition_batch_for_pcp( + num_scheduled_tokens_np, + num_reqs, + num_tokens_unpadded, + ) cascade_attn_prefix_lens = None # Disable cascade attention when using microbatching (DBO) @@ -3265,7 +3330,7 @@ def execute_model( (attn_metadata, spec_decode_common_attn_metadata) = ( self._build_attention_metadata( - num_tokens=num_tokens_unpadded, + num_tokens=total_num_scheduled_tokens, num_tokens_padded=num_tokens_padded if pad_attn else None, num_reqs=num_reqs, num_reqs_padded=num_reqs_padded if pad_attn else None, @@ -3334,10 +3399,8 @@ def execute_model( # ignores the padding from CUDA Graph. hidden_states = self.pcp_manager.get_restore_hidden_states( hidden_states, - num_tokens_unpadded, + total_num_scheduled_tokens, ) - # Restore total_num_scheduled_tokens. 
- scheduler_output.total_num_scheduled_tokens = num_scheduled_tokens if not self.broadcast_pp_output: # Common case. if not get_pp_group().is_last_rank: