From d4c20efc4657e8eef0747db20130b117f7efe2cf Mon Sep 17 00:00:00 2001 From: chaunceyjiang Date: Fri, 3 Apr 2026 17:19:44 +0800 Subject: [PATCH 01/13] [Feature] Support kv_transfer_params for parallel sampling (n>1) in PD Signed-off-by: chaunceyjiang --- .../openai/chat_completion/protocol.py | 4 +-- vllm/outputs.py | 19 ++++++++++-- vllm/v1/engine/output_processor.py | 10 +++++-- vllm/v1/engine/parallel_sampling.py | 29 +++++++++++++++---- 4 files changed, 50 insertions(+), 12 deletions(-) diff --git a/vllm/entrypoints/openai/chat_completion/protocol.py b/vllm/entrypoints/openai/chat_completion/protocol.py index 533959df6094..70aca1e11fd3 100644 --- a/vllm/entrypoints/openai/chat_completion/protocol.py +++ b/vllm/entrypoints/openai/chat_completion/protocol.py @@ -107,7 +107,7 @@ class ChatCompletionResponse(OpenAIBaseModel): # vLLM-specific fields that are not in OpenAI spec prompt_logprobs: list[dict[int, Logprob] | None] | None = None prompt_token_ids: list[int] | None = None - kv_transfer_params: dict[str, Any] | None = Field( + kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = Field( default=None, description="KVTransfer parameters." ) @@ -332,7 +332,7 @@ class ChatCompletionRequest(OpenAIBaseModel): ), ) - kv_transfer_params: dict[str, Any] | None = Field( + kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = Field( default=None, description="KVTransfer parameters used for disaggregated serving.", ) diff --git a/vllm/outputs.py b/vllm/outputs.py index 2c71d2afb1b5..e315ba676616 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -120,7 +120,7 @@ def __init__( encoder_prompt_token_ids: list[int] | None = None, num_cached_tokens: int | None = None, *, - kv_transfer_params: dict[str, Any] | None = None, + kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = None, # Forward compatibility, code that uses args added in new release can # still run with older versions of vLLM without breaking. **kwargs: Any, @@ -140,13 +140,25 @@ def __init__( self.encoder_prompt = encoder_prompt self.encoder_prompt_token_ids = encoder_prompt_token_ids self.num_cached_tokens = num_cached_tokens - self.kv_transfer_params = kv_transfer_params + self.kv_transfer_params_list = [] + if kv_transfer_params: + if isinstance(kv_transfer_params, list): + self.kv_transfer_params_list = kv_transfer_params + else: + self.kv_transfer_params_list = [kv_transfer_params] + + @property + def kv_transfer_params(self) -> dict[str, Any] | list[dict[str, Any]] | None: + if len(self.kv_transfer_params_list) == 1: + return self.kv_transfer_params_list[0] + if len(self.kv_transfer_params_list): + return self.kv_transfer_params_list + return None def add(self, next_output: "RequestOutput", aggregate: bool) -> None: """Merge subsequent RequestOutput into this one""" self.finished |= next_output.finished - self.kv_transfer_params = next_output.kv_transfer_params for next_completion in next_output.outputs: for i, completion in enumerate(self.outputs): @@ -171,6 +183,7 @@ def add(self, next_output: "RequestOutput", aggregate: bool) -> None: break else: self.outputs.append(next_completion) + self.kv_transfer_params_list.extend(next_output.kv_transfer_params_list) def __repr__(self) -> str: return ( diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index f9e965092288..6a479dbe8b26 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -321,7 +321,13 @@ def make_request_output( if self.parent_req is None: outputs = [output] else: - outputs, finished = self.parent_req.get_outputs(self.request_id, output) + if kv_transfer_params is None: + outputs, finished = self.parent_req.get_outputs(self.request_id, output) + else: + output_with_kv_transfer = self.parent_req.aggre_kv_transfer_params( + self.request_id, output, kv_transfer_params + ) + outputs, finished, kv_transfer_params = output_with_kv_transfer if not outputs: return None external_req_id = self.parent_req.external_req_id @@ -335,7 +341,7 @@ def _new_request_output( external_req_id: str, outputs: list[CompletionOutput] | list[PoolingOutput], finished: bool, - kv_transfer_params: dict[str, Any] | None = None, + kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = None, ) -> RequestOutput | PoolingRequestOutput: # If prompt embeds were used, put placeholder prompt token ids prompt_token_ids = self.prompt_token_ids diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index 8eb6fa057d37..346e504da6ad 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from copy import copy -from typing import cast +from copy import deepcopy +from typing import Any, cast from vllm.outputs import CompletionOutput from vllm.sampling_params import RequestOutputKind, SamplingParams @@ -26,6 +26,8 @@ class ParentRequest: # To aggregate child completions when not streaming output_aggregator: list[CompletionOutput] + # To store kv_transfer_params for child request + output_kv_transfer_params_list: list[dict[str, Any]] # To find the max number of generated tokens across all children max_num_generation_tokens: int @@ -48,6 +50,7 @@ def __init__(self, request: EngineCoreRequest) -> None: ) self.max_num_generation_tokens = 0 self.cached_child_sampling_params = None + self.output_kv_transfer_params_list = [] def _get_child_sampling_params( self, @@ -66,15 +69,21 @@ def _get_child_sampling_params( Child `sampling_params` instance. """ seed = self.sampling_params.seed + no_caching = seed is None and self.sampling_params.n > 1 if self.cached_child_sampling_params: # Reuse child sampling_params data structure return self.cached_child_sampling_params # Build child sampling_params - child_sampling_params = copy(self.sampling_params) + child_sampling_params = deepcopy(self.sampling_params) child_sampling_params.n = 1 + kv_transfer = child_sampling_params.extra_args.get("kv_transfer_params") + if kv_transfer is not None and isinstance(kv_transfer, list): + child_sampling_params.extra_args["kv_transfer_params"] = kv_transfer[index] + if seed is None: - # Cache child sampling_params for later reuse - self.cached_child_sampling_params = child_sampling_params + if not no_caching: + # Cache child sampling_params for later reuse + self.cached_child_sampling_params = child_sampling_params else: # Each child gets a clone with a unique seed child_sampling_params.seed = seed + index @@ -125,6 +134,16 @@ def get_outputs( finished = not self.child_requests return outputs, finished + def aggre_kv_transfer_params( + self, + child_request_id: str, + completion_output: CompletionOutput, + kv_transfer_params: dict[str, Any], + ) -> tuple[list[CompletionOutput], bool, list[dict[str, Any]]]: + outputs, finished = self.get_outputs(child_request_id, completion_output) + self.output_kv_transfer_params_list.append(kv_transfer_params) + return outputs, finished, self.output_kv_transfer_params_list + def observe_num_generation_tokens(self, num_generation_tokens: int): self.max_num_generation_tokens = max( num_generation_tokens, self.max_num_generation_tokens From 0836848eff90005e4557194971427a57ee673fab Mon Sep 17 00:00:00 2001 From: chaunceyjiang Date: Fri, 3 Apr 2026 17:48:29 +0800 Subject: [PATCH 02/13] [Feature] Support kv_transfer_params for parallel sampling (n>1) in PD Signed-off-by: chaunceyjiang --- vllm/v1/engine/output_processor.py | 4 +++- vllm/v1/engine/parallel_sampling.py | 11 ++++++----- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 6a479dbe8b26..d91788541643 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -272,7 +272,7 @@ def make_request_output( pooling_output: torch.Tensor | None, finish_reason: FinishReason | None, stop_reason: int | str | None, - kv_transfer_params: dict[str, Any] | None = None, + kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = None, routed_experts: np.ndarray | None = None, ) -> RequestOutput | PoolingRequestOutput | None: finished = finish_reason is not None @@ -327,6 +327,8 @@ def make_request_output( output_with_kv_transfer = self.parent_req.aggre_kv_transfer_params( self.request_id, output, kv_transfer_params ) + # overwrite kv_transfer_params with the aggregated one + # from children requests in case of parallel sampling outputs, finished, kv_transfer_params = output_with_kv_transfer if not outputs: return None diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index 346e504da6ad..ab9ce5b6897b 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from copy import deepcopy +from copy import copy from typing import Any, cast from vllm.outputs import CompletionOutput @@ -74,12 +74,13 @@ def _get_child_sampling_params( # Reuse child sampling_params data structure return self.cached_child_sampling_params # Build child sampling_params - child_sampling_params = deepcopy(self.sampling_params) + child_sampling_params = copy(self.sampling_params) child_sampling_params.n = 1 - kv_transfer = child_sampling_params.extra_args.get("kv_transfer_params") - if kv_transfer is not None and isinstance(kv_transfer, list): + extra_args = child_sampling_params.extra_args or {} + kv_transfer = extra_args.get("kv_transfer_params") + if kv_transfer and isinstance(kv_transfer, list): + child_sampling_params.extra_args = copy(extra_args) child_sampling_params.extra_args["kv_transfer_params"] = kv_transfer[index] - if seed is None: if not no_caching: # Cache child sampling_params for later reuse From b7482d8755be38c64f6fcad9576e37e5611f085e Mon Sep 17 00:00:00 2001 From: chaunceyjiang Date: Fri, 3 Apr 2026 17:52:57 +0800 Subject: [PATCH 03/13] [Feature] Support kv_transfer_params for parallel sampling (n>1) in PD Signed-off-by: chaunceyjiang --- vllm/v1/engine/output_processor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index d91788541643..0371987bdc0f 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -327,8 +327,8 @@ def make_request_output( output_with_kv_transfer = self.parent_req.aggre_kv_transfer_params( self.request_id, output, kv_transfer_params ) - # overwrite kv_transfer_params with the aggregated one - # from children requests in case of parallel sampling + # Overwrite kv_transfer_params using the aggregated values from + # child requests in the case of parallel sampling. outputs, finished, kv_transfer_params = output_with_kv_transfer if not outputs: return None From cc9a763e98e8ed5148035a33d2499caff90f718a Mon Sep 17 00:00:00 2001 From: chaunceyjiang Date: Fri, 3 Apr 2026 18:12:56 +0800 Subject: [PATCH 04/13] [Feature] Support kv_transfer_params for parallel sampling (n>1) in PD Signed-off-by: chaunceyjiang --- vllm/entrypoints/anthropic/protocol.py | 4 ++-- vllm/entrypoints/openai/completion/protocol.py | 4 ++-- vllm/entrypoints/openai/responses/context.py | 6 +++--- vllm/entrypoints/openai/responses/protocol.py | 2 +- vllm/entrypoints/serve/disagg/protocol.py | 4 ++-- vllm/v1/engine/output_processor.py | 1 + 6 files changed, 11 insertions(+), 10 deletions(-) diff --git a/vllm/entrypoints/anthropic/protocol.py b/vllm/entrypoints/anthropic/protocol.py index 3445f709109f..4882e77acc02 100644 --- a/vllm/entrypoints/anthropic/protocol.py +++ b/vllm/entrypoints/anthropic/protocol.py @@ -113,7 +113,7 @@ class AnthropicMessagesRequest(BaseModel): top_p: float | None = None # vLLM-specific fields that are not in Anthropic spec - kv_transfer_params: dict[str, Any] | None = Field( + kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = Field( default=None, description="KVTransfer parameters used for disaggregated serving.", ) @@ -188,7 +188,7 @@ class AnthropicMessagesResponse(BaseModel): usage: AnthropicUsage | None = None # vLLM-specific fields that are not in Anthropic spec - kv_transfer_params: dict[str, Any] | None = Field( + kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = Field( default=None, description="KVTransfer parameters." ) diff --git a/vllm/entrypoints/openai/completion/protocol.py b/vllm/entrypoints/openai/completion/protocol.py index c785d254084d..a0512fe87adf 100644 --- a/vllm/entrypoints/openai/completion/protocol.py +++ b/vllm/entrypoints/openai/completion/protocol.py @@ -154,7 +154,7 @@ class CompletionRequest(OpenAIBaseModel): ), ) - kv_transfer_params: dict[str, Any] | None = Field( + kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = Field( default=None, description="KVTransfer parameters used for disaggregated serving.", ) @@ -481,7 +481,7 @@ class CompletionResponse(OpenAIBaseModel): usage: UsageInfo # vLLM-specific fields that are not in OpenAI spec - kv_transfer_params: dict[str, Any] | None = Field( + kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = Field( default=None, description="KVTransfer parameters." ) diff --git a/vllm/entrypoints/openai/responses/context.py b/vllm/entrypoints/openai/responses/context.py index 48360173cf48..731d1c493e74 100644 --- a/vllm/entrypoints/openai/responses/context.py +++ b/vllm/entrypoints/openai/responses/context.py @@ -182,7 +182,7 @@ def __init__(self): self.all_turn_metrics = [] self.input_messages: list[ResponseRawMessageAndToken] = [] - self.kv_transfer_params: dict[str, Any] | None = None + self.kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = None def append_output(self, output) -> None: self.last_output = output @@ -311,7 +311,7 @@ def __init__( self.input_messages: list[ResponseRawMessageAndToken] = [] self.output_messages: list[ResponseRawMessageAndToken] = [] self._accumulated_token_ids: list[int] = [] - self.kv_transfer_params: dict[str, Any] | None = None + self.kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = None def append_output(self, output: RequestOutput) -> None: self.num_prompt_tokens = len(output.prompt_token_ids or []) @@ -544,7 +544,7 @@ def __init__( self.all_turn_metrics: list[TurnMetrics] = [] self.is_first_turn = True self.first_tok_of_message = True # For streaming support - self.kv_transfer_params: dict[str, Any] | None = None + self.kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = None def _update_num_reasoning_tokens(self): channel = self.parser.current_channel diff --git a/vllm/entrypoints/openai/responses/protocol.py b/vllm/entrypoints/openai/responses/protocol.py index d34ba2d75bba..07e6fe3a752b 100644 --- a/vllm/entrypoints/openai/responses/protocol.py +++ b/vllm/entrypoints/openai/responses/protocol.py @@ -589,7 +589,7 @@ def from_request( usage: ResponseUsage | None = None, input_messages: ResponseInputOutputMessage | None = None, output_messages: ResponseInputOutputMessage | None = None, - kv_transfer_params: dict[str, Any] | None = None, + kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = None, ) -> "ResponsesResponse": incomplete_details: IncompleteDetails | None = None if status == "incomplete": diff --git a/vllm/entrypoints/serve/disagg/protocol.py b/vllm/entrypoints/serve/disagg/protocol.py index af4e8c20c14c..15d980540484 100644 --- a/vllm/entrypoints/serve/disagg/protocol.py +++ b/vllm/entrypoints/serve/disagg/protocol.py @@ -102,7 +102,7 @@ def validate_token_ids(cls, v: list[int]) -> list[int]: "if the served model does not use priority scheduling." ), ) - kv_transfer_params: dict[str, Any] | None = Field( + kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = Field( default=None, description="KVTransfer parameters used for disaggregated serving.", ) @@ -135,7 +135,7 @@ class GenerateResponse(BaseModel): prompt_logprobs: list[dict[int, Logprob] | None] | None = None - kv_transfer_params: dict[str, Any] | None = Field( + kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = Field( default=None, description="KVTransfer parameters used for disaggregated serving.", ) diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 0371987bdc0f..94caa7942214 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -324,6 +324,7 @@ def make_request_output( if kv_transfer_params is None: outputs, finished = self.parent_req.get_outputs(self.request_id, output) else: + assert isinstance(kv_transfer_params, dict) output_with_kv_transfer = self.parent_req.aggre_kv_transfer_params( self.request_id, output, kv_transfer_params ) From f4a096f78dd2b2f0d27c5dc0bfd1d3a5dbaad38a Mon Sep 17 00:00:00 2001 From: chaunceyjiang Date: Fri, 3 Apr 2026 18:17:18 +0800 Subject: [PATCH 05/13] [Feature] Support kv_transfer_params for parallel sampling (n>1) in PD Signed-off-by: chaunceyjiang --- vllm/entrypoints/anthropic/protocol.py | 4 ++-- vllm/entrypoints/openai/responses/protocol.py | 2 +- vllm/entrypoints/serve/disagg/protocol.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/entrypoints/anthropic/protocol.py b/vllm/entrypoints/anthropic/protocol.py index 4882e77acc02..3445f709109f 100644 --- a/vllm/entrypoints/anthropic/protocol.py +++ b/vllm/entrypoints/anthropic/protocol.py @@ -113,7 +113,7 @@ class AnthropicMessagesRequest(BaseModel): top_p: float | None = None # vLLM-specific fields that are not in Anthropic spec - kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = Field( + kv_transfer_params: dict[str, Any] | None = Field( default=None, description="KVTransfer parameters used for disaggregated serving.", ) @@ -188,7 +188,7 @@ class AnthropicMessagesResponse(BaseModel): usage: AnthropicUsage | None = None # vLLM-specific fields that are not in Anthropic spec - kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = Field( + kv_transfer_params: dict[str, Any] | None = Field( default=None, description="KVTransfer parameters." ) diff --git a/vllm/entrypoints/openai/responses/protocol.py b/vllm/entrypoints/openai/responses/protocol.py index 07e6fe3a752b..d34ba2d75bba 100644 --- a/vllm/entrypoints/openai/responses/protocol.py +++ b/vllm/entrypoints/openai/responses/protocol.py @@ -589,7 +589,7 @@ def from_request( usage: ResponseUsage | None = None, input_messages: ResponseInputOutputMessage | None = None, output_messages: ResponseInputOutputMessage | None = None, - kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = None, + kv_transfer_params: dict[str, Any] | None = None, ) -> "ResponsesResponse": incomplete_details: IncompleteDetails | None = None if status == "incomplete": diff --git a/vllm/entrypoints/serve/disagg/protocol.py b/vllm/entrypoints/serve/disagg/protocol.py index 15d980540484..af4e8c20c14c 100644 --- a/vllm/entrypoints/serve/disagg/protocol.py +++ b/vllm/entrypoints/serve/disagg/protocol.py @@ -102,7 +102,7 @@ def validate_token_ids(cls, v: list[int]) -> list[int]: "if the served model does not use priority scheduling." ), ) - kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = Field( + kv_transfer_params: dict[str, Any] | None = Field( default=None, description="KVTransfer parameters used for disaggregated serving.", ) @@ -135,7 +135,7 @@ class GenerateResponse(BaseModel): prompt_logprobs: list[dict[int, Logprob] | None] | None = None - kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = Field( + kv_transfer_params: dict[str, Any] | None = Field( default=None, description="KVTransfer parameters used for disaggregated serving.", ) From 2eddc958406fd298de69128ec28954bd08a880ea Mon Sep 17 00:00:00 2001 From: chaunceyjiang Date: Fri, 3 Apr 2026 18:29:49 +0800 Subject: [PATCH 06/13] [Feature] Support kv_transfer_params for parallel sampling (n>1) in PD Signed-off-by: chaunceyjiang --- vllm/entrypoints/openai/completion/protocol.py | 4 ++-- vllm/entrypoints/openai/responses/context.py | 10 +++++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/vllm/entrypoints/openai/completion/protocol.py b/vllm/entrypoints/openai/completion/protocol.py index a0512fe87adf..c785d254084d 100644 --- a/vllm/entrypoints/openai/completion/protocol.py +++ b/vllm/entrypoints/openai/completion/protocol.py @@ -154,7 +154,7 @@ class CompletionRequest(OpenAIBaseModel): ), ) - kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = Field( + kv_transfer_params: dict[str, Any] | None = Field( default=None, description="KVTransfer parameters used for disaggregated serving.", ) @@ -481,7 +481,7 @@ class CompletionResponse(OpenAIBaseModel): usage: UsageInfo # vLLM-specific fields that are not in OpenAI spec - kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = Field( + kv_transfer_params: dict[str, Any] | None = Field( default=None, description="KVTransfer parameters." ) diff --git a/vllm/entrypoints/openai/responses/context.py b/vllm/entrypoints/openai/responses/context.py index 731d1c493e74..6cfcaa8b783a 100644 --- a/vllm/entrypoints/openai/responses/context.py +++ b/vllm/entrypoints/openai/responses/context.py @@ -182,7 +182,7 @@ def __init__(self): self.all_turn_metrics = [] self.input_messages: list[ResponseRawMessageAndToken] = [] - self.kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = None + self.kv_transfer_params: dict[str, Any] | None = None def append_output(self, output) -> None: self.last_output = output @@ -192,6 +192,7 @@ def append_output(self, output) -> None: self.num_cached_tokens = output.num_cached_tokens or 0 self.num_output_tokens += len(output.outputs[0].token_ids or []) if output.kv_transfer_params is not None: + assert isinstance(output.kv_transfer_params, dict) self.kv_transfer_params = output.kv_transfer_params # Accumulate text, token_ids, and logprobs for streaming mode @@ -311,13 +312,14 @@ def __init__( self.input_messages: list[ResponseRawMessageAndToken] = [] self.output_messages: list[ResponseRawMessageAndToken] = [] self._accumulated_token_ids: list[int] = [] - self.kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = None + self.kv_transfer_params: dict[str, Any] | None = None def append_output(self, output: RequestOutput) -> None: self.num_prompt_tokens = len(output.prompt_token_ids or []) self.num_cached_tokens = output.num_cached_tokens or 0 self.num_output_tokens += len(output.outputs[0].token_ids or []) if output.kv_transfer_params is not None: + assert isinstance(output.kv_transfer_params, dict) self.kv_transfer_params = output.kv_transfer_params self.parser.process(output.outputs[0]) output_token_ids = output.outputs[0].token_ids or [] @@ -544,7 +546,7 @@ def __init__( self.all_turn_metrics: list[TurnMetrics] = [] self.is_first_turn = True self.first_tok_of_message = True # For streaming support - self.kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = None + self.kv_transfer_params: dict[str, Any] | None = None def _update_num_reasoning_tokens(self): channel = self.parser.current_channel @@ -565,6 +567,7 @@ def append_output(self, output: RequestOutput) -> None: self._update_prefill_token_usage(output) self._update_decode_token_usage(output) if output.kv_transfer_params is not None: + assert isinstance(output.kv_transfer_params, dict) self.kv_transfer_params = output.kv_transfer_params # Append current turn to all turn list for next turn's calculations self.all_turn_metrics.append(self.current_turn_metrics.copy()) @@ -878,6 +881,7 @@ def append_output(self, output: RequestOutput) -> None: self.last_content_delta = last_delta_text self._update_decode_token_usage(output) if output.kv_transfer_params is not None: + assert isinstance(output.kv_transfer_params, dict) self.kv_transfer_params = output.kv_transfer_params # For streaming, update previous turn when message is complete From 6e53cde0285af8f6e3ae1b7a975865f15df47a91 Mon Sep 17 00:00:00 2001 From: chaunceyjiang Date: Fri, 3 Apr 2026 18:40:05 +0800 Subject: [PATCH 07/13] [Feature] Support kv_transfer_params for parallel sampling (n>1) in PD Signed-off-by: chaunceyjiang --- vllm/v1/engine/output_processor.py | 13 ++++++++++--- vllm/v1/engine/parallel_sampling.py | 6 ++++-- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 94caa7942214..ac83fe20f203 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -321,6 +321,7 @@ def make_request_output( if self.parent_req is None: outputs = [output] else: + child_kv_transfer_params = None if kv_transfer_params is None: outputs, finished = self.parent_req.get_outputs(self.request_id, output) else: @@ -328,11 +329,17 @@ def make_request_output( output_with_kv_transfer = self.parent_req.aggre_kv_transfer_params( self.request_id, output, kv_transfer_params ) - # Overwrite kv_transfer_params using the aggregated values from - # child requests in the case of parallel sampling. - outputs, finished, kv_transfer_params = output_with_kv_transfer + outputs, finished, child_kv_transfer_params = output_with_kv_transfer if not outputs: return None + # In the case of parallel sampling, the final output's kv_transfer_params + # is aggregated from all child requests, so we use the aggregated + # child_kv_transfer_params if available. + kv_transfer_params = ( + child_kv_transfer_params + if child_kv_transfer_params + else kv_transfer_params + ) external_req_id = self.parent_req.external_req_id return self._new_request_output( diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index ab9ce5b6897b..ff96c1dccba6 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -79,8 +79,10 @@ def _get_child_sampling_params( extra_args = child_sampling_params.extra_args or {} kv_transfer = extra_args.get("kv_transfer_params") if kv_transfer and isinstance(kv_transfer, list): - child_sampling_params.extra_args = copy(extra_args) - child_sampling_params.extra_args["kv_transfer_params"] = kv_transfer[index] + child_sampling_params.extra_args = { + **extra_args, + "kv_transfer_params": kv_transfer[index], + } if seed is None: if not no_caching: # Cache child sampling_params for later reuse From eae60763001a79d52b26f7be03ec84d51dbf4bcd Mon Sep 17 00:00:00 2001 From: chaunceyjiang Date: Tue, 7 Apr 2026 10:50:18 +0800 Subject: [PATCH 08/13] [Feature] Support kv_transfer_params for parallel sampling (n>1) in PD Signed-off-by: chaunceyjiang --- vllm/outputs.py | 27 +++++++++++++----------- vllm/v1/engine/output_processor.py | 34 +++++++++--------------------- 2 files changed, 25 insertions(+), 36 deletions(-) diff --git a/vllm/outputs.py b/vllm/outputs.py index e315ba676616..8d2e727bd113 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -35,6 +35,7 @@ class CompletionOutput: to stop, None if the completion finished for some other reason including encountering the EOS token. lora_request: The LoRA request that was used to generate the output. + kv_transfer_params: The params for remote K/V transfer. """ index: int @@ -46,6 +47,7 @@ class CompletionOutput: finish_reason: str | None = None stop_reason: int | str | None = None lora_request: LoRARequest | None = None + kv_transfer_params: dict[str, Any] | None = None def finished(self) -> bool: return self.finish_reason is not None @@ -59,7 +61,8 @@ def __repr__(self) -> str: f"cumulative_logprob={self.cumulative_logprob}, " f"logprobs={self.logprobs}, " f"finish_reason={self.finish_reason}, " - f"stop_reason={self.stop_reason})" + f"stop_reason={self.stop_reason}, " + f"kv_transfer_params={self.kv_transfer_params})" ) @@ -103,7 +106,6 @@ class RequestOutput: encoder_prompt_token_ids: The token IDs of the encoder prompt. None if decoder-only. num_cached_tokens: The number of tokens with prefix cache hit. - kv_transfer_params: The params for remote K/V transfer. """ def __init__( @@ -119,8 +121,6 @@ def __init__( encoder_prompt: str | None = None, encoder_prompt_token_ids: list[int] | None = None, num_cached_tokens: int | None = None, - *, - kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = None, # Forward compatibility, code that uses args added in new release can # still run with older versions of vLLM without breaking. **kwargs: Any, @@ -141,17 +141,20 @@ def __init__( self.encoder_prompt_token_ids = encoder_prompt_token_ids self.num_cached_tokens = num_cached_tokens self.kv_transfer_params_list = [] - if kv_transfer_params: - if isinstance(kv_transfer_params, list): - self.kv_transfer_params_list = kv_transfer_params - else: - self.kv_transfer_params_list = [kv_transfer_params] + for output in outputs: + if output.kv_transfer_params: + self.kv_transfer_params_list.append(output.kv_transfer_params) @property def kv_transfer_params(self) -> dict[str, Any] | list[dict[str, Any]] | None: if len(self.kv_transfer_params_list) == 1: + # keep backward compatibility for the common case where there is + # only one output and its kv_transfer_params is a dict return self.kv_transfer_params_list[0] if len(self.kv_transfer_params_list): + # for the case where there are multiple outputs, we return a list of + # kv_transfer_params dicts. This is for parallel sampling (n > 1) + # where each child request may have different kv_transfer_params. return self.kv_transfer_params_list return None @@ -159,7 +162,6 @@ def add(self, next_output: "RequestOutput", aggregate: bool) -> None: """Merge subsequent RequestOutput into this one""" self.finished |= next_output.finished - for next_completion in next_output.outputs: for i, completion in enumerate(self.outputs): if completion.index == next_completion.index: @@ -183,7 +185,7 @@ def add(self, next_output: "RequestOutput", aggregate: bool) -> None: break else: self.outputs.append(next_completion) - self.kv_transfer_params_list.extend(next_output.kv_transfer_params_list) + self.kv_transfer_params_list.extend(next_completion.kv_transfer_params) def __repr__(self) -> str: return ( @@ -197,7 +199,8 @@ def __repr__(self) -> str: f"finished={self.finished}, " f"metrics={self.metrics}, " f"lora_request={self.lora_request}, " - f"num_cached_tokens={self.num_cached_tokens})" + f"num_cached_tokens={self.num_cached_tokens}, " + f"kv_transfer_params={self.kv_transfer_params})" ) diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index ac83fe20f203..7ba880435812 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -272,7 +272,7 @@ def make_request_output( pooling_output: torch.Tensor | None, finish_reason: FinishReason | None, stop_reason: int | str | None, - kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = None, + kv_transfer_params: dict[str, Any] | None = None, routed_experts: np.ndarray | None = None, ) -> RequestOutput | PoolingRequestOutput | None: finished = finish_reason is not None @@ -315,43 +315,28 @@ def make_request_output( ) output = self._new_completion_output( - new_token_ids, finish_reason, stop_reason, routed_experts + new_token_ids, + finish_reason, + stop_reason, + routed_experts, + kv_transfer_params, ) if self.parent_req is None: outputs = [output] else: - child_kv_transfer_params = None - if kv_transfer_params is None: - outputs, finished = self.parent_req.get_outputs(self.request_id, output) - else: - assert isinstance(kv_transfer_params, dict) - output_with_kv_transfer = self.parent_req.aggre_kv_transfer_params( - self.request_id, output, kv_transfer_params - ) - outputs, finished, child_kv_transfer_params = output_with_kv_transfer + outputs, finished = self.parent_req.get_outputs(self.request_id, output) if not outputs: return None - # In the case of parallel sampling, the final output's kv_transfer_params - # is aggregated from all child requests, so we use the aggregated - # child_kv_transfer_params if available. - kv_transfer_params = ( - child_kv_transfer_params - if child_kv_transfer_params - else kv_transfer_params - ) external_req_id = self.parent_req.external_req_id - return self._new_request_output( - external_req_id, outputs, finished, kv_transfer_params - ) + return self._new_request_output(external_req_id, outputs, finished) def _new_request_output( self, external_req_id: str, outputs: list[CompletionOutput] | list[PoolingOutput], finished: bool, - kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = None, ) -> RequestOutput | PoolingRequestOutput: # If prompt embeds were used, put placeholder prompt token ids prompt_token_ids = self.prompt_token_ids @@ -384,7 +369,6 @@ def _new_request_output( prompt_logprobs=prompt_logprobs, outputs=cast(list[CompletionOutput], outputs), finished=finished, - kv_transfer_params=kv_transfer_params, num_cached_tokens=self.num_cached_tokens, metrics=self.stats, ) @@ -395,6 +379,7 @@ def _new_completion_output( finish_reason: FinishReason | None, stop_reason: int | str | None, routed_experts: np.ndarray | None = None, + kv_transfer_params: dict[str, Any] | None = None, ) -> CompletionOutput: assert self.detokenizer is not None assert self.logprobs_processor is not None @@ -420,6 +405,7 @@ def _new_completion_output( cumulative_logprob=self.logprobs_processor.cumulative_logprob, finish_reason=str(finish_reason) if finished else None, stop_reason=stop_reason if finished else None, + kv_transfer_params=kv_transfer_params, ) def _new_pooling_output(self, pooling_output: torch.Tensor) -> PoolingOutput: From 16d9bde695d2f8425ee71573acaaa877204f279a Mon Sep 17 00:00:00 2001 From: chaunceyjiang Date: Tue, 7 Apr 2026 10:52:16 +0800 Subject: [PATCH 09/13] [Feature] Support kv_transfer_params for parallel sampling (n>1) in PD Signed-off-by: chaunceyjiang --- vllm/v1/engine/parallel_sampling.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index ff96c1dccba6..7d99780ab21f 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -26,8 +26,6 @@ class ParentRequest: # To aggregate child completions when not streaming output_aggregator: list[CompletionOutput] - # To store kv_transfer_params for child request - output_kv_transfer_params_list: list[dict[str, Any]] # To find the max number of generated tokens across all children max_num_generation_tokens: int @@ -137,16 +135,6 @@ def get_outputs( finished = not self.child_requests return outputs, finished - def aggre_kv_transfer_params( - self, - child_request_id: str, - completion_output: CompletionOutput, - kv_transfer_params: dict[str, Any], - ) -> tuple[list[CompletionOutput], bool, list[dict[str, Any]]]: - outputs, finished = self.get_outputs(child_request_id, completion_output) - self.output_kv_transfer_params_list.append(kv_transfer_params) - return outputs, finished, self.output_kv_transfer_params_list - def observe_num_generation_tokens(self, num_generation_tokens: int): self.max_num_generation_tokens = max( num_generation_tokens, self.max_num_generation_tokens From 003d3ee22261b7937258b7531aee00e10393f1d6 Mon Sep 17 00:00:00 2001 From: chaunceyjiang Date: Tue, 7 Apr 2026 10:59:37 +0800 Subject: [PATCH 10/13] [Feature] Support kv_transfer_params for parallel sampling (n>1) in PD Signed-off-by: chaunceyjiang --- vllm/outputs.py | 12 +++++++++++- vllm/v1/engine/parallel_sampling.py | 3 +-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/vllm/outputs.py b/vllm/outputs.py index 8d2e727bd113..18c96579d6d0 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -179,13 +179,23 @@ def add(self, next_output: "RequestOutput", aggregate: bool) -> None: ) completion.finish_reason = next_completion.finish_reason completion.stop_reason = next_completion.stop_reason + completion.kv_transfer_params = ( + next_completion.kv_transfer_params + ) else: # Replace the output with the new one self.outputs[i] = next_completion + if next_completion.kv_transfer_params: + self.kv_transfer_params_list[i] = ( + next_completion.kv_transfer_params + ) break else: self.outputs.append(next_completion) - self.kv_transfer_params_list.extend(next_completion.kv_transfer_params) + if next_completion.kv_transfer_params: + self.kv_transfer_params_list.extend( + next_completion.kv_transfer_params + ) def __repr__(self) -> str: return ( diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index 7d99780ab21f..ce24c1a4011a 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from copy import copy -from typing import Any, cast +from typing import cast from vllm.outputs import CompletionOutput from vllm.sampling_params import RequestOutputKind, SamplingParams @@ -48,7 +48,6 @@ def __init__(self, request: EngineCoreRequest) -> None: ) self.max_num_generation_tokens = 0 self.cached_child_sampling_params = None - self.output_kv_transfer_params_list = [] def _get_child_sampling_params( self, From 2327dfe0df584731a950294b253a1a5a703f25f5 Mon Sep 17 00:00:00 2001 From: chaunceyjiang Date: Tue, 7 Apr 2026 11:02:47 +0800 Subject: [PATCH 11/13] [Feature] Support kv_transfer_params for parallel sampling (n>1) in PD Signed-off-by: chaunceyjiang --- vllm/outputs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/outputs.py b/vllm/outputs.py index 18c96579d6d0..4dbe6ff55022 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -140,7 +140,7 @@ def __init__( self.encoder_prompt = encoder_prompt self.encoder_prompt_token_ids = encoder_prompt_token_ids self.num_cached_tokens = num_cached_tokens - self.kv_transfer_params_list = [] + self.kv_transfer_params_list: list[dict[str, Any]] = [] for output in outputs: if output.kv_transfer_params: self.kv_transfer_params_list.append(output.kv_transfer_params) @@ -193,7 +193,7 @@ def add(self, next_output: "RequestOutput", aggregate: bool) -> None: else: self.outputs.append(next_completion) if next_completion.kv_transfer_params: - self.kv_transfer_params_list.extend( + self.kv_transfer_params_list.append( next_completion.kv_transfer_params ) From dd58080c6955e6a7999c5e1a3bf8292ab5c8a7fd Mon Sep 17 00:00:00 2001 From: chaunceyjiang Date: Wed, 8 Apr 2026 11:37:26 +0800 Subject: [PATCH 12/13] [Feature] Support kv_transfer_params for parallel sampling (n>1) in PD Signed-off-by: chaunceyjiang --- .../entrypoints/openai/completion/protocol.py | 4 +-- vllm/entrypoints/serve/disagg/protocol.py | 4 +-- vllm/outputs.py | 25 +++++++------------ vllm/v1/engine/parallel_sampling.py | 4 +-- 4 files changed, 15 insertions(+), 22 deletions(-) diff --git a/vllm/entrypoints/openai/completion/protocol.py b/vllm/entrypoints/openai/completion/protocol.py index c785d254084d..a0512fe87adf 100644 --- a/vllm/entrypoints/openai/completion/protocol.py +++ b/vllm/entrypoints/openai/completion/protocol.py @@ -154,7 +154,7 @@ class CompletionRequest(OpenAIBaseModel): ), ) - kv_transfer_params: dict[str, Any] | None = Field( + kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = Field( default=None, description="KVTransfer parameters used for disaggregated serving.", ) @@ -481,7 +481,7 @@ class CompletionResponse(OpenAIBaseModel): usage: UsageInfo # vLLM-specific fields that are not in OpenAI spec - kv_transfer_params: dict[str, Any] | None = Field( + kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = Field( default=None, description="KVTransfer parameters." ) diff --git a/vllm/entrypoints/serve/disagg/protocol.py b/vllm/entrypoints/serve/disagg/protocol.py index af4e8c20c14c..15d980540484 100644 --- a/vllm/entrypoints/serve/disagg/protocol.py +++ b/vllm/entrypoints/serve/disagg/protocol.py @@ -102,7 +102,7 @@ def validate_token_ids(cls, v: list[int]) -> list[int]: "if the served model does not use priority scheduling." ), ) - kv_transfer_params: dict[str, Any] | None = Field( + kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = Field( default=None, description="KVTransfer parameters used for disaggregated serving.", ) @@ -135,7 +135,7 @@ class GenerateResponse(BaseModel): prompt_logprobs: list[dict[int, Logprob] | None] | None = None - kv_transfer_params: dict[str, Any] | None = Field( + kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = Field( default=None, description="KVTransfer parameters used for disaggregated serving.", ) diff --git a/vllm/outputs.py b/vllm/outputs.py index 4dbe6ff55022..13475f8a5c1d 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -140,22 +140,23 @@ def __init__( self.encoder_prompt = encoder_prompt self.encoder_prompt_token_ids = encoder_prompt_token_ids self.num_cached_tokens = num_cached_tokens - self.kv_transfer_params_list: list[dict[str, Any]] = [] - for output in outputs: - if output.kv_transfer_params: - self.kv_transfer_params_list.append(output.kv_transfer_params) @property def kv_transfer_params(self) -> dict[str, Any] | list[dict[str, Any]] | None: - if len(self.kv_transfer_params_list) == 1: + params_list = [ + o.kv_transfer_params + for o in self.outputs + if o.kv_transfer_params + ] + if len(params_list) == 1: # keep backward compatibility for the common case where there is # only one output and its kv_transfer_params is a dict - return self.kv_transfer_params_list[0] - if len(self.kv_transfer_params_list): + return params_list[0] + if len(params_list) > 1: # for the case where there are multiple outputs, we return a list of # kv_transfer_params dicts. This is for parallel sampling (n > 1) # where each child request may have different kv_transfer_params. - return self.kv_transfer_params_list + return params_list return None def add(self, next_output: "RequestOutput", aggregate: bool) -> None: @@ -185,17 +186,9 @@ def add(self, next_output: "RequestOutput", aggregate: bool) -> None: else: # Replace the output with the new one self.outputs[i] = next_completion - if next_completion.kv_transfer_params: - self.kv_transfer_params_list[i] = ( - next_completion.kv_transfer_params - ) break else: self.outputs.append(next_completion) - if next_completion.kv_transfer_params: - self.kv_transfer_params_list.append( - next_completion.kv_transfer_params - ) def __repr__(self) -> str: return ( diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index ce24c1a4011a..bed1efc1679c 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -66,7 +66,7 @@ def _get_child_sampling_params( Child `sampling_params` instance. """ seed = self.sampling_params.seed - no_caching = seed is None and self.sampling_params.n > 1 + if self.cached_child_sampling_params: # Reuse child sampling_params data structure return self.cached_child_sampling_params @@ -81,7 +81,7 @@ def _get_child_sampling_params( "kv_transfer_params": kv_transfer[index], } if seed is None: - if not no_caching: + if not (kv_transfer and isinstance(kv_transfer, list)): # Cache child sampling_params for later reuse self.cached_child_sampling_params = child_sampling_params else: From a221751261282299345962c5191a2bdc027e9f68 Mon Sep 17 00:00:00 2001 From: chaunceyjiang Date: Wed, 8 Apr 2026 12:04:14 +0800 Subject: [PATCH 13/13] [Feature] Support kv_transfer_params for parallel sampling (n>1) in PD Signed-off-by: chaunceyjiang --- vllm/outputs.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/outputs.py b/vllm/outputs.py index 13475f8a5c1d..e2d609973d47 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -144,9 +144,7 @@ def __init__( @property def kv_transfer_params(self) -> dict[str, Any] | list[dict[str, Any]] | None: params_list = [ - o.kv_transfer_params - for o in self.outputs - if o.kv_transfer_params + o.kv_transfer_params for o in self.outputs if o.kv_transfer_params ] if len(params_list) == 1: # keep backward compatibility for the common case where there is