12 changes: 10 additions & 2 deletions tests/distributed/test_context_parallel.py
@@ -50,6 +50,7 @@ class ParallelSetup(NamedTuple):
tp_size: int
pp_size: int
dcp_size: int
pcp_size: int
cp_kv_cache_interleave_size: int
eager_mode: bool
chunked_prefill: bool
@@ -73,6 +74,7 @@ def detailed(
tp_base: int = 4,
pp_base: int = 1,
dcp_multipliers: list[float] | None = None,
pcp_base: int = 1,
cp_kv_cache_interleave_size: int = 1,
multi_node_only: bool = False,
runner: RunnerOption = "auto",
@@ -91,7 +93,8 @@ def detailed(
ParallelSetup(
tp_size=tp_base,
pp_size=pp_multiplier * pp_base,
dcp_size=int(dcp_multiplier * tp_base),
dcp_size=max(1, int(dcp_multiplier * tp_base)),
pcp_size=pcp_base,
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
eager_mode=eager_mode_val,
chunked_prefill=chunked_prefill_val,
@@ -129,6 +132,8 @@ def iter_params(self, model_id: str):
cp_kv_cache_interleave_size=64,
attn_backend="FLASHMLA",
),
CPTestSettings.detailed(tp_base=1, pcp_base=4, cp_kv_cache_interleave_size=64),
CPTestSettings.detailed(tp_base=2, pcp_base=2, cp_kv_cache_interleave_size=64),
],
"Qwen/Qwen2.5-1.5B-Instruct": [
CPTestSettings.detailed(
@@ -156,6 +161,7 @@ def _test_cp_gsm8k(
tp_size,
pp_size,
dcp_size,
pcp_size,
cp_kv_cache_interleave_size,
eager_mode,
chunked_prefill,
@@ -212,7 +218,9 @@ def _test_cp_gsm8k(
str(pp_size),
"--decode-context-parallel-size",
str(dcp_size),
"--dcp-kv-cache-interleave-size",
"--prefill-context-parallel-size",
str(pcp_size),
"--cp-kv-cache-interleave-size",
str(cp_kv_cache_interleave_size),
"--distributed-executor-backend",
distributed_backend,
10 changes: 5 additions & 5 deletions vllm/attention/backends/abstract.py
@@ -320,8 +320,8 @@ class AttentionImpl(ABC, Generic[T]):
pcp_world_size: int
pcp_rank: int

total_cp_world_size: int
total_cp_rank: int
cp_world_size: int
cp_rank: int

def __new__(cls, *args, **kwargs):
# use __new__ so that all subclasses will call this
@@ -343,11 +343,11 @@ def __new__(cls, *args, **kwargs):
except AssertionError:
self.pcp_world_size = 1
self.pcp_rank = 0
self.total_cp_world_size = self.pcp_world_size * self.dcp_world_size
self.total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
self.cp_world_size = self.pcp_world_size * self.dcp_world_size
self.cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
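        # e.g. pcp_world_size=2, dcp_world_size=4 gives cp_world_size=8;
        # (pcp_rank=1, dcp_rank=2) maps to cp_rank=6.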

self.need_to_return_lse_for_decode = (
self.dcp_world_size > 1 and self.can_return_lse_for_decode
self.cp_world_size > 1 and self.can_return_lse_for_decode
)
return self

217 changes: 217 additions & 0 deletions vllm/attention/ops/common.py
@@ -467,3 +467,220 @@ def unpack_seq_triton(
out = out.reshape(output_shape)

return out


@triton.jit
def _fused_pcp_qkv_select_kernel(
q_ptr,
q_stride_B,
q_stride_H,
k_ptr,
k_stride_B,
k_stride_H,
v_ptr,
v_stride_B,
v_stride_H,
query_start_ptr,
out_q_head_ptr,
out_q_tail_ptr,
out_k_head_ptr,
out_k_tail_ptr,
out_v_head_ptr,
out_v_tail_ptr,
pcp_world_size: tl.constexpr,
pcp_rank: tl.constexpr,
n_head: tl.constexpr,
q_head_dim: tl.constexpr,
k_head_dim: tl.constexpr,
v_head_dim: tl.constexpr,
SEQ_BLOCK_SIZE: tl.constexpr,
DIM_BLOCK_SIZE: tl.constexpr,
):
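    # Grid axis 0 enumerates (request, chunk) pairs: each request's gathered
    # KV spans 2 * pcp_world_size chunks of q_select_len tokens, and chunk
    # ids 0/1 double as the head/tail halves of this rank's local query.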
req_id = tl.program_id(0) // (2 * pcp_world_size)
seq_block_id = tl.program_id(0) % (2 * pcp_world_size)
head_id = tl.program_id(1)
dim_block_id = tl.program_id(2)
dim_off = tl.arange(0, DIM_BLOCK_SIZE) + dim_block_id * DIM_BLOCK_SIZE

q_start_loc = tl.load(query_start_ptr + req_id)
q_end_loc = tl.load(query_start_ptr + req_id + 1)
q_select_len = (q_end_loc - q_start_loc) // 2

# Select Q
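    # Chunk 0 copies the first half of this rank's query into q_head;
    # chunk 1 copies the second half into q_tail.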
if seq_block_id < 2:
block_q_start_loc = q_start_loc + seq_block_id * q_select_len
out_ptr = out_q_head_ptr if seq_block_id == 0 else out_q_tail_ptr
for qi in range(tl.cdiv(q_select_len, SEQ_BLOCK_SIZE)):
q_offset = tl.arange(0, SEQ_BLOCK_SIZE) + qi * SEQ_BLOCK_SIZE
mask = (dim_off[None, :] < q_head_dim) & (q_offset[:, None] < q_select_len)
q_src_idx = block_q_start_loc + q_offset[:, None]
q_dst_idx = q_start_loc // 2 + q_offset[:, None]
q_val = tl.load(
q_ptr
+ q_src_idx * q_stride_B
+ head_id * q_stride_H
+ dim_off[None, :],
mask=mask,
)
tl.store(
out_ptr
+ q_dst_idx * n_head * q_head_dim
+ head_id * q_head_dim
+ dim_off[None, :],
q_val,
mask=mask,
)

# Select KV
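    # The head output keeps KV chunks [0, pcp_rank], the causal prefix the
    # head query half attends to; the tail output keeps chunks
    # [0, 2 * pcp_world_size - 1 - pcp_rank] for the tail query half.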
kv_start_loc = q_start_loc * pcp_world_size
kv_select_len = q_select_len
k_d_mask = dim_off[None, :] < k_head_dim
v_d_mask = dim_off[None, :] < v_head_dim
block_src_kv_start_loc = kv_start_loc + seq_block_id * kv_select_len
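    # kv_start_loc // 2 // pcp_world_size == q_start_loc // 2, the per-request
    # base offset: each request contributes (pcp_rank + 1) chunks to the head
    # output and (2 * pcp_world_size - pcp_rank) chunks to the tail output.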
block_dst_kv_head_start_loc = (
kv_start_loc // 2 // pcp_world_size * (pcp_rank + 1)
+ seq_block_id * kv_select_len
)
block_dst_kv_tail_start_loc = (
kv_start_loc // 2 // pcp_world_size * (2 * pcp_world_size - pcp_rank)
+ seq_block_id * kv_select_len
)
for ki in range(tl.cdiv(kv_select_len, SEQ_BLOCK_SIZE)):
kv_offset = tl.arange(0, SEQ_BLOCK_SIZE) + ki * SEQ_BLOCK_SIZE
kv_block_mask = kv_offset[:, None] < kv_select_len
kv_src_idx = block_src_kv_start_loc + kv_offset[:, None]
kv_dst_idx_head = block_dst_kv_head_start_loc + kv_offset[:, None]
kv_dst_idx_tail = block_dst_kv_tail_start_loc + kv_offset[:, None]
k_val = tl.load(
k_ptr + kv_src_idx * k_stride_B + head_id * k_stride_H + dim_off[None, :],
mask=k_d_mask & kv_block_mask,
)
v_val = tl.load(
v_ptr + kv_src_idx * v_stride_B + head_id * v_stride_H + dim_off[None, :],
mask=v_d_mask & kv_block_mask,
)
if seq_block_id < pcp_rank + 1:
tl.store(
out_k_head_ptr
+ kv_dst_idx_head * n_head * k_head_dim
+ head_id * k_head_dim
+ dim_off[None, :],
k_val,
mask=k_d_mask & kv_block_mask,
)
tl.store(
out_v_head_ptr
+ kv_dst_idx_head * n_head * v_head_dim
+ head_id * v_head_dim
+ dim_off[None, :],
v_val,
mask=v_d_mask & kv_block_mask,
)
if seq_block_id < 2 * pcp_world_size - pcp_rank:
tl.store(
out_k_tail_ptr
+ kv_dst_idx_tail * n_head * k_head_dim
+ head_id * k_head_dim
+ dim_off[None, :],
k_val,
mask=k_d_mask & kv_block_mask,
)
tl.store(
out_v_tail_ptr
+ kv_dst_idx_tail * n_head * v_head_dim
+ head_id * v_head_dim
+ dim_off[None, :],
v_val,
mask=v_d_mask & kv_block_mask,
)


def fused_pcp_qkv_select(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
query_start_loc: torch.Tensor,
pcp_world_size: int,
pcp_rank: int,
):
"""
Select the query and kv tensors for PCP. Instead of calling
`torch.index_select` multiple times, this function fuses the
selection for Q, K, and V into a single kernel to reduce
kernel launch overhead.
Args:
q: query tensor on the current PCP rank.
k: key tensor across PCP ranks.
v: value tensor across PCP ranks.
query_start_loc: start location of each query.
pcp_world_size: number of PCP ranks.
pcp_rank: rank of the current PCP rank.
Returns:
q_head: selected query tensor for pcp head.
k_head: selected key tensor for pcp head.
v_head: selected value tensor for pcp head.
q_tail: selected query tensor for pcp tail.
k_tail: selected key tensor for pcp tail.
v_tail: selected value tensor for pcp tail.

"""
q_head = torch.empty(
(q.size(0) // 2,) + q.shape[1:], device=q.device, dtype=q.dtype
)
q_tail = torch.empty_like(q_head)
k_head = torch.empty(
(q.size(0) // 2 * (pcp_rank + 1),) + k.shape[1:], device=k.device, dtype=k.dtype
)
v_head = torch.empty(
(q.size(0) // 2 * (pcp_rank + 1),) + v.shape[1:], device=v.device, dtype=v.dtype
)
k_tail = torch.empty(
(q.size(0) // 2 * (2 * pcp_world_size - pcp_rank),) + k.shape[1:],
device=k.device,
dtype=k.dtype,
)
v_tail = torch.empty(
(q.size(0) // 2 * (2 * pcp_world_size - pcp_rank),) + v.shape[1:],
device=v.device,
dtype=v.dtype,
)
BS = len(query_start_loc) - 1
DIM_BLOCK_SIZE: int = 64
SEQ_BLOCK_SIZE: int = 256
assert q.shape[1] == k.shape[1] == v.shape[1]
n_head = q.shape[1]
    # Ceil-divide the widest head dim into DIM_BLOCK_SIZE-wide blocks.
    n_dim_block = triton.cdiv(
        max(q.shape[2], k.shape[2], v.shape[2]), DIM_BLOCK_SIZE
    )
grid = (
2 * pcp_world_size * BS,
n_head,
n_dim_block,
)
_fused_pcp_qkv_select_kernel[grid](
q,
q.stride(0),
q.stride(1),
k,
k.stride(0),
k.stride(1),
v,
v.stride(0),
v.stride(1),
query_start_loc,
q_head,
q_tail,
k_head,
k_tail,
v_head,
v_tail,
pcp_world_size,
pcp_rank,
n_head,
q.shape[2],
k.shape[2],
v.shape[2],
SEQ_BLOCK_SIZE,
DIM_BLOCK_SIZE,
)
return q_head, k_head, v_head, q_tail, k_tail, v_tail
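
A minimal usage sketch for the new helper (shapes and values are illustrative
assumptions, not taken from this PR):

    import torch
    from vllm.attention.ops.common import fused_pcp_qkv_select

    # Hypothetical setup: 1 request with 8 local query tokens, 4 heads,
    # head_dim 64, running as rank 0 of a 2-way PCP group.
    pcp_world_size, pcp_rank = 2, 0
    q = torch.randn(8, 4, 64, device="cuda")
    k = torch.randn(8 * pcp_world_size, 4, 64, device="cuda")  # gathered KV
    v = torch.randn_like(k)
    query_start_loc = torch.tensor([0, 8], device="cuda", dtype=torch.int32)

    q_head, k_head, v_head, q_tail, k_tail, v_tail = fused_pcp_qkv_select(
        q, k, v, query_start_loc, pcp_world_size, pcp_rank
    )
    # q_head/q_tail: 4 tokens each (the two halves of the local query).
    # k_head/v_head: (pcp_rank + 1) * 4 = 4 tokens.
    # k_tail/v_tail: (2 * pcp_world_size - pcp_rank) * 4 = 16 tokens.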
10 changes: 5 additions & 5 deletions vllm/config/parallel.py
@@ -242,12 +242,12 @@ class is dynamically inherited by the worker class. This is used to inject
"""
cp_kv_cache_interleave_size: int = 1
"""Interleave size of kv_cache storage while using DCP or PCP.
For `total_cp_rank = pcp_rank * dcp_world_size + dcp_rank`,
and `total_cp_world_size = pcp_world_size * dcp_world_size`.
store interleave_size tokens on total_cp_rank i,
then store next interleave_size tokens on total_cp_rank i+1.
    With `cp_rank = pcp_rank * dcp_world_size + dcp_rank`
    and `cp_world_size = pcp_world_size * dcp_world_size`,
    store interleave_size tokens on cp_rank i,
    then store the next interleave_size tokens on cp_rank i+1.
Interleave_size=1: token-level alignment, where token `i` is stored on
total_cp_rank `i % total_cp_world_size`.
cp_rank `i % cp_world_size`.
Interleave_size=block_size: block-level alignment, where tokens are
first populated to the preceding ranks. Tokens are then stored
in (rank i+1, block j) only after (rank i, block j) is fully occupied.
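
As a worked sketch of the interleaving described above (hypothetical helper,
not part of this PR):

    def rank_of_token(i: int, interleave_size: int, cp_world_size: int) -> int:
        # Groups of interleave_size consecutive tokens are assigned to
        # cp_ranks round-robin.
        return (i // interleave_size) % cp_world_size

    # interleave_size=1, cp_world_size=4: tokens 0..4 -> ranks 0, 1, 2, 3, 0
    # interleave_size=2, cp_world_size=4: tokens 0..7 -> ranks 0, 0, 1, 1, 2, 2, 3, 3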
3 changes: 0 additions & 3 deletions vllm/distributed/parallel_state.py
@@ -1092,9 +1092,6 @@ def get_dcp_group() -> GroupCoordinator:
return _DCP


# kept for backward compatibility
get_context_model_parallel_group = get_dcp_group

_PP: GroupCoordinator | None = None


8 changes: 8 additions & 0 deletions vllm/model_executor/layers/fused_moe/config.py
@@ -1027,6 +1027,10 @@ def tp_size(self):
def dp_size(self):
return self.moe_parallel_config.dp_size

@property
def pcp_size(self):
return self.moe_parallel_config.pcp_size

@property
def ep_size(self):
return self.moe_parallel_config.ep_size
@@ -1039,6 +1043,10 @@ def tp_rank(self):
def dp_rank(self):
return self.moe_parallel_config.dp_rank

@property
def pcp_rank(self):
return self.moe_parallel_config.pcp_rank

@property
def ep_rank(self):
return self.moe_parallel_config.ep_rank
14 changes: 14 additions & 0 deletions vllm/platforms/cuda.py
@@ -233,6 +233,20 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
"Forcing kv cache block size to 64 for FlashMLASparse backend."
)

# lazy import to avoid circular import
from vllm.config import CUDAGraphMode

compilation_config = vllm_config.compilation_config
if (
compilation_config.cudagraph_mode.has_full_cudagraphs()
and parallel_config.prefill_context_parallel_size > 1
):
logger.warning_once(
"Prefill context parallel (PCP) is enabled, which is "
"incompatible with full CUDA graphs. "
"Overriding cudagraph_mode to PIECEWISE."
)
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
scheduler_config = vllm_config.scheduler_config
# Note: model_config may be None during testing
if (
4 changes: 2 additions & 2 deletions vllm/v1/attention/backends/flash_attn.py
@@ -45,7 +45,7 @@
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
get_dcp_local_seq_lens,
get_cp_local_seq_lens,
get_kv_cache_layout,
)
from vllm.v1.kv_cache_interface import AttentionSpec
@@ -405,7 +405,7 @@ def schedule(
query_kv_lens = query_start_loc[1:] - query_start_loc[:-1]
dcp_context_kv_lens = seq_lens - query_kv_lens

dcp_context_kv_lens = get_dcp_local_seq_lens(
dcp_context_kv_lens = get_cp_local_seq_lens(
dcp_context_kv_lens,
self.dcp_world_size,
self.dcp_rank,