142 changes: 87 additions & 55 deletions vllm/v1/worker/cp_utils.py
@@ -79,69 +79,65 @@ def _get_cumsum_and_arange(
arange = arange_np[:total_num_tokens] - cumsums_offsets
return cu_num_tokens, arange

def compute_rank_indices(
    self,
    num_scheduled_tokens: np.ndarray,
    num_computed_tokens: np.ndarray,
    arange_np: np.ndarray,
    reorder_batch_threshold: int | None,
) -> tuple[np.ndarray, torch.Tensor, torch.Tensor]:
"""
Compute which tokens this PCP rank processes and their indices
into the original batch.

When using Prefill Context Parallelism (PCP), each request's prefill
sequence is split across multiple PCP ranks. The splitting strategy
used here is the "DualChunkSwap" style: each request's (padded)
sequence is split into 2 * pcp_world_size chunks and ranks are
assigned chunks in an interleaved head/tail pattern to balance load.

This method:
- Computes how many of each request's tokens the current PCP rank
  should process (pcp_num_scheduled).
- Computes indices into the original batch for gathering real tokens.
- Updates runner state arrays used to restore original order and mask
out padded tokens after allgather:
- self.num_pcp_pads_cpu: number of pads added per request
- self.pcp_unpad_mask_cpu: boolean mask marking real tokens in the
padded allgather buffer
- self.pcp_allgather_restore_idx: index array used to restore
original ordering after per-rank allgather and interleaving.

Args:
num_scheduled_tokens: 1D numpy array of per-request token counts.
num_computed_tokens: 1D numpy array of already computed tokens per
request.
arange_np: Pre-allocated arange buffer for efficient operations.
reorder_batch_threshold: Threshold for decode vs prefill requests.

Returns:
Tuple (pcp_num_scheduled, pcp_rank_indices, padding_mask):
- pcp_num_scheduled: per-request token counts for this rank
- pcp_rank_indices: GPU tensor of indices into original batch for
real tokens (used with advanced indexing to gather tokens)
- padding_mask: GPU tensor, True for padding slots in PCP batch

Example:
Assume tokens = [1, 5, 8], pcp_world_size = 2:
- pcp_rank=0 gets pcp_num_scheduled=[1, 4, 4], positions=[0,0,1,6,7,0,1,6,7]
- pcp_rank=1 gets pcp_num_scheduled=[1, 4, 4], positions=[0,2,3,4,5,2,3,4,5]
Meanwhile, these are the same on every PCP rank:
- num_pcp_pads_cpu: [1, 3, 0]
- pcp_unpad_mask_cpu: [T,F,T,T,T,T,T,F,F,F,T,T,T,T,T,T,T,T]
- pcp_allgather_restore_idx: [0,9,1,2,10,11,12,13,3,4,5,6,14,15,16,17,7,8]
"""

assert reorder_batch_threshold is not None, (
"PCP depends on reorder batch to split decode and prefill requests."
)
num_reqs = len(num_scheduled_tokens)
num_decode_reqs = sum(num_scheduled_tokens <= reorder_batch_threshold)
num_decode_tokens = sum(num_scheduled_tokens[:num_decode_reqs])

# DualChunkSwap requires alignment to a multiple of (2 * pcp_world_size).
# We first pad each request's token count up to that multiple.
num_padded_scheduled_tokens = np.ceil(
@@ -157,19 +153,23 @@ def update_tokens_for_pcp(
self.num_pcp_pads_cpu[:num_reqs] = (
num_padded_scheduled_tokens - num_scheduled_tokens
)

# cu_padded_tokens: cumulative sum of padded token counts,
# pcp_padded_arange: per-request arange flattened for padded tokens.
cu_padded_tokens, pcp_padded_arange = self._get_cumsum_and_arange(
num_padded_scheduled_tokens, arange_np
)

# Build the mask that marks which positions in the padded allgather
# buffer correspond to real (unpadded) tokens.
self.pcp_unpad_mask_cpu[: pcp_padded_arange.shape[0]] = (
pcp_padded_arange
< np.repeat(num_scheduled_tokens, num_padded_scheduled_tokens)
)
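# Illustrative check (not part of the diff), using the docstring example:
# with real counts [1, 5, 8] padded to [2, 8, 8], pcp_padded_arange is
# [0,1, 0..7, 0..7] and the repeat gives [1,1, 5 x8, 8 x8], so the mask
# is [T,F, T,T,T,T,T,F,F,F, T,T,T,T,T,T,T,T], matching the docstring.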

# pcp_tokens: tokens per request for this rank after splitting
pcp_tokens = num_padded_scheduled_tokens // self.pcp_world_size

# Compute per-request "chunk sizes" for the head/tail splitting.
# For prefill requests, we further split the pcp_tokens into two chunks
# (head and tail). For decode requests, the chunk equals pcp_tokens.
@@ -189,16 +189,16 @@ def update_tokens_for_pcp(

def get_current_rank_positions(
positions_start_loc: int | np.ndarray, rank: int
) -> np.ndarray:
"""
Compute flattened positions for the given rank with a given start
offset for each request (positions_start_loc).

- For head chunks: start at positions_start_loc + rank * chunk_size.
- For tail chunks: start at positions_start_loc +
(2*pcp_world_size - rank - 1) * chunk_size.
- For decode requests: no tail chunks; their positions are filled
from the contiguous (unpadded) `tokens` arange instead.
"""
positions = np.zeros(len(pcp_head_chunk_mask), dtype=np.int32)
head_start_loc = positions_start_loc + rank * pcp_chunk_sizes
@@ -218,9 +218,10 @@ def get_current_rank_positions(
)
return positions

# Get positions for this rank (position VALUES, not indices)
positions = get_current_rank_positions(0, self.pcp_rank)
# Decode tokens are duplicated only after allgather, but their positions
# are the same as without prefill context parallelism.
if num_decode_reqs > 0:
positions[:num_decode_tokens] = self._get_cumsum_and_arange(
num_scheduled_tokens[:num_decode_reqs], arange_np
@@ -238,10 +239,41 @@ def get_current_rank_positions(
all_positions.argsort()
)
self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0])

# Now compute indices into the original batch: positions[i] is the
# position VALUE for PCP token i; we need the original token it maps to.

# Compute cumsum of original tokens for request start offsets
cu_orig_tokens = np.cumsum(num_scheduled_tokens)
orig_start_offsets = np.concatenate([[0], cu_orig_tokens[:-1]])

pcp_total_tokens = int(pcp_tokens.sum())

# padding_mask: True if position >= original seq len for that request
orig_seq_lens_expanded = np.repeat(num_scheduled_tokens, pcp_tokens)
padding_mask_np = positions[:pcp_total_tokens] >= orig_seq_lens_expanded

# For real tokens, compute index into original batch
# original_index = orig_start_offset[req] + position
orig_start_expanded = np.repeat(orig_start_offsets, pcp_tokens)
# Only compute for non-padding tokens
real_indices = orig_start_expanded + positions[:pcp_total_tokens]
# Clamp padding indices to 0 (they won't be used anyway)
real_indices = np.where(padding_mask_np, 0, real_indices)

# Extract just the indices for real tokens
pcp_rank_indices = real_indices[~padding_mask_np]

# Convert to GPU tensors
pcp_rank_indices_gpu = torch.from_numpy(pcp_rank_indices.astype(np.int64)).to(
self.device, non_blocking=True
)
padding_mask_gpu = torch.from_numpy(padding_mask_np.copy()).to(
self.device, non_blocking=True
)

return pcp_tokens[:num_reqs], pcp_rank_indices_gpu, padding_mask_gpu
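As a sanity check, the following minimal sketch (not part of the diff) reproduces the index-mapping arithmetic above with the docstring example, assuming pcp_world_size = 2 and the rank-0 values shown there; all names are local to the sketch:

import numpy as np

num_scheduled_tokens = np.array([1, 5, 8])
pcp_tokens = np.array([1, 4, 4])                   # per-request tokens on rank 0
positions = np.array([0, 0, 1, 6, 7, 0, 1, 6, 7])  # rank-0 position values

cu_orig_tokens = np.cumsum(num_scheduled_tokens)                 # [1, 6, 14]
orig_start_offsets = np.concatenate([[0], cu_orig_tokens[:-1]])  # [0, 1, 6]

# Positions past a request's real length are pads (positions 6, 7 of request 1).
orig_seq_lens_expanded = np.repeat(num_scheduled_tokens, pcp_tokens)
padding_mask = positions >= orig_seq_lens_expanded

# original index = request start offset + position value; pads clamped to 0.
orig_start_expanded = np.repeat(orig_start_offsets, pcp_tokens)
real_indices = np.where(padding_mask, 0, orig_start_expanded + positions)
print(real_indices[~padding_mask].tolist())  # [0, 1, 2, 6, 7, 12, 13]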

def get_logits_indices(self, cu_num_tokens: np.ndarray, num_reqs: int):
return (
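For reference, here is a minimal standalone sketch (also not part of the diff) of the "DualChunkSwap" assignment described in the docstring; dual_chunk_positions is a hypothetical helper, not a function in cp_utils.py. Each padded prefill request is split into 2 * pcp_world_size equal chunks, and rank r takes head chunk r plus tail chunk 2 * pcp_world_size - 1 - r, so every rank sees one early and one late chunk:

import numpy as np

def dual_chunk_positions(padded_len: int, rank: int, world_size: int) -> np.ndarray:
    """Position values assigned to `rank` for one padded prefill request."""
    chunk_size = padded_len // (2 * world_size)
    head = np.arange(rank * chunk_size, (rank + 1) * chunk_size)
    tail_chunk = 2 * world_size - 1 - rank
    tail = np.arange(tail_chunk * chunk_size, (tail_chunk + 1) * chunk_size)
    return np.concatenate([head, tail])

# Request with 5 real tokens, padded to 8 (a multiple of 2 * pcp_world_size):
print(dual_chunk_positions(8, rank=0, world_size=2).tolist())  # [0, 1, 6, 7]
print(dual_chunk_positions(8, rank=1, world_size=2).tolist())  # [2, 3, 4, 5]

Pairing head chunk r with tail chunk 2 * world_size - 1 - r balances causal-attention cost across ranks, since early chunks attend to few tokens and late chunks to many.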