diff --git a/tests/distributed/test_context_parallel.py b/tests/distributed/test_context_parallel.py index a28630921771..d0414f81b33b 100644 --- a/tests/distributed/test_context_parallel.py +++ b/tests/distributed/test_context_parallel.py @@ -50,6 +50,7 @@ class ParallelSetup(NamedTuple): tp_size: int pp_size: int dcp_size: int + pcp_size: int cp_kv_cache_interleave_size: int eager_mode: bool chunked_prefill: bool @@ -73,6 +74,7 @@ def detailed( tp_base: int = 4, pp_base: int = 1, dcp_multipliers: list[float] | None = None, + pcp_base: int = 1, cp_kv_cache_interleave_size: int = 1, multi_node_only: bool = False, runner: RunnerOption = "auto", @@ -91,7 +93,8 @@ def detailed( ParallelSetup( tp_size=tp_base, pp_size=pp_multiplier * pp_base, - dcp_size=int(dcp_multiplier * tp_base), + dcp_size=max(1, int(dcp_multiplier * tp_base)), + pcp_size=pcp_base, cp_kv_cache_interleave_size=cp_kv_cache_interleave_size, eager_mode=eager_mode_val, chunked_prefill=chunked_prefill_val, @@ -129,6 +132,8 @@ def iter_params(self, model_id: str): cp_kv_cache_interleave_size=64, attn_backend="FLASHMLA", ), + CPTestSettings.detailed(tp_base=1, pcp_base=4, cp_kv_cache_interleave_size=64), + CPTestSettings.detailed(tp_base=2, pcp_base=2, cp_kv_cache_interleave_size=64), ], "Qwen/Qwen2.5-1.5B-Instruct": [ CPTestSettings.detailed( @@ -156,6 +161,7 @@ def _test_cp_gsm8k( tp_size, pp_size, dcp_size, + pcp_size, cp_kv_cache_interleave_size, eager_mode, chunked_prefill, @@ -212,7 +218,9 @@ def _test_cp_gsm8k( str(pp_size), "--decode-context-parallel-size", str(dcp_size), - "--dcp-kv-cache-interleave-size", + "--prefill-context-parallel-size", + str(pcp_size), + "--cp-kv-cache-interleave-size", str(cp_kv_cache_interleave_size), "--distributed-executor-backend", distributed_backend, diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 025ede1eb0a4..595de8a46ce4 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -320,8 +320,8 @@ class AttentionImpl(ABC, Generic[T]): pcp_world_size: int pcp_rank: int - total_cp_world_size: int - total_cp_rank: int + cp_world_size: int + cp_rank: int def __new__(cls, *args, **kwargs): # use __new__ so that all subclasses will call this @@ -343,11 +343,11 @@ def __new__(cls, *args, **kwargs): except AssertionError: self.pcp_world_size = 1 self.pcp_rank = 0 - self.total_cp_world_size = self.pcp_world_size * self.dcp_world_size - self.total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank + self.cp_world_size = self.pcp_world_size * self.dcp_world_size + self.cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank self.need_to_return_lse_for_decode = ( - self.dcp_world_size > 1 and self.can_return_lse_for_decode + self.cp_world_size > 1 and self.can_return_lse_for_decode ) return self diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py index bd6bc864d45d..b114c5f60067 100644 --- a/vllm/attention/ops/common.py +++ b/vllm/attention/ops/common.py @@ -467,3 +467,220 @@ def unpack_seq_triton( out = out.reshape(output_shape) return out + + +@triton.jit +def _fused_pcp_qkv_select_kernel( + q_ptr, + q_stride_B, + q_stride_H, + k_ptr, + k_stride_B, + k_stride_H, + v_ptr, + v_stride_B, + v_stride_H, + query_start_ptr, + out_q_head_ptr, + out_q_tail_ptr, + out_k_head_ptr, + out_k_tail_ptr, + out_v_head_ptr, + out_v_tail_ptr, + pcp_world_size: tl.constexpr, + pcp_rank: tl.constexpr, + n_head: tl.constexpr, + q_head_dim: tl.constexpr, + k_head_dim: tl.constexpr, + v_head_dim: tl.constexpr, + SEQ_BLOCK_SIZE: tl.constexpr, + DIM_BLOCK_SIZE: tl.constexpr, +): + req_id = tl.program_id(0) // (2 * pcp_world_size) + seq_block_id = tl.program_id(0) % (2 * pcp_world_size) + head_id = tl.program_id(1) + dim_block_id = tl.program_id(2) + dim_off = tl.arange(0, DIM_BLOCK_SIZE) + dim_block_id * DIM_BLOCK_SIZE + + q_start_loc = tl.load(query_start_ptr + req_id) + q_end_loc = tl.load(query_start_ptr + req_id + 1) + q_select_len = (q_end_loc - q_start_loc) // 2 + + # Select Q + if seq_block_id < 2: + block_q_start_loc = q_start_loc + seq_block_id * q_select_len + out_ptr = out_q_head_ptr if seq_block_id == 0 else out_q_tail_ptr + for qi in range(tl.cdiv(q_select_len, SEQ_BLOCK_SIZE)): + q_offset = tl.arange(0, SEQ_BLOCK_SIZE) + qi * SEQ_BLOCK_SIZE + mask = (dim_off[None, :] < q_head_dim) & (q_offset[:, None] < q_select_len) + q_src_idx = block_q_start_loc + q_offset[:, None] + q_dst_idx = q_start_loc // 2 + q_offset[:, None] + q_val = tl.load( + q_ptr + + q_src_idx * q_stride_B + + head_id * q_stride_H + + dim_off[None, :], + mask=mask, + ) + tl.store( + out_ptr + + q_dst_idx * n_head * q_head_dim + + head_id * q_head_dim + + dim_off[None, :], + q_val, + mask=mask, + ) + + # Select KV + kv_start_loc = q_start_loc * pcp_world_size + kv_select_len = q_select_len + k_d_mask = dim_off[None, :] < k_head_dim + v_d_mask = dim_off[None, :] < v_head_dim + block_src_kv_start_loc = kv_start_loc + seq_block_id * kv_select_len + block_dst_kv_head_start_loc = ( + kv_start_loc // 2 // pcp_world_size * (pcp_rank + 1) + + seq_block_id * kv_select_len + ) + block_dst_kv_tail_start_loc = ( + kv_start_loc // 2 // pcp_world_size * (2 * pcp_world_size - pcp_rank) + + seq_block_id * kv_select_len + ) + for ki in range(tl.cdiv(kv_select_len, SEQ_BLOCK_SIZE)): + kv_offset = tl.arange(0, SEQ_BLOCK_SIZE) + ki * SEQ_BLOCK_SIZE + kv_block_mask = kv_offset[:, None] < kv_select_len + kv_src_idx = block_src_kv_start_loc + kv_offset[:, None] + kv_dst_idx_head = block_dst_kv_head_start_loc + kv_offset[:, None] + kv_dst_idx_tail = block_dst_kv_tail_start_loc + kv_offset[:, None] + k_val = tl.load( + k_ptr + kv_src_idx * k_stride_B + head_id * k_stride_H + dim_off[None, :], + mask=k_d_mask & kv_block_mask, + ) + v_val = tl.load( + v_ptr + kv_src_idx * v_stride_B + head_id * v_stride_H + dim_off[None, :], + mask=v_d_mask & kv_block_mask, + ) + if seq_block_id < pcp_rank + 1: + tl.store( + out_k_head_ptr + + kv_dst_idx_head * n_head * k_head_dim + + head_id * k_head_dim + + dim_off[None, :], + k_val, + mask=k_d_mask & kv_block_mask, + ) + tl.store( + out_v_head_ptr + + kv_dst_idx_head * n_head * v_head_dim + + head_id * v_head_dim + + dim_off[None, :], + v_val, + mask=v_d_mask & kv_block_mask, + ) + if seq_block_id < 2 * pcp_world_size - pcp_rank: + tl.store( + out_k_tail_ptr + + kv_dst_idx_tail * n_head * k_head_dim + + head_id * k_head_dim + + dim_off[None, :], + k_val, + mask=k_d_mask & kv_block_mask, + ) + tl.store( + out_v_tail_ptr + + kv_dst_idx_tail * n_head * v_head_dim + + head_id * v_head_dim + + dim_off[None, :], + v_val, + mask=v_d_mask & kv_block_mask, + ) + + +def fused_pcp_qkv_select( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + query_start_loc: torch.Tensor, + pcp_world_size: int, + pcp_rank: int, +): + """ + Select the query and kv tensors for PCP. Instead of calling + `torch.index_select` multiple times, this function fuses the + selection for Q, K, and V into a single kernel to reduce + kernel launch overhead. + Args: + q: query tensor on the current PCP rank. + k: key tensor across PCP ranks. + v: value tensor across PCP ranks. + query_start_loc: start location of each query. + pcp_world_size: number of PCP ranks. + pcp_rank: rank of the current PCP rank. + Returns: + q_head: selected query tensor for pcp head. + k_head: selected key tensor for pcp head. + v_head: selected value tensor for pcp head. + q_tail: selected query tensor for pcp tail. + k_tail: selected key tensor for pcp tail. + v_tail: selected value tensor for pcp tail. + + """ + q_head = torch.empty( + (q.size(0) // 2,) + q.shape[1:], device=q.device, dtype=q.dtype + ) + q_tail = torch.empty_like(q_head) + k_head = torch.empty( + (q.size(0) // 2 * (pcp_rank + 1),) + k.shape[1:], device=k.device, dtype=k.dtype + ) + v_head = torch.empty( + (q.size(0) // 2 * (pcp_rank + 1),) + v.shape[1:], device=v.device, dtype=v.dtype + ) + k_tail = torch.empty( + (q.size(0) // 2 * (2 * pcp_world_size - pcp_rank),) + k.shape[1:], + device=k.device, + dtype=k.dtype, + ) + v_tail = torch.empty( + (q.size(0) // 2 * (2 * pcp_world_size - pcp_rank),) + v.shape[1:], + device=v.device, + dtype=v.dtype, + ) + BS = len(query_start_loc) - 1 + DIM_BLOCK_SIZE: int = 64 + SEQ_BLOCK_SIZE: int = 256 + assert q.shape[1] == k.shape[1] == v.shape[1] + n_head = q.shape[1] + n_dim_block = ( + max(q.shape[2], k.shape[2], v.shape[2]) + DIM_BLOCK_SIZE + ) // DIM_BLOCK_SIZE + grid = ( + 2 * pcp_world_size * BS, + n_head, + n_dim_block, + ) + _fused_pcp_qkv_select_kernel[grid]( + q, + q.stride(0), + q.stride(1), + k, + k.stride(0), + k.stride(1), + v, + v.stride(0), + v.stride(1), + query_start_loc, + q_head, + q_tail, + k_head, + k_tail, + v_head, + v_tail, + pcp_world_size, + pcp_rank, + n_head, + q.shape[2], + k.shape[2], + v.shape[2], + SEQ_BLOCK_SIZE, + DIM_BLOCK_SIZE, + ) + return q_head, k_head, v_head, q_tail, k_tail, v_tail diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 11504fb08355..fea2cd181fbd 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -242,12 +242,12 @@ class is dynamically inherited by the worker class. This is used to inject """ cp_kv_cache_interleave_size: int = 1 """Interleave size of kv_cache storage while using DCP or PCP. - For `total_cp_rank = pcp_rank * dcp_world_size + dcp_rank`, - and `total_cp_world_size = pcp_world_size * dcp_world_size`. - store interleave_size tokens on total_cp_rank i, - then store next interleave_size tokens on total_cp_rank i+1. + For `cp_rank = pcp_rank * dcp_world_size + dcp_rank`, + and `cp_world_size = pcp_world_size * dcp_world_size`. + store interleave_size tokens on cp_rank i, + then store next interleave_size tokens on cp_rank i+1. Interleave_size=1: token-level alignment, where token `i` is stored on - total_cp_rank `i % total_cp_world_size`. + cp_rank `i % cp_world_size`. Interleave_size=block_size: block-level alignment, where tokens are first populated to the preceding ranks. Tokens are then stored in (rank i+1, block j) only after (rank i, block j) is fully occupied. diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index f5ada5a009ec..bbd0adfbcce8 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1092,9 +1092,6 @@ def get_dcp_group() -> GroupCoordinator: return _DCP -# kept for backward compatibility -get_context_model_parallel_group = get_dcp_group - _PP: GroupCoordinator | None = None diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index d581e91f36d0..4266514bc94e 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -1027,6 +1027,10 @@ def tp_size(self): def dp_size(self): return self.moe_parallel_config.dp_size + @property + def pcp_size(self): + return self.moe_parallel_config.pcp_size + @property def ep_size(self): return self.moe_parallel_config.ep_size @@ -1039,6 +1043,10 @@ def tp_rank(self): def dp_rank(self): return self.moe_parallel_config.dp_rank + @property + def pcp_rank(self): + return self.moe_parallel_config.pcp_rank + @property def ep_rank(self): return self.moe_parallel_config.ep_rank diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 2dc4ba5d70ca..028e37900369 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -233,6 +233,20 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: "Forcing kv cache block size to 64 for FlashMLASparse backend." ) + # lazy import to avoid circular import + from vllm.config import CUDAGraphMode + + compilation_config = vllm_config.compilation_config + if ( + compilation_config.cudagraph_mode.has_full_cudagraphs() + and parallel_config.prefill_context_parallel_size > 1 + ): + logger.warning_once( + "Prefill context parallel (PCP) is enabled, which is " + "incompatible with full CUDA graphs. " + "Overriding cudagraph_mode to PIECEWISE." + ) + compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE scheduler_config = vllm_config.scheduler_config # Note: model_config may be None during testing if ( diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 3445e998d637..82badd4477a3 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -45,7 +45,7 @@ AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, - get_dcp_local_seq_lens, + get_cp_local_seq_lens, get_kv_cache_layout, ) from vllm.v1.kv_cache_interface import AttentionSpec @@ -405,7 +405,7 @@ def schedule( query_kv_lens = query_start_loc[1:] - query_start_loc[:-1] dcp_context_kv_lens = seq_lens - query_kv_lens - dcp_context_kv_lens = get_dcp_local_seq_lens( + dcp_context_kv_lens = get_cp_local_seq_lens( dcp_context_kv_lens, self.dcp_world_size, self.dcp_rank, diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 623ae892ecda..7f4400ca2b8a 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -53,7 +53,7 @@ AttentionMetadataBuilder, CommonAttentionMetadata, KVCacheLayoutType, - get_dcp_local_seq_lens, + get_cp_local_seq_lens, get_kv_cache_layout, get_per_layer_parameters, infer_global_hyperparameters, @@ -892,7 +892,7 @@ def build( seq_lens_cpu[num_decodes:] - query_lens_prefill_cpu ) - seq_lens_cpu = get_dcp_local_seq_lens( + seq_lens_cpu = get_cp_local_seq_lens( seq_lens_cpu, self.dcp_world_size, self.dcp_rank, diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index e9ec96835f27..c848497a45df 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -205,11 +205,19 @@ MLAAttentionImpl, ) from vllm.attention.backends.utils import get_mla_dims -from vllm.attention.ops.common import cp_lse_ag_out_rs +from vllm.attention.ops.common import ( + cp_lse_ag_out_ar, + cp_lse_ag_out_rs, + fused_pcp_qkv_select, +) from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.attention.utils.fa_utils import get_flash_attn_version from vllm.config import VllmConfig, get_current_vllm_config -from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank +from vllm.distributed.parallel_state import ( + get_dcp_group, + get_pcp_group, + is_global_first_rank, +) from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, @@ -225,9 +233,11 @@ from vllm.v1.attention.backends.utils import ( AttentionMetadataBuilder, CommonAttentionMetadata, - get_dcp_local_seq_lens, + get_cp_local_seq_lens, + get_pcp_query_indices, get_per_layer_parameters, infer_global_hyperparameters, + pcp_kv_allgather_and_restore, split_decodes_and_prefills, ) from vllm.v1.kv_cache_interface import AttentionSpec @@ -350,6 +360,14 @@ class ChunkedContextMetadata: cu_seq_lens_lst: list[list[int]] | None = None chunk_size: int | None = None + @dataclass + class PCPMetadata: + # For PCP + output_restore_idx: torch.Tensor | None = None + query_start_loc: torch.Tensor | None = None + kv_head_start_loc: torch.Tensor | None = None + kv_tail_start_loc: torch.Tensor | None = None + block_table: torch.Tensor query_start_loc: torch.Tensor max_query_len: int @@ -357,6 +375,7 @@ class ChunkedContextMetadata: query_seq_lens: torch.Tensor | None = None workspace_buffer: torch.Tensor | None = None q_data_type: torch.dtype | None = None + pcp_metadata: PCPMetadata | None = None @dataclass @@ -379,7 +398,7 @@ class ChunkedContextMetadata(MLACommonPrefillMetadata.ChunkedContextMetadata): class MLACommonDecodeMetadata: block_table: torch.Tensor seq_lens: torch.Tensor - dcp_tot_seq_lens: torch.Tensor | None + cp_tot_seq_lens: torch.Tensor | None D = TypeVar("D", bound=MLACommonDecodeMetadata) @@ -426,6 +445,8 @@ class MLACommonMetadata(Generic[D]): | None ) = None + pcp_allgather_restore_idx: torch.Tensor | None = None + def __post_init__(self): if self.head_dim is not None and not MLACommonBackend.supports_head_size( self.head_dim @@ -535,7 +556,7 @@ def __init__( vllm_config: VllmConfig, device: torch.device, metadata_cls: type[M] | None = None, - supports_dcp_with_varlen: bool = False, + supports_cp_with_varlen: bool = False, ): self.metadata_cls = ( metadata_cls if metadata_cls is not None else MLACommonMetadata @@ -558,9 +579,19 @@ def __init__( # DCP might not be initialized in testing self.dcp_world_size = 1 self.dcp_rank = 0 - self.dcp_local_block_size = parallel_config.cp_kv_cache_interleave_size - self.dcp_virtual_block_size = self.dcp_local_block_size * self.dcp_world_size + try: + self.pcp_world_size = get_pcp_group().world_size + self.pcp_rank = get_pcp_group().rank_in_group + except AssertionError: + # PCP might not be initialized in testing + self.pcp_world_size = 1 + self.pcp_rank = 0 + self.cp_world_size = self.dcp_world_size * self.pcp_world_size + self.cp_local_block_size = parallel_config.cp_kv_cache_interleave_size + self.cp_virtual_block_size = self.cp_local_block_size * self.cp_world_size self.cp_kv_cache_interleave_size = parallel_config.cp_kv_cache_interleave_size + # TODO(yyj) Remove this once the PCP bug for decode_length > 1 is fixed. + supports_cp_with_varlen = supports_cp_with_varlen and self.pcp_world_size == 1 # Don't try to access the runner on AMD if self.aot_schedule: @@ -570,16 +601,16 @@ def __init__( self.determine_chunked_prefill_workspace_size(vllm_config) ) - if self.dcp_world_size > 1: - # Note(hc): The local kvcache is incomplete when DCP is triggered, - # an additional kvcache allgather across the DCP group is therefore - # required, so the workspace has to be enlarged by 1/DCP relative + if self.cp_world_size > 1: + # Note(hc): The local kvcache is incomplete when DCP or PCP is triggered, + # an additional kvcache allgather across the DCP&PCP group is therefore + # required, so the workspace has to be enlarged by 1/CP relative # to the original TP allocation. - assert self.chunked_prefill_workspace_size % self.dcp_world_size == 0 + assert self.chunked_prefill_workspace_size % self.cp_world_size == 0 self.chunked_prefill_workspace = torch.empty( ( self.chunked_prefill_workspace_size - + self.chunked_prefill_workspace_size // self.dcp_world_size, + + self.chunked_prefill_workspace_size // self.cp_world_size, self.model_config.get_head_size(), ), dtype=self.model_config.dtype, @@ -636,7 +667,7 @@ def __init__( supports_spec_decode = self.query_len_support != QueryLenSupport.SINGLE_ONLY self._init_reorder_batch_threshold( - self.reorder_batch_threshold, supports_spec_decode, supports_dcp_with_varlen + self.reorder_batch_threshold, supports_spec_decode, supports_cp_with_varlen ) # Validate consistency between query_len_support and reorder_batch_threshold @@ -730,12 +761,12 @@ def _build_decode( query_start_loc_cpu: torch.Tensor, query_start_loc_device: torch.Tensor, num_decode_tokens: int, - dcp_tot_seq_lens_device: torch.Tensor | None, + cp_tot_seq_lens_device: torch.Tensor | None, ) -> MLACommonDecodeMetadata: return MLACommonDecodeMetadata( block_table=block_table_tensor, seq_lens=seq_lens_device, - dcp_tot_seq_lens=dcp_tot_seq_lens_device, + cp_tot_seq_lens=cp_tot_seq_lens_device, ) def build_for_cudagraph_capture( @@ -765,6 +796,7 @@ def build( num_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len max_seq_len = common_attn_metadata.max_seq_len + pcp_allgather_restore_idx = common_attn_metadata.pcp_allgather_restore_idx # Note(simon): be careful about the CPU <> GPU memory movement in this # function. We should avoid GPU -> CPU sync as much as possible because @@ -776,7 +808,11 @@ def build( query_start_loc = common_attn_metadata.query_start_loc query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu seq_lens = common_attn_metadata.seq_lens - dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens + cp_local_seq_lens = common_attn_metadata.cp_local_seq_lens + + query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + + num_computed_tokens_cpu = common_attn_metadata.seq_lens_cpu - query_seq_lens_cpu num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( @@ -801,6 +837,9 @@ def build( prefill_query_start_loc = ( query_start_loc[reqs_start:] - query_start_loc[reqs_start] ) + prefill_query_start_loc_cpu = ( + query_start_loc_cpu[reqs_start:] - query_start_loc_cpu[reqs_start] + ) chunked_context_metadata = None if max_context_len_cpu > 0: @@ -863,34 +902,34 @@ def build( chunk_len = chunk_token_to_seq_tensor.shape[0] token_to_seq_tensor_cpu[i, :chunk_len] = chunk_token_to_seq_tensor - if self.dcp_world_size > 1: - local_context_lens_allranks = get_dcp_local_seq_lens( + if self.cp_world_size > 1: + local_context_lens_allranks = get_cp_local_seq_lens( context_lens_cpu, - self.dcp_world_size, + self.cp_world_size, None, - self.dcp_local_block_size, + self.cp_local_block_size, ) # Note(qcs): The max local context lengths - # padded to `dcp_local_block_size`. + # padded to `cp_local_block_size`. padded_local_context_lens_cpu = ( cdiv( context_lens_cpu, - self.dcp_virtual_block_size, + self.cp_virtual_block_size, ) - * self.dcp_local_block_size + * self.cp_local_block_size ) # Note(hc): The above max_context_chunk already enforces - # block_size alignment, DCP just need the block_size can - # be divisible by dcp_world_size, because DCP use + # block_size alignment, DCP and PCP just need the block_size can + # be divisible by cp_world_size, because DCP and PCP use # cp_gather_cache which not require `cp_chunk_starts` # aligned to page_size. - assert max_context_chunk % self.dcp_world_size == 0 + assert max_context_chunk % self.cp_world_size == 0 padded_local_max_context_chunk_across_ranks = ( cdiv( max_context_chunk, - self.dcp_virtual_block_size, + self.cp_virtual_block_size, ) - * self.dcp_local_block_size + * self.cp_local_block_size ) local_chunk_starts = ( torch.arange(num_chunks, dtype=torch.int32) @@ -922,7 +961,7 @@ def build( if self._use_cudnn_prefill else MLACommonPrefillMetadata.ChunkedContextMetadata ) - if self.dcp_world_size > 1: + if self.cp_world_size > 1: chunked_context_metadata = chunked_context_metadata_cls( cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), starts=local_chunk_starts.to(device, non_blocking=True), @@ -964,11 +1003,33 @@ def build( <= self.chunked_prefill_workspace_size ) + pcp_metadata = None + if self.pcp_world_size > 1: + # NOTE(yyj): We need to get the indices here for + # restoring the output. + q_head_idx, q_tail_idx = get_pcp_query_indices( + prefill_query_start_loc_cpu + ) + output_res_idx = torch.cat([q_head_idx, q_tail_idx]).argsort() + pcp_metadata = MLACommonPrefillMetadata.PCPMetadata( + output_restore_idx=output_res_idx.to( + device, dtype=torch.int32, non_blocking=True + ), + query_start_loc=prefill_query_start_loc // 2, + kv_head_start_loc=prefill_query_start_loc + // 2 + * (self.pcp_rank + 1), + kv_tail_start_loc=prefill_query_start_loc + // 2 + * (self.pcp_world_size * 2 - self.pcp_rank), + ) + prefill_metadata = self.prefill_metadata_cls( block_table=block_table_tensor[reqs_start:, ...], query_start_loc=prefill_query_start_loc, max_query_len=max_query_len, chunked_context=chunked_context_metadata, + pcp_metadata=pcp_metadata, ) if self._use_cudnn_prefill: @@ -986,17 +1047,17 @@ def build( decode_metadata = None if num_decodes > 0: - dcp_tot_seq_lens_device = None - if self.dcp_world_size > 1: - dcp_tot_seq_lens_device = seq_lens[:num_decodes] - seq_lens = dcp_local_seq_lens + cp_tot_seq_lens_device = None + if self.cp_world_size > 1: + cp_tot_seq_lens_device = seq_lens[:num_decodes] + seq_lens = cp_local_seq_lens - # After DCP distribution, the maximum number of tokens for any rank is + # After CP distribution, the maximum number of tokens for any rank is # ceil(L / (N * I)) * I, where L is max_seq_len, N is dcp_world_size, # and I is cp_kv_cache_interleave_size. # This eliminates GPU->CPU sync while minimizing workspace # over-allocation. - num_partitions = self.dcp_world_size * self.cp_kv_cache_interleave_size + num_partitions = self.cp_world_size * self.cp_kv_cache_interleave_size max_seq_len = ( (max_seq_len + num_partitions - 1) // num_partitions ) * self.cp_kv_cache_interleave_size @@ -1008,7 +1069,7 @@ def build( query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1], query_start_loc_device=query_start_loc[: num_decodes + 1], num_decode_tokens=num_decode_tokens, - dcp_tot_seq_lens_device=dcp_tot_seq_lens_device, + cp_tot_seq_lens_device=cp_tot_seq_lens_device, ) attn_metadata = self.metadata_cls( @@ -1025,6 +1086,7 @@ def build( num_prefills=num_prefills, prefill=prefill_metadata, decode=decode_metadata, + pcp_allgather_restore_idx=pcp_allgather_restore_idx, ) if self._use_fi_prefill and num_prefills > 0: @@ -1331,6 +1393,9 @@ def __init__(self, *args, **kwargs) -> None: ) self.dcp_world_size: int | None = None + self.pcp_world_size: int | None = None + self.dcp_rank: int | None = None + self.pcp_rank: int | None = None self.chunked_prefill_workspace_size = ( MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size( @@ -1381,24 +1446,103 @@ def _flash_attn_varlen_diff_headdims( def _run_prefill_new_tokens_fa( self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse ): - return self._flash_attn_varlen_diff_headdims( - q=q, - k=k, - v=v, - cu_seqlens_q=prefill.query_start_loc, - cu_seqlens_k=prefill.query_start_loc, - max_seqlen_q=prefill.max_query_len, - max_seqlen_k=prefill.max_query_len, - softmax_scale=self.scale, - causal=True, - return_softmax_lse=return_softmax_lse, - ) + assert self.pcp_world_size is not None + assert self.pcp_rank is not None + if self.pcp_world_size > 1: + # NOTE When PCP is enabled, we split the queries keys and values into + # "head" and "tail" parts using the DualChunkSwap strategy to balance + # workload across PCP ranks. We run attention twice (once for the head + # part and once for the tail part), then concatenate the results and + # restore the original ordering. + # + # Considering pcp_world_size=2 and sequence is [0,1,2,3,4,5,6,7] + # + # pcp_rank0: Q [0,1,6,7] KV [0,1,2,3,4,5,6,7] + # Q\KV 0 1 2 3 4 5 6 7 + # head 0 1 0 0 0 0 0 0 0 + # 1 1 1 0 0 0 0 0 0 + # ------------------- + # tail 6 1 1 1 1 1 1 1 0 + # 7 1 1 1 1 1 1 1 1 + # + # pcp_rank1: Q[2,3,4,5] KV [0,1,2,3,4,5,6,7] + # Q\KV 0 1 2 3 4 5 6 7 + # head 2 1 1 1 0 0 0 0 0 + # 3 1 1 1 1 0 0 0 0 + # ------------------- + # tail 4 1 1 1 1 1 0 0 0 + # 5 1 1 1 1 1 1 0 0 + + q_head, k_head, v_head, q_tail, k_tail, v_tail = fused_pcp_qkv_select( + q=q, + k=k, + v=v, + query_start_loc=prefill.query_start_loc, + pcp_rank=self.pcp_rank, + pcp_world_size=self.pcp_world_size, + ) + + pcp_metadata = prefill.pcp_metadata + assert pcp_metadata is not None + output_head, lse_head = self._flash_attn_varlen_diff_headdims( + q=q_head, + k=k_head, + v=v_head, + cu_seqlens_q=pcp_metadata.query_start_loc, + cu_seqlens_k=pcp_metadata.kv_head_start_loc, + max_seqlen_q=prefill.max_query_len // 2, + max_seqlen_k=prefill.max_query_len // 2 * (self.pcp_rank + 1), + softmax_scale=self.scale, + causal=True, + return_softmax_lse=True, + ) + + output_tail, lse_tail = self._flash_attn_varlen_diff_headdims( + q=q_tail, + k=k_tail, + v=v_tail, + cu_seqlens_q=pcp_metadata.query_start_loc, + cu_seqlens_k=pcp_metadata.kv_tail_start_loc, + max_seqlen_q=prefill.max_query_len // 2, + max_seqlen_k=prefill.max_query_len + // 2 + * (self.pcp_world_size * 2 - self.pcp_rank), + softmax_scale=self.scale, + causal=True, + return_softmax_lse=True, + ) + + output = torch.cat([output_head, output_tail], dim=0) + output_restore_idx = pcp_metadata.output_restore_idx + if return_softmax_lse: + # FA returns LSE in shape [ H, B ] + lse = torch.cat([lse_head, lse_tail], dim=-1) + return ( + torch.index_select(output, 0, output_restore_idx), + torch.index_select(lse, -1, output_restore_idx), + ) + else: + return torch.index_select(output, 0, output_restore_idx) + else: + return self._flash_attn_varlen_diff_headdims( + q=q, + k=k, + v=v, + cu_seqlens_q=prefill.query_start_loc, + cu_seqlens_k=prefill.query_start_loc, + max_seqlen_q=prefill.max_query_len, + max_seqlen_k=prefill.max_query_len, + softmax_scale=self.scale, + causal=True, + return_softmax_lse=return_softmax_lse, + ) def _run_prefill_new_tokens_fi( self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse ): assert isinstance(prefill, FlashInferPrefillMetadata) assert prefill.prefill_main is not None + assert self.pcp_world_size == 1, "PCP is not supported for FlashInfer Prefill." ret = prefill.prefill_main.run( q=q, @@ -1416,6 +1560,8 @@ def _run_prefill_new_tokens_cudnn( ): assert isinstance(prefill, CudnnPrefillMetadata) assert prefill.query_seq_lens is not None + assert self.pcp_world_size == 1, "PCP is not supported for CUDNN Prefill." + output, lse = cudnn_batch_prefill_with_kv_cache( q=q, k_cache=k, @@ -1501,6 +1647,7 @@ def _run_prefill_new_tokens_trtllm_ragged( assert prefill.query_seq_lens is not None assert prefill.workspace_buffer is not None + assert self.pcp_world_size == 1, "PCP is not supported for TRT-LLM Prefill." ret = trtllm_ragged_attention_deepseek( query=q, @@ -1763,9 +1910,9 @@ def _context_parallel_compute_prefill_context( kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, k_scale: torch.Tensor, - dcp_world_size: int, + cp_world_size: int, ): - assert k_scale is None, "DCP not support scaled kvcache now." + assert k_scale is None, "PCP/DCP not support scaled kvcache now." assert attn_metadata.prefill is not None prefill_metadata = attn_metadata.prefill assert prefill_metadata.chunked_context is not None @@ -1792,19 +1939,23 @@ def _context_parallel_compute_prefill_context( seq_starts=prefill_metadata.chunked_context.starts[i], ) # workspace - # |------- N tokens --------|--------- N*dcp_size tokens ----------| + # |------- N tokens --------|--------- N*cp_size tokens ----------| # |<- use for loca_gather ->|<--------- use for allgather -------->| - allgather_offset = workspace.shape[0] // (dcp_world_size + 1) - assert allgather_offset * (dcp_world_size + 1) == workspace.shape[0] + allgather_offset = workspace.shape[0] // (cp_world_size + 1) + assert allgather_offset * (cp_world_size + 1) == workspace.shape[0] assert toks <= allgather_offset local_gathered_kvcache = workspace[:toks] cur_allgather_workspace = workspace[ - allgather_offset : allgather_offset * (1 + dcp_world_size) + allgather_offset : allgather_offset * (1 + cp_world_size) ] - assert toks * dcp_world_size <= cur_allgather_workspace.shape[0] - cur_allgather_kvcache = cur_allgather_workspace[: toks * dcp_world_size] + assert toks * cp_world_size <= cur_allgather_workspace.shape[0] + cur_allgather_kvcache = cur_allgather_workspace[: toks * cp_world_size] + # TODO(yyj) Reduce to a single all-gather operation cur_allgather_kvcache.copy_( - get_dcp_group().all_gather(local_gathered_kvcache, dim=0) + get_pcp_group().all_gather( + get_dcp_group().all_gather(local_gathered_kvcache, dim=0), + dim=0, + ) ) assert ( cur_allgather_kvcache.shape[-1] @@ -1874,6 +2025,7 @@ def _forward_prefill( # TODO (zyongye): Prefill function here assert attn_metadata.prefill is not None assert self.dcp_world_size is not None + assert self.pcp_world_size is not None has_context = attn_metadata.prefill.chunked_context is not None kv_nope = self.kv_b_proj(kv_c_normed)[0].view( @@ -1893,14 +2045,14 @@ def _forward_prefill( if has_context: suffix_output, suffix_lse = output_prefill - if self.dcp_world_size > 1: + if self.dcp_world_size * self.pcp_world_size > 1: context_output, context_lse = ( self._context_parallel_compute_prefill_context( q, kv_c_and_k_pe_cache, attn_metadata, k_scale=None, - dcp_world_size=self.dcp_world_size, + cp_world_size=self.dcp_world_size * self.pcp_world_size, ) ) else: @@ -1975,17 +2127,34 @@ def forward( if self.dcp_world_size is None: self.dcp_world_size = get_dcp_group().world_size + if self.pcp_world_size is None: + self.pcp_world_size = get_pcp_group().world_size + if self.dcp_rank is None: + self.dcp_rank = get_dcp_group().rank_in_group + if self.pcp_rank is None: + self.pcp_rank = get_pcp_group().rank_in_group fp8_attention = self.kv_cache_dtype.startswith("fp8") num_actual_toks = attn_metadata.num_actual_tokens + k_c_normed = k_c_normed[:num_actual_toks, ...] + k_pe = k_pe[:num_actual_toks, ...] + + if self.pcp_world_size > 1: + assert attn_metadata.pcp_allgather_restore_idx is not None + k_c_normed, k_pe = pcp_kv_allgather_and_restore( + k_c_normed, + k_pe, + num_actual_toks, + attn_metadata.pcp_allgather_restore_idx, + get_pcp_group(), + ) + # Inputs and outputs may be padded for CUDA graphs output_padded = output output = output[:num_actual_toks, ...] q = q[:num_actual_toks, ...] - k_c_normed = k_c_normed[:num_actual_toks, ...] - k_pe = k_pe[:num_actual_toks, ...] assert ( attn_metadata.num_decodes is not None @@ -1997,12 +2166,6 @@ def forward( has_prefill = attn_metadata.num_prefills > 0 num_decode_tokens = attn_metadata.num_decode_tokens - decode_q = q[:num_decode_tokens] - - prefill_q = q[num_decode_tokens:] - prefill_k_pe = k_pe[num_decode_tokens:] - prefill_k_c_normed = k_c_normed[num_decode_tokens:] - # write the latent and rope to kv cache if kv_cache.numel() > 0: ops.concat_and_cache_mla( @@ -2018,6 +2181,9 @@ def forward( kv_cache = kv_cache.view(current_platform.fp8_dtype()) if has_prefill: + prefill_q = q[num_decode_tokens:] + prefill_k_pe = k_pe[num_decode_tokens * self.pcp_world_size :] + prefill_k_c_normed = k_c_normed[num_decode_tokens * self.pcp_world_size :] self._forward_prefill( prefill_q, prefill_k_c_normed, @@ -2031,6 +2197,7 @@ def forward( if has_decode: assert attn_metadata.decode is not None + decode_q = q[:num_decode_tokens] decode_q_nope, decode_q_pe = decode_q.split( [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 ) @@ -2113,10 +2280,20 @@ def forward( # correct dcp attn_out with lse. if self.dcp_world_size > 1: - attn_out = cp_lse_ag_out_rs( + attn_out, lse = cp_lse_ag_out_rs( attn_out, lse, get_dcp_group(), + return_lse=True, + is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False), + ) + + # recorect pcp attn_out with lse. + if self.pcp_world_size > 1: + attn_out = cp_lse_ag_out_ar( + attn_out, + lse, + get_pcp_group(), is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False), ) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index b4a68f472e9c..922b1920f40e 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -112,7 +112,7 @@ def __init__( vllm_config, device, FlashAttnMLAMetadata, - supports_dcp_with_varlen=(interleave_size == 1), + supports_cp_with_varlen=(interleave_size == 1), ) self.max_num_splits = 0 # No upper bound on the number of splits. self.fa_aot_schedule = get_flash_attn_version() == 3 @@ -174,7 +174,7 @@ def _build_decode( query_start_loc_cpu: torch.Tensor, query_start_loc_device: torch.Tensor, num_decode_tokens: int, - dcp_tot_seq_lens_device: torch.Tensor | None, + cp_tot_seq_lens_device: torch.Tensor | None, ) -> FlashAttnMLADecodeMetadata: query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] max_query_len = query_lens_cpu.max().item() @@ -224,13 +224,14 @@ def _build_decode( max_seq_len=max_seq_len, scheduler_metadata=scheduler_metadata, max_num_splits=max_num_splits, - dcp_tot_seq_lens=dcp_tot_seq_lens_device, + cp_tot_seq_lens=cp_tot_seq_lens_device, ) return metadata class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]): can_return_lse_for_decode: bool = True + supports_pcp: bool = True def __init__( self, @@ -311,6 +312,10 @@ def _forward_decode( # to prevent invalid grid configuration during graph capture. max_seqlen_q = max(attn_metadata.decode.max_query_len, 1) + assert self.pcp_world_size is not None + assert self.dcp_world_size is not None + assert self.pcp_rank is not None + assert self.dcp_rank is not None attn_out = flash_attn_varlen_func( q=q_pe, k=k_pe_cache.unsqueeze(-2), # Add head dim of 1 @@ -327,14 +332,14 @@ def _forward_decode( fa_version=3, # only version 3 is supported scheduler_metadata=attn_metadata.decode.scheduler_metadata, num_splits=attn_metadata.decode.max_num_splits, - cp_world_size=self.dcp_world_size, - cp_rank=self.dcp_rank, - cp_tot_seqused_k=attn_metadata.decode.dcp_tot_seq_lens, + cp_world_size=self.pcp_world_size * self.dcp_world_size, + cp_rank=self.pcp_rank * self.dcp_world_size + self.dcp_rank, + cp_tot_seqused_k=attn_metadata.decode.cp_tot_seq_lens, ) if self.need_to_return_lse_for_decode: o, lse = attn_out - # FA returns LSE in shape [ H, B ] but DCP wants [ B, H ] + # FA returns LSE in shape [ H, B ] but CP wants [ B, H ] return o, lse.transpose(0, 1) # [ H, B ] -> [ B, H ] else: o = attn_out diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 913503ce4494..d3150671cb06 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -148,7 +148,7 @@ def _build_decode( query_start_loc_cpu: torch.Tensor, query_start_loc_device: torch.Tensor, num_decode_tokens: int, - dcp_tot_seq_lens_device: torch.Tensor | None, + cp_tot_seq_lens_device: torch.Tensor | None, ) -> FlashMLADecodeMetadata: query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] # we use the max but all should be the same due to uniform length requirement @@ -194,12 +194,13 @@ def _build_decode( seq_lens=seq_lens_device, tile_scheduler_metadata=tile_scheduler_metadata, num_splits=num_splits, - dcp_tot_seq_lens=dcp_tot_seq_lens_device, + cp_tot_seq_lens=cp_tot_seq_lens_device, ) class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): can_return_lse_for_decode: bool = True + supports_pcp: bool = True def __init__( self, diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index e8921f8a1c40..34bbfe220659 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -111,7 +111,7 @@ def _build_decode( query_start_loc_cpu: torch.Tensor, query_start_loc_device: torch.Tensor, num_decode_tokens: int, - dcp_tot_seq_lens_device: torch.Tensor | None, + cp_tot_seq_lens_device: torch.Tensor | None, ) -> AiterMLADecodeMetadata: # kernel block size is always 1, although the kv block size is not 1. device = self.device @@ -172,7 +172,7 @@ def _build_decode( paged_kv_indices=paged_kv_indices, paged_kv_last_page_len=paged_kv_last_page_len, qo_indptr=qo_indptr, - dcp_tot_seq_lens=dcp_tot_seq_lens_device, + cp_tot_seq_lens=cp_tot_seq_lens_device, max_qo_len=max_qo_len, attn_out_dtype=self.decode_attn_out_dtype, ) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 6b94f786a26b..78cd61618eee 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -37,6 +37,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import ( get_kv_connector_cache_layout, ) +from vllm.distributed.parallel_state import GroupCoordinator from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.v1.kv_cache_interface import AttentionSpec @@ -92,9 +93,14 @@ class CommonAttentionMetadata: encoder_seq_lens: torch.Tensor | None = None encoder_seq_lens_cpu: np.ndarray | None = None - dcp_local_seq_lens: torch.Tensor | None = None - dcp_local_seq_lens_cpu: torch.Tensor | None = None - """Sequence lengths of the local rank in decode context parallelism world""" + cp_local_seq_lens: torch.Tensor | None = None + cp_local_seq_lens_cpu: torch.Tensor | None = None + """ + Sequence lengths of the local rank in (decode & prefill) context parallelism world + """ + + pcp_allgather_restore_idx: torch.Tensor | None = None + """ Indices to restore the original order of KV in prefill context parallelism """ # WARNING: Deprecated fields. Will be removed in a future release (v0.14.0) _seq_lens_cpu: torch.Tensor | None = None @@ -156,8 +162,8 @@ def unpadded( num_logits_indices=self.num_logits_indices, encoder_seq_lens=maybe_slice_reqs(self.encoder_seq_lens), encoder_seq_lens_cpu=maybe_slice_reqs(self.encoder_seq_lens_cpu), - dcp_local_seq_lens=maybe_slice_reqs(self.dcp_local_seq_lens), - dcp_local_seq_lens_cpu=maybe_slice_reqs(self.dcp_local_seq_lens_cpu), + cp_local_seq_lens=maybe_slice_reqs(self.cp_local_seq_lens), + cp_local_seq_lens_cpu=maybe_slice_reqs(self.cp_local_seq_lens_cpu), ) @@ -348,7 +354,7 @@ def _init_reorder_batch_threshold( self, reorder_batch_threshold: int | None = 1, supports_spec_as_decode: bool = False, - supports_dcp_with_varlen: bool = False, + supports_cp_with_varlen: bool = False, ) -> None: self.reorder_batch_threshold = reorder_batch_threshold if self.reorder_batch_threshold is not None and supports_spec_as_decode: @@ -367,8 +373,8 @@ def _init_reorder_batch_threshold( if ( self.vllm_config.parallel_config.decode_context_parallel_size > 1 - and not supports_dcp_with_varlen - ): + or self.vllm_config.parallel_config.prefill_context_parallel_size > 1 + ) and not supports_cp_with_varlen: self.reorder_batch_threshold = 1 @abstractmethod @@ -1207,26 +1213,25 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor): return nums_dict, batch_ptr, token_chunk_offset_ptr -def get_dcp_local_seq_lens( +def get_cp_local_seq_lens( seq_lens: torch.Tensor, - dcp_size: int = 1, - dcp_rank: int | None = None, + cp_world_size: int = 1, + cp_rank: int | None = None, cp_kv_cache_interleave_size: int = 1, ) -> torch.Tensor: - """While using dcp, kv_cache size stored on each rank may be different, - use this function to calculate split decode seq_lens of each dcp rank. - Only consider dcp now, we can extend the case of cp based on this. + """While using dcp or pcp, kv_cache size stored on each rank may be different, + use this function to calculate split decode seq_lens of each cp rank. """ num_requests = seq_lens.size(0) - if dcp_rank is None: + if cp_rank is None: rank_offsets = ( - torch.arange(dcp_size, dtype=torch.int32, device=seq_lens.device) + torch.arange(cp_world_size, dtype=torch.int32, device=seq_lens.device) .unsqueeze(0) .repeat(num_requests, 1) ) else: rank_offsets = torch.tensor( - [[dcp_rank]], dtype=torch.int32, device=seq_lens.device + [[cp_rank]], dtype=torch.int32, device=seq_lens.device ) seq_lens_tiled = ( seq_lens.to(torch.int32).unsqueeze(-1).repeat(1, rank_offsets.shape[1]) @@ -1234,14 +1239,97 @@ def get_dcp_local_seq_lens( base = ( seq_lens_tiled // cp_kv_cache_interleave_size - // dcp_size + // cp_world_size * cp_kv_cache_interleave_size ) - remainder = seq_lens_tiled - base * dcp_size + remainder = seq_lens_tiled - base * cp_world_size remainder = torch.clip( remainder - rank_offsets * cp_kv_cache_interleave_size, 0, cp_kv_cache_interleave_size, ) - dcp_local_seq_lens = base + remainder - return dcp_local_seq_lens.squeeze(1) + cp_local_seq_lens = base + remainder + return cp_local_seq_lens.squeeze(1) + + +def pcp_kv_allgather_and_restore( + key: torch.Tensor, + value: torch.Tensor, + num_actual_tokens: int, + pcp_allgather_restore_idx: torch.Tensor, + pcp_group: GroupCoordinator, +): + """ + All-gather key and value tensors across PCP ranks and restore the original order. + Args: + key: key tensor for the current pcp rank. + value: value tensor for the current pcp rank. + num_actual_tokens: number of actual tokens (Exclude graph padding tokens). + pcp_allgather_restore_idx: indices to restore the original order. + pcp_group: PCP group coordinator. + Returns: + key: all-gathered and restored key tensor. + value: all-gathered and restored value tensor. + """ + # NOTE(yyj): we must `slice` key and value because pcp_allgather_restore_idx + # ignores the padding from CUDA Graph. + # TODO(yyj) Batch all-gather operations to reduce launch overhead. + # Be careful about the dimensions of key and value. + key_across_cp = pcp_group.all_gather(key[:num_actual_tokens].contiguous(), dim=0) + value_across_cp = pcp_group.all_gather( + value[:num_actual_tokens].contiguous(), dim=0 + ) + # Reorder kv after pcp allgather. + # Note that there are duplicate decoding tokens after allgather. + key = torch.index_select(key_across_cp, 0, pcp_allgather_restore_idx) + value = torch.index_select(value_across_cp, 0, pcp_allgather_restore_idx) + return key, value + + +def get_pcp_part_indices( + cu_num_tokens: torch.Tensor, + M: int, + N: int, + return_head=False, + return_tail=False, +): + """ + When using PCP, we need to split the KV and Query and select a local shard. + This function helps get the indices of the selected shards. + Args: + cu_num_tokens: cumulative number of tokens. + M: the number of shards to select. + N: the number of shards to split. + return_head: whether to return the indices start from head. + return_tail: whether to return the indices start from tail. + """ + cu_num_tokens_np = np.asarray(cu_num_tokens) # e.g. [0,2,4,8] + starts = cu_num_tokens_np[:-1] # [0, 2, 4] + ends = cu_num_tokens_np[1:] # [2, 4, 8] + select_len = (ends - starts) * M // N # [1, 1, 2], M=1, N=2 + select_num_tokens = cu_num_tokens_np[-1] * M // N + + seq_ids = np.repeat(np.arange(len(select_len)), select_len) # [0,1,2,2] + + start_loc = np.concatenate([[0], np.cumsum(select_len)[:-1]]) # [0,1,2] + local_offsets = np.arange(select_num_tokens) - start_loc[seq_ids] # [0,0,0,1] + head_indices = None + tail_indices = None + if return_head: + head_indices = starts[seq_ids] + local_offsets + if return_tail: + start_loc = ends - select_len + tail_indices = start_loc[seq_ids] + local_offsets + + return head_indices, tail_indices + + +def get_pcp_query_indices(cu_num_tokens: torch.Tensor): + head_indices, tail_indices = get_pcp_part_indices( + cu_num_tokens, + 1, + 2, + return_head=True, + return_tail=True, + ) + return torch.from_numpy(head_indices), torch.from_numpy(tail_indices) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 66697132b365..6a37b67ac4c2 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -687,7 +687,9 @@ def prepare_inputs_padded( block_table_tensor=common_attn_metadata.block_table_tensor, slot_mapping=common_attn_metadata.slot_mapping[:total_num_tokens], causal=True, - dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens, + cp_local_seq_lens=common_attn_metadata.cp_local_seq_lens, + cp_local_seq_lens_cpu=common_attn_metadata.cp_local_seq_lens_cpu, + pcp_allgather_restore_idx=common_attn_metadata.pcp_allgather_restore_idx, ) return ( @@ -967,7 +969,9 @@ def prepare_inputs( block_table_tensor=common_attn_metadata.block_table_tensor, slot_mapping=common_attn_metadata.slot_mapping[token_indices], causal=True, - dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens, + cp_local_seq_lens=common_attn_metadata.cp_local_seq_lens, + cp_local_seq_lens_cpu=common_attn_metadata.cp_local_seq_lens_cpu, + pcp_allgather_restore_idx=common_attn_metadata.pcp_allgather_restore_idx, ) return spec_common_attn_metadata, token_indices diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 37ec0fb97e06..96a86cdc7c69 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -138,16 +138,16 @@ def compute_slot_mapping( # NOTE(woosuk): We can't simply use `token_indices // block_size` # here because M (max_model_len) is not necessarily divisible by # block_size. - total_cp_world_size = self.pcp_world_size * self.dcp_world_size - total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank - if total_cp_world_size > 1: + cp_world_size = self.pcp_world_size * self.dcp_world_size + cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank + if cp_world_size > 1: # Note(hc): The DCP implement store kvcache with an interleave # style, the kvcache for the token whose token_idx is i is # always stored on the GPU whose dcp_rank equals i % cp_world_size: # Use a "virtual block" which equals to world_size * block_size # for block_table_indices calculation. - virtual_block_size = self.block_size * total_cp_world_size + virtual_block_size = self.block_size * cp_world_size block_table_indices = ( req_indices * self.max_num_blocks_per_req + positions // virtual_block_size @@ -160,13 +160,13 @@ def compute_slot_mapping( mask = ( virtual_block_offsets // self.cp_kv_cache_interleave_size - % total_cp_world_size - == total_cp_rank + % cp_world_size + == cp_rank ) # Calculate local block_offsets block_offsets = ( virtual_block_offsets - // (total_cp_world_size * self.cp_kv_cache_interleave_size) + // (cp_world_size * self.cp_kv_cache_interleave_size) * self.cp_kv_cache_interleave_size + virtual_block_offsets % self.cp_kv_cache_interleave_size ) diff --git a/vllm/v1/worker/cp_utils.py b/vllm/v1/worker/cp_utils.py index f666c739b0be..aac94721bacf 100644 --- a/vllm/v1/worker/cp_utils.py +++ b/vllm/v1/worker/cp_utils.py @@ -2,7 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from typing import TYPE_CHECKING, Any, cast +import numpy as np +import torch + from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.distributed.parallel_state import get_pcp_group +from vllm.v1.utils import CpuGpuBuffer if TYPE_CHECKING: from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase @@ -10,6 +15,284 @@ AttentionLayerBase = object +class PCPManager: + """ + Manager for Prefill Context Parallelism (PCP) metadata and buffers. + + This manager encapsulates all PCP-related buffers and logic so that the + ModelRunner can access them via `self.pcp_manager`. + """ + + def __init__( + self, + pcp_world_size: int, + pcp_rank: int, + max_padded_num_tokens: int, + max_num_reqs: int, + device: torch.device, + pin_memory: bool = False, + ) -> None: + self.pcp_world_size = pcp_world_size + self.pcp_rank = pcp_rank + self.device = device + self.pcp_allgather_restore_idx = CpuGpuBuffer( + max_padded_num_tokens, + dtype=torch.int64, + device=device, + pin_memory=pin_memory, + ) + self.pcp_padded_slot_mapping = torch.empty( + (max_padded_num_tokens,), + dtype=torch.int64, + device=device, + ) + self.num_pcp_pads_cpu_tensor = torch.zeros( + (max_num_reqs,), device="cpu", dtype=torch.int64 + ) + self.num_pcp_pads_cpu = self.num_pcp_pads_cpu_tensor.numpy() + self.pcp_unpad_mask_cpu_tensor = torch.zeros( + (max_padded_num_tokens,), + device="cpu", + dtype=torch.bool, + ) + self.pcp_unpad_mask_cpu = self.pcp_unpad_mask_cpu_tensor.numpy() + + def _get_cumsum_and_arange( + self, + num_scheduled_tokens: np.ndarray, + arange_np: np.ndarray, + cumsum_dtype: np.dtype | None = None, + ) -> tuple[np.ndarray, np.ndarray]: + """Get the cumulative sum and batched arange of the given array. + # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]) + # Equivalent to but faster than: + # np.concatenate([np.arange(n) for n in num_scheduled_tokens]) + """ + # Step 1. [2, 5, 3] -> [2, 7, 10] + cu_num_tokens = np.cumsum(num_scheduled_tokens, dtype=cumsum_dtype) + total_num_tokens = cu_num_tokens[-1] + # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7] + cumsums_offsets = np.repeat( + cu_num_tokens - num_scheduled_tokens, num_scheduled_tokens + ) + # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + arange = arange_np[:total_num_tokens] - cumsums_offsets + return cu_num_tokens, arange + + def update_tokens_for_pcp( + self, + num_scheduled_tokens: np.ndarray, + arange_np: np.ndarray, + num_reqs: int, + reorder_batch_threshold: int | None = None, + ) -> tuple[np.ndarray, np.ndarray]: + """ + Update token counts and positions for Prefill Context Parallelism (PCP). + + When using Prefill Context Parallelism, each request's prefill sequence is + split across multiple PCP ranks. The splitting strategy used here is the + "DualChunkSwap" style: each request's (padded) sequence is split into + 2 * pcp_world_size chunks and ranks are assigned chunks in an interleaved + head/tail pattern to balance load. + + This function: + - Computes how many tokens each request should be processed by the current + PCP rank (pcp_tokens). + - Computes the flattened positions of those tokens within the local + padded buffer (pcp_positions). + - Updates runner state arrays used to restore original order and mask out + padded tokens after allgather: + - self.num_pcp_pads_cpu: number of pads added per request + - self.pcp_unpad_mask_cpu: boolean mask marking real tokens in the + padded allgather buffer + - self.pcp_allgather_restore_idx: index array used to restore original + ordering after per-rank allgather and interleaving. + + Args: + num_scheduled_tokens: 1D numpy array of length num_reqs containing + the number of new tokens scheduled per request. + arange_np: 1D numpy array of length max_padded_num_tokens used for + efficient batched arange operations. + num_reqs: Total number of requests in the batch. + reorder_batch_threshold: Threshold for decode vs prefill requests. + + Returns: + Tuple (pcp_tokens, pcp_positions): + - pcp_tokens: number of tokens per request that this PCP rank will + actually process (after splitting / replication). + - pcp_positions: flattened positions for those tokens on this rank, + used to build the positions buffer for the model. + + Example: + >>> Assume tokens = [1, 5, 8], pcp_world_size = 2. After _update_tokens_for_pcp. + >>> pcp_rank = 0 get ([1, 4, 4], [0, 0, 1, 6, 7, 0, 1, 6, 7]) + >>> pcp_rank = 1 get ([1, 4, 4], [0, 2, 3, 4, 5, 2, 3, 4, 5]) + >>> Meanwhile, the following results are same for each pcp rank + >>> self.num_pcp_pads_cpu + [1, 3, 0] + >>> self.pcp_unpad_mask_cpu + [True, False, True, True, True, True, True, False, False, + False, True, True, True, True, True, True, True, True] + >>> self.pcp_allgather_resotre_idx + [0, 9, 1, 2, 10, 11, 12, 13, 3, 4, 5, 6, 14, 15, 16, 17, 7, 8] + """ + + assert reorder_batch_threshold is not None, ( + "PCP depends on reorder batch to split decode and prefill requests." + ) + num_decode_reqs = sum(num_scheduled_tokens <= reorder_batch_threshold) + num_decode_tokens = sum(num_scheduled_tokens[:num_decode_reqs]) + # DualChunkSwap requires alignment to a multiple of (2 * pcp_world_size). + # We first pad each request's token count up to that multiple. + num_padded_scheduled_tokens = np.ceil( + num_scheduled_tokens / (2 * self.pcp_world_size) + ).astype(np.int32) * (2 * self.pcp_world_size) + # PCP does not split decode requests. For decode requests, we instead + # duplicate the scheduled tokens across the pcp_world_size ranks. + num_padded_scheduled_tokens[:num_decode_reqs] = ( + num_scheduled_tokens[:num_decode_reqs] * self.pcp_world_size + ) + + # Record how many pads were added per request (padded - original). + self.num_pcp_pads_cpu[:num_reqs] = ( + num_padded_scheduled_tokens - num_scheduled_tokens + ) + # cu_padded_tokens: cumulative sum of padded token counts, + # pcp_padded_arange: per-request arange flattened for padded tokens. + cu_padded_tokens, pcp_padded_arange = self._get_cumsum_and_arange( + num_padded_scheduled_tokens, arange_np + ) + # Build the mask that marks which positions in the padded allgather buffer + # correspond to real (unpadded) tokens. + self.pcp_unpad_mask_cpu[: pcp_padded_arange.shape[0]] = ( + pcp_padded_arange + < np.repeat(num_scheduled_tokens, num_padded_scheduled_tokens) + ) + + pcp_tokens = num_padded_scheduled_tokens // self.pcp_world_size + # Compute per-request "chunk sizes" for the head/tail splitting. + # For prefill requests, we further split the pcp_tokens into two chunks + # (head and tail). For decode requests, the chunk equals pcp_tokens. + pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1) + pcp_chunk_sizes[:num_decode_reqs] = pcp_tokens[:num_decode_reqs] + + # Build arange-style helpers for pcp tokens and chunk sizes: + # - pcp_arange gives indices repeated for each token in pcp_tokens + # - pcp_chunk_arange gives indices repeated for each position inside chunks + _, pcp_arange = self._get_cumsum_and_arange(pcp_tokens, arange_np) + _, pcp_chunk_arange = self._get_cumsum_and_arange(pcp_chunk_sizes, arange_np) + + # Mask that marks whether a position belongs to the head chunk (True) + # or the tail chunk (False). For decode requests, tail chunk won't exist + # and is handled specially below. + pcp_head_chunk_mask = pcp_arange < np.repeat(pcp_chunk_sizes, pcp_tokens) + + def get_current_rank_positions( + positions_start_loc: int | np.ndarray, rank: int + ): + """ + Compute flattened positions for the given rank with a given start + offset for each request (positions_start_loc). + + - For head chunks: start at positions_start_loc + rank * chunk_size. + - For tail chunks: start at positions_start_loc + (2*pcp_world_size- rank - + 1) * chunk_size. + - For decode requests: no tail chunks; their positions are filled from the + contiguous (unpadded) `tokens` arange instead (handled after). + """ + positions = np.zeros(len(pcp_head_chunk_mask), dtype=np.int32) + head_start_loc = positions_start_loc + rank * pcp_chunk_sizes + tail_start_loc = ( + positions_start_loc + + (2 * self.pcp_world_size - rank - 1) * pcp_chunk_sizes + ) + # Fill head positions using chunk arange offset by head_start_loc. + positions[pcp_head_chunk_mask] = pcp_chunk_arange + np.repeat( + head_start_loc, pcp_chunk_sizes + ) + # Fill tail positions. Note decode requests do not have tail chunks, + # so the tail filling is only for prefill positions. + positions[~pcp_head_chunk_mask] = ( + pcp_chunk_arange[num_decode_tokens:] + + np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_tokens:] + ) + return positions + + positions = get_current_rank_positions(0, self.pcp_rank) + # Decode tokens are duplicated only after AG. But their positions are + # same without prefill context parallel. + if num_decode_reqs > 0: + positions[:num_decode_tokens] = self._get_cumsum_and_arange( + num_scheduled_tokens[:num_decode_reqs], arange_np + )[1] + + # Build the restore index used after allgather. + padded_pos_start_loc = np.roll(cu_padded_tokens, 1) + padded_pos_start_loc[0] = 0 + all_positions_lst = [ + get_current_rank_positions(padded_pos_start_loc, rank_i) + for rank_i in range(self.pcp_world_size) + ] + all_positions = np.concatenate(all_positions_lst) + self.pcp_allgather_restore_idx.np[: all_positions.shape[0]] = ( + all_positions.argsort() + ) + self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0]) + return ( + pcp_tokens[:num_reqs], + positions, + ) + + def get_logits_indices(self, cu_num_tokens: np.ndarray, num_reqs: int): + return ( + torch.from_numpy(cu_num_tokens) * self.pcp_world_size + - self.num_pcp_pads_cpu_tensor[:num_reqs] + - 1 + ).to(self.device, non_blocking=True) + + def get_discard_request_mask( + self, + num_computed_tokens_cpu: np.ndarray, + num_scheduled_tokens: np.ndarray, + num_reqs: int, + num_tokens_np: np.ndarray, + ): + return ( + num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens * self.pcp_world_size + - self.num_pcp_pads_cpu[:num_reqs] + ) < num_tokens_np + + def get_padded_slot_mapping(self, num_tokens: int, slot_mapping: torch.Tensor): + # After pcp allgather and restore, there are padded tokens in kv, + # so we need pad slotmapping for alignment. + pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[ + : num_tokens * self.pcp_world_size + ] + cp_unpad_mask = self.pcp_unpad_mask_cpu_tensor[ + : num_tokens * self.pcp_world_size + ] + pcp_padded_slot_mapping.fill_(-1) + pcp_padded_slot_mapping[cp_unpad_mask] = slot_mapping + return pcp_padded_slot_mapping + + def get_restore_hidden_states( + self, hidden_states: torch.Tensor, num_tokens_unpadded: int + ): + # NOTE we must `slice` hidden_states because pcp_allgather_restore_idx + # ignores the padding from CUDA Graph. + hidden_states = get_pcp_group().all_gather( + hidden_states[:num_tokens_unpadded], + 0, + ) + restore_idx = self.pcp_allgather_restore_idx.gpu[: hidden_states.shape[0]] + return torch.index_select( + hidden_states, + 0, + restore_idx, + ) + + def check_attention_cp_compatibility(vllm_config: VllmConfig) -> None: pcp_size = vllm_config.parallel_config.prefill_context_parallel_size dcp_size = vllm_config.parallel_config.decode_context_parallel_size diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4117354dd8de..cc33bce7dbd1 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -43,6 +43,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks from vllm.distributed.parallel_state import ( get_dcp_group, + get_pcp_group, get_pp_group, get_tp_group, graph_capture, @@ -108,7 +109,7 @@ AttentionMetadataBuilder, CommonAttentionMetadata, create_fast_prefill_custom_backend, - get_dcp_local_seq_lens, + get_cp_local_seq_lens, reorder_batch_to_split_decodes_and_prefills, split_attn_metadata, ) @@ -152,7 +153,7 @@ from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext -from vllm.v1.worker.cp_utils import check_attention_cp_compatibility +from vllm.v1.worker.cp_utils import PCPManager, check_attention_cp_compatibility from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.ec_connector_model_runner_mixin import ECConnectorModelRunnerMixin from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch @@ -316,7 +317,11 @@ def __init__( # Always set to false after the first forward pass self.calculate_kv_scales = self.cache_config.calculate_kv_scales self.dcp_world_size = self.parallel_config.decode_context_parallel_size + self.pcp_world_size = self.parallel_config.prefill_context_parallel_size + self.cp_world_size = self.dcp_world_size * self.pcp_world_size self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group + self.pcp_rank = 0 if self.pcp_world_size <= 1 else get_pcp_group().rank_in_group + self.cp_rank = self.dcp_world_size * self.pcp_rank + self.dcp_rank self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_reqs = scheduler_config.max_num_seqs @@ -481,25 +486,38 @@ def __init__( # Cache the device properties. self._init_device_properties() + if self.pcp_world_size > 1: + # NOTE For PCP, we will pad the tokens of each request + # to a multiple of 2 * pcp_size that is possible greater + # than the max_num_batched_tokens. + max_padded_num_tokens = ( + self.max_num_tokens + self.max_num_reqs * 2 * self.pcp_world_size + ) + else: + max_padded_num_tokens = self.max_num_tokens + # Persistent buffers for CUDA graphs. - self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32) - self.positions = self._make_buffer(self.max_num_tokens, dtype=torch.int64) + self.input_ids = self._make_buffer(max_padded_num_tokens, dtype=torch.int32) + self.positions = self._make_buffer(max_padded_num_tokens, dtype=torch.int64) self.query_start_loc = self._make_buffer( self.max_num_reqs + 1, dtype=torch.int32 ) self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32) self.encoder_seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32) - if self.dcp_world_size > 1: - self.dcp_local_seq_lens = self._make_buffer( + if self.cp_world_size > 1: + self.cp_local_seq_lens = self._make_buffer( self.max_num_reqs, dtype=torch.int32 ) # Because inputs_embeds may be bfloat16 and we don't need a numpy # version of this tensor, avoid a RuntimeError by not creating a # numpy buffer. self.inputs_embeds = self._make_buffer( - self.max_num_tokens, self.inputs_embeds_size, dtype=self.dtype, numpy=False + max_padded_num_tokens, + self.inputs_embeds_size, + dtype=self.dtype, + numpy=False, ) - self.is_token_ids = self._make_buffer(self.max_num_tokens, dtype=torch.bool) + self.is_token_ids = self._make_buffer(max_padded_num_tokens, dtype=torch.bool) self.discard_request_mask = self._make_buffer( self.max_num_reqs, dtype=torch.bool ) @@ -512,7 +530,20 @@ def __init__( # Only relevant for multimodal models if self.supports_mm_inputs: - self.is_mm_embed = self._make_buffer(self.max_num_tokens, dtype=torch.bool) + self.is_mm_embed = self._make_buffer( + max_padded_num_tokens, dtype=torch.bool + ) + + # Manager for Prefill Context Parallism + if self.pcp_world_size > 1: + self.pcp_manager = PCPManager( + self.pcp_world_size, + self.pcp_rank, + max_padded_num_tokens, + self.max_num_reqs, + self.device, + self.pin_memory, + ) # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -527,7 +558,7 @@ def __init__( # 1D-RoPE. # See page 5 of https://arxiv.org/abs/2409.12191 self.mrope_positions = self._make_buffer( - (3, self.max_num_tokens + 1), dtype=torch.int64 + (3, max_padded_num_tokens + 1), dtype=torch.int64 ) # Only relevant for models using XD-RoPE (e.g, HunYuan-VL) @@ -543,7 +574,7 @@ def __init__( # OPTIMIZATION: Cache the tensors rather than creating them every step. # Keep in int64 to avoid overflow with long context self.arange_np = np.arange( - max(self.max_num_reqs + 1, self.max_model_len, self.max_num_tokens), + max(self.max_num_reqs + 1, self.max_model_len, max_padded_num_tokens), dtype=np.int64, ) @@ -557,7 +588,7 @@ def __init__( self.kv_sharing_fast_prefill_logits_indices = None if self.cache_config.kv_sharing_fast_prefill: self.kv_sharing_fast_prefill_logits_indices = torch.zeros( - self.max_num_tokens, dtype=torch.int32, device=self.device + max_padded_num_tokens, dtype=torch.int32, device=self.device ) self.uniform_decode_query_len = 1 + self.num_spec_tokens @@ -1329,6 +1360,32 @@ def _prepare_inputs( out=positions_np, ) + self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np) + self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens) + + if self.pcp_world_size > 1: + num_scheduled_tokens[:num_reqs], pcp_positions = ( + self.pcp_manager.update_tokens_for_pcp( + num_scheduled_tokens[:num_reqs], + self.arange_np, + self.input_batch.num_reqs, + self.reorder_batch_threshold, + ) + ) + + # Re-update after PCP split sequences. + total_num_scheduled_tokens = sum(num_scheduled_tokens) + scheduler_output.total_num_scheduled_tokens = total_num_scheduled_tokens + + req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) + cu_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens) + positions_np = self.positions.np[:total_num_scheduled_tokens] + np.add( + self.input_batch.num_computed_tokens_cpu[req_indices], + pcp_positions[:total_num_scheduled_tokens], + out=positions_np, + ) + # Calculate M-RoPE positions. # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -1404,9 +1461,6 @@ def _prepare_inputs( output_idx += num_sched - self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np) - self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens) - # Prepare the attention metadata. self.query_start_loc.np[0] = 0 self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens @@ -1428,9 +1482,19 @@ def _prepare_inputs( # Record which requests should not be sampled, # so that we could clear the sampled tokens before returning - self.discard_request_mask.np[:num_reqs] = ( - self.seq_lens.np[:num_reqs] < num_tokens_np - ) + if self.pcp_world_size > 1: + self.discard_request_mask.np[:num_reqs] = ( + self.pcp_manager.get_discard_request_mask( + num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu, + num_scheduled_tokens=num_scheduled_tokens, + num_reqs=num_reqs, + num_tokens_np=num_tokens_np, + ) + ) + else: + self.discard_request_mask.np[:num_reqs] = ( + self.seq_lens.np[:num_reqs] < num_tokens_np + ) self.discard_request_mask.copy_to_gpu(num_reqs) # Copy the tensors to the GPU. @@ -1464,10 +1528,15 @@ def _prepare_inputs( # We will ignore the sampled tokens from the partial requests. # TODO: Support prompt logprobs. logits_indices = query_start_loc[1:] - 1 + if self.pcp_world_size > 1: + logits_indices = self.pcp_manager.get_logits_indices( + cu_num_tokens, num_reqs + ) num_draft_tokens = None spec_decode_metadata = None num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) else: + assert self.pcp_world_size == 1, "PCP not support spec decode now" # Get the number of draft tokens for each request. # Iterate over the dictionary rather than all requests since not all # requests have draft tokens. @@ -1535,6 +1604,10 @@ def _build_attention_metadata( if len(self.kv_cache_config.kv_cache_groups) == 0: return {}, None + assert num_tokens_padded is None or self.pcp_world_size == 1, ( + "PCP not support pad attn now" + ) + num_tokens_padded = num_tokens_padded or num_tokens num_reqs_padded = num_reqs_padded or num_reqs assert num_reqs_padded is not None and num_tokens_padded is not None @@ -1563,6 +1636,13 @@ def _build_attention_metadata( def _get_block_table_and_slot_mapping(kv_cache_gid: int): assert num_reqs_padded is not None and num_tokens_padded is not None kv_cache_spec = kv_cache_groups[kv_cache_gid].kv_cache_spec + + maybe_pcp_full_tokens = ( + num_tokens_padded + if self.pcp_world_size == 1 + else num_tokens * self.pcp_world_size + - sum(self.pcp_manager.num_pcp_pads_cpu[:num_reqs]) + ) if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec): blk_table_tensor = torch.zeros( (num_reqs_padded, 1), @@ -1577,12 +1657,18 @@ def _get_block_table_and_slot_mapping(kv_cache_gid: int): else: blk_table = self.input_batch.block_table[kv_cache_gid] blk_table_tensor = blk_table.get_device_tensor(num_reqs_padded) - slot_mapping = blk_table.slot_mapping.gpu[:num_tokens_padded] - - # Fill unused with -1. Needed for reshape_and_cache in full cuda - # graph mode. `blk_table_tensor` -1 to match mamba PAD_SLOT_ID - slot_mapping[num_tokens:num_tokens_padded].fill_(-1) - blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1) + slot_mapping = blk_table.slot_mapping.gpu[:maybe_pcp_full_tokens] + + # Fill unused with -1. Needed for reshape_and_cache in full cuda + # graph mode. `blk_table_tensor` -1 to match mamba PAD_SLOT_ID + if self.pcp_world_size == 1: + slot_mapping[num_tokens:num_tokens_padded].fill_(-1) + blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1) + if self.pcp_world_size > 1: + slot_mapping = self.pcp_manager.get_padded_slot_mapping( + num_tokens, + slot_mapping, + ) return blk_table_tensor, slot_mapping @@ -1604,20 +1690,25 @@ def _get_block_table_and_slot_mapping(kv_cache_gid: int): causal=True, ) - if self.dcp_world_size > 1: - self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens( + if self.cp_world_size > 1: + self.cp_local_seq_lens.cpu[:num_reqs] = get_cp_local_seq_lens( self.seq_lens.cpu[:num_reqs], - self.dcp_world_size, - self.dcp_rank, + self.cp_world_size, + self.cp_rank, self.parallel_config.cp_kv_cache_interleave_size, ) - self.dcp_local_seq_lens.cpu[num_reqs:].fill_(0) - self.dcp_local_seq_lens.copy_to_gpu(num_reqs_padded) + self.cp_local_seq_lens.cpu[num_reqs:].fill_(0) + self.cp_local_seq_lens.copy_to_gpu(num_reqs_padded) - cm_base.dcp_local_seq_lens = self.dcp_local_seq_lens.gpu[:num_reqs_padded] - cm_base.dcp_local_seq_lens_cpu = self.dcp_local_seq_lens.cpu[ - :num_reqs_padded - ] + cm_base.cp_local_seq_lens = self.cp_local_seq_lens.gpu[:num_reqs_padded] + cm_base.cp_local_seq_lens_cpu = self.cp_local_seq_lens.cpu[:num_reqs_padded] + + if self.pcp_world_size > 1: + cm_base.pcp_allgather_restore_idx = ( + self.pcp_manager.pcp_allgather_restore_idx.gpu[ + : num_tokens * self.pcp_world_size + ] + ) if logits_indices is not None and self.cache_config.kv_sharing_fast_prefill: cm_base.num_logits_indices = logits_indices.size(0) @@ -3111,6 +3202,9 @@ def execute_model( scheduler_output, num_scheduled_tokens_np, ) + if self.pcp_world_size > 1: + max_num_scheduled_tokens = int(num_scheduled_tokens_np.max()) + num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens cascade_attn_prefix_lens = None # Disable cascade attention when using microbatching (DBO) @@ -3235,6 +3329,15 @@ def execute_model( hidden_states = model_output aux_hidden_states = None + if self.pcp_world_size > 1: + # NOTE we must `slice` hidden_states because pcp_allgather_restore_idx + # ignores the padding from CUDA Graph. + hidden_states = self.pcp_manager.get_restore_hidden_states( + hidden_states, + num_tokens_unpadded, + ) + # Restore total_num_scheduled_tokens. + scheduler_output.total_num_scheduled_tokens = num_scheduled_tokens if not self.broadcast_pp_output: # Common case. if not get_pp_group().is_last_rank: @@ -4640,7 +4743,7 @@ def profile_run(self) -> None: # Add `is_profile` here to pre-allocate communication buffers hidden_states, last_hidden_states = self._dummy_run( - self.max_num_tokens, is_profile=True + self.max_num_tokens // self.pcp_world_size, is_profile=True ) if get_pp_group().is_last_rank: if self.is_pooling_model: