[Attention] PCP alternative implementation #36306

Draft
LucasWilkinson wants to merge 2 commits into vllm-project:main from neuralmagic:pcp-virtual-request-kv-gather
Conversation

Collaborator

@LucasWilkinson LucasWilkinson commented Mar 7, 2026

Alternative implementation to: #33403

Uses "virtual batches"

…ther

Implements Prefill Context Parallelism (PCP) using the DualChunkSwap
approach for load balancing. Each PCP rank processes both head and tail
tokens from each request, which balances the attention workload: head
tokens attend to fewer KV entries, while tail tokens attend to more.

Key changes:
- PCPVirtualRequestManager with DualChunkSwap partitioning
- KV all-gather in MLA attention before cache write
- slot_mapping all-gather during metadata building
- pcp_kv_restore_idx for reordering gathered KV to global position order

Example with 8 tokens and PCP=2:
  - Rank 0: positions [0,1,6,7] (head=[0,1], tail=[6,7])
  - Rank 1: positions [2,3,4,5] (head=[2,3], tail=[4,5])

After KV all-gather, all ranks have full KV [0,8) for correct attention.
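
The DualChunkSwap split described above can be sketched in a few lines of numpy. This is an illustrative reconstruction from the example, not the actual `PCPVirtualRequestManager` API: the sequence is cut into `2 * pcp_world_size` chunks, and rank `r` is paired with chunk `r` (head) and its mirror chunk `2P-1-r` (tail). The same layout also determines `pcp_kv_restore_idx`.

```python
import numpy as np

def dual_chunk_swap_positions(num_tokens: int, pcp_world_size: int, rank: int) -> np.ndarray:
    """Assign rank `rank` a head chunk and its mirrored tail chunk.

    The sequence is split into 2*P equal chunks; rank r takes chunk r (head)
    and chunk 2P-1-r (tail), so ranks holding cheap early tokens also hold
    expensive late tokens.
    """
    chunk = num_tokens // (2 * pcp_world_size)
    head = np.arange(rank * chunk, (rank + 1) * chunk)
    tail_chunk = 2 * pcp_world_size - 1 - rank
    tail = np.arange(tail_chunk * chunk, (tail_chunk + 1) * chunk)
    return np.concatenate([head, tail])

def kv_restore_idx(num_tokens: int, pcp_world_size: int) -> np.ndarray:
    """Index that reorders rank-concatenated (all-gathered) KV into global token order."""
    gathered = np.concatenate([
        dual_chunk_swap_positions(num_tokens, pcp_world_size, r)
        for r in range(pcp_world_size)
    ])
    return np.argsort(gathered)

print(dual_chunk_swap_positions(8, 2, 0))  # [0 1 6 7]
print(dual_chunk_swap_positions(8, 2, 1))  # [2 3 4 5]
print(kv_restore_idx(8, 2))                # [0 1 4 5 6 7 2 3]
```

Indexing the gathered KV with `kv_restore_idx` yields tokens 0..7 in order, matching the "full KV [0,8)" claim above.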

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
@mergify mergify bot added the v1 label Mar 7, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces an alternative implementation for Prefill Context Parallelism (PCP) using a "virtual request" approach with a new PCPVirtualRequestManager. The changes are well-tested and integrated into the GPUModelRunner and attention layers. My review focuses on a couple of areas for improvement regarding code duplication and performance optimization.

Comment on lines +915 to +941
```python
# PCP: All-gather KV from all ranks before cache write
kv_for_cache = kv_c_normed
kpe_for_cache = k_pe
if attn_layer.impl.pcp_world_size > 1:
    attn_metadata = forward_context.attn_metadata
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata.get(layer_name)
    pcp_kv_restore_idx = (
        getattr(attn_metadata, "pcp_kv_restore_idx", None)
        if attn_metadata
        else None
    )
    if pcp_kv_restore_idx is not None:
        from vllm.distributed.parallel_state import get_pcp_group

        pcp_group = get_pcp_group()
        # Batch all-gather kv_c_normed and k_pe for efficiency
        kv_combined = torch.cat(
            [kv_c_normed, k_pe.view(k_pe.shape[0], -1)], dim=-1
        )
        kv_combined = pcp_group.all_gather(kv_combined.contiguous(), dim=0)
        kv_combined = kv_combined[pcp_kv_restore_idx]
        kv_lora_rank = kv_c_normed.shape[-1]
        kv_for_cache = kv_combined[..., :kv_lora_rank]
        kpe_for_cache = kv_combined[..., kv_lora_rank:].view(
            -1, k_pe.shape[1], k_pe.shape[2]
        )
```
Contributor


high

This block of code for PCP KV all-gather is nearly identical to the one in MLAAttention.forward (lines 443-463). This duplication makes the code harder to maintain. Consider refactoring this logic into a shared helper function to reduce redundancy and improve clarity.
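
A shared helper along these lines could absorb both copies. The function name, signature, and the single-rank stub group below are hypothetical, for illustration only, not vLLM's actual API:

```python
import torch

class _SingleRankGroup:
    """Stub standing in for the PCP process group (world size 1), for demo only."""
    def all_gather(self, tensor: torch.Tensor, dim: int = 0) -> torch.Tensor:
        return tensor

def pcp_all_gather_kv(kv_c_normed, k_pe, pcp_group, pcp_kv_restore_idx):
    """All-gather latent KV and rope keys across PCP ranks, then restore
    global token order so every rank can write the full KV to its cache."""
    # Batch the two tensors into one all-gather for efficiency.
    kv_combined = torch.cat([kv_c_normed, k_pe.view(k_pe.shape[0], -1)], dim=-1)
    kv_combined = pcp_group.all_gather(kv_combined.contiguous(), dim=0)
    kv_combined = kv_combined[pcp_kv_restore_idx]
    kv_lora_rank = kv_c_normed.shape[-1]
    kv_for_cache = kv_combined[..., :kv_lora_rank]
    kpe_for_cache = kv_combined[..., kv_lora_rank:].view(-1, k_pe.shape[1], k_pe.shape[2])
    return kv_for_cache, kpe_for_cache

# Demo: with world size 1 and an identity restore index, inputs round-trip.
kv_c = torch.randn(4, 16)
k_pe = torch.randn(4, 1, 8)
kv_out, kpe_out = pcp_all_gather_kv(kv_c, k_pe, _SingleRankGroup(), torch.arange(4))
```

Both call sites would then reduce to a single call passing their own metadata's `pcp_kv_restore_idx`.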

Comment on lines +1582 to +1591
```python
if pcp_cu_num_tokens is not None:
    cu_num_tokens = pcp_cu_num_tokens
    # Compute arange from cumsum
    arange = (
        np.concatenate([np.arange(n) for n in num_scheduled_tokens])
        if len(num_scheduled_tokens) > 0
        else np.array([], dtype=np.int64)
    )
else:
    cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens)
```

high

The current implementation of arange computation using np.concatenate with a list comprehension can be inefficient for large batches. A more performant approach using vectorized numpy operations, similar to the one in _get_cumsum_and_arange, should be used.

Suggested change

```python
if pcp_cu_num_tokens is not None:
    cu_num_tokens = pcp_cu_num_tokens
    # Compute arange from cumsum
    arange = (
        np.concatenate([np.arange(n) for n in num_scheduled_tokens])
        if len(num_scheduled_tokens) > 0
        else np.array([], dtype=np.int64)
    )
else:
    cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens)
```

```python
if pcp_cu_num_tokens is not None:
    cu_num_tokens = pcp_cu_num_tokens
    # Compute arange from cumsum more efficiently
    if len(num_scheduled_tokens) > 0:
        cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
                                    num_scheduled_tokens)
        arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets
    else:
        arange = np.array([], dtype=np.int64)
else:
    cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens)
```
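
The vectorized trick in the suggestion can be verified standalone: subtracting each request's start offset (repeated per token) from a flat `np.arange` reproduces the per-request aranges. This is a self-contained numpy demo; `self.arange_np` is assumed to be a precomputed `np.arange` buffer.

```python
import numpy as np

num_scheduled_tokens = np.array([2, 3, 1])
cu_num_tokens = np.cumsum(num_scheduled_tokens)  # [2, 5, 6]
total = int(cu_num_tokens[-1])

# For each token, the cumulative token count before its request starts.
offsets = np.repeat(cu_num_tokens - num_scheduled_tokens, num_scheduled_tokens)
arange = np.arange(total) - offsets  # vectorized, no Python-level loop

# Matches the list-comprehension version exactly.
assert np.array_equal(
    arange, np.concatenate([np.arange(n) for n in num_scheduled_tokens])
)
print(arange)  # [0 1 0 1 2 0]
```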

…nostic

Refactors the PCP partitioning to use true virtual requests, where each
physical request becomes 2 virtual requests (head + tail). This allows
_prepare_inputs to be completely PCP-agnostic.

Key changes:
- partition() now returns (virtual_num_scheduled, virtual_num_computed,
  virtual_to_physical) instead of positions and req_indices
- _prepare_inputs uses virtual_num_computed for position computation
- _prepare_inputs uses virtual_to_physical for physical request lookups
  (block table, token_ids, etc.)
- Standard position formula (num_computed + arange) produces correct
  positions automatically

Example with 8 tokens and PCP=2:
  Rank 0:
    - vreq 0 (head): num_scheduled=2, num_computed=0
    - vreq 1 (tail): num_scheduled=2, num_computed=6
  Rank 1:
    - vreq 0 (head): num_scheduled=2, num_computed=2
    - vreq 1 (tail): num_scheduled=2, num_computed=4
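
The example above can be checked with the standard position formula. A minimal sketch (the helper name is illustrative, not the actual `_prepare_inputs` code): each virtual request contributes `num_computed + arange(num_scheduled)`.

```python
import numpy as np

def positions_for_rank(virtual_num_scheduled, virtual_num_computed):
    """Global positions per rank: num_computed + arange(num_scheduled)
    for each virtual request, concatenated in virtual-request order."""
    return np.concatenate([
        computed + np.arange(scheduled)
        for scheduled, computed in zip(virtual_num_scheduled, virtual_num_computed)
    ])

# Rank 0: head vreq (computed=0) and tail vreq (computed=6)
print(positions_for_rank([2, 2], [0, 6]))  # [0 1 6 7]
# Rank 1: head vreq (computed=2) and tail vreq (computed=4)
print(positions_for_rank([2, 2], [2, 4]))  # [2 3 4 5]
```

These match the DualChunkSwap assignment from the first commit, which is why `_prepare_inputs` needs no PCP-specific position logic.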

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>