From 901028a51c985fc46819026d1ae816205bf92654 Mon Sep 17 00:00:00 2001 From: QiuChunshuo Date: Fri, 10 Oct 2025 19:43:45 +0800 Subject: [PATCH 01/12] disable cudagraph global for cp Signed-off-by: QiuChunshuo --- vllm/platforms/cuda.py | 8 ++++++++ vllm/v1/attention/backends/flashinfer.py | 3 +-- vllm/v1/attention/backends/utils.py | 3 +++ 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 05f129f513a0..39f7a13786d7 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -192,6 +192,14 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: "compatible. Set the all_to_all backend to deepep_low_latency " "to use those kernels instead.") compilation_config.cudagraph_mode = CUDAGraphMode.NONE + + if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and parallel_config.context_parallel_size > 1): + logger.info( + "Context Parallel: disabling cudagraphs since CP." + ) + compilation_config.cudagraph_mode = CUDAGraphMode.NONE + @classmethod def get_current_memory_usage(cls, diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 4bf47c8f3e08..48a0264d9e3b 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -262,9 +262,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], self.kv_cache_spec.block_size) max_num_reqs = vllm_config.scheduler_config.max_num_seqs max_num_pages = max_num_reqs * max_num_pages_per_req - # NOTE(qcs): Context Parallel do not support graph mode now self.enable_cuda_graph = (self.compilation_config.cudagraph_mode.\ - decode_mode() == CUDAGraphMode.FULL and self.cp_world_size == 1) + decode_mode() == CUDAGraphMode.FULL) if self.enable_cuda_graph: # For full cudagraph capture, one `decode_wrapper` for each batch # size is needed for FlashInfer. 
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index ff4e10e82edd..87fe97ce9d48 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -139,6 +139,8 @@ def _make_metadata_with_slice( block_table_tensor = attn_metadata.block_table_tensor[request_slice] slot_mapping = attn_metadata.slot_mapping[token_slice] + # TODO(qcs): check if we can split query_positions and + # cp_kv_recover_idx as following approach query_positions = attn_metadata.query_positions[token_slice] \ if attn_metadata.query_positions is not None else None cp_kv_recover_idx = attn_metadata.cp_kv_recover_idx[token_slice] \ @@ -710,6 +712,7 @@ def split_decodes_and_prefills( num_prefills = num_reqs - num_decodes num_decode_tokens = query_start_loc[first_prefill].item() num_prefill_tokens = num_tokens - num_decode_tokens + print(f"q lens: {query_lens}, num_tokens: {num_tokens}, D_tokens: {num_decode_tokens}, P_tokens: {num_prefill_tokens} ") return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) From cf4d3ab68bfc756e2fff361b8cbdc5c5f1dd7a8b Mon Sep 17 00:00:00 2001 From: QiuChunshuo Date: Sat, 11 Oct 2025 18:11:26 +0800 Subject: [PATCH 02/12] [Feature] support multi-requests Signed-off-by: QiuChunshuo --- vllm/v1/attention/backends/flashinfer.py | 90 ++++++++++-------------- vllm/v1/attention/backends/utils.py | 3 +- vllm/v1/worker/gpu_model_runner.py | 77 ++++++++++---------- 3 files changed, 76 insertions(+), 94 deletions(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 48a0264d9e3b..18fe47910766 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -603,16 +603,19 @@ def build(self, if self.cp_world_size > 1: # NOTE(qcs): no chunked prefill and prefix caching kv_indptr_cpu = qo_indptr_cpu * self.cp_world_size + kv_lens = kv_indptr_cpu[1:] - kv_indptr_cpu[:-1] - \ + 
common_attn_metadata.num_cp_pads[prefill_start:] + kv_indptr_cpu[1:] = torch.cumsum(kv_lens, 0) # init custom mask for head-tail query order mask_arr = [] - q_pos = common_attn_metadata.query_positions + q_pos = common_attn_metadata.query_positions[prefill_start:] for i in range(num_prefills): # |---------|--|--| # |-------| # cp_world_size = 2 # Q = 2 # C = 8 - # cur_q_pos = [0,3] + # cur_q_pos = [0,3] // [1, 2] in another rank # context_mask_i.shape = (2, 8) # upper = [0,1,2,3] # local_mask_i = [[True, False, False, False], @@ -625,7 +628,7 @@ def build(self, mask_arr.append(torch.zeros(0, dtype=torch.bool)) continue context_mask_i = torch.ones((Q, C), dtype=torch.bool) - upper = torch.arange(Q*self.cp_world_size) + upper = torch.arange(kv_lens[i]) local_mask_i = (upper.unsqueeze(0) <= cur_q_pos.unsqueeze(1)) mask_i = torch.cat([context_mask_i, local_mask_i], dim=1) mask_arr.append(mask_i.flatten()) @@ -873,6 +876,26 @@ def forward( # performance to make sure it does not introduce any overhead. num_actual_tokens = attn_metadata.num_actual_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + num_prefill_tokens = attn_metadata.num_prefill_tokens + + key_across_cp = get_cp_group().all_gather( + key.contiguous(), dim=0) + value_across_cp = get_cp_group().all_gather( + value.contiguous(), dim=0) + if (self.cp_world_size > 1 + and attn_metadata.cp_kv_recover_idx is not None): + # reorder kv after cp allgather and remove duplicate decoding tokens + key_across_cp = torch.index_select( + key_across_cp, 0, + attn_metadata.cp_kv_recover_idx + ) + value_across_cp = torch.index_select( + value_across_cp, 0, + attn_metadata.cp_kv_recover_idx + ) + key = key_across_cp + value = value_across_cp if self.kv_sharing_target_layer_name is None: # Reshape the input keys and values and store them in the cache. 
@@ -882,17 +905,16 @@ def forward( # and value[:num_actual_tokens] because the reshape_and_cache_flash # op uses the slot_mapping's shape to determine the number of # actual tokens. - if self.cp_world_size == 1: - torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - kv_cache[:, 0], - kv_cache[:, 1], - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + kv_cache[:, 0], + kv_cache[:, 1], + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 # to process the cache when the kv_cache_dtype is fp8 @@ -912,9 +934,6 @@ def forward( output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache)) return output - num_decode_tokens = attn_metadata.num_decode_tokens - num_prefill_tokens = attn_metadata.num_prefill_tokens - stride_order = FlashInferBackend.get_kv_cache_stride_order() kv_cache_permute = kv_cache.permute(*stride_order) # Regular attention (common case). 
@@ -932,34 +951,12 @@ def forward( self.logits_soft_cap or 0.0) assert prefill_wrapper._sm_scale == self.scale if self.cp_world_size > 1: - key_across_cp = get_cp_group().all_gather( - key[num_decode_tokens:].contiguous(), dim=0) - value_across_cp = get_cp_group().all_gather( - value[num_decode_tokens:].contiguous(), dim=0) - key_across_cp = torch.index_select( - key_across_cp, 0, - attn_metadata.cp_kv_recover_idx - ) - value_across_cp = torch.index_select( - value_across_cp, 0, - attn_metadata.cp_kv_recover_idx - ) - torch.ops._C_cache_ops.reshape_and_cache_flash( - key_across_cp, - value_across_cp, - kv_cache[:, 0], - kv_cache[:, 1], - attn_metadata.slot_mapping[num_decode_tokens:], - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) # TODO(qcs): 考虑 chunked prefill/ prefix cache 情况下 # kvcache的获取与拼接 prefill_wrapper.run( prefill_query, - key_across_cp, - value_across_cp, + key[num_decode_tokens:], + value[num_decode_tokens:], out=output[num_decode_tokens:], ) else: @@ -1046,17 +1043,6 @@ def forward( or 0.0) assert decode_wrapper._sm_scale == self.scale if self.cp_world_size > 1: - torch.ops._C_cache_ops.reshape_and_cache_flash( - key[:num_decode_tokens], - value[:num_decode_tokens], - kv_cache[:, 0], - kv_cache[:, 1], - attn_metadata.slot_mapping[:num_decode_tokens], - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - kv_cache_permute = kv_cache.permute(*stride_order) out, lse = decode_wrapper.run( decode_query, kv_cache_permute, diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 87fe97ce9d48..47a9018eb2de 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -84,6 +84,8 @@ class CommonAttentionMetadata: # Needed by custom mask calc for context parallelism query_positions: Optional[np.ndarray] = None cp_kv_recover_idx: Optional[torch.Tensor] = None + cp_kv_recover_idx_for_slotmapping: Optional[torch.Tensor] = None + num_cp_pads: Optional[torch.Tensor] = None 
def slice_query_start_locs( query_start_loc: torch.Tensor, @@ -712,7 +714,6 @@ def split_decodes_and_prefills( num_prefills = num_reqs - num_decodes num_decode_tokens = query_start_loc[first_prefill].item() num_prefill_tokens = num_tokens - num_decode_tokens - print(f"q lens: {query_lens}, num_tokens: {num_tokens}, D_tokens: {num_decode_tokens}, P_tokens: {num_prefill_tokens} ") return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 757cc5e7fccc..9565eb5717fe 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -836,10 +836,12 @@ def _num_scheduled_tokens_prefill_cp(self, num_tokens, def _update_tokens_for_cp(self, tokens, scheduler_output: "SchedulerOutput"): if not self.cp_world_size > 1: + self.num_cp_pads = None + self.cp_kv_recover_idx = None return tokens num_reqs = self.input_batch.num_reqs - self.num_cp_pads = np.empty(num_reqs, dtype=np.int32) - self.cp_kv_recover_idx: List[List[int]] = [[] + self.num_cp_pads = torch.empty(num_reqs, dtype=torch.int32) + self.cp_kv_recover_idx: Union(List[List[int]], torch.Tensor, None) = [[] for _ in range(self.cp_world_size) ] self.position_cp = np.zeros(self.max_num_tokens, dtype=np.int32) @@ -861,26 +863,31 @@ def _update_tokens_for_cp(self, tokens, scheduler_output: "SchedulerOutput"): start_index += num_tokens tokens[i] = num_tokens else: - self.num_cp_pads[i] = 0 + self.num_cp_pads[i] = self.cp_world_size-1 # we allgather cp_world_size duplicated tokens in decode phase self.position_cp[start_index:start_index + num_tokens] = [idx for idx in range(num_tokens)] start_index += num_tokens - for rank in range(len(self.cp_kv_recover_idx)): - self.cp_kv_recover_idx[rank].append(rank) + num_added_recover_tokens = len(self.cp_kv_recover_idx[0]) * self.cp_world_size + for rank in range(self.cp_world_size): + self.cp_kv_recover_idx[rank].append(rank+num_added_recover_tokens) + + + 
cp_kv_recover_idx = torch.from_numpy(np.concatenate(self.cp_kv_recover_idx) + ).to(device=self.device) + cp_kv_recover_idx = cp_kv_recover_idx.to( + torch.float32).argsort( + stable=True).to(torch.int32) + mask = torch.ones_like(cp_kv_recover_idx).to(torch.bool) + cur_req_end_loc = 0 + for i, req_id in enumerate(self.input_batch.req_ids): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + cp_pad = self.num_cp_pads[i] + cur_req_end_loc += num_tokens + cp_pad + mask[cur_req_end_loc-cp_pad:cur_req_end_loc] = 0 + self.cp_kv_recover_idx = cp_kv_recover_idx[mask] + return tokens - def _update_logits_indices_for_cp(self, cu_num_tokens, scheduler_output: "SchedulerOutput"): - # todo: find a better way to get is_prefill - is_prefill = list( - scheduler_output.num_scheduled_tokens.values())[0] > 1 - num_reqs = self.input_batch.num_reqs - if self.cp_world_size > 1 and is_prefill: - # logits_indices = cu_num_tokens - num_cp_pads[:num_reqs] - 1 # if without all-gather and only sample on cp0 - logits_indices = cu_num_tokens * self.cp_world_size \ - - torch.tensor(self.num_cp_pads[:num_reqs]).to(cu_num_tokens) - 1 - else: - logits_indices = cu_num_tokens - 1 - return logits_indices def _get_cumsum_and_arange( self, @@ -1039,7 +1046,7 @@ def _prepare_inputs( if self.cp_world_size > 1: req_indices_for_slotmapping = np.repeat(self.arange_np[:num_reqs], original_num_scheduled_tokens) - _, original_arange = self._get_cumsum_and_arange( + cu_num_tokens_for_logits_indices, original_arange = self._get_cumsum_and_arange( original_num_scheduled_tokens) positions_np_for_slotmapping = self.positions.np[ :total_num_scheduled_tokens_for_slotmapping].copy() @@ -1053,6 +1060,7 @@ def _prepare_inputs( np.add(self.input_batch.num_computed_tokens_cpu[req_indices], arange, out=positions_np) + cu_num_tokens_for_logits_indices = cu_num_tokens req_indices_for_slotmapping = req_indices positions_np_for_slotmapping = positions_np @@ -1191,10 +1199,7 @@ def _prepare_inputs( # from these 
partial requests, we do so for simplicity. # We will ignore the sampled tokens from the partial requests. # TODO: Support prompt logprobs. - logits_indices = self._update_logits_indices_for_cp( - query_start_loc[1:], - scheduler_output - ) + logits_indices = torch.from_numpy(cu_num_tokens_for_logits_indices) - 1 num_draft_tokens = None spec_decode_metadata = None else: @@ -1239,18 +1244,6 @@ def _prepare_inputs( if self.cp_world_size > 1: # Prepare the metadata for Context Parallel total_num_scheduled_tokens_for_slotmapping = sum(original_num_scheduled_tokens[:num_reqs]) - - total_prefill_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs]) - cp_kv_recover_idx = torch.zeros(total_prefill_num_scheduled_tokens * self.cp_world_size, - dtype=torch.int32, - device=self.device) - cp_kv_recover_idx.copy_(torch.tensor( - np.array(self.cp_kv_recover_idx).flatten().tolist()), - non_blocking=True) - self.cp_kv_recover_idx = cp_kv_recover_idx.to( - torch.float32).argsort().to(torch.int32) - else: - self.cp_kv_recover_idx = None # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. for kv_cache_group_id, kv_cache_group_spec in enumerate( @@ -1305,6 +1298,7 @@ def _prepare_inputs( encoder_seq_lens=encoder_seq_lens, query_positions=positions_np, cp_kv_recover_idx=self.cp_kv_recover_idx, + num_cp_pads=self.num_cp_pads, ) if self.speculative_config and \ @@ -1970,11 +1964,6 @@ def _pool( ) def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int: - cp_size = self.vllm_config.parallel_config.context_parallel_size - if cp_size > 1: - # TODO(qcs): When ContextParallel is adapted to GraphMode, - # revise this length alignment strategy again. 
- return cdiv(num_scheduled_tokens, self.cp_world_size * 2) * 2 if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and not envs.VLLM_DISABLE_PAD_FOR_CUDAGRAPH and hasattr(self, "cudagraph_batch_sizes") @@ -1999,6 +1988,7 @@ def _preprocess( intermediate_tensors: Optional[IntermediateTensors] = None, ubatch_slices: Optional[UBatchSlices] = None, num_tokens_after_padding: Optional[torch.Tensor] = None, + num_scheduled_tokens_after_cp: Optional[int] = None, ) -> tuple[int, int, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], torch.Tensor, Optional[IntermediateTensors], dict[str, Any]]: @@ -2009,7 +1999,11 @@ def _preprocess( num_input_tokens = int(num_tokens_after_padding[0].item() * 2) self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens) elif ubatch_slices is None: - num_input_tokens = self._get_num_input_tokens(num_scheduled_tokens) + if self.cp_world_size == 1: + num_input_tokens = self._get_num_input_tokens(num_scheduled_tokens) + else: + assert num_scheduled_tokens_after_cp is not None + num_input_tokens = self._get_num_input_tokens(num_scheduled_tokens_after_cp) num_pad, num_tokens_after_padding = self.get_dp_padding( num_input_tokens) num_input_tokens += num_pad @@ -2316,7 +2310,8 @@ def execute_model( intermediate_tensors, model_kwargs, ) = self._preprocess(scheduler_output, intermediate_tensors, - ubatch_slices, num_tokens_after_padding) + ubatch_slices, num_tokens_after_padding, + sum(num_scheduled_tokens_np)) if ubatch_slices is not None: num_input_tokens = num_input_tokens // 2 From 07f9920cee8aade2227ab3dcadaf50ccb6ad8cb7 Mon Sep 17 00:00:00 2001 From: QiuChunshuo Date: Sat, 11 Oct 2025 18:26:57 +0800 Subject: [PATCH 03/12] [typo] remove unused parameters Signed-off-by: QiuChunshuo --- vllm/v1/attention/backends/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 47a9018eb2de..8b5cccf63d51 100644 --- 
a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -84,7 +84,6 @@ class CommonAttentionMetadata: # Needed by custom mask calc for context parallelism query_positions: Optional[np.ndarray] = None cp_kv_recover_idx: Optional[torch.Tensor] = None - cp_kv_recover_idx_for_slotmapping: Optional[torch.Tensor] = None num_cp_pads: Optional[torch.Tensor] = None def slice_query_start_locs( From 1407678190034d250c9208ab490770836ed1aabb Mon Sep 17 00:00:00 2001 From: QiuChunshuo Date: Tue, 14 Oct 2025 10:03:52 +0800 Subject: [PATCH 04/12] [ModelRunner] Refactor _update_tokens_for_cp Signed-off-by: QiuChunshuo --- vllm/v1/attention/backends/flashinfer.py | 54 +++--- vllm/v1/attention/backends/utils.py | 9 +- vllm/v1/worker/gpu_model_runner.py | 227 +++++++++++------------ 3 files changed, 148 insertions(+), 142 deletions(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 18fe47910766..6712bde45c80 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -239,7 +239,7 @@ class FlashInferMetadata: paged_kv_indptr_gpu: Optional[torch.Tensor] = None # For context parallel - cp_kv_recover_idx: Optional[torch.Tensor] = None + cp_allgather_restore_idx: Optional[torch.Tensor] = None class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): @@ -551,7 +551,7 @@ def build(self, num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, use_cascade=use_cascade, - cp_kv_recover_idx=common_attn_metadata.cp_kv_recover_idx, + cp_allgather_restore_idx=common_attn_metadata.cp_allgather_restore_idx, ) qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu @@ -598,17 +598,15 @@ def build(self, qo_indptr_cpu = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[ prefill_start] paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:] - prefill_num_computed_tokens_cpu = num_computed_tokens_cpu[prefill_start:] + prefill_num_computed_tokens_cpu = 
\ + num_computed_tokens_cpu[prefill_start:] if not attn_metadata.prefill_use_trtllm: if self.cp_world_size > 1: - # NOTE(qcs): no chunked prefill and prefix caching kv_indptr_cpu = qo_indptr_cpu * self.cp_world_size - kv_lens = kv_indptr_cpu[1:] - kv_indptr_cpu[:-1] - \ - common_attn_metadata.num_cp_pads[prefill_start:] - kv_indptr_cpu[1:] = torch.cumsum(kv_lens, 0) # init custom mask for head-tail query order mask_arr = [] - q_pos = common_attn_metadata.query_positions[prefill_start:] + q_pos = common_attn_metadata.query_positions[ + prefill_start:] for i in range(num_prefills): # |---------|--|--| # |-------| @@ -619,18 +617,23 @@ def build(self, # context_mask_i.shape = (2, 8) # upper = [0,1,2,3] # local_mask_i = [[True, False, False, False], - # [True, True, True, True]] # size=(2, 4) + # [True, True, True, True]] # mask_i.shape = (2, 12) - cur_q_pos = torch.from_numpy(q_pos[qo_indptr_cpu[i]:qo_indptr_cpu[i+1]]) + cur_q_pos = torch.from_numpy(q_pos[qo_indptr_cpu[i] + :qo_indptr_cpu[i+1]]) Q = len(cur_q_pos) C = prefill_num_computed_tokens_cpu[i] if Q <= 0: - mask_arr.append(torch.zeros(0, dtype=torch.bool)) + mask_arr.append(torch.zeros(0, + dtype=torch.bool)) continue - context_mask_i = torch.ones((Q, C), dtype=torch.bool) - upper = torch.arange(kv_lens[i]) - local_mask_i = (upper.unsqueeze(0) <= cur_q_pos.unsqueeze(1)) - mask_i = torch.cat([context_mask_i, local_mask_i], dim=1) + context_mask_i = torch.ones((Q, C), + dtype=torch.bool) + upper = torch.arange(Q*self.cp_world_size) + local_mask_i = (upper.unsqueeze(0) + <= cur_q_pos.unsqueeze(1)) + mask_i = torch.cat([context_mask_i, local_mask_i], + dim=1) mask_arr.append(mask_i.flatten()) custom_mask = torch.cat(mask_arr, dim=0).to(self.device) @@ -884,15 +887,17 @@ def forward( value_across_cp = get_cp_group().all_gather( value.contiguous(), dim=0) if (self.cp_world_size > 1 - and attn_metadata.cp_kv_recover_idx is not None): - # reorder kv after cp allgather and remove duplicate decoding tokens + and 
attn_metadata.cp_allgather_restore_idx is not None): + # Reorder kv after cp allgather. + # Note that there are duplicate decoding tokens, + # but we only save the first one in kvcache. key_across_cp = torch.index_select( key_across_cp, 0, - attn_metadata.cp_kv_recover_idx + attn_metadata.cp_allgather_restore_idx ) value_across_cp = torch.index_select( value_across_cp, 0, - attn_metadata.cp_kv_recover_idx + attn_metadata.cp_allgather_restore_idx ) key = key_across_cp value = value_across_cp @@ -951,12 +956,15 @@ def forward( self.logits_soft_cap or 0.0) assert prefill_wrapper._sm_scale == self.scale if self.cp_world_size > 1: - # TODO(qcs): 考虑 chunked prefill/ prefix cache 情况下 - # kvcache的获取与拼接 + # NOTE(qcs): Allgather causes duplicate decoding tokens. + prefill_key = key[ + num_decode_tokens*self.cp_world_size:] + prefill_value = value[ + num_decode_tokens*self.cp_world_size:] prefill_wrapper.run( prefill_query, - key[num_decode_tokens:], - value[num_decode_tokens:], + prefill_key, + prefill_value, out=output[num_decode_tokens:], ) else: diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 8b5cccf63d51..2dd7dd89f0bb 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -83,8 +83,7 @@ class CommonAttentionMetadata: # Needed by custom mask calc for context parallelism query_positions: Optional[np.ndarray] = None - cp_kv_recover_idx: Optional[torch.Tensor] = None - num_cp_pads: Optional[torch.Tensor] = None + cp_allgather_restore_idx: Optional[torch.Tensor] = None def slice_query_start_locs( query_start_loc: torch.Tensor, @@ -144,8 +143,8 @@ def _make_metadata_with_slice( # cp_kv_recover_idx as following approach query_positions = attn_metadata.query_positions[token_slice] \ if attn_metadata.query_positions is not None else None - cp_kv_recover_idx = attn_metadata.cp_kv_recover_idx[token_slice] \ - if attn_metadata.cp_kv_recover_idx is not None else None + cp_allgather_restore_idx = 
attn_metadata.cp_allgather_restore_idx[token_slice] \ + if attn_metadata.cp_allgather_restore_idx is not None else None return CommonAttentionMetadata( query_start_loc=query_start_loc, @@ -160,7 +159,7 @@ def _make_metadata_with_slice( block_table_tensor=block_table_tensor, slot_mapping=slot_mapping, query_positions=query_positions, - cp_kv_recover_idx=cp_kv_recover_idx, + cp_allgather_restore_idx=cp_allgather_restore_idx, ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9565eb5717fe..91b14d61abcf 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -797,96 +797,92 @@ def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs: dummy_modality = mm_budget.get_modality_with_max_tokens() return self._get_mm_dummy_batch(dummy_modality, num_seqs) - def _num_scheduled_tokens_prefill_cp(self, num_tokens, - num_computed_tokens, - cp_kv_recover_idx): - num_scheduled_tokens = num_tokens - num_computed_tokens - num_cp_padded_scheduled_tokens = cdiv( - num_scheduled_tokens, 2 * self.cp_world_size) * (2 * self.cp_world_size - ) # pad to 2*cp_world_size - cp_pad = num_cp_padded_scheduled_tokens - num_scheduled_tokens - full_indices = list( - range(self.max_num_tokens * self.cp_world_size * self.dcp_world_size + - self.cp_world_size * self.dcp_world_size * self.max_num_reqs)) - chunk_size = num_cp_padded_scheduled_tokens // (2 * self.cp_world_size) - - # split position_ids (and use split position_ids to split input_ids afterwards) - req_position_cp = [] - req_position_cp.extend( - full_indices[self.cp_rank * chunk_size:(self.cp_rank + 1) * - chunk_size]) - req_position_cp.extend( - full_indices[num_cp_padded_scheduled_tokens - (self.cp_rank + 1) * - chunk_size:num_cp_padded_scheduled_tokens - - self.cp_rank * chunk_size]) - - # used to recover kv order in cp prefill (after all-gather kv and before storing kv_cache) - num_added_recover_tokens = len(cp_kv_recover_idx[0]) * self.cp_world_size - 
for rank in range(self.cp_world_size): - cp_kv_recover_idx[rank].extend( - full_indices[rank * chunk_size + - num_added_recover_tokens:(rank + 1) * chunk_size + - num_added_recover_tokens]) - cp_kv_recover_idx[rank].extend(full_indices[ - num_cp_padded_scheduled_tokens - (rank + 1) * chunk_size + - num_added_recover_tokens:num_cp_padded_scheduled_tokens - - rank * chunk_size + num_added_recover_tokens]) - - return req_position_cp, num_cp_padded_scheduled_tokens, cp_pad - - def _update_tokens_for_cp(self, tokens, scheduler_output: "SchedulerOutput"): - if not self.cp_world_size > 1: - self.num_cp_pads = None - self.cp_kv_recover_idx = None - return tokens + def _update_tokens_for_cp(self, tokens): + """ + If context parallelism is enabled, we will calculate + the number of tokens `tokens` after sequence splitting. + Meanwhile, we will compute: + `positions` the new token positions, + `num_cp_pads` the number of padding tokens per request for alignment, + `unpad_mask` the mask for non-padded tokens, + `cp_allgather_restore_idx` indices to restore the original vector + order after CP allgather. 
+ Example: + >>> tokens = [1, 5, 8] + >>> cp_world_size = 2 + >>> cp_rank = 0 + >>> _update_tokens_for_cp(tokens) + ([1, 4, 4], [0, 2, 3, 4, 5, 2, 3, 4, 5], [1, 3, 0], [True, False, + True, True, True, True, True, False, False, False, True, True, + True, True, True, True, True, True], [0, 9, 1, 2, 10, 11, 12, 13, + 3, 4, 5, 6, 14, 15, 16, 17, 7, 8]) + >>> cp_rank = 1 + >>> _update_tokens_for_cp(tokens) + ([1, 4, 4], [0, 0, 1, 6, 7, 0, 1, 6, 7], [1, 3, 0], [True, False, + True, True, True, True, True, False, False, False, True, True, + True, True, True, True, True, True], [0, 9, 1, 2, 10, 11, 12, 13, + 3, 4, 5, 6, 14, 15, 16, 17, 7, 8]) + """ num_reqs = self.input_batch.num_reqs - self.num_cp_pads = torch.empty(num_reqs, dtype=torch.int32) - self.cp_kv_recover_idx: Union(List[List[int]], torch.Tensor, None) = [[] - for _ in range(self.cp_world_size) - ] - self.position_cp = np.zeros(self.max_num_tokens, dtype=np.int32) - start_index = 0 - - for i, req_id in enumerate(self.input_batch.req_ids): - num_tokens = scheduler_output.num_scheduled_tokens[req_id] - is_prefill = num_tokens > 1 # todo: compare num prompt tokens and num sch tokens + computed tokens - if is_prefill: - # when cp > 1 & prefill, need to pad & split sequence here - req_position_cp, num_cp_padded_scheduled_tokens, self.num_cp_pads[ - i] = self._num_scheduled_tokens_prefill_cp( - num_tokens, - self.input_batch.num_computed_tokens_cpu[i], - self.cp_kv_recover_idx) - num_tokens = len(req_position_cp) - self.position_cp[start_index:start_index + - num_tokens] = req_position_cp - start_index += num_tokens - tokens[i] = num_tokens - else: - self.num_cp_pads[i] = self.cp_world_size-1 # we allgather cp_world_size duplicated tokens in decode phase - self.position_cp[start_index:start_index + - num_tokens] = [idx for idx in range(num_tokens)] - start_index += num_tokens - num_added_recover_tokens = len(self.cp_kv_recover_idx[0]) * self.cp_world_size - for rank in range(self.cp_world_size): - 
self.cp_kv_recover_idx[rank].append(rank+num_added_recover_tokens) - - - cp_kv_recover_idx = torch.from_numpy(np.concatenate(self.cp_kv_recover_idx) - ).to(device=self.device) - cp_kv_recover_idx = cp_kv_recover_idx.to( - torch.float32).argsort( - stable=True).to(torch.int32) - mask = torch.ones_like(cp_kv_recover_idx).to(torch.bool) - cur_req_end_loc = 0 - for i, req_id in enumerate(self.input_batch.req_ids): - num_tokens = scheduler_output.num_scheduled_tokens[req_id] - cp_pad = self.num_cp_pads[i] - cur_req_end_loc += num_tokens + cp_pad - mask[cur_req_end_loc-cp_pad:cur_req_end_loc] = 0 - self.cp_kv_recover_idx = cp_kv_recover_idx[mask] - - return tokens + num_cp_pads = torch.zeros(num_reqs, dtype=torch.int32) + if not self.cp_world_size > 1: + return tokens, None, num_cp_pads, None, None + + num_decode_reqs = sum(self.input_batch.num_computed_tokens_cpu[ + :num_reqs] >= self.input_batch.num_prompt_tokens[:num_reqs]) + + num_padded_scheduled_tokens = np.ceil( + tokens / (2 * self.cp_world_size) + ) * (2 * self.cp_world_size) + # we align scheduled tokens of decode reqs to cp_world_size instead + # of 2*cp_world_size + num_padded_scheduled_tokens[:num_decode_reqs] = self.cp_world_size + num_cp_pads = torch.from_numpy(num_padded_scheduled_tokens - tokens) + cu_padded_tokens, cp_padded_arange = \ + self._get_cumsum_and_arange(num_padded_scheduled_tokens) + unpad_mask = torch.from_numpy(cp_padded_arange < + np.repeat( + tokens, + num_padded_scheduled_tokens + )) + + cp_tokens = num_padded_scheduled_tokens // self.cp_world_size + cp_chunk_sizes = (cp_tokens // 2).clip(min=1) + _, cp_arange = self._get_cumsum_and_arange(cp_tokens) + _, cp_chunk_arange = self._get_cumsum_and_arange(cp_chunk_sizes) + cp_head_chunk_mask = cp_arange < np.repeat(cp_chunk_sizes, + cp_tokens) + + + def get_current_rank_positions(cu_tokens, rank): + positions_start_loc = np.zeros_like(cu_tokens) + positions_start_loc[1:] = cu_tokens[:-1] + positions = np.zeros(len(cp_head_chunk_mask), 
dtype=np.int32) + head_start_loc = positions_start_loc + rank * cp_chunk_sizes + tail_start_loc = positions_start_loc + \ + (2 * self.cp_world_size - rank - 1) * cp_chunk_sizes + positions[cp_head_chunk_mask] = cp_chunk_arange + \ + np.repeat(head_start_loc, cp_chunk_sizes) + # Decode reqs do not have tail chunks. + positions[~cp_head_chunk_mask] = \ + cp_chunk_arange[num_decode_reqs:] + \ + np.repeat(tail_start_loc, cp_chunk_sizes)[num_decode_reqs:] + return positions + + positions = get_current_rank_positions(np.zeros(num_reqs, + dtype=np.int32), + self.cp_rank) + # Decode tokens are duplicate and their positions always be 0. + positions[:num_decode_reqs] = 0 + + all_positions = [get_current_rank_positions(cu_padded_tokens, + rank_i) + for rank_i in range(self.cp_world_size)] + all_positions = torch.from_numpy(np.concatenate(all_positions)) + cp_allgather_restore_idx = all_positions.float().argsort( + ).long().to(self.device) + return (cp_tokens, positions, num_cp_pads, + unpad_mask, cp_allgather_restore_idx) def _get_cumsum_and_arange( @@ -1023,10 +1019,13 @@ def _prepare_inputs( # Get the number of scheduled tokens for each request. 
req_ids = self.input_batch.req_ids tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] - total_num_scheduled_tokens_for_slotmapping = total_num_scheduled_tokens + # NOTE(qcs): we need compute slotmapping for all kv + # instead of sliced sequences + total_num_scheduled_tokens4sltmap = total_num_scheduled_tokens original_num_scheduled_tokens = np.array(tokens, dtype=np.int32) - tokens = self._update_tokens_for_cp(tokens, scheduler_output) - num_scheduled_tokens = np.array(tokens, dtype=np.int32) + num_scheduled_tokens, positions_cp, num_cp_pads, unpad_mask, \ + self.cp_allgather_restore_idx = self._update_tokens_for_cp( + original_num_scheduled_tokens) # update total_num_scheduled_tokens total_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs]) max_num_scheduled_tokens = max(tokens) @@ -1044,23 +1043,22 @@ def _prepare_inputs( # Get positions. positions_np = self.positions.np[:total_num_scheduled_tokens] if self.cp_world_size > 1: + assert positions_cp is not None req_indices_for_slotmapping = np.repeat(self.arange_np[:num_reqs], original_num_scheduled_tokens) - cu_num_tokens_for_logits_indices, original_arange = self._get_cumsum_and_arange( + _, original_arange = self._get_cumsum_and_arange( original_num_scheduled_tokens) - positions_np_for_slotmapping = self.positions.np[ - :total_num_scheduled_tokens_for_slotmapping].copy() + positions_np_for_slotmapping = \ np.add(self.input_batch.num_computed_tokens_cpu[req_indices_for_slotmapping], original_arange, - out=positions_np_for_slotmapping) + ) np.add(self.input_batch.num_computed_tokens_cpu[req_indices], - self.position_cp[:total_num_scheduled_tokens], + positions_cp[:total_num_scheduled_tokens], out=positions_np) else: np.add(self.input_batch.num_computed_tokens_cpu[req_indices], arange, out=positions_np) - cu_num_tokens_for_logits_indices = cu_num_tokens req_indices_for_slotmapping = req_indices positions_np_for_slotmapping = positions_np @@ -1133,7 +1131,7 @@ def _prepare_inputs( 
self.input_batch.block_table.compute_slot_mapping( req_indices_for_slotmapping, positions_np_for_slotmapping) self.input_batch.block_table.commit_slot_mapping( - total_num_scheduled_tokens_for_slotmapping) + total_num_scheduled_tokens4sltmap) # Prepare the attention metadata. self.query_start_loc.np[0] = 0 @@ -1199,7 +1197,8 @@ def _prepare_inputs( # from these partial requests, we do so for simplicity. # We will ignore the sampled tokens from the partial requests. # TODO: Support prompt logprobs. - logits_indices = torch.from_numpy(cu_num_tokens_for_logits_indices) - 1 + logits_indices = torch.from_numpy(cu_num_tokens) * \ + self.cp_world_size - num_cp_pads - 1 num_draft_tokens = None spec_decode_metadata = None else: @@ -1241,9 +1240,6 @@ def _prepare_inputs( self.num_accepted_tokens.np[num_reqs:].fill(1) self.num_accepted_tokens.copy_to_gpu() - if self.cp_world_size > 1: - # Prepare the metadata for Context Parallel - total_num_scheduled_tokens_for_slotmapping = sum(original_num_scheduled_tokens[:num_reqs]) # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. for kv_cache_group_id, kv_cache_group_spec in enumerate( @@ -1261,7 +1257,7 @@ def _prepare_inputs( device=self.device, ) slot_mapping = torch.zeros( - (total_num_scheduled_tokens_for_slotmapping, ), + (total_num_scheduled_tokens4sltmap, ), dtype=torch.int64, device=self.device, ) @@ -1270,16 +1266,24 @@ def _prepare_inputs( blk_table = self.input_batch.block_table[kv_cache_group_id] blk_table_tensor = blk_table.get_device_tensor(num_reqs) slot_mapping = blk_table.slot_mapping.gpu[: - total_num_scheduled_tokens_for_slotmapping] + total_num_scheduled_tokens4sltmap] # Fill unused with -1. Needed for reshape_and_cache in full cuda # graph mode. 
- blk_table.slot_mapping.gpu[total_num_scheduled_tokens_for_slotmapping:].fill_( - -1) + blk_table.slot_mapping.gpu[total_num_scheduled_tokens4sltmap: + ].fill_(-1) num_common_prefix_blocks = ( scheduler_output. num_common_prefix_blocks[kv_cache_group_id]) + if self.cp_world_size > 1: + assert unpad_mask is not None + # After cp allgather and restore, there are padded tokens in + # kv, so we need pad slotmapping for alignment. + padded_slot_mapping = torch.full((unpad_mask.shape[0],), + 1).to(slot_mapping) + padded_slot_mapping[unpad_mask] = slot_mapping + slot_mapping = padded_slot_mapping common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, query_start_loc_cpu=query_start_loc_cpu, @@ -1297,8 +1301,7 @@ def _prepare_inputs( causal=True, encoder_seq_lens=encoder_seq_lens, query_positions=positions_np, - cp_kv_recover_idx=self.cp_kv_recover_idx, - num_cp_pads=self.num_cp_pads, + cp_allgather_restore_idx=self.cp_allgather_restore_idx, ) if self.speculative_config and \ @@ -2356,13 +2359,9 @@ def execute_model( aux_hidden_states = None if self.cp_world_size > 1: - if isinstance(attn_metadata, dict): - cp_kv_recover_idx = list(attn_metadata.values())[0].cp_kv_recover_idx - else: - cp_kv_recover_idx = attn_metadata.cp_kv_recover_idx hidden_states = get_cp_group().all_gather(hidden_states, 0) hidden_states = torch.index_select( - hidden_states, 0, cp_kv_recover_idx) + hidden_states, 0, self.cp_allgather_restore_idx) if not self.broadcast_pp_output: # Common case. 
if not get_pp_group().is_last_rank: From b40f23effc9b64a9488a35d48e905ecd0355ed88 Mon Sep 17 00:00:00 2001 From: QiuChunshuo Date: Tue, 14 Oct 2025 11:14:44 +0800 Subject: [PATCH 05/12] [bugfix] type convert error Signed-off-by: QiuChunshuo --- vllm/v1/worker/gpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 91b14d61abcf..0c50258efcab 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -833,7 +833,7 @@ def _update_tokens_for_cp(self, tokens): num_padded_scheduled_tokens = np.ceil( tokens / (2 * self.cp_world_size) - ) * (2 * self.cp_world_size) + ).astype(np.int32) * (2 * self.cp_world_size) # we align scheduled tokens of decode reqs to cp_world_size instead # of 2*cp_world_size num_padded_scheduled_tokens[:num_decode_reqs] = self.cp_world_size From cdf24953c6a16280b161bdaaa20fa35005dd44ed Mon Sep 17 00:00:00 2001 From: QiuChunshuo Date: Tue, 14 Oct 2025 13:06:04 +0800 Subject: [PATCH 06/12] [flashinfer] Vectorized custom_mask computation Signed-off-by: QiuChunshuo --- vllm/v1/attention/backends/flashinfer.py | 49 ++++++++---------------- 1 file changed, 17 insertions(+), 32 deletions(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 6712bde45c80..192da433ecf0 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -604,38 +604,23 @@ def build(self, if self.cp_world_size > 1: kv_indptr_cpu = qo_indptr_cpu * self.cp_world_size # init custom mask for head-tail query order - mask_arr = [] - q_pos = common_attn_metadata.query_positions[ - prefill_start:] - for i in range(num_prefills): - # |---------|--|--| - # |-------| - # cp_world_size = 2 - # Q = 2 - # C = 8 - # cur_q_pos = [0,3] // [1, 2] in another rank - # context_mask_i.shape = (2, 8) - # upper = [0,1,2,3] - # local_mask_i = [[True, False, False, False], - # 
[True, True, True, True]] - # mask_i.shape = (2, 12) - cur_q_pos = torch.from_numpy(q_pos[qo_indptr_cpu[i] - :qo_indptr_cpu[i+1]]) - Q = len(cur_q_pos) - C = prefill_num_computed_tokens_cpu[i] - if Q <= 0: - mask_arr.append(torch.zeros(0, - dtype=torch.bool)) - continue - context_mask_i = torch.ones((Q, C), - dtype=torch.bool) - upper = torch.arange(Q*self.cp_world_size) - local_mask_i = (upper.unsqueeze(0) - <= cur_q_pos.unsqueeze(1)) - mask_i = torch.cat([context_mask_i, local_mask_i], - dim=1) - mask_arr.append(mask_i.flatten()) - custom_mask = torch.cat(mask_arr, dim=0).to(self.device) + q_pos = torch.from_numpy( + common_attn_metadata.query_positions[ + prefill_start:]).long() + kv_lens = prefill_num_computed_tokens_cpu + \ + kv_indptr_cpu[1:] - kv_indptr_cpu[:-1] + max_q_lens = int(q_pos.max().item()) + 1 + max_kv_lens = int(kv_lens.max().item()) + mask = torch.ones(max_q_lens, max_kv_lens, + dtype=torch.bool).tril() + selected_rows = torch.index_select(mask, 0, q_pos) + col_indices = torch.arange(max_kv_lens).expand(q_pos.size(0), -1) + valid_mask = col_indices < torch.repeat_interleave( + kv_lens, + qo_indptr_cpu[1:] - \ + qo_indptr_cpu[:-1] + ).unsqueeze(1) + custom_mask = selected_rows[valid_mask].to(self.device) attn_metadata.prefill_wrapper.plan( qo_indptr_cpu.to(self.device), From 50ab73cd6bf9cc865aac0f3af6953ae0ef744fdf Mon Sep 17 00:00:00 2001 From: QiuChunshuo Date: Tue, 14 Oct 2025 14:22:37 +0800 Subject: [PATCH 07/12] [bugfix] add cp_world_size&cp_rank to AttentionImpl Signed-off-by: QiuChunshuo --- vllm/attention/backends/abstract.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 0bd7e80f544c..5b8c9fe5c32e 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -264,6 +264,9 @@ class AttentionImpl(ABC, Generic[T]): dcp_world_size: int dcp_rank: int + cp_world_size: int + cp_rank: int + def __new__(cls, *args, **kwargs): # 
use __new__ so that all subclasses will call this self = super().__new__(cls) From 67b052cca6cf7f41754488b9bddef9352b3056a9 Mon Sep 17 00:00:00 2001 From: QiuChunshuo Date: Tue, 14 Oct 2025 15:18:00 +0800 Subject: [PATCH 08/12] [format] break too long lines and add assert for variables of uncertain type Signed-off-by: QiuChunshuo --- .../model_executor/layers/fused_moe/config.py | 6 ++++-- vllm/v1/attention/backends/flashinfer.py | 1 + vllm/v1/attention/backends/utils.py | 5 +++-- vllm/v1/executor/multiproc_executor.py | 6 ++++-- vllm/v1/worker/block_table.py | 21 ++++++++++++------- vllm/v1/worker/gpu_model_runner.py | 2 +- 6 files changed, 27 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 8fe8f3053e35..b5d8217952d2 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -7,7 +7,8 @@ import vllm.envs as envs from vllm.config import ParallelConfig -from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank, get_context_model_parallel_rank +from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank, \ + get_context_model_parallel_rank from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape) @@ -604,7 +605,8 @@ def make(tp_size_: int, dp_size_: int, cp_size_: int, level's of parallelism to use in the fused moe layer. Args: - tp_size_ (int): `tp_size` pa use_ep = (dp_size_ * tp_size_ssed into the FusedMoE constructor. + tp_size_ (int): `tp_size` pa use_ep = (dp_size_ * tp_size_ssed into + the FusedMoE constructor. dp_size_ (int): `dp_size` passed into the FusedMoE constructor. vllm_parallel_config (ParallelConfig): vLLM's parallel config object which contains the `enable_expert_parallel` flag. 
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 192da433ecf0..df69d300747e 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -602,6 +602,7 @@ def build(self, num_computed_tokens_cpu[prefill_start:] if not attn_metadata.prefill_use_trtllm: if self.cp_world_size > 1: + assert common_attn_metadata.query_positions is not None kv_indptr_cpu = qo_indptr_cpu * self.cp_world_size # init custom mask for head-tail query order q_pos = torch.from_numpy( diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 2dd7dd89f0bb..df7e1ab9792d 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -143,8 +143,9 @@ def _make_metadata_with_slice( # cp_kv_recover_idx as following approach query_positions = attn_metadata.query_positions[token_slice] \ if attn_metadata.query_positions is not None else None - cp_allgather_restore_idx = attn_metadata.cp_allgather_restore_idx[token_slice] \ - if attn_metadata.cp_allgather_restore_idx is not None else None + cp_allgather_restore_idx = attn_metadata.cp_allgather_restore_idx[ + token_slice] if attn_metadata.cp_allgather_restore_idx is not None \ + else None return CommonAttentionMetadata( query_start_loc=query_start_loc, diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index e728dbb96272..71937c94231d 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -64,7 +64,8 @@ def _init_executor(self) -> None: tensor_parallel_size = self.parallel_config.tensor_parallel_size pp_parallel_size = self.parallel_config.pipeline_parallel_size context_parallel_size = self.parallel_config.context_parallel_size - assert self.world_size == tensor_parallel_size * pp_parallel_size * context_parallel_size, ( + assert self.world_size == tensor_parallel_size * pp_parallel_size * \ + 
context_parallel_size, ( f"world_size ({self.world_size}) must be equal to the " f"tensor_parallel_size ({tensor_parallel_size}) x pipeline" f"_parallel_size ({pp_parallel_size}) x context" @@ -345,7 +346,8 @@ def _get_output_rank(self) -> int: # 16-23, PP rank 2 # 24-31, PP rank 3 # so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 3) - return self.world_size - self.parallel_config.tensor_parallel_size * self.parallel_config.context_parallel_size + return self.world_size - self.parallel_config.tensor_parallel_size * \ + self.parallel_config.context_parallel_size @dataclass diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index d25cb699d346..851b6012673e 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -92,18 +92,21 @@ def compute_slot_mapping(self, req_indices: np.ndarray, # Use a "virtual block" which equals to world_size * block_size # for block_table_indices calculation. - virtual_block_size = self.block_size * self.dcp_world_size * self.cp_world_size + virtual_block_size = self.block_size * self.dcp_world_size * \ + self.cp_world_size block_table_indices = (req_indices * self.max_num_blocks_per_req + positions // virtual_block_size) block_numbers = self.block_table.np.ravel()[block_table_indices] # Use virtual_block_size for mask calculation, which marks local # tokens. 
virtual_block_offsets = positions % virtual_block_size - self.current_rank = self.dcp_world_size * self.cp_rank + self.dcp_rank - mask = (virtual_block_offsets % - (self.dcp_world_size * self.cp_world_size) == self.current_rank) + self.current_rank = self.dcp_world_size * self.cp_rank + \ + self.dcp_rank + mask = (virtual_block_offsets % (self.dcp_world_size * \ + self.cp_world_size) == self.current_rank) # Calculate local block_offsets - block_offsets = virtual_block_offsets // (self.dcp_world_size * self.cp_world_size) + block_offsets = virtual_block_offsets // \ + (self.dcp_world_size * self.cp_world_size) # Calculate slot_mapping slot_mapping = block_numbers * self.block_size + block_offsets # Write final slots, use -1 for not-local @@ -147,8 +150,12 @@ def _make_buffer(self, *size: Union[int, torch.SymInt], device=self.device, pin_memory=self.pin_memory) - def get_split_computed_tokens(self, num_computed_tokens: np.ndarray) -> list[list[list[int]]]: - "Splits computed token counts across dcp and sp dimensions for distributed allocation." + def get_split_computed_tokens(self, num_computed_tokens: np.ndarray) \ + -> list[list[list[int]]]: + """ + Splits computed token counts across dcp and sp dimensions for + distributed allocation. + """ num_requests = len(num_computed_tokens) num_computed_tokens_of_dcp_sp = [[ [0] * self.dcp_world_size for _ in range(self.cp_world_size) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0c50258efcab..6e76d9a77b16 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -806,7 +806,7 @@ def _update_tokens_for_cp(self, tokens): `num_cp_pads` the number of padding tokens per request for alignment, `unpad_mask` the mask for non-padded tokens, `cp_allgather_restore_idx` indices to restore the original vector - order after CP allgather. + order after CP allgather. 
Example: >>> tokens = [1, 5, 8] >>> cp_world_size = 2 From b461a55e9422749e2f404ef0364e0af0d5b58617 Mon Sep 17 00:00:00 2001 From: QiuChunshuo Date: Tue, 14 Oct 2025 16:01:35 +0800 Subject: [PATCH 09/12] [format] Rename to distinguish different types of variables Signed-off-by: QiuChunshuo --- vllm/v1/worker/gpu_model_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 6e76d9a77b16..c9318ab67394 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -875,10 +875,10 @@ def get_current_rank_positions(cu_tokens, rank): # Decode tokens are duplicate and their positions always be 0. positions[:num_decode_reqs] = 0 - all_positions = [get_current_rank_positions(cu_padded_tokens, + all_positions_lst = [get_current_rank_positions(cu_padded_tokens, rank_i) for rank_i in range(self.cp_world_size)] - all_positions = torch.from_numpy(np.concatenate(all_positions)) + all_positions = torch.from_numpy(np.concatenate(all_positions_lst)) cp_allgather_restore_idx = all_positions.float().argsort( ).long().to(self.device) return (cp_tokens, positions, num_cp_pads, From c9fd35a9c6621996a8b3f60dbe280f8948513945 Mon Sep 17 00:00:00 2001 From: QiuChunshuo Date: Tue, 14 Oct 2025 17:24:30 +0800 Subject: [PATCH 10/12] [format] add error for unreachable branch. 
Signed-off-by: QiuChunshuo --- vllm/v1/worker/gpu_model_runner.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c9318ab67394..6fda2b875a8c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2010,6 +2010,9 @@ def _preprocess( num_pad, num_tokens_after_padding = self.get_dp_padding( num_input_tokens) num_input_tokens += num_pad + else: + raise RuntimeError(f"Unreachable branch, please check the value " + f"of ubatch_slices({ubatch_slices}).") # _prepare_inputs may reorder the batch, so we must gather multi # modal outputs after that to ensure the correct order From 3212997329d4cabf0b558ab9840f0afacd18b12d Mon Sep 17 00:00:00 2001 From: QiuChunshuo Date: Tue, 14 Oct 2025 19:10:37 +0800 Subject: [PATCH 11/12] [Perf] Persist some variables that require round-by-round initialization. Signed-off-by: QiuChunshuo --- vllm/v1/worker/gpu_model_runner.py | 82 +++++++++++++++++++----------- 1 file changed, 51 insertions(+), 31 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 6fda2b875a8c..793feb708c01 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -359,6 +359,24 @@ def __init__( self.num_accepted_tokens = self._make_buffer(self.max_num_reqs, dtype=torch.int64) + # Persistent buffers for Context Parallelism + self.cp_allgather_restore_idx = self._make_buffer(self.max_num_tokens, + dtype=torch.int64) + self.cp_padded_slot_mapping = torch.empty((self.max_num_tokens, ), + dtype=torch.int64, + device=self.device,) + self.num_cp_pads_cpu_tensor = torch.zeros((self.max_num_reqs, ), + device="cpu", + dtype=torch.int64, + pin_memory=True) + self.num_cp_pads_cpu = self.num_cp_pads_cpu_tensor.numpy() + self.cp_unpad_mask_cpu_tensor = torch.zeros((self.max_num_tokens, ), + device="cpu", + dtype=torch.bool, + pin_memory=True) + self.cp_unpad_mask_cpu = 
self.cp_unpad_mask_cpu_tensor.numpy() + + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: # NOTE: `mrope_positions` is implemented with one additional dummy @@ -824,9 +842,9 @@ def _update_tokens_for_cp(self, tokens): 3, 4, 5, 6, 14, 15, 16, 17, 7, 8]) """ num_reqs = self.input_batch.num_reqs - num_cp_pads = torch.zeros(num_reqs, dtype=torch.int32) + self.num_cp_pads_cpu[:num_reqs] = 0 if not self.cp_world_size > 1: - return tokens, None, num_cp_pads, None, None + return tokens, None num_decode_reqs = sum(self.input_batch.num_computed_tokens_cpu[ :num_reqs] >= self.input_batch.num_prompt_tokens[:num_reqs]) @@ -837,14 +855,11 @@ def _update_tokens_for_cp(self, tokens): # we align scheduled tokens of decode reqs to cp_world_size instead # of 2*cp_world_size num_padded_scheduled_tokens[:num_decode_reqs] = self.cp_world_size - num_cp_pads = torch.from_numpy(num_padded_scheduled_tokens - tokens) + self.num_cp_pads_cpu[:num_reqs] = num_padded_scheduled_tokens - tokens cu_padded_tokens, cp_padded_arange = \ self._get_cumsum_and_arange(num_padded_scheduled_tokens) - unpad_mask = torch.from_numpy(cp_padded_arange < - np.repeat( - tokens, - num_padded_scheduled_tokens - )) + self.cp_unpad_mask_cpu[:cp_padded_arange.shape[0]] = \ + cp_padded_arange < np.repeat(tokens, num_padded_scheduled_tokens) cp_tokens = num_padded_scheduled_tokens // self.cp_world_size cp_chunk_sizes = (cp_tokens // 2).clip(min=1) @@ -854,9 +869,10 @@ def _update_tokens_for_cp(self, tokens): cp_tokens) - def get_current_rank_positions(cu_tokens, rank): - positions_start_loc = np.zeros_like(cu_tokens) - positions_start_loc[1:] = cu_tokens[:-1] + def get_current_rank_positions( + positions_start_loc: Union[int, np.ndarray], + rank: int + ): positions = np.zeros(len(cp_head_chunk_mask), dtype=np.int32) head_start_loc = positions_start_loc + rank * cp_chunk_sizes tail_start_loc = positions_start_loc + \ @@ -869,20 +885,20 @@ def get_current_rank_positions(cu_tokens, rank): 
np.repeat(tail_start_loc, cp_chunk_sizes)[num_decode_reqs:] return positions - positions = get_current_rank_positions(np.zeros(num_reqs, - dtype=np.int32), - self.cp_rank) + positions = get_current_rank_positions(0, self.cp_rank) # Decode tokens are duplicate and their positions always be 0. positions[:num_decode_reqs] = 0 - all_positions_lst = [get_current_rank_positions(cu_padded_tokens, + padded_pos_start_loc = np.roll(cu_padded_tokens, 1) + padded_pos_start_loc[0] = 0 + all_positions_lst = [get_current_rank_positions(padded_pos_start_loc, rank_i) for rank_i in range(self.cp_world_size)] - all_positions = torch.from_numpy(np.concatenate(all_positions_lst)) - cp_allgather_restore_idx = all_positions.float().argsort( - ).long().to(self.device) - return (cp_tokens, positions, num_cp_pads, - unpad_mask, cp_allgather_restore_idx) + all_positions = np.concatenate(all_positions_lst) + self.cp_allgather_restore_idx.np[:all_positions.shape[0]] = \ + all_positions.argsort() + self.cp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0]) + return cp_tokens, positions def _get_cumsum_and_arange( @@ -1023,9 +1039,8 @@ def _prepare_inputs( # instead of sliced sequences total_num_scheduled_tokens4sltmap = total_num_scheduled_tokens original_num_scheduled_tokens = np.array(tokens, dtype=np.int32) - num_scheduled_tokens, positions_cp, num_cp_pads, unpad_mask, \ - self.cp_allgather_restore_idx = self._update_tokens_for_cp( - original_num_scheduled_tokens) + num_scheduled_tokens, positions_cp = self._update_tokens_for_cp( + original_num_scheduled_tokens) # update total_num_scheduled_tokens total_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs]) max_num_scheduled_tokens = max(tokens) @@ -1198,7 +1213,7 @@ def _prepare_inputs( # We will ignore the sampled tokens from the partial requests. # TODO: Support prompt logprobs. 
logits_indices = torch.from_numpy(cu_num_tokens) * \ - self.cp_world_size - num_cp_pads - 1 + self.cp_world_size - self.num_cp_pads_cpu_tensor[:num_reqs] - 1 num_draft_tokens = None spec_decode_metadata = None else: @@ -1277,13 +1292,15 @@ def _prepare_inputs( num_common_prefix_blocks[kv_cache_group_id]) if self.cp_world_size > 1: - assert unpad_mask is not None # After cp allgather and restore, there are padded tokens in # kv, so we need pad slotmapping for alignment. - padded_slot_mapping = torch.full((unpad_mask.shape[0],), - 1).to(slot_mapping) - padded_slot_mapping[unpad_mask] = slot_mapping - slot_mapping = padded_slot_mapping + cp_padded_slot_mapping = self.cp_padded_slot_mapping[ + :total_num_scheduled_tokens*self.cp_world_size] + cp_unpad_mask = self.cp_unpad_mask_cpu_tensor[ + :total_num_scheduled_tokens*self.cp_world_size] + cp_padded_slot_mapping.fill_(-1) + cp_padded_slot_mapping[cp_unpad_mask] = slot_mapping + slot_mapping = cp_padded_slot_mapping common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, query_start_loc_cpu=query_start_loc_cpu, @@ -1301,7 +1318,8 @@ def _prepare_inputs( causal=True, encoder_seq_lens=encoder_seq_lens, query_positions=positions_np, - cp_allgather_restore_idx=self.cp_allgather_restore_idx, + cp_allgather_restore_idx=self.cp_allgather_restore_idx.gpu[ + :total_num_scheduled_tokens*self.cp_world_size], ) if self.speculative_config and \ @@ -2364,7 +2382,9 @@ def execute_model( if self.cp_world_size > 1: hidden_states = get_cp_group().all_gather(hidden_states, 0) hidden_states = torch.index_select( - hidden_states, 0, self.cp_allgather_restore_idx) + hidden_states, 0, self.cp_allgather_restore_idx.gpu[ + :hidden_states.shape[0] + ]) if not self.broadcast_pp_output: # Common case. 
if not get_pp_group().is_last_rank: From 51cb1f2238104fc6219cb05cb33831ad48138c10 Mon Sep 17 00:00:00 2001 From: QiuChunshuo Date: Tue, 14 Oct 2025 19:55:59 +0800 Subject: [PATCH 12/12] [format] isort Signed-off-by: QiuChunshuo --- vllm/model_executor/layers/fused_moe/config.py | 4 ++-- vllm/model_executor/layers/fused_moe/layer.py | 4 ++-- vllm/v1/attention/backends/flashinfer.py | 2 +- vllm/v1/executor/multiproc_executor.py | 6 +++--- vllm/v1/worker/block_table.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 10 +++++----- 6 files changed, 14 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index b5d8217952d2..1a1b1fea1933 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -7,8 +7,8 @@ import vllm.envs as envs from vllm.config import ParallelConfig -from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank, \ - get_context_model_parallel_rank +from vllm.distributed import (get_context_model_parallel_rank, get_dp_group, + get_tensor_model_parallel_rank) from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 22aefc8e1617..217a2e2c2f8a 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -13,9 +13,9 @@ import vllm.envs as envs from vllm.config import get_current_vllm_config from vllm.config.parallel import ExpertPlacementStrategy -from vllm.distributed import (get_dp_group, get_ep_group, +from vllm.distributed import (get_context_model_parallel_world_size, + get_dp_group, get_ep_group, get_tensor_model_parallel_world_size, - get_context_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.distributed.eplb.eplb_state import EplbState from 
vllm.forward_context import ForwardContext, get_forward_context diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index df69d300747e..f00829c69867 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -21,8 +21,8 @@ AttentionType) from vllm.attention.ops.common import cp_lse_ag_out_ar from vllm.config import CUDAGraphMode, VllmConfig -from vllm.logger import init_logger from vllm.distributed.parallel_state import get_cp_group +from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kFp8StaticTensorSym, kNvfp4Quant) from vllm.platforms import current_platform diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 71937c94231d..c3c4b4b9be1b 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -28,9 +28,9 @@ destroy_model_parallel) from vllm.distributed.device_communicators.shm_broadcast import (Handle, MessageQueue) -from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, - get_pp_group, get_tp_group, - get_cp_group) +from vllm.distributed.parallel_state import (get_cp_group, get_dp_group, + get_ep_group, get_pp_group, + get_tp_group) from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import worker_receiver_cache_from_config diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 851b6012673e..ac722a332503 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -5,7 +5,7 @@ import numpy as np import torch -from vllm.distributed import get_dcp_group, get_cp_group +from vllm.distributed import get_cp_group, get_dcp_group from vllm.logger import init_logger from vllm.utils import cdiv from vllm.v1.utils import CpuGpuBuffer diff --git a/vllm/v1/worker/gpu_model_runner.py 
b/vllm/v1/worker/gpu_model_runner.py index 793feb708c01..887f87263aba 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -8,7 +8,7 @@ from collections.abc import Iterator from contextlib import contextmanager from copy import deepcopy -from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union, cast, List +from typing import TYPE_CHECKING, Any, List, NamedTuple, Optional, Union, cast import numpy as np import torch @@ -31,8 +31,8 @@ has_kv_transfer_group) from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks from vllm.distributed.parallel_state import ( - get_pp_group, get_tp_group, get_dcp_group, get_cp_group, graph_capture, is_global_first_rank, - prepare_communication_buffer_for_model) + get_cp_group, get_pp_group, get_tp_group, graph_capture, + is_global_first_rank, prepare_communication_buffer_for_model) from vllm.forward_context import (BatchDescriptor, DPMetadata, set_forward_context) from vllm.logger import init_logger @@ -55,10 +55,10 @@ from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.tasks import GenerationTask, PoolingTask, SupportedTask from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - GiB_bytes, check_use_alibi, get_dtype_size, + GiB_bytes, cdiv, check_use_alibi, get_dtype_size, is_pin_memory_available, length_from_prompt_token_ids_or_embeds, round_up, - supports_dynamo, cdiv) + supports_dynamo) from vllm.v1.attention.backends.flash_attn import AttentionMetadata from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.utils import (