From d09cda7876d0ed18565bad0471bf6819b4e7237a Mon Sep 17 00:00:00 2001 From: FENP Date: Fri, 17 Oct 2025 11:43:42 +0800 Subject: [PATCH 01/15] Init support PCP with FlashInfer. Co-authored-by: FENP Co-authored-by: QiuChunshuo Co-authored-by: LookAround Signed-off-by: FENP Signed-off-by: QiuChunshuo Signed-off-by: LookAround --- tests/distributed/test_context_parallel.py | 29 ++- vllm/attention/backends/abstract.py | 13 ++ vllm/attention/ops/common.py | 30 +++ vllm/config/parallel.py | 8 +- vllm/distributed/parallel_state.py | 64 +++++- vllm/engine/arg_utils.py | 13 +- .../model_executor/layers/fused_moe/config.py | 31 ++- vllm/model_executor/layers/fused_moe/layer.py | 8 + vllm/platforms/cuda.py | 7 + vllm/v1/attention/backends/flashinfer.py | 191 +++++++++++++--- vllm/v1/attention/backends/mla/common.py | 46 ++-- vllm/v1/attention/backends/utils.py | 19 ++ vllm/v1/core/kv_cache_coordinator.py | 17 ++ vllm/v1/core/kv_cache_manager.py | 6 +- vllm/v1/core/sched/scheduler.py | 2 + vllm/v1/core/single_type_kv_cache_manager.py | 19 +- vllm/v1/executor/multiproc_executor.py | 17 +- vllm/v1/kv_cache_interface.py | 3 + vllm/v1/worker/block_table.py | 24 +- vllm/v1/worker/gpu_model_runner.py | 216 ++++++++++++++++-- vllm/v1/worker/gpu_worker.py | 1 + 21 files changed, 655 insertions(+), 109 deletions(-) diff --git a/tests/distributed/test_context_parallel.py b/tests/distributed/test_context_parallel.py index 5495640af07e..3f70745a63f4 100644 --- a/tests/distributed/test_context_parallel.py +++ b/tests/distributed/test_context_parallel.py @@ -30,6 +30,7 @@ class ParallelSetup(NamedTuple): tp_size: int pp_size: int dcp_size: int + pcp_size: int eager_mode: bool chunked_prefill: bool @@ -37,6 +38,7 @@ class ParallelSetup(NamedTuple): class CPTestOptions(NamedTuple): multi_node_only: bool load_format: str | None = None + attn_backend: str = "FLASH_ATTN" @dataclass @@ -52,20 +54,25 @@ def detailed( tp_base: int = 4, pp_base: int = 1, dcp_base: int = 1, + pcp_base: int = 1, 
multi_node_only: bool = False, runner: RunnerOption = "auto", load_format: str | None = None, + attn_backend: str = "FLASH_ATTN", ): parallel_setups = [] for eager_mode_val in [False]: for pp_multiplier in [1]: - for dcp_multiplier in [0.5, 1]: + # TODO(qcs): Test the effect of mixed activation + # when CP and DCP are compatible. + for pcp_multiplier, dcp_multiplier in zip([1, 2, 1], [0.5, 1, 1]): for chunked_prefill_val in [True]: parallel_setups.append( ParallelSetup( tp_size=tp_base, pp_size=pp_multiplier * pp_base, dcp_size=int(dcp_multiplier * tp_base), + pcp_size=int(pcp_multiplier * pcp_base), eager_mode=eager_mode_val, chunked_prefill=chunked_prefill_val, ) @@ -75,7 +82,9 @@ def detailed( distributed_backends=["mp"], runner=runner, test_options=CPTestOptions( - multi_node_only=multi_node_only, load_format=load_format + multi_node_only=multi_node_only, + load_format=load_format, + attn_backend=attn_backend, ), ) @@ -108,11 +117,12 @@ def _compare_cp_with_tp( tp_size, pp_size, dcp_size, + pcp_size, eager_mode, chunked_prefill, ) = parallel_setup - multi_node_only, load_format = test_options + multi_node_only, load_format, attn_backend = test_options model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) model_info.check_transformers_version(on_fail="skip") @@ -155,7 +165,7 @@ def _compare_cp_with_tp( "--max-model-len", "2048", "--max-num-seqs", - "8", + "16", ] if chunked_prefill: common_args.append("--enable-chunked-prefill") @@ -172,6 +182,10 @@ def _compare_cp_with_tp( if hf_overrides: common_args.extend(["--hf-overrides", json.dumps(hf_overrides)]) + cp_env = tp_env = { + "VLLM_ATTENTION_BACKEND": attn_backend, + } + cp_args = [ *common_args, "--tensor-parallel-size", @@ -180,6 +194,8 @@ def _compare_cp_with_tp( str(pp_size), "--decode-context-parallel-size", str(dcp_size), + "--prefill-context-parallel-size", + str(pcp_size), "--distributed-executor-backend", distributed_backend, ] @@ -198,12 +214,15 @@ def _compare_cp_with_tp( model_id, cp_args, tp_args, 
+ cp_env, + tp_env, method=method, max_wait_seconds=720, ) CP_TEXT_GENERATION_MODELS = { + # [MLA attention only] "deepseek-ai/DeepSeek-V2-Lite-Chat": [ CPTestSettings.detailed(), CPTestSettings.detailed(tp_base=2), @@ -211,6 +230,8 @@ def _compare_cp_with_tp( "bigcode/gpt_bigcode-santacoder": [ CPTestSettings.detailed(), CPTestSettings.detailed(tp_base=2), + CPTestSettings.detailed(attn_backend="FLASHINFER"), + CPTestSettings.detailed(tp_base=2, attn_backend="FLASHINFER"), ], } diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index e9c6a278a941..3a96bd7d6fd8 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -127,6 +127,9 @@ class AttentionImpl(ABC, Generic[T]): dcp_world_size: int dcp_rank: int + pcp_world_size: int + pcp_rank: int + def __new__(cls, *args, **kwargs): # use __new__ so that all subclasses will call this self = super().__new__(cls) @@ -139,6 +142,16 @@ def __new__(cls, *args, **kwargs): # DCP might not be initialized in testing self.dcp_world_size = 1 self.dcp_rank = 0 + try: + from vllm.distributed.parallel_state import get_pcp_group + + self.pcp_world_size = get_pcp_group().world_size + self.pcp_rank = get_pcp_group().rank_in_group + except AssertionError: + # PCP might not be initialized in testing + self.pcp_world_size = 1 + self.pcp_rank = 0 + self.need_to_return_lse_for_decode = ( self.dcp_world_size > 1 and self.can_return_lse_for_decode ) diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py index b6b7ecd2552a..217f32b8baec 100644 --- a/vllm/attention/ops/common.py +++ b/vllm/attention/ops/common.py @@ -205,6 +205,36 @@ def cp_lse_ag_out_rs( return out +def cp_lse_ag_out_ar( + cp_attn_out: torch.Tensor, + cp_attn_lse: torch.Tensor, + cp_group: GroupCoordinator, + ctx: CPTritonContext = None, +): + """ + cp_attn_out: [ B, H, D ] + cp_attn_lse: [ B, H ] + """ + if cp_group.world_size == 1: + return cp_attn_out + + if ctx is None: + ctx = 
CPTritonContext() + + lses = torch.empty( + (cp_group.world_size,) + cp_attn_lse.shape, + dtype=cp_attn_lse.dtype, + device=cp_attn_lse.device, + ) + + cp_attn_lse = cp_attn_lse.contiguous() + lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses) + out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx) + assert out.is_contiguous() + out = cp_group.all_reduce(out) + return out + + @triton.jit def _pack_seq_kernel( x_ptr, # [N, D] diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 953aa1a147de..760e01686dd5 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -71,6 +71,8 @@ class ParallelConfig: """Number of pipeline parallel groups.""" tensor_parallel_size: int = 1 """Number of tensor parallel groups.""" + prefill_context_parallel_size: int = 1 + """Number of prefill context parallel groups.""" data_parallel_size: int = 1 """Number of data parallel groups. MoE layers will be sharded according to the product of the tensor parallel size and data parallel size.""" @@ -467,7 +469,11 @@ def __post_init__(self) -> None: ) # Continue with the rest of the initialization - self.world_size = self.pipeline_parallel_size * self.tensor_parallel_size + self.world_size = ( + self.pipeline_parallel_size + * self.tensor_parallel_size + * self.prefill_context_parallel_size + ) if self.distributed_executor_backend == "external_launcher": logger.info("Using external launcher for distributed inference.") diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 132fb9049163..b00d916e83b8 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1085,6 +1085,24 @@ def get_pp_group() -> GroupCoordinator: return _PP +_PCP: GroupCoordinator | None = None + + +def get_pcp_group() -> GroupCoordinator: + assert _PCP is not None, "prefill context parallel group is not initialized" + return _PCP + + +def get_prefill_context_model_parallel_world_size(): + """Return 
world size for the tensor model parallel group.""" + return get_pcp_group().world_size + + +def get_prefill_context_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" + return get_pcp_group().rank_in_group + + @deprecated( "`get_pipeline_model_parallel_group` has been replaced with " "`get_pp_group` and may be removed in v0.12. Please use " @@ -1207,6 +1225,7 @@ def init_distributed_environment( def initialize_model_parallel( tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, + context_model_parallel_size: int = 1, decode_context_model_parallel_size: int | None = 1, backend: str | None = None, ) -> None: @@ -1256,7 +1275,11 @@ def initialize_model_parallel( # to get group_ranks for each dimension, transpose that dimension to the # last dimension, then reshape to 2D, then unbind the last dimension all_ranks = torch.arange(world_size).reshape( - -1, data_parallel_size, pipeline_model_parallel_size, tensor_model_parallel_size + -1, + data_parallel_size, + pipeline_model_parallel_size, + context_model_parallel_size, + tensor_model_parallel_size, ) # noqa # Build the tensor model-parallel groups. 
@@ -1295,7 +1318,7 @@ def initialize_model_parallel( global _PP assert _PP is None, "pipeline model parallel group is already initialized" group_ranks = ( - all_ranks.transpose(2, 3).reshape(-1, pipeline_model_parallel_size).unbind(0) + all_ranks.transpose(2, 4).reshape(-1, pipeline_model_parallel_size).unbind(0) ) group_ranks = [x.tolist() for x in group_ranks] _PP = init_model_parallel_group( @@ -1304,7 +1327,7 @@ def initialize_model_parallel( global _DP assert _DP is None, "data parallel group is already initialized" - group_ranks = all_ranks.transpose(1, 3).reshape(-1, data_parallel_size).unbind(0) + group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0) group_ranks = [x.tolist() for x in group_ranks] _DP = init_model_parallel_group( group_ranks, get_world_group().local_rank, backend, group_name="dp" @@ -1314,7 +1337,12 @@ def initialize_model_parallel( assert _EP is None, "expert parallel group is already initialized" group_ranks = ( all_ranks.transpose(1, 2) - .reshape(-1, data_parallel_size * tensor_model_parallel_size) + .reshape( + -1, + data_parallel_size + * tensor_model_parallel_size + * context_model_parallel_size, + ) .unbind(0) ) group_ranks = [x.tolist() for x in group_ranks] @@ -1322,21 +1350,33 @@ def initialize_model_parallel( group_ranks, get_world_group().local_rank, backend, group_name="ep" ) + global _PCP + assert _PCP is None, "prefill context parallel group is already initialized" + group_ranks = ( + all_ranks.transpose(3, 4).reshape(-1, context_model_parallel_size).unbind(0) + ) + group_ranks = [x.tolist() for x in group_ranks] + _PCP = init_model_parallel_group( + group_ranks, get_world_group().local_rank, backend, group_name="pcp" + ) + logger.info( "rank %s in world size %s is assigned as " - "DP rank %s, PP rank %s, TP rank %s, EP rank %s", + "DP rank %s, PP rank %s, TP rank %s, EP rank %s, PCP rank %s", rank, world_size, _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group, _EP.rank_in_group, + 
_PCP.rank_in_group, ) def ensure_model_parallel_initialized( tensor_model_parallel_size: int, pipeline_model_parallel_size: int, + prefill_context_model_parallel_size: int = 1, decode_context_model_parallel_size: int | None = 1, backend: str | None = None, ) -> None: @@ -1349,6 +1389,7 @@ def ensure_model_parallel_initialized( initialize_model_parallel( tensor_model_parallel_size, pipeline_model_parallel_size, + prefill_context_model_parallel_size, decode_context_model_parallel_size, backend, ) @@ -1365,6 +1406,12 @@ def ensure_model_parallel_initialized( f"got: {pp_world_size=} vs. " f"wanted: {pipeline_model_parallel_size=}" ) + pcp_world_size = get_pcp_group().world_size + assert pcp_world_size == prefill_context_model_parallel_size, ( + "prefill context parallel group already initialized, but of unexpected size: " + f"{pcp_world_size=} vs. " + f"{prefill_context_model_parallel_size=}" + ) def prepare_communication_buffer_for_model(model: torch.nn.Module): @@ -1382,6 +1429,8 @@ def prepare_communication_buffer_for_model(model: torch.nn.Module): _DP.prepare_communication_buffer_for_model(model) if _EP is not None: _EP.prepare_communication_buffer_for_model(model) + if _PCP is not None: + _PCP.prepare_communication_buffer_for_model(model) def model_parallel_is_initialized(): @@ -1471,6 +1520,11 @@ def destroy_model_parallel(): _EP.destroy() _EP = None + global _PCP + if _PCP: + _PCP.destroy() + _PCP = None + def destroy_distributed_environment(): global _WORLD, _NODE_COUNT diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 917d0ec9f7f3..18e981d9b847 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -371,6 +371,7 @@ class EngineArgs: # number of P/D disaggregation (or other disaggregation) workers pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size tensor_parallel_size: int = ParallelConfig.tensor_parallel_size + prefill_context_parallel_size: int = ParallelConfig.prefill_context_parallel_size 
decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size data_parallel_size: int = ParallelConfig.data_parallel_size data_parallel_rank: int | None = None @@ -722,14 +723,19 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parallel_group.add_argument( "--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"] ) + parallel_group.add_argument( + "--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"] + ) + parallel_group.add_argument( + "--prefill-context-parallel-size", + "-pcp", + **parallel_kwargs["prefill_context_parallel_size"], + ) parallel_group.add_argument( "--decode-context-parallel-size", "-dcp", **parallel_kwargs["decode_context_parallel_size"], ) - parallel_group.add_argument( - "--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"] - ) parallel_group.add_argument( "--data-parallel-rank", "-dpn", @@ -1466,6 +1472,7 @@ def create_engine_config( parallel_config = ParallelConfig( pipeline_parallel_size=self.pipeline_parallel_size, tensor_parallel_size=self.tensor_parallel_size, + prefill_context_parallel_size=self.prefill_context_parallel_size, data_parallel_size=self.data_parallel_size, data_parallel_rank=self.data_parallel_rank or 0, data_parallel_external_lb=data_parallel_external_lb, diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 38ea6acc0fc5..32b2b6573272 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -7,7 +7,11 @@ import vllm.envs as envs from vllm.config import ParallelConfig -from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank +from vllm.distributed import ( + get_dp_group, + get_prefill_context_model_parallel_rank, + get_tensor_model_parallel_rank, +) from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import ( OCP_MX_DTYPES, @@ -634,9 +638,11 @@ 
def biased_moe_quant_config( @dataclass class FusedMoEParallelConfig: tp_size: int + pcp_size: int dp_size: int ep_size: int tp_rank: int + pcp_rank: int dp_rank: int ep_rank: int @@ -664,7 +670,10 @@ def use_deepep_ll_kernels(self): @staticmethod def make( - tp_size_: int, dp_size_: int, vllm_parallel_config: ParallelConfig + tp_size_: int, + dp_size_: int, + pcp_size_: int, + vllm_parallel_config: ParallelConfig, ) -> "FusedMoEParallelConfig": """ Determine MoE parallel configuration. Based on the input `tp_size_`, @@ -672,7 +681,8 @@ def make( level's of parallelism to use in the fused moe layer. Args: - tp_size_ (int): `tp_size` passed into the FusedMoE constructor. + tp_size_ (int): `tp_size` pa use_ep = (dp_size_ * tp_size_ssed into + the FusedMoE constructor. dp_size_ (int): `dp_size` passed into the FusedMoE constructor. vllm_parallel_config (ParallelConfig): vLLM's parallel config object which contains the `enable_expert_parallel` flag. @@ -745,16 +755,23 @@ def flatten_tp_across_dp(dp_rank: int): tp_rank = dp_rank * tp_size_ + tp_rank return tp_size, tp_rank - use_ep = dp_size_ * tp_size_ > 1 and vllm_parallel_config.enable_expert_parallel + use_ep = ( + dp_size_ * tp_size_ * pcp_size_ > 1 + and vllm_parallel_config.enable_expert_parallel + ) dp_size = dp_size_ dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0 tp_size, tp_rank = flatten_tp_across_dp(dp_rank) + pcp_size = pcp_size_ + pcp_rank = get_prefill_context_model_parallel_rank() if pcp_size_ > 1 else 0 if not use_ep: return FusedMoEParallelConfig( tp_size=tp_size, tp_rank=tp_rank, + pcp_size=pcp_size, + pcp_rank=pcp_rank, dp_size=dp_size, dp_rank=dp_rank, ep_size=1, @@ -766,11 +783,13 @@ def flatten_tp_across_dp(dp_rank: int): assert use_ep # In EP, each device owns a set of experts fully. There is no tensor # parallel update tp_size, tp_rank, ep_size and ep_rank to reflect that. 
- ep_size = tp_size - ep_rank = tp_rank + ep_size = tp_size * pcp_size + ep_rank = tp_rank + tp_size * pcp_rank return FusedMoEParallelConfig( tp_size=1, tp_rank=0, + pcp_size=1, + pcp_rank=0, dp_size=dp_size, dp_rank=dp_rank, ep_size=ep_size, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 3bb544a49f3a..12edc274ad16 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -18,6 +18,7 @@ from vllm.distributed import ( get_dp_group, get_ep_group, + get_prefill_context_model_parallel_world_size, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) @@ -1061,6 +1062,7 @@ def __init__( tp_size: int | None = None, ep_size: int | None = None, dp_size: int | None = None, + pcp_size: int | None = None, prefix: str = "", custom_routing_function: Callable | None = None, scoring_func: str = "softmax", @@ -1098,6 +1100,11 @@ def __init__( tp_size if tp_size is not None else get_tensor_model_parallel_world_size() ) dp_size_ = dp_size if dp_size is not None else get_dp_group().world_size + pcp_size_ = ( + pcp_size + if pcp_size is not None + else get_prefill_context_model_parallel_world_size() + ) self.is_sequence_parallel = is_sequence_parallel self.sp_size = tp_size_ if is_sequence_parallel else 1 @@ -1105,6 +1112,7 @@ def __init__( self.moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make( tp_size_=tp_size_, dp_size_=dp_size_, + pcp_size_=pcp_size_, vllm_parallel_config=vllm_config.parallel_config, ) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index a6b9df7c1446..c1b4729c58e7 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -206,6 +206,13 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: ) compilation_config.cudagraph_mode = CUDAGraphMode.NONE + if ( + compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and parallel_config.prefill_context_parallel_size 
> 1 + ): + logger.info("Prefill Context Parallel: disabling cudagraphs since PCP.") + compilation_config.cudagraph_mode = CUDAGraphMode.NONE + @classmethod def get_current_memory_usage( cls, device: torch.types.Device | None = None diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index cd54b964c41f..aeef2db86677 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -10,6 +10,7 @@ from flashinfer import ( BatchDecodeWithPagedKVCacheWrapper, BatchPrefillWithPagedKVCacheWrapper, + BatchPrefillWithRaggedKVCacheWrapper, MultiLevelCascadeAttentionWrapper, ) from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache @@ -22,7 +23,9 @@ AttentionType, MultipleOf, ) +from vllm.attention.ops.common import cp_lse_ag_out_ar from vllm.config import CUDAGraphMode, VllmConfig +from vllm.distributed.parallel_state import get_pcp_group from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, @@ -269,6 +272,9 @@ class FlashInferMetadata: qo_indptr_gpu: torch.Tensor | None = None paged_kv_indptr_gpu: torch.Tensor | None = None + # For context parallel + pcp_allgather_restore_idx: torch.Tensor | None = None + class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): cudagraph_support: ClassVar[AttentionCGSupport] = ( @@ -326,6 +332,14 @@ def __init__( self.compilation_config.max_capture_size, ) + try: + self.pcp_world_size = get_pcp_group().world_size + self.pcp_rank = get_pcp_group().rank_in_group + except AssertionError: + # PCP might not be initialized in testing + self.pcp_world_size = 1 + self.pcp_rank = 0 + self.num_qo_heads = self.model_config.get_num_attention_heads( self.vllm_config.parallel_config ) @@ -413,9 +427,14 @@ def _get_workspace_buffer(self): def _get_prefill_wrapper(self): if self._prefill_wrapper is None: - self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( 
- self._get_workspace_buffer(), get_kv_cache_layout() - ) + if self.pcp_world_size > 1: + self._prefill_wrapper = BatchPrefillWithRaggedKVCacheWrapper( + self._get_workspace_buffer(), get_kv_cache_layout() + ) + else: + self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( + self._get_workspace_buffer(), get_kv_cache_layout() + ) return self._prefill_wrapper def _get_decode_wrapper(self, batch_size: int, use_cudagraph: bool = False): @@ -482,7 +501,12 @@ def build( max_seq_len = common_attn_metadata.max_seq_len seq_lens = common_attn_metadata.seq_lens seq_lens_cpu = common_attn_metadata.seq_lens_cpu + if self.pcp_world_size > 1: + seq_lens_cpu = seq_lens_cpu // self.pcp_world_size + ( + self.pcp_rank < seq_lens_cpu % self.pcp_world_size + ) seq_lens_np = seq_lens_cpu.numpy() + num_computed_tokens_cpu = common_attn_metadata.num_computed_tokens_cpu block_table_tensor = common_attn_metadata.block_table_tensor num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size @@ -573,6 +597,11 @@ def build( has_sinks=self.has_sinks, has_spec=uses_spec_reorder, ) + if self.pcp_world_size > 1 and (prefill_use_trtllm or decode_use_trtllm): + raise NotImplementedError( + "Trtllm not support lse, please use flash attention " + "or disable attention sinks." + ) if not (prefill_use_trtllm and decode_use_trtllm): if self.has_sinks: @@ -615,6 +644,7 @@ def build( num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, use_cascade=use_cascade, + pcp_allgather_restore_idx=common_attn_metadata.pcp_allgather_restore_idx, ) qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu @@ -660,7 +690,6 @@ def build( qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[prefill_start] ) paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:] - # Recompute max_q_len for the slice of requests we are using # for prefills. 
This can be different from max_q_len when # we have a non-uniform batch with some short decodes offloaded @@ -669,24 +698,69 @@ def build( attn_metadata.max_q_len_prefill = int(query_lens_prefill.max().item()) if not attn_metadata.prefill_use_trtllm: - attn_metadata.prefill_wrapper.plan( - qo_indptr_cpu, - paged_kv_indptr_cpu, - paged_kv_indices, - paged_kv_last_page_len_cpu[prefill_start:], - self.num_qo_heads, - self.num_kv_heads, - self.head_dim, - self.page_size, - causal=True, - sm_scale=self.sm_scale, - window_left=self.window_left, - logits_soft_cap=self.logits_soft_cap, - q_data_type=self.q_data_type, - kv_data_type=self.kv_cache_dtype, - fixed_split_size=self.prefill_fixed_split_size, - disable_split_kv=self.disable_split_kv, - ) + if self.pcp_world_size > 1: + assert common_attn_metadata.query_positions is not None + prefill_num_computed_tokens_cpu = num_computed_tokens_cpu[ + prefill_start: + ] + kv_indptr_cpu = qo_indptr_cpu * self.pcp_world_size + # init custom mask for head-tail query order + q_pos = torch.from_numpy( + common_attn_metadata.query_positions[prefill_start:] + ).long() + kv_lens = ( + prefill_num_computed_tokens_cpu + + kv_indptr_cpu[1:] + - kv_indptr_cpu[:-1] + ) + max_q_lens = int(q_pos.max().item()) + 1 + max_kv_lens = int(kv_lens.max().item()) + mask = torch.ones( + max_q_lens, max_kv_lens, dtype=torch.bool + ).tril() + selected_rows = torch.index_select(mask, 0, q_pos) + col_indices = torch.arange(max_kv_lens).expand( + q_pos.size(0), -1 + ) + valid_mask = col_indices < torch.repeat_interleave( + kv_lens, qo_indptr_cpu[1:] - qo_indptr_cpu[:-1] + ).unsqueeze(1) + custom_mask = selected_rows[valid_mask].to(self.device) + + attn_metadata.prefill_wrapper.plan( + qo_indptr_cpu.to(self.device), + kv_indptr_cpu.to(self.device), + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + custom_mask=custom_mask, + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + q_data_type=self.q_data_type, 
+ kv_data_type=self.kv_cache_dtype, + fixed_split_size=self.prefill_fixed_split_size, + disable_split_kv=self.disable_split_kv, + ) + else: + attn_metadata.prefill_wrapper.plan( + qo_indptr_cpu, + paged_kv_indptr_cpu, + paged_kv_indices, + paged_kv_last_page_len_cpu[prefill_start:], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + causal=True, + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + q_data_type=self.q_data_type, + kv_data_type=self.kv_cache_dtype, + fixed_split_size=self.prefill_fixed_split_size, + disable_split_kv=self.disable_split_kv, + ) else: attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to( self.device, non_blocking=True @@ -757,6 +831,8 @@ def use_cascade_attention(self, *args, **kwargs) -> bool: # TODO: The cascade wrapper currently does not support setting # kv cache dtype to something different from query dtype. return False + if self.pcp_world_size > 1: + return False # TODO: Cascade attention doesn't work, disable it for now # return use_cascade_attention(*args, **kwargs) return False @@ -926,6 +1002,24 @@ def forward( num_actual_tokens = attn_metadata.num_actual_tokens + key_across_cp = get_pcp_group().all_gather(key.contiguous(), dim=0) + value_across_cp = get_pcp_group().all_gather(value.contiguous(), dim=0) + if ( + self.pcp_world_size > 1 + and attn_metadata.pcp_allgather_restore_idx is not None + ): + # Reorder kv after cp allgather. + # Note that there are duplicate decoding tokens, + # but we only save the first one in kvcache. + key_across_cp = torch.index_select( + key_across_cp, 0, attn_metadata.pcp_allgather_restore_idx + ) + value_across_cp = torch.index_select( + value_across_cp, 0, attn_metadata.pcp_allgather_restore_idx + ) + key = key_across_cp + value = value_across_cp + if self.kv_sharing_target_layer_name is None: # Reshape the input keys and values and store them in the cache. # Skip this if sharing KV cache with an earlier attention layer. 
@@ -981,17 +1075,28 @@ def forward( assert prefill_wrapper is not None if not attn_metadata.prefill_use_trtllm: - assert prefill_wrapper._causal assert prefill_wrapper._window_left == self.window_left assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0) assert prefill_wrapper._sm_scale == self.scale - prefill_wrapper.run( - prefill_query, - kv_cache_permute, - k_scale=layer._k_scale_float, - v_scale=layer._v_scale_float, - out=output[num_decode_tokens:], - ) + if self.pcp_world_size > 1: + # NOTE(qcs): Allgather causes duplicate decoding tokens. + prefill_key = key[num_decode_tokens * self.pcp_world_size :] + prefill_value = value[num_decode_tokens * self.pcp_world_size :] + prefill_wrapper.run( + prefill_query, + prefill_key, + prefill_value, + out=output[num_decode_tokens:], + ) + else: + assert prefill_wrapper._causal + prefill_wrapper.run( + prefill_query, + kv_cache_permute, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + out=output[num_decode_tokens:], + ) else: # prefill_query may be non-contiguous prefill_query = prefill_query.contiguous() @@ -1067,13 +1172,25 @@ def forward( assert decode_wrapper._window_left == self.window_left assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0) assert decode_wrapper._sm_scale == self.scale - decode_wrapper.run( - decode_query, - kv_cache_permute, - k_scale=layer._k_scale_float, - v_scale=layer._v_scale_float, - out=output[:num_decode_tokens], - ) + if self.pcp_world_size > 1: + out, lse = decode_wrapper.run( + decode_query, + kv_cache_permute, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + return_lse=True, + ) + output[:num_decode_tokens] = cp_lse_ag_out_ar( + out, lse, get_pcp_group() + ) + else: + decode_wrapper.run( + decode_query, + kv_cache_permute, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + out=output[:num_decode_tokens], + ) else: # decode_query may be non-contiguous decode_query = decode_query.contiguous() diff 
--git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 51a9032f4269..567b61db1bc7 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -976,9 +976,9 @@ def build( def reorg_kvcache( allgatered_kv_c_normed: torch.Tensor, allgatered_k_pe: torch.Tensor, - cp_chunk_seq_lens_lst: list[int], + dcp_chunk_seq_lens_lst: list[int], origin_context_lens: list[int], - cp_world_size: int, + dcp_world_size: int, sum_seq_len: int, max_seq_len: int, chunk_size: int, @@ -986,14 +986,14 @@ def reorg_kvcache( toks: int, ) -> tuple[torch.Tensor, torch.Tensor]: """ - reorg kvcache after cp local gather to tp layout for attn kernel. + reorg kvcache after dcp local gather to tp layout for attn kernel. Args: - cp_chunk_seq_lens_lst: chunk context lengths under CP. - origin_context_lens: origin full context lengths under CP. - cp_world_size: CP size. - sum_seq_len: the sum of cp_chunk_seq_lens_lst. - max_seq_len: the max value of cp_chunk_seq_lens_lst. + dcp_chunk_seq_lens_lst: chunk context lengths under DCP. + origin_context_lens: origin full context lengths under DCP. + dcp_world_size: DCP size. + sum_seq_len: the sum of dcp_chunk_seq_lens_lst. + max_seq_len: the max value of dcp_chunk_seq_lens_lst. chunk_size: equals to max_context_chunk from chunked_context_metadata building. chunk_idx: chunk idx of chunked_prefill. 
@@ -1003,37 +1003,37 @@ def reorg_kvcache( k_pe_segments = [] src_token_idx = 0 max_seq_len_check = 0 - for cp_chunk_seq_len, origin_context_len in zip( - cp_chunk_seq_lens_lst, origin_context_lens + for dcp_chunk_seq_len, origin_context_len in zip( + dcp_chunk_seq_lens_lst, origin_context_lens ): chunk_context_len = chunk_size - if cp_chunk_seq_len != 0: + if dcp_chunk_seq_len != 0: chunk_context_len = min( chunk_context_len, origin_context_len - chunk_size * chunk_idx ) - cp_target_rank = (chunk_context_len - 1) % cp_world_size + dcp_target_rank = (chunk_context_len - 1) % dcp_world_size cur_seq_len = 0 - for rank in range(cp_world_size): - if rank > cp_target_rank and cp_chunk_seq_len: - real_cp_chunk_seq_len = cp_chunk_seq_len - 1 + for rank in range(dcp_world_size): + if rank > dcp_target_rank and dcp_chunk_seq_len: + real_dcp_chunk_seq_len = dcp_chunk_seq_len - 1 else: - real_cp_chunk_seq_len = cp_chunk_seq_len - if real_cp_chunk_seq_len: + real_dcp_chunk_seq_len = dcp_chunk_seq_len + if real_dcp_chunk_seq_len: kv_c_segment = allgatered_kv_c_normed[ rank * toks + src_token_idx : rank * toks + src_token_idx - + real_cp_chunk_seq_len + + real_dcp_chunk_seq_len ] k_pe_segment = allgatered_k_pe[ rank * toks + src_token_idx : rank * toks + src_token_idx - + real_cp_chunk_seq_len + + real_dcp_chunk_seq_len ] kv_c_segments.append(kv_c_segment) k_pe_segments.append(k_pe_segment) - cur_seq_len += real_cp_chunk_seq_len + cur_seq_len += real_dcp_chunk_seq_len max_seq_len_check = max(max_seq_len_check, cur_seq_len) - src_token_idx += cp_chunk_seq_len + src_token_idx += dcp_chunk_seq_len reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0) reorganized_k_pe = torch.cat(k_pe_segments, dim=0) assert reorganized_kv_c_normed.shape[0] == sum_seq_len @@ -1637,11 +1637,11 @@ def _context_parallel_compute_prefill_context( kv_c_normed, k_pe = reorg_kvcache( allgatered_kv_c_normed, allgatered_k_pe, - cp_chunk_seq_lens_lst=prefill_metadata.chunked_context.cp_chunk_seq_lens[ + 
dcp_chunk_seq_lens_lst=prefill_metadata.chunked_context.cp_chunk_seq_lens[ i ], origin_context_lens=prefill_metadata.chunked_context.origin_context_lens, - cp_world_size=dcp_world_size, + dcp_world_size=dcp_world_size, sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i][-1], max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i], chunk_size=prefill_metadata.chunked_context.chunk_size, diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index cb5855548098..8c17732a49ce 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -94,6 +94,10 @@ class CommonAttentionMetadata: dcp_local_seq_lens: torch.Tensor | None = None """Sequence lengths of the local rank in decode context parallelism world""" + # Needed by custom mask calc for context parallelism + query_positions: np.ndarray | None = None + pcp_allgather_restore_idx: torch.Tensor | None = None + def slice_query_start_locs( query_start_loc: torch.Tensor, @@ -190,6 +194,19 @@ def _make_metadata_with_slice( block_table_tensor = attn_metadata.block_table_tensor[request_slice] slot_mapping = attn_metadata.slot_mapping[token_slice] + # TODO(qcs): check if we can split query_positions and + # cp_kv_recover_idx as following approach + query_positions = ( + attn_metadata.query_positions[token_slice] + if attn_metadata.query_positions is not None + else None + ) + cp_allgather_restore_idx = ( + attn_metadata.pcp_allgather_restore_idx[token_slice] + if attn_metadata.pcp_allgather_restore_idx is not None + else None + ) + return CommonAttentionMetadata( query_start_loc=query_start_loc, query_start_loc_cpu=query_start_loc_cpu, @@ -202,6 +219,8 @@ def _make_metadata_with_slice( max_seq_len=max_seq_len, block_table_tensor=block_table_tensor, slot_mapping=slot_mapping, + query_positions=query_positions, + pcp_allgather_restore_idx=cp_allgather_restore_idx, ) diff --git a/vllm/v1/core/kv_cache_coordinator.py 
b/vllm/v1/core/kv_cache_coordinator.py index 137e5e0cdb6d..c65db42bebd4 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -27,6 +27,7 @@ def __init__( enable_caching: bool, enable_kv_cache_events: bool, dcp_world_size: int, + pcp_world_size: int, ): self.kv_cache_config = kv_cache_config self.max_model_len = max_model_len @@ -44,6 +45,7 @@ def __init__( block_pool=self.block_pool, kv_cache_group_id=i, dcp_world_size=dcp_world_size, + pcp_world_size=pcp_world_size, ) for i, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups) ) @@ -210,6 +212,7 @@ def __init__( use_eagle: bool, enable_kv_cache_events: bool, dcp_world_size: int, + pcp_world_size: int, ): super().__init__( kv_cache_config, @@ -218,6 +221,7 @@ def __init__( False, enable_kv_cache_events, dcp_world_size=dcp_world_size, + pcp_world_size=pcp_world_size, ) self.num_single_type_manager = len(self.single_type_managers) @@ -250,6 +254,7 @@ def __init__( enable_caching: bool, enable_kv_cache_events: bool, dcp_world_size: int, + pcp_world_size: int, ): super().__init__( kv_cache_config, @@ -258,12 +263,16 @@ def __init__( enable_caching, enable_kv_cache_events, dcp_world_size=dcp_world_size, + pcp_world_size=pcp_world_size, ) self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec self.block_size = self.kv_cache_spec.block_size self.dcp_world_size = dcp_world_size + self.pcp_world_size = pcp_world_size if dcp_world_size > 1: self.block_size *= dcp_world_size + if pcp_world_size > 1: + self.block_size *= pcp_world_size assert len(self.kv_cache_config.kv_cache_groups) == 1, ( "UnitaryKVCacheCoordinator assumes only one kv cache group" ) @@ -281,6 +290,7 @@ def find_longest_cache_hit( kv_cache_spec=self.kv_cache_spec, use_eagle=self.use_eagle, dcp_world_size=self.dcp_world_size, + pcp_world_size=self.pcp_world_size, ) return hit_blocks, len(hit_blocks[0]) * self.block_size @@ -302,6 +312,7 @@ def __init__( enable_caching: bool, 
enable_kv_cache_events: bool, dcp_world_size: int, + pcp_world_size: int, ): super().__init__( kv_cache_config, @@ -310,8 +321,10 @@ def __init__( enable_caching, enable_kv_cache_events, dcp_world_size=dcp_world_size, + pcp_world_size=pcp_world_size, ) assert dcp_world_size == 1, "DCP not support hybrid attn now." + assert pcp_world_size == 1, "PCP not support hybrid attn now" self.verify_and_split_kv_cache_groups() def verify_and_split_kv_cache_groups(self) -> None: @@ -452,6 +465,7 @@ def get_kv_cache_coordinator( enable_caching: bool, enable_kv_cache_events: bool, dcp_world_size: int, + pcp_world_size: int, ) -> KVCacheCoordinator: if not enable_caching: return KVCacheCoordinatorNoPrefixCache( @@ -460,6 +474,7 @@ def get_kv_cache_coordinator( use_eagle, enable_kv_cache_events, dcp_world_size=dcp_world_size, + pcp_world_size=pcp_world_size, ) if len(kv_cache_config.kv_cache_groups) == 1: return UnitaryKVCacheCoordinator( @@ -469,6 +484,7 @@ def get_kv_cache_coordinator( enable_caching, enable_kv_cache_events, dcp_world_size=dcp_world_size, + pcp_world_size=pcp_world_size, ) return HybridKVCacheCoordinator( kv_cache_config, @@ -477,4 +493,5 @@ def get_kv_cache_coordinator( enable_caching, enable_kv_cache_events, dcp_world_size=dcp_world_size, + pcp_world_size=pcp_world_size, ) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 74176e4b2051..ef9028b61eb1 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -100,6 +100,7 @@ def __init__( log_stats: bool = False, enable_kv_cache_events: bool = False, dcp_world_size: int = 1, + pcp_world_size: int = 1, ) -> None: self.max_model_len = max_model_len @@ -124,12 +125,12 @@ def __init__( 0 ].kv_cache_spec.block_size - if dcp_world_size > 1: + if dcp_world_size * pcp_world_size > 1: assert len(kv_cache_config.kv_cache_groups) == 1 # Note(hc): need revisit. 
When both DCP and any future # PCP are enabled, the block_size may need to be scaled # by a factor of dcp_size × pcp_size? - self.block_size *= dcp_world_size + self.block_size *= dcp_world_size * pcp_world_size self.coordinator = get_kv_cache_coordinator( kv_cache_config=kv_cache_config, @@ -138,6 +139,7 @@ def __init__( enable_caching=self.enable_caching, enable_kv_cache_events=enable_kv_cache_events, dcp_world_size=dcp_world_size, + pcp_world_size=pcp_world_size, ) self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups) self.block_pool = self.coordinator.block_pool diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 08368b7d99ef..64c8e8185ba1 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -106,6 +106,7 @@ def __init__( self.block_size = block_size self.dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size + self.pcp_world_size = vllm_config.parallel_config.prefill_context_parallel_size # req_id -> Request self.requests: dict[str, Request] = {} @@ -170,6 +171,7 @@ def __init__( log_stats=self.log_stats, enable_kv_cache_events=self.enable_kv_cache_events, dcp_world_size=self.dcp_world_size, + pcp_world_size=self.pcp_world_size, ) self.use_pp = self.parallel_config.pipeline_parallel_size > 1 diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 586034182686..82f66d9c202c 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -32,6 +32,7 @@ def __init__( block_pool: BlockPool, kv_cache_group_id: int, dcp_world_size: int = 1, + pcp_world_size: int = 1, ) -> None: """ Initializes the SingleTypeKVCacheManager. 
@@ -42,8 +43,9 @@ def __init__( """ self.block_size = kv_cache_spec.block_size self.dcp_world_size = dcp_world_size - if self.dcp_world_size > 1: - self.block_size *= dcp_world_size + self.pcp_world_size = pcp_world_size + if self.dcp_world_size * self.pcp_world_size > 1: + self.block_size *= dcp_world_size * pcp_world_size self.kv_cache_spec = kv_cache_spec self.block_pool = block_pool @@ -212,6 +214,7 @@ def find_longest_cache_hit( kv_cache_spec: KVCacheSpec, use_eagle: bool, dcp_world_size: int = 1, + pcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: """ Get the longest cache hit prefix of the blocks that is not longer than @@ -268,6 +271,7 @@ def find_longest_cache_hit( kv_cache_spec: KVCacheSpec, use_eagle: bool, dcp_world_size: int = 1, + pcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance( kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec) @@ -279,8 +283,8 @@ def find_longest_cache_hit( [] for _ in range(len(kv_cache_group_ids)) ) block_size = kv_cache_spec.block_size - if dcp_world_size > 1: - block_size *= dcp_world_size + if dcp_world_size * pcp_world_size > 1: + block_size *= dcp_world_size * pcp_world_size max_num_blocks = max_length // block_size for block_hash in itertools.islice(block_hashes, max_num_blocks): # block_hashes is a chain of block hashes. If a block hash is not @@ -331,11 +335,13 @@ def find_longest_cache_hit( kv_cache_spec: KVCacheSpec, use_eagle: bool, dcp_world_size: int = 1, + pcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance(kv_cache_spec, SlidingWindowSpec), ( "SlidingWindowManager can only be used for sliding window groups" ) assert dcp_world_size == 1, "DCP not support sliding window attn now." + assert pcp_world_size == 1, "CP not support sliding window attn now." # The number of contiguous blocks needed for prefix cache hit. 
# -1 since the input token itself is also included in the window @@ -434,6 +440,7 @@ def find_longest_cache_hit( kv_cache_spec: KVCacheSpec, use_eagle: bool, dcp_world_size: int = 1, + pcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: """ For chunked local attention, we need to find the longest cache hit @@ -474,6 +481,7 @@ def find_longest_cache_hit( "Hybrid KV cache is not supported for " + "eagle + chunked local attention." ) assert dcp_world_size == 1, "DCP not support chunked local attn now." + assert pcp_world_size == 1, "CP not support chunked local attn now." max_num_blocks = max_length // kv_cache_spec.block_size if max_length > 0: local_attention_start_idx = ( @@ -558,11 +566,13 @@ def find_longest_cache_hit( kv_cache_spec: KVCacheSpec, use_eagle: bool, dcp_world_size: int = 1, + pcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance(kv_cache_spec, MambaSpec), ( "MambaManager can only be used for mamba groups" ) assert dcp_world_size == 1, "DCP not support mamba now." + assert pcp_world_size == 1, "CP not support mamba now." computed_blocks: tuple[list[KVCacheBlock], ...] 
= tuple( [] for _ in range(len(kv_cache_group_ids)) ) @@ -658,6 +668,7 @@ def find_longest_cache_hit( kv_cache_spec: KVCacheSpec, use_eagle: bool, dcp_world_size: int = 1, + pcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance(kv_cache_spec, CrossAttentionSpec), ( "CrossAttentionManager can only be used for cross-attention groups" diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 38e8f4ab85d9..c4cbd3078a16 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -30,6 +30,7 @@ from vllm.distributed.parallel_state import ( get_dp_group, get_ep_group, + get_pcp_group, get_pp_group, get_tp_group, ) @@ -67,10 +68,12 @@ def _init_executor(self) -> None: self.world_size = self.parallel_config.world_size tensor_parallel_size = self.parallel_config.tensor_parallel_size pp_parallel_size = self.parallel_config.pipeline_parallel_size - assert self.world_size == tensor_parallel_size * pp_parallel_size, ( + pcp_size = self.parallel_config.prefill_context_parallel_size + assert self.world_size == tensor_parallel_size * pp_parallel_size * pcp_size, ( f"world_size ({self.world_size}) must be equal to the " f"tensor_parallel_size ({tensor_parallel_size}) x pipeline" - f"_parallel_size ({pp_parallel_size}). " + f"_parallel_size ({pp_parallel_size}) x prefill_context" + f"_parallel_size ({pcp_size}). " ) # Set multiprocessing envs @@ -362,7 +365,11 @@ def _get_output_rank(self) -> int: # 16-23, PP rank 2 # 24-31, PP rank 3 # so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 
3) - return self.world_size - self.parallel_config.tensor_parallel_size + return ( + self.world_size + - self.parallel_config.tensor_parallel_size + * self.parallel_config.prefill_context_parallel_size + ) @dataclass @@ -715,11 +722,15 @@ def setup_proc_title_and_log_prefix(enable_ep: bool) -> None: pp_rank = get_pp_group().rank_in_group tp_size = get_tp_group().world_size tp_rank = get_tp_group().rank_in_group + pcp_size = get_pcp_group().world_size + pcp_rank = get_pcp_group().rank_in_group process_name = "Worker" if dp_size > 1: process_name += f"_DP{dp_rank}" if pp_size > 1: process_name += f"_PP{pp_rank}" + if pcp_size > 1: + process_name += f"_PCP{pcp_rank}" if tp_size > 1: process_name += f"_TP{tp_rank}" if enable_ep: diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index a9ef1b92c243..c983bd21ee82 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -88,10 +88,13 @@ class FullAttentionSpec(AttentionSpec): def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: max_model_len = vllm_config.model_config.max_model_len dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size + pcp_world_size = vllm_config.parallel_config.prefill_context_parallel_size # Note(hc): each dcp rank only need save # (max_model_len//dcp_world_size) tokens locally. 
if dcp_world_size > 1: max_model_len = cdiv(max_model_len, dcp_world_size) + if pcp_world_size > 1: + max_model_len = cdiv(max_model_len, pcp_world_size) return cdiv(max_model_len, self.block_size) * self.page_size_bytes @classmethod diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 9bf06d51609f..813323755004 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -4,7 +4,7 @@ import numpy as np import torch -from vllm.distributed import get_dcp_group +from vllm.distributed import get_dcp_group, get_pcp_group from vllm.logger import init_logger from vllm.utils import cdiv from vllm.v1.utils import CpuGpuBuffer @@ -80,12 +80,16 @@ def __init__( self._kernel_block_arange = None try: + self.pcp_world_size = get_pcp_group().world_size + self.pcp_rank = get_pcp_group().rank_in_group self.dcp_world_size = get_dcp_group().world_size self.dcp_rank = get_dcp_group().rank_in_group except AssertionError: # DCP might not be initialized in testing self.dcp_world_size = 1 self.dcp_rank = 0 + self.pcp_world_size = 1 + self.pcp_rank = 0 def append_row( self, @@ -127,14 +131,16 @@ def compute_slot_mapping( # NOTE(woosuk): We can't simply use `token_indices // block_size` # here because M (max_model_len) is not necessarily divisible by # block_size. - if self.dcp_world_size > 1: + if self.dcp_world_size * self.pcp_world_size > 1: # Note(hc): The DCP implement store kvcache with an interleave # style, the kvcache for the token whose token_idx is i is - # always stored on the GPU whose dcp_rank equals i % cp_world_size: + # always stored on the GPU whose dcp_rank equals i % pcp_world_size: # Use a "virtual block" which equals to world_size * block_size # for block_table_indices calculation. 
- virtual_block_size = self.block_size * self.dcp_world_size + virtual_block_size = ( + self.block_size * self.dcp_world_size * self.pcp_world_size + ) block_table_indices = ( req_indices * self.max_num_blocks_per_req + positions // virtual_block_size @@ -144,9 +150,15 @@ def compute_slot_mapping( # Use virtual_block_size for mask calculation, which marks local # tokens. virtual_block_offsets = positions % virtual_block_size - mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank + self.current_rank = self.dcp_world_size * self.pcp_rank + self.dcp_rank + mask = ( + virtual_block_offsets % (self.dcp_world_size * self.pcp_world_size) + == self.current_rank + ) # Calculate local block_offsets - block_offsets = virtual_block_offsets // self.dcp_world_size + block_offsets = virtual_block_offsets // ( + self.dcp_world_size * self.pcp_world_size + ) # Calculate slot_mapping slot_mapping = block_numbers * self.block_size + block_offsets # Write final slots, use -1 for not-local diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7e72ce937be4..4dbef7d50a10 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -35,6 +35,7 @@ from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks from vllm.distributed.parallel_state import ( + get_pcp_group, get_pp_group, get_tp_group, graph_capture, @@ -252,6 +253,8 @@ def __init__( # This will be overridden in load_model() self.is_multimodal_pruning_enabled = False self.max_model_len = model_config.max_model_len + self.pcp_world_size = get_pcp_group().world_size + self.pcp_rank = get_pcp_group().rank_in_group self.dcp_world_size = self.parallel_config.decode_context_parallel_size self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_reqs = scheduler_config.max_num_seqs @@ -431,6 +434,24 @@ def __init__( if 
self.supports_mm_inputs: self.is_mm_embed = self._make_buffer(self.max_num_tokens, dtype=torch.bool) + # Persistent buffers for Context Parallism + self.pcp_allgather_restore_idx = self._make_buffer( + self.max_num_tokens, dtype=torch.int64 + ) + self.pcp_padded_slot_mapping = torch.empty( + (self.max_num_tokens,), + dtype=torch.int64, + device=self.device, + ) + self.num_pcp_pads_cpu_tensor = torch.zeros( + (self.max_num_reqs,), device="cpu", dtype=torch.int64, pin_memory=True + ) + self.num_pcp_pads_cpu = self.num_pcp_pads_cpu_tensor.numpy() + self.pcp_unpad_mask_cpu_tensor = torch.zeros( + (self.max_num_tokens,), device="cpu", dtype=torch.bool, pin_memory=True + ) + self.pcp_unpad_mask_cpu = self.pcp_unpad_mask_cpu_tensor.numpy() + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: # NOTE: `mrope_positions` is implemented with one additional dummy @@ -919,6 +940,101 @@ def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs: dummy_modality = mm_budget.get_modality_with_max_tokens() return self._get_mm_dummy_batch(dummy_modality, num_seqs) + def _update_tokens_for_pcp(self, tokens): + """ + If prefill context parallelism is enabled, we will calculate + the number of tokens `tokens` after sequence splitting. + Meanwhile, we will compute: + `positions` the new token positions, + `self.num_pcp_pads_cpu` the number of padding tokens + per request for alignment, + `self.pcp_unpad_mask_cpu` the mask for non-padded tokens, + `self.pcp_allgather_restore_idx` indices to restore the + original vector order after PCP allgather. 
+ Example: + >>> tokens = [1, 5, 8] + >>> pcp_world_size = 2 + >>> pcp_rank = 0 + >>> _update_tokens_for_pcp(tokens) + ([1, 4, 4], [0, 2, 3, 4, 5, 2, 3, 4, 5]) + >>> pcp_rank = 1 + >>> _update_tokens_for_pcp(tokens) + ([1, 4, 4], [0, 0, 1, 6, 7, 0, 1, 6, 7]) + >>> # the following results are same for each pcp rank + >>> self.num_pcp_pads_cpu + [1, 3, 0] + >>> self.pcp_unpad_mask_cpu + [True, False, True, True, True, True, True, False, False, + False, True, True, True, True, True, True, True, True] + >>> self.pcp_allgather_resotre_idx + [0, 9, 1, 2, 10, 11, 12, 13, 3, 4, 5, 6, 14, 15, 16, 17, 7, 8] + """ + num_reqs = self.input_batch.num_reqs + self.num_pcp_pads_cpu[:num_reqs] = 0 + if not self.pcp_world_size > 1: + return tokens, None + + num_decode_reqs = sum( + self.input_batch.num_computed_tokens_cpu[:num_reqs] + >= self.input_batch.num_prompt_tokens[:num_reqs] + ) + + num_padded_scheduled_tokens = np.ceil( + tokens / (2 * self.pcp_world_size) + ).astype(np.int32) * (2 * self.pcp_world_size) + # we align scheduled tokens of decode reqs to pcp_world_size instead + # of 2*pcp_world_size + num_padded_scheduled_tokens[:num_decode_reqs] = self.pcp_world_size + self.num_pcp_pads_cpu[:num_reqs] = num_padded_scheduled_tokens - tokens + cu_padded_tokens, pcp_padded_arange = self._get_cumsum_and_arange( + num_padded_scheduled_tokens + ) + self.pcp_unpad_mask_cpu[: pcp_padded_arange.shape[0]] = ( + pcp_padded_arange < np.repeat(tokens, num_padded_scheduled_tokens) + ) + + pcp_tokens = num_padded_scheduled_tokens // self.pcp_world_size + pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1) + _, pcp_arange = self._get_cumsum_and_arange(pcp_tokens) + _, pcp_chunk_arange = self._get_cumsum_and_arange(pcp_chunk_sizes) + pcp_head_chunk_mask = pcp_arange < np.repeat(pcp_chunk_sizes, pcp_tokens) + + def get_current_rank_positions( + positions_start_loc: int | np.ndarray, rank: int + ): + positions = np.zeros(len(pcp_head_chunk_mask), dtype=np.int32) + head_start_loc = 
positions_start_loc + rank * pcp_chunk_sizes + tail_start_loc = ( + positions_start_loc + + (2 * self.pcp_world_size - rank - 1) * pcp_chunk_sizes + ) + positions[pcp_head_chunk_mask] = pcp_chunk_arange + np.repeat( + head_start_loc, pcp_chunk_sizes + ) + # Decode reqs do not have tail chunks. + positions[~pcp_head_chunk_mask] = ( + pcp_chunk_arange[num_decode_reqs:] + + np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_reqs:] + ) + return positions + + positions = get_current_rank_positions(0, self.pcp_rank) + # Decode tokens are duplicate and their positions always be 0. + positions[:num_decode_reqs] = 0 + + padded_pos_start_loc = np.roll(cu_padded_tokens, 1) + padded_pos_start_loc[0] = 0 + all_positions_lst = [ + get_current_rank_positions(padded_pos_start_loc, rank_i) + for rank_i in range(self.pcp_world_size) + ] + all_positions = np.concatenate(all_positions_lst) + self.pcp_allgather_restore_idx.np[: all_positions.shape[0]] = ( + all_positions.argsort() + ) + self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0]) + return pcp_tokens, positions + def _get_cumsum_and_arange( self, num_tokens: np.ndarray, @@ -1064,8 +1180,26 @@ def _prepare_inputs( # Get the number of scheduled tokens for each request. 
req_ids = self.input_batch.req_ids tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] + # NOTE(qcs): we need compute slotmapping for all kv + # instead of sliced sequences num_scheduled_tokens = np.array(tokens, dtype=np.int32) - max_num_scheduled_tokens = max(tokens) + + req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) + _, arange = self._get_cumsum_and_arange(num_scheduled_tokens) + positions_np = np.add( + self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + ) + self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np) + self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens) + + num_scheduled_tokens, positions_cp = self._update_tokens_for_pcp( + num_scheduled_tokens + ) + # update total_num_scheduled_tokens + total_num_scheduled_tokens = sum(num_scheduled_tokens) + total_num_pcp_pads = sum(self.num_pcp_pads_cpu[:num_reqs]) + max_num_scheduled_tokens = max(num_scheduled_tokens) # Get request indices. # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] @@ -1077,11 +1211,13 @@ def _prepare_inputs( # Get positions. positions_np = self.positions.np[:total_num_scheduled_tokens] - np.add( - self.input_batch.num_computed_tokens_cpu[req_indices], - arange, - out=positions_np, - ) + if self.pcp_world_size > 1: + assert positions_cp is not None + np.add( + self.input_batch.num_computed_tokens_cpu[req_indices], + positions_cp[:total_num_scheduled_tokens], + out=positions_np, + ) # Calculate M-RoPE positions. # Only relevant for models using M-RoPE (e.g, Qwen2-VL) @@ -1153,9 +1289,6 @@ def _prepare_inputs( output_idx += num_sched - self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np) - self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens) - # Prepare the attention metadata. 
self.query_start_loc.np[0] = 0 self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens @@ -1201,7 +1334,10 @@ def _prepare_inputs( # Record the index of requests that should not be sampled, # so that we could clear the sampled tokens before returning - discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np + discard_requests_mask = ( + self.input_batch.num_computed_tokens_cpu[:num_reqs] + + np.array(tokens, dtype=np.int32) + ) < num_tokens_np discard_request_indices = np.nonzero(discard_requests_mask)[0] self.num_discarded_requests = len(discard_request_indices) self.discard_request_indices.np[: self.num_discarded_requests] = ( @@ -1230,10 +1366,15 @@ def _prepare_inputs( # from these partial requests, we do so for simplicity. # We will ignore the sampled tokens from the partial requests. # TODO: Support prompt logprobs. - logits_indices = query_start_loc[1:] - 1 + logits_indices = ( + torch.from_numpy(cu_num_tokens) * self.pcp_world_size + - self.num_pcp_pads_cpu_tensor[:num_reqs] + - 1 + ) num_draft_tokens = None spec_decode_metadata = None else: + assert self.pcp_world_size == 1, "PCP not support spec decode now" # Get the number of draft tokens for each request. # Iterate over the dictionary rather than all requests since not all # requests have draft tokens. @@ -1299,6 +1440,12 @@ def _prepare_inputs( scheduler_output, kv_cache_group_spec.kv_cache_spec, num_reqs ) + slot_mapping_size = ( + total_num_scheduled_tokens + if self.pcp_world_size == 1 + else total_num_scheduled_tokens * self.pcp_world_size + - total_num_pcp_pads + ) if isinstance(kv_cache_group_spec.kv_cache_spec, EncoderOnlyAttentionSpec): # Encoder-only layers do not have KV cache, so we need to # create a dummy block table and slot mapping for them. 
@@ -1308,7 +1455,7 @@ def _prepare_inputs( device=self.device, ) slot_mapping = torch.zeros( - (total_num_scheduled_tokens,), + (slot_mapping_size,), dtype=torch.int64, device=self.device, ) @@ -1316,15 +1463,29 @@ def _prepare_inputs( else: blk_table = self.input_batch.block_table[kv_cache_group_id] blk_table_tensor = blk_table.get_device_tensor(num_reqs) - slot_mapping = blk_table.slot_mapping.gpu[:total_num_scheduled_tokens] + + slot_mapping = blk_table.slot_mapping.gpu[:slot_mapping_size] # Fill unused with -1. Needed for reshape_and_cache in full cuda # graph mode. - blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(-1) + blk_table.slot_mapping.gpu[slot_mapping_size:].fill_(-1) num_common_prefix_blocks = scheduler_output.num_common_prefix_blocks[ kv_cache_group_id ] + if self.pcp_world_size > 1: + # After cp allgather and restore, there are padded tokens in + # kv, so we need pad slotmapping for alignment. + pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[ + : total_num_scheduled_tokens * self.pcp_world_size + ] + cp_unpad_mask = self.pcp_unpad_mask_cpu_tensor[ + : total_num_scheduled_tokens * self.pcp_world_size + ] + pcp_padded_slot_mapping.fill_(-1) + pcp_padded_slot_mapping[cp_unpad_mask] = slot_mapping + slot_mapping = pcp_padded_slot_mapping + common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, query_start_loc_cpu=query_start_loc_cpu, @@ -1341,6 +1502,12 @@ def _prepare_inputs( num_logits_indices=logits_indices.size(0), causal=True, encoder_seq_lens=encoder_seq_lens, + query_positions=positions_np, + pcp_allgather_restore_idx=self.pcp_allgather_restore_idx.gpu[ + : total_num_scheduled_tokens * self.pcp_world_size + ] + if self.pcp_world_size > 1 + else None, dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs] if self.dcp_world_size > 1 else None, @@ -2447,6 +2614,12 @@ def execute_model( self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens) elif num_tokens_across_dp is not None: num_input_tokens 
= int(num_tokens_across_dp[dp_rank].item()) + elif self.pcp_world_size > 1: + # NOTE(qcs): For PCP, we pad num_scheduled_tokens_np but + # do not update total_num_scheduled_tokens in scheduler_output + num_input_tokens = self._get_num_input_tokens( + sum(num_scheduled_tokens_np) + ) else: num_input_tokens = self._get_num_input_tokens( scheduler_output.total_num_scheduled_tokens @@ -2517,6 +2690,13 @@ def execute_model( hidden_states = model_output aux_hidden_states = None + if self.pcp_world_size > 1: + hidden_states = get_pcp_group().all_gather(hidden_states, 0) + hidden_states = torch.index_select( + hidden_states, + 0, + self.pcp_allgather_restore_idx.gpu[: hidden_states.shape[0]], + ) if not self.broadcast_pp_output: # Common case. if not get_pp_group().is_last_rank: @@ -3307,7 +3487,9 @@ def _dummy_run( self.seq_lens.np[num_reqs:] = 0 self.seq_lens.copy_to_gpu() - cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens) + cum_num_tokens, query_positions = self._get_cumsum_and_arange( + num_scheduled_tokens + ) self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens self.query_start_loc.copy_to_gpu() @@ -3333,6 +3515,10 @@ def _dummy_run( kv_cache_group_id ].slot_mapping.gpu[:num_tokens], causal=True, + query_positions=query_positions, + pcp_allgather_restore_idx=self.pcp_allgather_restore_idx.gpu[ + : total_num_scheduled_tokens * self.pcp_world_size + ], dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs] if self.dcp_world_size > 1 else None, diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 00dc7682c973..964ab774e6ec 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -776,6 +776,7 @@ def init_worker_distributed_environment( ensure_model_parallel_initialized( parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size, + parallel_config.prefill_context_parallel_size, parallel_config.decode_context_parallel_size, ) From 551d87eb7827db5da62d4ed6588b65bfc3ffd312 Mon 
Sep 17 00:00:00 2001 From: QiuChunshuo Date: Mon, 20 Oct 2025 17:37:32 +0800 Subject: [PATCH 02/15] [typo] wrong param name and comment Signed-off-by: QiuChunshuo --- vllm/distributed/parallel_state.py | 8 ++++---- vllm/model_executor/layers/fused_moe/config.py | 3 +-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index b00d916e83b8..8f7792d8c7af 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1225,7 +1225,7 @@ def init_distributed_environment( def initialize_model_parallel( tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, - context_model_parallel_size: int = 1, + prefill_context_model_parallel_size: int = 1, decode_context_model_parallel_size: int | None = 1, backend: str | None = None, ) -> None: @@ -1278,7 +1278,7 @@ def initialize_model_parallel( -1, data_parallel_size, pipeline_model_parallel_size, - context_model_parallel_size, + prefill_context_model_parallel_size, tensor_model_parallel_size, ) # noqa @@ -1341,7 +1341,7 @@ def initialize_model_parallel( -1, data_parallel_size * tensor_model_parallel_size - * context_model_parallel_size, + * prefill_context_model_parallel_size, ) .unbind(0) ) @@ -1353,7 +1353,7 @@ def initialize_model_parallel( global _PCP assert _PCP is None, "prefill context parallel group is already initialized" group_ranks = ( - all_ranks.transpose(3, 4).reshape(-1, context_model_parallel_size).unbind(0) + all_ranks.transpose(3, 4).reshape(-1, prefill_context_model_parallel_size).unbind(0) ) group_ranks = [x.tolist() for x in group_ranks] _PCP = init_model_parallel_group( diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 32b2b6573272..f024f0d0fcbb 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -681,8 +681,7 @@ def make( level's of parallelism to use 
in the fused moe layer. Args: - tp_size_ (int): `tp_size` pa use_ep = (dp_size_ * tp_size_ssed into - the FusedMoE constructor. + tp_size_ (int): `tp_size` passed into the FusedMoE constructor. dp_size_ (int): `dp_size` passed into the FusedMoE constructor. vllm_parallel_config (ParallelConfig): vLLM's parallel config object which contains the `enable_expert_parallel` flag. From f4e83323b76282cfb240bebd78f5fc2787fb5805 Mon Sep 17 00:00:00 2001 From: QiuChunshuo Date: Wed, 22 Oct 2025 21:42:25 +0800 Subject: [PATCH 03/15] [bugfix] number of padded tokens may greater than max_num_batched_tokens Signed-off-by: QiuChunshuo --- vllm/v1/worker/gpu_model_runner.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4dbef7d50a10..63d35d778286 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -435,11 +435,13 @@ def __init__( self.is_mm_embed = self._make_buffer(self.max_num_tokens, dtype=torch.bool) # Persistent buffers for Context Parallism + max_num_padded_tokens = self.max_num_tokens + self.max_num_reqs * 2 * self.pcp_world_size self.pcp_allgather_restore_idx = self._make_buffer( - self.max_num_tokens, dtype=torch.int64 + max_num_padded_tokens, + dtype=torch.int64 ) self.pcp_padded_slot_mapping = torch.empty( - (self.max_num_tokens,), + (max_num_padded_tokens,), dtype=torch.int64, device=self.device, ) @@ -448,7 +450,7 @@ def __init__( ) self.num_pcp_pads_cpu = self.num_pcp_pads_cpu_tensor.numpy() self.pcp_unpad_mask_cpu_tensor = torch.zeros( - (self.max_num_tokens,), device="cpu", dtype=torch.bool, pin_memory=True + (max_num_padded_tokens,), device="cpu", dtype=torch.bool, pin_memory=True ) self.pcp_unpad_mask_cpu = self.pcp_unpad_mask_cpu_tensor.numpy() From eb5628ea1ec5b59b821141c4e1414715a383a30e Mon Sep 17 00:00:00 2001 From: FENP Date: Thu, 23 Oct 2025 11:33:03 +0800 Subject: [PATCH 04/15] bug fix: write positions when not 
use pcp Signed-off-by: FENP --- vllm/v1/worker/gpu_model_runner.py | 45 ++++++++++++++++-------------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 63d35d778286..156ba77e2d35 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1182,26 +1182,8 @@ def _prepare_inputs( # Get the number of scheduled tokens for each request. req_ids = self.input_batch.req_ids tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] - # NOTE(qcs): we need compute slotmapping for all kv - # instead of sliced sequences num_scheduled_tokens = np.array(tokens, dtype=np.int32) - - req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) - _, arange = self._get_cumsum_and_arange(num_scheduled_tokens) - positions_np = np.add( - self.input_batch.num_computed_tokens_cpu[req_indices], - arange, - ) - self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np) - self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens) - - num_scheduled_tokens, positions_cp = self._update_tokens_for_pcp( - num_scheduled_tokens - ) - # update total_num_scheduled_tokens - total_num_scheduled_tokens = sum(num_scheduled_tokens) - total_num_pcp_pads = sum(self.num_pcp_pads_cpu[:num_reqs]) - max_num_scheduled_tokens = max(num_scheduled_tokens) + max_num_scheduled_tokens = max(tokens) # Get request indices. # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] @@ -1213,11 +1195,32 @@ def _prepare_inputs( # Get positions. 
positions_np = self.positions.np[:total_num_scheduled_tokens] + np.add( + self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np, + ) + + self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np) + self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens) + if self.pcp_world_size > 1: - assert positions_cp is not None + num_scheduled_tokens, pcp_positions = self._update_tokens_for_pcp( + num_scheduled_tokens + ) + assert pcp_positions is not None + + # Re-update after PCP split sequences. + total_num_scheduled_tokens = sum(num_scheduled_tokens) + total_num_pcp_pads = sum(self.num_pcp_pads_cpu[:num_reqs]) + max_num_scheduled_tokens = max(num_scheduled_tokens) + + req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) + cu_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens) + positions_np = self.positions.np[:total_num_scheduled_tokens] np.add( self.input_batch.num_computed_tokens_cpu[req_indices], - positions_cp[:total_num_scheduled_tokens], + pcp_positions[:total_num_scheduled_tokens], out=positions_np, ) From 44473de230a3a3dabbf6e148020b1e11fbca03af Mon Sep 17 00:00:00 2001 From: FENP Date: Tue, 21 Oct 2025 14:57:00 +0800 Subject: [PATCH 05/15] disable prefix caching and chunk prefill when using PCP Signed-off-by: FENP --- vllm/engine/arg_utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 18e981d9b847..fe288d5281bf 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1734,6 +1734,15 @@ def _set_default_args( self.enable_prefix_caching = False else: self.enable_prefix_caching = True + + if self.prefill_context_parallel_size > 1: + self.enable_chunked_prefill = False + self.enable_prefix_caching = False + logger.warning( + "--prefill-context-parallel-size > 1 is not compatible with " + "chunked prefill and prefix caching now. 
Chunked prefill " + "and prefix caching have been disabled." + ) else: pooling_type = model_config.pooler_config.pooling_type is_causal = getattr(model_config.hf_config, "is_causal", True) From f0ab17cc6e2708656100931266476ea9b4c1b214 Mon Sep 17 00:00:00 2001 From: QiuChunshuo Date: Thu, 23 Oct 2025 16:21:34 +0800 Subject: [PATCH 06/15] [Perf] Optimize custom_mask computation Signed-off-by: QiuChunshuo --- vllm/v1/attention/backends/flashinfer.py | 52 +++++++++++++++--------- 1 file changed, 32 insertions(+), 20 deletions(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index aeef2db86677..1ccad7543acd 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -480,6 +480,28 @@ def _get_cascade_wrapper(self): ) return self._cascade_wrapper + def _get_pcp_custom_mask( + self, + qo_indptr_cpu: torch.Tensor, + q_pos: torch.Tensor, + kv_lens: torch.Tensor, + ) -> torch.Tensor: + max_q_lens = int(q_pos.max().item()) + 1 + max_kv_lens = int(kv_lens.max().item()) + mask = torch.ones( + max_q_lens, + max_kv_lens, + dtype=torch.bool, + device=q_pos.device, + ).tril() + custom_mask_lst = [ + mask[q_pos[q_pos_start_loc:q_pos_end_loc], :kv_len].flatten() + for kv_len, q_pos_start_loc, q_pos_end_loc in + zip(kv_lens, qo_indptr_cpu[:-1], qo_indptr_cpu[1:]) + ] + custom_mask = torch.cat(custom_mask_lst) + return custom_mask + def build( self, common_prefix_len: int, @@ -705,27 +727,17 @@ def build( ] kv_indptr_cpu = qo_indptr_cpu * self.pcp_world_size # init custom mask for head-tail query order - q_pos = torch.from_numpy( - common_attn_metadata.query_positions[prefill_start:] - ).long() - kv_lens = ( - prefill_num_computed_tokens_cpu - + kv_indptr_cpu[1:] - - kv_indptr_cpu[:-1] - ) - max_q_lens = int(q_pos.max().item()) + 1 - max_kv_lens = int(kv_lens.max().item()) - mask = torch.ones( - max_q_lens, max_kv_lens, dtype=torch.bool - ).tril() - selected_rows = torch.index_select(mask, 0, 
q_pos) - col_indices = torch.arange(max_kv_lens).expand( - q_pos.size(0), -1 + custom_mask = self._get_pcp_custom_mask( + qo_indptr_cpu=qo_indptr_cpu, + q_pos=torch.from_numpy( + common_attn_metadata.query_positions[prefill_start:] + ).long().to(self.device), + kv_lens=( + prefill_num_computed_tokens_cpu + + kv_indptr_cpu[1:] + - kv_indptr_cpu[:-1] + ).to(self.device), ) - valid_mask = col_indices < torch.repeat_interleave( - kv_lens, qo_indptr_cpu[1:] - qo_indptr_cpu[:-1] - ).unsqueeze(1) - custom_mask = selected_rows[valid_mask].to(self.device) attn_metadata.prefill_wrapper.plan( qo_indptr_cpu.to(self.device), From 272c8f1b6c1cdaebe8740ccee994bc450f868312 Mon Sep 17 00:00:00 2001 From: FENP Date: Tue, 21 Oct 2025 14:58:32 +0800 Subject: [PATCH 07/15] code cleanup and fix scheduler_block_size Signed-off-by: FENP --- vllm/attention/ops/common.py | 35 +++++++++---------- vllm/distributed/parallel_state.py | 20 ----------- .../model_executor/layers/fused_moe/config.py | 4 +-- vllm/model_executor/layers/fused_moe/layer.py | 4 +-- vllm/v1/core/single_type_kv_cache_manager.py | 6 ++-- vllm/v1/engine/core.py | 1 + 6 files changed, 25 insertions(+), 45 deletions(-) diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py index 217f32b8baec..1b7c1dabc10c 100644 --- a/vllm/attention/ops/common.py +++ b/vllm/attention/ops/common.py @@ -168,12 +168,11 @@ def correct_attn_out( return out, lse -def cp_lse_ag_out_rs( +def _cp_lse_common( cp_attn_out: torch.Tensor, cp_attn_lse: torch.Tensor, cp_group: GroupCoordinator, ctx: CPTritonContext = None, - return_lse=False, ): """ cp_attn_out: [ B, H, D ] @@ -195,6 +194,21 @@ def cp_lse_ag_out_rs( lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses) out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx) assert out.is_contiguous() + return out, lse + + +def cp_lse_ag_out_rs( + cp_attn_out: torch.Tensor, + cp_attn_lse: torch.Tensor, + cp_group: GroupCoordinator, + ctx: CPTritonContext = None, + 
return_lse: bool = False, +): + """ + cp_attn_out: [ B, H, D ] + cp_attn_lse: [ B, H ] + """ + out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx) out = cp_group.reduce_scatter(out, dim=1) if return_lse: @@ -215,22 +229,7 @@ def cp_lse_ag_out_ar( cp_attn_out: [ B, H, D ] cp_attn_lse: [ B, H ] """ - if cp_group.world_size == 1: - return cp_attn_out - - if ctx is None: - ctx = CPTritonContext() - - lses = torch.empty( - (cp_group.world_size,) + cp_attn_lse.shape, - dtype=cp_attn_lse.dtype, - device=cp_attn_lse.device, - ) - - cp_attn_lse = cp_attn_lse.contiguous() - lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses) - out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx) - assert out.is_contiguous() + out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx) out = cp_group.all_reduce(out) return out diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 8f7792d8c7af..6924464872a3 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1093,16 +1093,6 @@ def get_pcp_group() -> GroupCoordinator: return _PCP -def get_prefill_context_model_parallel_world_size(): - """Return world size for the tensor model parallel group.""" - return get_pcp_group().world_size - - -def get_prefill_context_model_parallel_rank(): - """Return my rank for the tensor model parallel group.""" - return get_pcp_group().rank_in_group - - @deprecated( "`get_pipeline_model_parallel_group` has been replaced with " "`get_pp_group` and may be removed in v0.12. 
Please use " @@ -1476,16 +1466,6 @@ def get_tensor_model_parallel_rank(): return get_tp_group().rank_in_group -def get_decode_context_model_parallel_world_size(): - """Return world size for the decode context model parallel group.""" - return get_dcp_group().world_size - - -def get_decode_context_model_parallel_rank(): - """Return my rank for the decode context model parallel group.""" - return get_dcp_group().rank_in_group - - def get_node_count() -> int: """Return the total number of nodes in the distributed environment.""" assert _NODE_COUNT is not None, "distributed environment is not initialized" diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index f024f0d0fcbb..f62b5705180d 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -9,7 +9,7 @@ from vllm.config import ParallelConfig from vllm.distributed import ( get_dp_group, - get_prefill_context_model_parallel_rank, + get_pcp_group, get_tensor_model_parallel_rank, ) from vllm.logger import init_logger @@ -763,7 +763,7 @@ def flatten_tp_across_dp(dp_rank: int): dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0 tp_size, tp_rank = flatten_tp_across_dp(dp_rank) pcp_size = pcp_size_ - pcp_rank = get_prefill_context_model_parallel_rank() if pcp_size_ > 1 else 0 + pcp_rank = get_pcp_group().rank_in_group if pcp_size_ > 1 else 0 if not use_ep: return FusedMoEParallelConfig( diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 12edc274ad16..6c5864dfd368 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -18,7 +18,7 @@ from vllm.distributed import ( get_dp_group, get_ep_group, - get_prefill_context_model_parallel_world_size, + get_pcp_group, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) @@ -1103,7 +1103,7 @@ def __init__( pcp_size_ = ( 
pcp_size if pcp_size is not None - else get_prefill_context_model_parallel_world_size() + else get_pcp_group().world_size ) self.is_sequence_parallel = is_sequence_parallel diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 82f66d9c202c..a1235581f813 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -341,7 +341,7 @@ def find_longest_cache_hit( "SlidingWindowManager can only be used for sliding window groups" ) assert dcp_world_size == 1, "DCP not support sliding window attn now." - assert pcp_world_size == 1, "CP not support sliding window attn now." + assert pcp_world_size == 1, "PCP not support sliding window attn now." # The number of contiguous blocks needed for prefix cache hit. # -1 since the input token itself is also included in the window @@ -481,7 +481,7 @@ def find_longest_cache_hit( "Hybrid KV cache is not supported for " + "eagle + chunked local attention." ) assert dcp_world_size == 1, "DCP not support chunked local attn now." - assert pcp_world_size == 1, "CP not support chunked local attn now." + assert pcp_world_size == 1, "PCP not support chunked local attn now." max_num_blocks = max_length // kv_cache_spec.block_size if max_length > 0: local_attention_start_idx = ( @@ -572,7 +572,7 @@ def find_longest_cache_hit( "MambaManager can only be used for mamba groups" ) assert dcp_world_size == 1, "DCP not support mamba now." - assert pcp_world_size == 1, "CP not support mamba now." + assert pcp_world_size == 1, "PCP not support mamba now." computed_blocks: tuple[list[KVCacheBlock], ...] 
= tuple( [] for _ in range(len(kv_cache_group_ids)) ) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 0ca60ce5cf9a..8bf517d11f92 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -148,6 +148,7 @@ def __init__( scheduler_block_size = ( vllm_config.cache_config.block_size * vllm_config.parallel_config.decode_context_parallel_size + * vllm_config.parallel_config.prefill_context_parallel_size ) self.scheduler: SchedulerInterface = Scheduler( From e2e2952ba2732bd374d8c540b2ad9f405c4bfa31 Mon Sep 17 00:00:00 2001 From: FENP Date: Tue, 21 Oct 2025 15:41:23 +0800 Subject: [PATCH 08/15] increase kv cache size by pcp size Signed-off-by: FENP --- vllm/v1/core/kv_cache_utils.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 6c9a77ccb2b6..806f6ef34be8 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1189,11 +1189,20 @@ def _report_kv_cache_config( // len(kv_cache_config.kv_cache_groups) * min_block_size ) - if vllm_config.parallel_config.decode_context_parallel_size > 1: - num_tokens *= vllm_config.parallel_config.decode_context_parallel_size + if ( + vllm_config.parallel_config.prefill_context_parallel_size * + vllm_config.parallel_config.decode_context_parallel_size > 1 + ): + num_tokens *= (vllm_config.parallel_config.prefill_context_parallel_size * + vllm_config.parallel_config.decode_context_parallel_size) + cp_size = (vllm_config.parallel_config.prefill_context_parallel_size * + vllm_config.parallel_config.decode_context_parallel_size) logger.info( - "Multiplying the GPU KV cache size by the dcp_world_size %d.", - vllm_config.parallel_config.decode_context_parallel_size, + "Multiplying the GPU KV cache size by the cp_world_size %d " + "(pcp_world_size %d * dcp_world_size %d).", + cp_size, + vllm_config.parallel_config.prefill_context_parallel_size, + 
vllm_config.parallel_config.decode_context_parallel_size ) num_tokens_str = f"{num_tokens:,}" logger.info("GPU KV cache size: %s tokens", num_tokens_str) From 1598b45d7ed1111a6da56ed0b08f9a3ac10de6f1 Mon Sep 17 00:00:00 2001 From: FENP Date: Fri, 24 Oct 2025 15:36:28 +0800 Subject: [PATCH 09/15] Perf: support PIECEWISE cuda graph for PCP Signed-off-by: FENP --- vllm/config/vllm.py | 9 +++++++++ vllm/platforms/cuda.py | 7 ------- vllm/v1/attention/backends/flashinfer.py | 25 ++++++++++++------------ vllm/v1/worker/gpu_model_runner.py | 7 ++++++- 4 files changed, 28 insertions(+), 20 deletions(-) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 7ee522ea9f0c..67eae381a58d 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -359,6 +359,15 @@ def __post_init__(self): ): self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE + # prefill context parallel do not support full cudagraphs now. + if self.parallel_config.prefill_context_parallel_size > 1: + logger.warning( + "Prefill context parallel (PCP) is enabled, which is " + "incompatible with full CUDA graphs. Set " + "cudagraph_mode to PIECEWISE." + ) + self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE + # decode context parallel do not support full cudagraphs now. 
if self.parallel_config.decode_context_parallel_size > 1: logger.warning( diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index c1b4729c58e7..a6b9df7c1446 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -206,13 +206,6 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: ) compilation_config.cudagraph_mode = CUDAGraphMode.NONE - if ( - compilation_config.cudagraph_mode != CUDAGraphMode.NONE - and parallel_config.prefill_context_parallel_size > 1 - ): - logger.info("Prefill Context Parallel: disabling cudagraphs since PCP.") - compilation_config.cudagraph_mode = CUDAGraphMode.NONE - @classmethod def get_current_memory_usage( cls, device: torch.types.Device | None = None diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 1ccad7543acd..b51da01cc2b4 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -1014,24 +1014,25 @@ def forward( num_actual_tokens = attn_metadata.num_actual_tokens - key_across_cp = get_pcp_group().all_gather(key.contiguous(), dim=0) - value_across_cp = get_pcp_group().all_gather(value.contiguous(), dim=0) - if ( - self.pcp_world_size > 1 - and attn_metadata.pcp_allgather_restore_idx is not None - ): - # Reorder kv after cp allgather. + if (self.pcp_world_size > 1): + assert attn_metadata.pcp_allgather_restore_idx is not None + # NOTE(yyj): we must `slice` key and value because pcp_allgather_restore_idx + # ignores the padding from CUDA Graph. To be optimized for performance! + key_across_cp = get_pcp_group().all_gather( + key[:num_actual_tokens].contiguous(), dim=0 + ) + value_across_cp = get_pcp_group().all_gather( + value[:num_actual_tokens].contiguous(), dim=0 + ) + # Reorder kv after pcp allgather. # Note that there are duplicate decoding tokens, # but we only save the first one in kvcache. 
- key_across_cp = torch.index_select( + key = torch.index_select( key_across_cp, 0, attn_metadata.pcp_allgather_restore_idx ) - value_across_cp = torch.index_select( + value = torch.index_select( value_across_cp, 0, attn_metadata.pcp_allgather_restore_idx ) - key = key_across_cp - value = value_across_cp - if self.kv_sharing_target_layer_name is None: # Reshape the input keys and values and store them in the cache. # Skip this if sharing KV cache with an earlier attention layer. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 156ba77e2d35..ce250b827c0f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2696,7 +2696,12 @@ def execute_model( aux_hidden_states = None if self.pcp_world_size > 1: - hidden_states = get_pcp_group().all_gather(hidden_states, 0) + # NOTE we must `slice` hidden_states because pcp_allgather_restore_idx + # ignores the padding from CUDA Graph. + hidden_states = get_pcp_group().all_gather( + hidden_states[:num_scheduled_tokens], + 0, + ) hidden_states = torch.index_select( hidden_states, 0, From 502fc0d8068d8a886180c0002d4ec73e1fe9f11e Mon Sep 17 00:00:00 2001 From: QiuChunshuo Date: Mon, 27 Oct 2025 19:48:18 +0800 Subject: [PATCH 10/15] [bugfix] fix _update_tokens_for_pcp for MTP Signed-off-by: QiuChunshuo --- vllm/v1/worker/gpu_model_runner.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ce250b827c0f..4f4e52d07f0a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -958,10 +958,10 @@ def _update_tokens_for_pcp(self, tokens): >>> pcp_world_size = 2 >>> pcp_rank = 0 >>> _update_tokens_for_pcp(tokens) - ([1, 4, 4], [0, 2, 3, 4, 5, 2, 3, 4, 5]) + ([1, 4, 4], [0, 0, 1, 6, 7, 0, 1, 6, 7]) >>> pcp_rank = 1 >>> _update_tokens_for_pcp(tokens) - ([1, 4, 4], [0, 0, 1, 6, 7, 0, 1, 6, 7]) + ([1, 4, 4], [0, 2, 
3, 4, 5, 2, 3, 4, 5]) >>> # the following results are same for each pcp rank >>> self.num_pcp_pads_cpu [1, 3, 0] @@ -980,13 +980,15 @@ def _update_tokens_for_pcp(self, tokens): self.input_batch.num_computed_tokens_cpu[:num_reqs] >= self.input_batch.num_prompt_tokens[:num_reqs] ) + num_decode_tokens = sum(tokens[:num_decode_reqs]) num_padded_scheduled_tokens = np.ceil( tokens / (2 * self.pcp_world_size) ).astype(np.int32) * (2 * self.pcp_world_size) - # we align scheduled tokens of decode reqs to pcp_world_size instead - # of 2*pcp_world_size - num_padded_scheduled_tokens[:num_decode_reqs] = self.pcp_world_size + # we duplicate scheduled tokens of decode reqs to pcp_world_size + num_padded_scheduled_tokens[:num_decode_reqs] = ( + tokens[:num_decode_reqs] * self.pcp_world_size + ) self.num_pcp_pads_cpu[:num_reqs] = num_padded_scheduled_tokens - tokens cu_padded_tokens, pcp_padded_arange = self._get_cumsum_and_arange( num_padded_scheduled_tokens @@ -997,6 +999,7 @@ def _update_tokens_for_pcp(self, tokens): pcp_tokens = num_padded_scheduled_tokens // self.pcp_world_size pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1) + pcp_chunk_sizes[:num_decode_reqs] = pcp_tokens[:num_decode_reqs] _, pcp_arange = self._get_cumsum_and_arange(pcp_tokens) _, pcp_chunk_arange = self._get_cumsum_and_arange(pcp_chunk_sizes) pcp_head_chunk_mask = pcp_arange < np.repeat(pcp_chunk_sizes, pcp_tokens) @@ -1015,14 +1018,18 @@ def get_current_rank_positions( ) # Decode reqs do not have tail chunks. positions[~pcp_head_chunk_mask] = ( - pcp_chunk_arange[num_decode_reqs:] - + np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_reqs:] + pcp_chunk_arange[num_decode_tokens:] + + np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_tokens:] ) return positions positions = get_current_rank_positions(0, self.pcp_rank) - # Decode tokens are duplicate and their positions always be 0. - positions[:num_decode_reqs] = 0 + # Decode tokens are duplicated only after AG. 
But their positions are + # same without prefill context parallel. + if num_decode_reqs > 0: + positions[:num_decode_tokens] = self._get_cumsum_and_arange( + tokens[:num_decode_reqs] + )[1] padded_pos_start_loc = np.roll(cu_padded_tokens, 1) padded_pos_start_loc[0] = 0 From 512228871c5f72c610413deb8e537b9c6260f01f Mon Sep 17 00:00:00 2001 From: FENP Date: Tue, 28 Oct 2025 15:15:40 +0800 Subject: [PATCH 11/15] bug fix: add dispatch & combine to PCP Signed-off-by: FENP --- .../model_executor/layers/fused_moe/config.py | 20 ++++++++-------- vllm/model_executor/layers/fused_moe/layer.py | 24 +++++++++++++++++++ 2 files changed, 34 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index f62b5705180d..13a0942fa54b 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -746,12 +746,12 @@ def make( between the 4 devices. """ - def flatten_tp_across_dp(dp_rank: int): + def flatten_tp_across_dp_and_pcp(dp_rank: int, pcp_rank: int): tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank() - # There are actually dp_size_ * tp_size_ devices. Update tp_size - # and tp_rank so we shard across all devices. - tp_size = dp_size_ * tp_size_ - tp_rank = dp_rank * tp_size_ + tp_rank + # There are actually dp_size_ * pcp_size_ * tp_size_ devices. + # Update tp_size and tp_rank so we shard across all devices. 
+ tp_size = dp_size_ * pcp_size_ * tp_size_ + tp_rank = dp_rank * pcp_size_ * tp_size_ + pcp_rank * tp_size_ + tp_rank return tp_size, tp_rank use_ep = ( @@ -761,9 +761,9 @@ def flatten_tp_across_dp(dp_rank: int): dp_size = dp_size_ dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0 - tp_size, tp_rank = flatten_tp_across_dp(dp_rank) pcp_size = pcp_size_ pcp_rank = get_pcp_group().rank_in_group if pcp_size_ > 1 else 0 + tp_size, tp_rank = flatten_tp_across_dp_and_pcp(dp_rank, pcp_rank) if not use_ep: return FusedMoEParallelConfig( @@ -782,13 +782,13 @@ def flatten_tp_across_dp(dp_rank: int): assert use_ep # In EP, each device owns a set of experts fully. There is no tensor # parallel update tp_size, tp_rank, ep_size and ep_rank to reflect that. - ep_size = tp_size * pcp_size - ep_rank = tp_rank + tp_size * pcp_rank + ep_size = tp_size + ep_rank = tp_rank return FusedMoEParallelConfig( tp_size=1, tp_rank=0, - pcp_size=1, - pcp_rank=0, + pcp_size=pcp_size, + pcp_rank=pcp_rank, dp_size=dp_size, dp_rank=dp_rank, ep_size=ep_size, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 6c5864dfd368..9c9a6ba295db 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1342,6 +1342,10 @@ def tp_size(self): @property def dp_size(self): return self.moe_parallel_config.dp_size + + @property + def pcp_size(self): + return self.moe_parallel_config.pcp_size @property def ep_size(self): @@ -1354,6 +1358,10 @@ def tp_rank(self): @property def dp_rank(self): return self.moe_parallel_config.dp_rank + + @property + def pcp_rank(self): + return self.moe_parallel_config.pcp_rank @property def ep_rank(self): @@ -2340,6 +2348,16 @@ def forward_impl( hidden_states, router_logits, self.is_sequence_parallel ) + if self.pcp_size > 1: + hidden_states = get_pcp_group().all_gather( + hidden_states, + dim=0, + ) + router_logits = get_pcp_group().all_gather( + 
router_logits, + dim=0, + ) + # Matrix multiply. final_hidden_states = self.quant_method.apply( layer=self, @@ -2383,6 +2401,12 @@ def reduce_output( if do_naive_dispatch_combine and do_combine: states = get_ep_group().combine(states, self.is_sequence_parallel) + if self.pcp_size > 1: + states = get_pcp_group().reduce_scatter( + states, + dim=0, + ) + if ( not self.is_sequence_parallel and self.reduce_results From c30363e7d41fa32a942e04bfcbcd4f5da2d12417 Mon Sep 17 00:00:00 2001 From: QiuChunshuo Date: Tue, 28 Oct 2025 17:43:59 +0800 Subject: [PATCH 12/15] [bugfix] use wrong slice for hidden_states before ag Signed-off-by: QiuChunshuo --- vllm/v1/worker/gpu_model_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4f4e52d07f0a..5e2729486fb5 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2630,7 +2630,7 @@ def execute_model( # NOTE(qcs): For PCP, we pad num_scheduled_tokens_np but # do not update total_num_scheduled_tokens in scheduler_output num_input_tokens = self._get_num_input_tokens( - sum(num_scheduled_tokens_np) + num_scheduled_tokens_np.sum() ) else: num_input_tokens = self._get_num_input_tokens( @@ -2706,7 +2706,7 @@ def execute_model( # NOTE we must `slice` hidden_states because pcp_allgather_restore_idx # ignores the padding from CUDA Graph. 
hidden_states = get_pcp_group().all_gather( - hidden_states[:num_scheduled_tokens], + hidden_states[:num_scheduled_tokens_np.sum()], 0, ) hidden_states = torch.index_select( From 1bb334b76efa61b5c9dd2e3f37f58227a96a335e Mon Sep 17 00:00:00 2001 From: QiuChunshuo Date: Wed, 29 Oct 2025 09:19:04 +0800 Subject: [PATCH 13/15] [bugfix] allocate more buffer for pcp pad Signed-off-by: QiuChunshuo --- vllm/v1/worker/gpu_model_runner.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5e2729486fb5..b7d78dee9133 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -400,9 +400,13 @@ def __init__( # Cache the device properties. self._init_device_properties() + if self.pcp_world_size > 1: + max_num_padded_tokens = self.max_num_tokens + self.max_num_reqs * 2 * self.pcp_world_size + else: + max_num_padded_tokens = self.max_num_tokens # Persistent buffers for CUDA graphs. 
- self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32) - self.positions = self._make_buffer(self.max_num_tokens, dtype=torch.int64) + self.input_ids = self._make_buffer(max_num_padded_tokens, dtype=torch.int32) + self.positions = self._make_buffer(max_num_padded_tokens, dtype=torch.int64) self.query_start_loc = self._make_buffer( self.max_num_reqs + 1, dtype=torch.int32 ) @@ -417,7 +421,7 @@ def __init__( self.inputs_embeds = self._make_buffer( self.max_num_tokens, self.hidden_size, dtype=self.dtype, numpy=False ) - self.is_token_ids = self._make_buffer(self.max_num_tokens, dtype=torch.bool) + self.is_token_ids = self._make_buffer(max_num_padded_tokens, dtype=torch.bool) self.discard_request_indices = self._make_buffer( self.max_num_reqs, dtype=torch.int64 ) @@ -435,7 +439,6 @@ def __init__( self.is_mm_embed = self._make_buffer(self.max_num_tokens, dtype=torch.bool) # Persistent buffers for Context Parallism - max_num_padded_tokens = self.max_num_tokens + self.max_num_reqs * 2 * self.pcp_world_size self.pcp_allgather_restore_idx = self._make_buffer( max_num_padded_tokens, dtype=torch.int64 @@ -476,7 +479,7 @@ def __init__( # OPTIMIZATION: Cache the tensors rather than creating them every step. 
# Keep in int64 to avoid overflow with long context self.arange_np = np.arange( - max(self.max_num_reqs + 1, self.max_model_len, self.max_num_tokens), + max(self.max_num_reqs + 1, self.max_model_len, max_num_padded_tokens), dtype=np.int64, ) @@ -490,7 +493,7 @@ def __init__( self.kv_sharing_fast_prefill_logits_indices = None if self.cache_config.kv_sharing_fast_prefill: self.kv_sharing_fast_prefill_logits_indices = torch.zeros( - self.max_num_tokens, dtype=torch.int32, device=self.device + max_num_padded_tokens, dtype=torch.int32, device=self.device ) self.uniform_decode_query_len = ( From d09bbc664dcc64157ae704b12bd559e673e53360 Mon Sep 17 00:00:00 2001 From: FENP Date: Wed, 29 Oct 2025 17:26:45 +0800 Subject: [PATCH 14/15] misc:fix comments Signed-off-by: FENP --- tests/distributed/test_context_parallel.py | 2 +- vllm/model_executor/layers/fused_moe/layer.py | 3 +++ vllm/v1/worker/gpu_model_runner.py | 4 ++-- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/distributed/test_context_parallel.py b/tests/distributed/test_context_parallel.py index 3f70745a63f4..71df5c81b320 100644 --- a/tests/distributed/test_context_parallel.py +++ b/tests/distributed/test_context_parallel.py @@ -64,7 +64,7 @@ def detailed( for eager_mode_val in [False]: for pp_multiplier in [1]: # TODO(qcs): Test the effect of mixed activation - # when CP and DCP are compatible. + # when PCP and DCP are compatible. for pcp_multiplier, dcp_multiplier in zip([1, 2, 1], [0.5, 1, 1]): for chunked_prefill_val in [True]: parallel_setups.append( diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 9c9a6ba295db..f929a5cee604 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -2348,6 +2348,9 @@ def forward_impl( hidden_states, router_logits, self.is_sequence_parallel ) + # NOTE: Similar with DP, PCP also needs dispatch and combine. 
For + # simplicity, AgRsAll2All was added separately for PCP here. Maybe + # we should modify All2AllManager abstract to better support PCP. if self.pcp_size > 1: hidden_states = get_pcp_group().all_gather( hidden_states, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index b7d78dee9133..cec2c8e92108 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -438,7 +438,7 @@ def __init__( if self.supports_mm_inputs: self.is_mm_embed = self._make_buffer(self.max_num_tokens, dtype=torch.bool) - # Persistent buffers for Context Parallism + # Persistent buffers for Prefill Context Parallism self.pcp_allgather_restore_idx = self._make_buffer( max_num_padded_tokens, dtype=torch.int64 @@ -1489,7 +1489,7 @@ def _prepare_inputs( ] if self.pcp_world_size > 1: - # After cp allgather and restore, there are padded tokens in + # After pcp allgather and restore, there are padded tokens in # kv, so we need pad slotmapping for alignment. 
pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[ : total_num_scheduled_tokens * self.pcp_world_size From a58383337bbd7ef0a6b072737af257b288e36c2c Mon Sep 17 00:00:00 2001 From: QiuChunshuo Date: Fri, 31 Oct 2025 16:23:17 +0800 Subject: [PATCH 15/15] [Perf] remove custom_mask Co-authored-by: QiuChunshuo Co-authored-by: gaojc <1055866782@qq.com> Signed-off-by: QiuChunshuo Signed-off-by: gaojc <1055866782@qq.com> --- vllm/v1/attention/backends/flashinfer.py | 134 ++++++++++++----- vllm/v1/attention/backends/utils.py | 14 ++ vllm/v1/worker/gpu_model_runner.py | 177 +++++++++++++++++++---- 3 files changed, 256 insertions(+), 69 deletions(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index b51da01cc2b4..96c1b1c220db 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -24,6 +24,7 @@ MultipleOf, ) from vllm.attention.ops.common import cp_lse_ag_out_ar +from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.config import CUDAGraphMode, VllmConfig from vllm.distributed.parallel_state import get_pcp_group from vllm.logger import init_logger @@ -51,6 +52,7 @@ get_per_layer_parameters, infer_global_hyperparameters, split_decodes_and_prefills, + PrefillContextParallelMetadata, ) from vllm.v1.kv_cache_interface import AttentionSpec @@ -274,6 +276,7 @@ class FlashInferMetadata: # For context parallel pcp_allgather_restore_idx: torch.Tensor | None = None + pcp_metadata: PrefillContextParallelMetadata | None = None class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): @@ -425,16 +428,18 @@ def _get_workspace_buffer(self): ) return self._workspace_buffer - def _get_prefill_wrapper(self): - if self._prefill_wrapper is None: - if self.pcp_world_size > 1: - self._prefill_wrapper = BatchPrefillWithRaggedKVCacheWrapper( - self._get_workspace_buffer(), get_kv_cache_layout() - ) - else: - self._prefill_wrapper = 
BatchPrefillWithPagedKVCacheWrapper( + def _get_prefill_wrapper(self, attn_metadata): + # if self._prefill_wrapper is None: + if self.pcp_world_size > 1: + self._prefill_wrapper = {} + for key in ["head", "tail"]: + self._prefill_wrapper[key] = BatchPrefillWithRaggedKVCacheWrapper( self._get_workspace_buffer(), get_kv_cache_layout() ) + else: + self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( + self._get_workspace_buffer(), get_kv_cache_layout() + ) return self._prefill_wrapper def _get_decode_wrapper(self, batch_size: int, use_cudagraph: bool = False): @@ -667,6 +672,7 @@ def build( num_prefill_tokens=num_prefill_tokens, use_cascade=use_cascade, pcp_allgather_restore_idx=common_attn_metadata.pcp_allgather_restore_idx, + pcp_metadata=common_attn_metadata.pcp_metadata, ) qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu @@ -699,7 +705,7 @@ def build( if num_prefills > 0: # Decodes are first so prefills start after the last decode prefill_start = num_decodes - attn_metadata.prefill_wrapper = self._get_prefill_wrapper() + attn_metadata.prefill_wrapper = self._get_prefill_wrapper(common_attn_metadata) assert qo_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1 assert paged_kv_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1 assert ( @@ -721,38 +727,40 @@ def build( if not attn_metadata.prefill_use_trtllm: if self.pcp_world_size > 1: + assert common_attn_metadata.pcp_metadata is not None assert common_attn_metadata.query_positions is not None - prefill_num_computed_tokens_cpu = num_computed_tokens_cpu[ - prefill_start: - ] - kv_indptr_cpu = qo_indptr_cpu * self.pcp_world_size - # init custom mask for head-tail query order - custom_mask = self._get_pcp_custom_mask( - qo_indptr_cpu=qo_indptr_cpu, - q_pos=torch.from_numpy( - common_attn_metadata.query_positions[prefill_start:] - ).long().to(self.device), - kv_lens=( - prefill_num_computed_tokens_cpu - + kv_indptr_cpu[1:] - - kv_indptr_cpu[:-1] - ).to(self.device), - ) - 
attn_metadata.prefill_wrapper.plan( + pcp_metadata = common_attn_metadata.pcp_metadata + qo_indptr_cpu = pcp_metadata.q_head_start_loc + kv_for_head_indptr = pcp_metadata.kv_for_head_indptr + kv_for_tail_indptr = pcp_metadata.kv_for_tail_indptr + + attn_metadata.prefill_wrapper["head"].plan( qo_indptr_cpu.to(self.device), - kv_indptr_cpu.to(self.device), + kv_for_head_indptr.to(self.device), self.num_qo_heads, self.num_kv_heads, self.head_dim, - custom_mask=custom_mask, + causal=True, + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + q_data_type=self.q_data_type, + kv_data_type=self.kv_cache_dtype, + ) + # tail + attn_metadata.prefill_wrapper["tail"].plan( + qo_indptr_cpu.to(self.device), + kv_for_tail_indptr.to(self.device), + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + causal=True, sm_scale=self.sm_scale, window_left=self.window_left, logits_soft_cap=self.logits_soft_cap, q_data_type=self.q_data_type, kv_data_type=self.kv_cache_dtype, - fixed_split_size=self.prefill_fixed_split_size, - disable_split_kv=self.disable_split_kv, ) else: attn_metadata.prefill_wrapper.plan( @@ -926,6 +934,32 @@ def process_weights_after_loading(self, act_dtype: torch.dtype): if self.sinks is not None and self.sinks.dtype != torch.float32: self.sinks = self.sinks.to(torch.float32) + def _attention_with_head_and_tail(self, + q_head: torch.Tensor, + q_tail: torch.Tensor, + k_head: torch.Tensor, + v_head: torch.Tensor, + k_tail: torch.Tensor, + v_tail: torch.Tensor, + prefill_wrapper: BatchPrefillWithRaggedKVCacheWrapper, + ): + output_head = torch.empty_like(q_head) + prefill_wrapper["head"].run( + q_head, + k_head, + v_head, + out=output_head, + ) + + output_tail = torch.empty_like(q_tail) + prefill_wrapper["tail"].run( + q_tail, + k_tail, + v_tail, + out=output_tail, + ) + return output_head, output_tail + def forward( self, layer: torch.nn.Module, @@ -1088,20 +1122,44 @@ def forward( assert prefill_wrapper is not None if 
not attn_metadata.prefill_use_trtllm: - assert prefill_wrapper._window_left == self.window_left - assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0) - assert prefill_wrapper._sm_scale == self.scale if self.pcp_world_size > 1: + assert type(prefill_wrapper) == dict + for _, prefill_wrapper_i in prefill_wrapper.items(): + assert prefill_wrapper_i._window_left == self.window_left + assert prefill_wrapper_i._logits_soft_cap == (self.logits_soft_cap or 0.0) + assert prefill_wrapper_i._sm_scale == self.scale + assert attn_metadata.pcp_metadata is not None + pcp_metadata = attn_metadata.pcp_metadata + q_head_indices = pcp_metadata.q_head_indices + q_tail_indices = pcp_metadata.q_tail_indices + kv_for_head_indices = pcp_metadata.kv_for_head_indices + kv_for_tail_indices = pcp_metadata.kv_for_tail_indices + q_full_indices = pcp_metadata.q_full_indices + # NOTE(qcs): Allgather causes duplicate decoding tokens. prefill_key = key[num_decode_tokens * self.pcp_world_size :] prefill_value = value[num_decode_tokens * self.pcp_world_size :] - prefill_wrapper.run( - prefill_query, - prefill_key, - prefill_value, - out=output[num_decode_tokens:], + + output_head, output_tail = self._attention_with_head_and_tail( + torch.index_select(prefill_query, 0, q_head_indices), + torch.index_select(prefill_query, 0, q_tail_indices), + torch.index_select(prefill_key, 0, kv_for_head_indices), + torch.index_select(prefill_value, 0, kv_for_head_indices), + torch.index_select(prefill_key, 0, kv_for_tail_indices), + torch.index_select(prefill_value, 0, kv_for_tail_indices), + prefill_wrapper, + ) + + output_full = torch.index_select( + torch.cat([output_head, output_tail], dim=0), + 0, + q_full_indices ) + output[num_decode_tokens:] = output_full else: + assert prefill_wrapper._window_left == self.window_left + assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0) + assert prefill_wrapper._sm_scale == self.scale assert prefill_wrapper._causal 
prefill_wrapper.run( prefill_query, diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 8c17732a49ce..930dfc31881c 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -48,6 +48,19 @@ def is_valid_kv_cache_layout(value: str) -> bool: return value in get_args(KVCacheLayoutType) +@dataclass +class PrefillContextParallelMetadata: + """ + Attention metadata for prefill context parallel + """ + q_head_indices: torch.Tensor + q_tail_indices: torch.Tensor + q_head_start_loc: torch.Tensor + kv_for_head_indices: torch.Tensor + kv_for_tail_indices : torch.Tensor + kv_for_head_indptr: torch.Tensor + kv_for_tail_indptr: torch.Tensor + q_full_indices: torch.Tensor @dataclass class CommonAttentionMetadata: @@ -97,6 +110,7 @@ class CommonAttentionMetadata: # Needed by custom mask calc for context parallelism query_positions: np.ndarray | None = None pcp_allgather_restore_idx: torch.Tensor | None = None + pcp_metadata: PrefillContextParallelMetadata | None = None def slice_query_start_locs( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index cec2c8e92108..015b9e638e02 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -93,6 +93,7 @@ AttentionMetadataBuilder, CommonAttentionMetadata, create_fast_prefill_custom_backend, + PrefillContextParallelMetadata, reorder_batch_to_split_decodes_and_prefills, split_attn_metadata, ) @@ -439,23 +440,24 @@ def __init__( self.is_mm_embed = self._make_buffer(self.max_num_tokens, dtype=torch.bool) # Persistent buffers for Prefill Context Parallism - self.pcp_allgather_restore_idx = self._make_buffer( - max_num_padded_tokens, - dtype=torch.int64 - ) - self.pcp_padded_slot_mapping = torch.empty( - (max_num_padded_tokens,), - dtype=torch.int64, - device=self.device, - ) - self.num_pcp_pads_cpu_tensor = torch.zeros( - (self.max_num_reqs,), device="cpu", dtype=torch.int64, pin_memory=True - ) - 
self.num_pcp_pads_cpu = self.num_pcp_pads_cpu_tensor.numpy() - self.pcp_unpad_mask_cpu_tensor = torch.zeros( - (max_num_padded_tokens,), device="cpu", dtype=torch.bool, pin_memory=True - ) - self.pcp_unpad_mask_cpu = self.pcp_unpad_mask_cpu_tensor.numpy() + if self.pcp_world_size > 1: + self.pcp_allgather_restore_idx = self._make_buffer( + max_num_padded_tokens, + dtype=torch.int64 + ) + self.pcp_padded_slot_mapping = torch.empty( + (max_num_padded_tokens,), + dtype=torch.int64, + device=self.device, + ) + self.num_pcp_pads_cpu_tensor = torch.zeros( + (self.max_num_reqs,), device="cpu", dtype=torch.int64, pin_memory=True + ) + self.num_pcp_pads_cpu = self.num_pcp_pads_cpu_tensor.numpy() + self.pcp_unpad_mask_cpu_tensor = torch.zeros( + (max_num_padded_tokens,), device="cpu", dtype=torch.bool, pin_memory=True + ) + self.pcp_unpad_mask_cpu = self.pcp_unpad_mask_cpu_tensor.numpy() # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -945,6 +947,99 @@ def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs: dummy_modality = mm_budget.get_modality_with_max_tokens() return self._get_mm_dummy_batch(dummy_modality, num_seqs) + def _get_pcp_metadata(self, q_lens, kv_lens): + """ + During the prefill phase, the attention computation is divided into + two parts: q_head and q_tail. Here, we calculate the kv indices + corresponding to q_head or q_tail. Meanwhile, the q and kv indptr are + also computed to build the attention wrapper.
+ If the pcp_size is 2, the variables are as follows: + >>> q_lens [4, 8] kv_lens [8, 16] + >>> pcp_chunk_sizes [2, 4] + >>> q_indptr [0, 2, 4] + >>> q_head_indices [0, 1, 4, 5, 6, 7] q_tail_indices [2, 3, 8, 9, 10, 11] + >>> kv_head_len r0 [2, 4] / r1 [4, 8] + >>> kv_for_head_indptr r0 [0, 2, 6] / r1 [0, 4, 12] + >>> kv_for_head_indices r0 [0, 1, 8, 9, 10, 11] + >>> r1 [0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15] + >>> kv_tail_len r0 [8, 16] / r1 [6, 12] + >>> kv_for_tail_indptr r0 [0, 8, 24] / r1 [0, 6, 18] + >>> kv_for_tail_indices r0 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ..., 23] + >>> r1 [0, 1, 2, 3, 4, 5, 8, 9, ..., 19] + """ + pcp_chunk_sizes = q_lens // 2 + q_indptr = np.zeros(len(pcp_chunk_sizes) + 1) + q_indptr[1:], q_chunk_arange = self._get_cumsum_and_arange(pcp_chunk_sizes) + + # [4, 12] -> [12, 4] + q_head_start_loc = np.roll(np.cumsum(q_lens), 1) + q_head_start_loc[0] = 0 # [0, 4] + q_head_indices = q_chunk_arange + np.repeat( + q_head_start_loc, + pcp_chunk_sizes, + ) + + # [0, 4] + [2, 4] = [2, 8] + q_tail_start_loc = q_head_start_loc + pcp_chunk_sizes + q_tail_indices = q_chunk_arange + np.repeat( + q_tail_start_loc, + pcp_chunk_sizes, + ) + + # [8, 24] -> [24, 8] + kv_start_loc = np.roll(np.cumsum(kv_lens), 1) + kv_start_loc[0] = 0 # [0, 8] + # kv_for_q_head + kv_head_len = pcp_chunk_sizes * (self.pcp_rank + 1) + kv_for_head_indptr = np.zeros(len(kv_head_len) + 1) + kv_for_head_indptr[1:], kv_nomask_head_arange = self._get_cumsum_and_arange(kv_head_len) + kv_for_head_indices = kv_nomask_head_arange + np.repeat( + kv_start_loc, + kv_head_len, + ) + # kv_for_q_tail + kv_tail_len = pcp_chunk_sizes * (2 * self.pcp_world_size - self.pcp_rank) + kv_for_tail_indptr = np.zeros(len(kv_tail_len) + 1) + kv_for_tail_indptr[1:], kv_nomask_tail_arange = self._get_cumsum_and_arange(kv_tail_len) + kv_for_tail_indices = kv_nomask_tail_arange + np.repeat( + kv_start_loc, + kv_tail_len, + ) + + head_tail_indices = { + "q_head": q_head_indices, + "q_tail": q_tail_indices, + 
"kv_head": kv_for_head_indices, + "kv_tail": kv_for_tail_indices, + } + head_tail_indptr = { + "q": q_indptr, + "kv_head": kv_for_head_indptr, + "kv_tail": kv_for_tail_indptr + } + for key, value in head_tail_indices.items(): + head_tail_indices[key] = torch.from_numpy(value).to( + device=self.device, dtype=torch.int64, non_blocking=True + ) + for key, value in head_tail_indptr.items(): + head_tail_indptr[key] = torch.from_numpy(value).to( + dtype=torch.int64 + ) + + q_full_indices = torch.cat([head_tail_indices["q_head"], head_tail_indices["q_tail"]]) + q_full_indices = q_full_indices.to(torch.float32).argsort().to(torch.int32) + + return PrefillContextParallelMetadata( + q_head_indices=head_tail_indices["q_head"], + q_tail_indices=head_tail_indices["q_tail"], + q_head_start_loc=head_tail_indptr["q"], + kv_for_head_indices=head_tail_indices["kv_head"], + kv_for_tail_indices=head_tail_indices["kv_tail"], + kv_for_head_indptr=head_tail_indptr["kv_head"], + kv_for_tail_indptr=head_tail_indptr["kv_tail"], + q_full_indices=q_full_indices, + ) + def _update_tokens_for_pcp(self, tokens): """ If prefill context parallelism is enabled, we will calculate @@ -976,8 +1071,6 @@ def _update_tokens_for_pcp(self, tokens): """ num_reqs = self.input_batch.num_reqs self.num_pcp_pads_cpu[:num_reqs] = 0 - if not self.pcp_world_size > 1: - return tokens, None num_decode_reqs = sum( self.input_batch.num_computed_tokens_cpu[:num_reqs] @@ -1045,7 +1138,15 @@ def get_current_rank_positions( all_positions.argsort() ) self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0]) - return pcp_tokens, positions + return ( + pcp_tokens, + positions, + self._get_pcp_metadata( + pcp_tokens[num_decode_reqs:], + num_padded_scheduled_tokens[num_decode_reqs:], + ) if num_reqs > num_decode_reqs + else None, + ) def _get_cumsum_and_arange( self, @@ -1214,11 +1315,12 @@ def _prepare_inputs( self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np) 
self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens) + pcp_metadata = None if self.pcp_world_size > 1: - num_scheduled_tokens, pcp_positions = self._update_tokens_for_pcp( - num_scheduled_tokens - ) - assert pcp_positions is not None + num_scheduled_tokens, pcp_positions, pcp_metadata = \ + self._update_tokens_for_pcp( + num_scheduled_tokens + ) # Re-update after PCP split sequences. total_num_scheduled_tokens = sum(num_scheduled_tokens) @@ -1381,11 +1483,14 @@ def _prepare_inputs( # from these partial requests, we do so for simplicity. # We will ignore the sampled tokens from the partial requests. # TODO: Support prompt logprobs. - logits_indices = ( - torch.from_numpy(cu_num_tokens) * self.pcp_world_size - - self.num_pcp_pads_cpu_tensor[:num_reqs] - - 1 - ) + if self.pcp_world_size > 1: + logits_indices = ( + torch.from_numpy(cu_num_tokens) * self.pcp_world_size + - self.num_pcp_pads_cpu_tensor[:num_reqs] + - 1 + ) + else: + logits_indices = query_start_loc[1:] - 1 num_draft_tokens = None spec_decode_metadata = None else: @@ -1523,6 +1628,7 @@ def _prepare_inputs( ] if self.pcp_world_size > 1 else None, + pcp_metadata=pcp_metadata, dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs] if self.dcp_world_size > 1 else None, @@ -3512,7 +3618,14 @@ def _dummy_run( ) self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens self.query_start_loc.copy_to_gpu() - + + pcp_metadata = None + if self.pcp_world_size > 1: + num_decode_reqs = sum(num_scheduled_tokens == 1) + pcp_metadata = self._get_pcp_metadata( + num_scheduled_tokens[num_decode_reqs:], + num_scheduled_tokens[num_decode_reqs:] * 2, + ) for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups ): @@ -3536,9 +3649,11 @@ def _dummy_run( ].slot_mapping.gpu[:num_tokens], causal=True, query_positions=query_positions, + pcp_metadata=pcp_metadata if self.pcp_world_size > 1 else None, pcp_allgather_restore_idx=self.pcp_allgather_restore_idx.gpu[ : 
total_num_scheduled_tokens * self.pcp_world_size - ], + ] if self.pcp_world_size > 1 + else None, dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs] if self.dcp_world_size > 1 else None,