diff --git a/vllm/entrypoints/openai/chat_completion/protocol.py b/vllm/entrypoints/openai/chat_completion/protocol.py index 533959df6094..70aca1e11fd3 100644 --- a/vllm/entrypoints/openai/chat_completion/protocol.py +++ b/vllm/entrypoints/openai/chat_completion/protocol.py @@ -107,7 +107,7 @@ class ChatCompletionResponse(OpenAIBaseModel): # vLLM-specific fields that are not in OpenAI spec prompt_logprobs: list[dict[int, Logprob] | None] | None = None prompt_token_ids: list[int] | None = None - kv_transfer_params: dict[str, Any] | None = Field( + kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = Field( default=None, description="KVTransfer parameters." ) @@ -332,7 +332,7 @@ class ChatCompletionRequest(OpenAIBaseModel): ), ) - kv_transfer_params: dict[str, Any] | None = Field( + kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = Field( default=None, description="KVTransfer parameters used for disaggregated serving.", ) diff --git a/vllm/entrypoints/openai/completion/protocol.py b/vllm/entrypoints/openai/completion/protocol.py index c785d254084d..a0512fe87adf 100644 --- a/vllm/entrypoints/openai/completion/protocol.py +++ b/vllm/entrypoints/openai/completion/protocol.py @@ -154,7 +154,7 @@ class CompletionRequest(OpenAIBaseModel): ), ) - kv_transfer_params: dict[str, Any] | None = Field( + kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = Field( default=None, description="KVTransfer parameters used for disaggregated serving.", ) @@ -481,7 +481,7 @@ class CompletionResponse(OpenAIBaseModel): usage: UsageInfo # vLLM-specific fields that are not in OpenAI spec - kv_transfer_params: dict[str, Any] | None = Field( + kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = Field( default=None, description="KVTransfer parameters." ) diff --git a/vllm/entrypoints/openai/responses/context.py b/vllm/entrypoints/openai/responses/context.py index 48360173cf48..6cfcaa8b783a 100644 --- a/vllm/entrypoints/openai/responses/context.py +++ b/vllm/entrypoints/openai/responses/context.py @@ -192,6 +192,7 @@ def append_output(self, output) -> None: self.num_cached_tokens = output.num_cached_tokens or 0 self.num_output_tokens += len(output.outputs[0].token_ids or []) if output.kv_transfer_params is not None: + assert isinstance(output.kv_transfer_params, dict) self.kv_transfer_params = output.kv_transfer_params # Accumulate text, token_ids, and logprobs for streaming mode @@ -318,6 +319,7 @@ def append_output(self, output: RequestOutput) -> None: self.num_cached_tokens = output.num_cached_tokens or 0 self.num_output_tokens += len(output.outputs[0].token_ids or []) if output.kv_transfer_params is not None: + assert isinstance(output.kv_transfer_params, dict) self.kv_transfer_params = output.kv_transfer_params self.parser.process(output.outputs[0]) output_token_ids = output.outputs[0].token_ids or [] @@ -565,6 +567,7 @@ def append_output(self, output: RequestOutput) -> None: self._update_prefill_token_usage(output) self._update_decode_token_usage(output) if output.kv_transfer_params is not None: + assert isinstance(output.kv_transfer_params, dict) self.kv_transfer_params = output.kv_transfer_params # Append current turn to all turn list for next turn's calculations self.all_turn_metrics.append(self.current_turn_metrics.copy()) @@ -878,6 +881,7 @@ def append_output(self, output: RequestOutput) -> None: self.last_content_delta = last_delta_text self._update_decode_token_usage(output) if output.kv_transfer_params is not None: + assert isinstance(output.kv_transfer_params, dict) self.kv_transfer_params = output.kv_transfer_params # For streaming, update previous turn when message is complete diff --git a/vllm/entrypoints/serve/disagg/protocol.py b/vllm/entrypoints/serve/disagg/protocol.py index af4e8c20c14c..15d980540484 100644 --- a/vllm/entrypoints/serve/disagg/protocol.py +++ b/vllm/entrypoints/serve/disagg/protocol.py @@ -102,7 +102,7 @@ def validate_token_ids(cls, v: list[int]) -> list[int]: "if the served model does not use priority scheduling." ), ) - kv_transfer_params: dict[str, Any] | None = Field( + kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = Field( default=None, description="KVTransfer parameters used for disaggregated serving.", ) @@ -135,7 +135,7 @@ class GenerateResponse(BaseModel): prompt_logprobs: list[dict[int, Logprob] | None] | None = None - kv_transfer_params: dict[str, Any] | None = Field( + kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = Field( default=None, description="KVTransfer parameters used for disaggregated serving.", ) diff --git a/vllm/outputs.py b/vllm/outputs.py index 2c71d2afb1b5..e2d609973d47 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -35,6 +35,7 @@ class CompletionOutput: to stop, None if the completion finished for some other reason including encountering the EOS token. lora_request: The LoRA request that was used to generate the output. + kv_transfer_params: The params for remote K/V transfer. """ index: int @@ -46,6 +47,7 @@ class CompletionOutput: finish_reason: str | None = None stop_reason: int | str | None = None lora_request: LoRARequest | None = None + kv_transfer_params: dict[str, Any] | None = None def finished(self) -> bool: return self.finish_reason is not None @@ -59,7 +61,8 @@ def __repr__(self) -> str: f"cumulative_logprob={self.cumulative_logprob}, " f"logprobs={self.logprobs}, " f"finish_reason={self.finish_reason}, " - f"stop_reason={self.stop_reason})" + f"stop_reason={self.stop_reason}, " + f"kv_transfer_params={self.kv_transfer_params})" ) @@ -103,7 +106,6 @@ class RequestOutput: encoder_prompt_token_ids: The token IDs of the encoder prompt. None if decoder-only. num_cached_tokens: The number of tokens with prefix cache hit. - kv_transfer_params: The params for remote K/V transfer. """ def __init__( @@ -119,8 +121,6 @@ def __init__( encoder_prompt: str | None = None, encoder_prompt_token_ids: list[int] | None = None, num_cached_tokens: int | None = None, - *, - kv_transfer_params: dict[str, Any] | None = None, # Forward compatibility, code that uses args added in new release can # still run with older versions of vLLM without breaking. **kwargs: Any, @@ -140,14 +140,27 @@ def __init__( self.encoder_prompt = encoder_prompt self.encoder_prompt_token_ids = encoder_prompt_token_ids self.num_cached_tokens = num_cached_tokens - self.kv_transfer_params = kv_transfer_params + + @property + def kv_transfer_params(self) -> dict[str, Any] | list[dict[str, Any]] | None: + params_list = [ + o.kv_transfer_params for o in self.outputs if o.kv_transfer_params + ] + if len(params_list) == 1: + # keep backward compatibility for the common case where there is + # only one output and its kv_transfer_params is a dict + return params_list[0] + if len(params_list) > 1: + # for the case where there are multiple outputs, we return a list of + # kv_transfer_params dicts. This is for parallel sampling (n > 1) + # where each child request may have different kv_transfer_params. + return params_list + return None def add(self, next_output: "RequestOutput", aggregate: bool) -> None: """Merge subsequent RequestOutput into this one""" self.finished |= next_output.finished - self.kv_transfer_params = next_output.kv_transfer_params - for next_completion in next_output.outputs: for i, completion in enumerate(self.outputs): if completion.index == next_completion.index: @@ -165,6 +178,9 @@ def add(self, next_output: "RequestOutput", aggregate: bool) -> None: ) completion.finish_reason = next_completion.finish_reason completion.stop_reason = next_completion.stop_reason + completion.kv_transfer_params = ( + next_completion.kv_transfer_params + ) else: # Replace the output with the new one self.outputs[i] = next_completion @@ -184,7 +200,8 @@ def __repr__(self) -> str: f"finished={self.finished}, " f"metrics={self.metrics}, " f"lora_request={self.lora_request}, " - f"num_cached_tokens={self.num_cached_tokens})" + f"num_cached_tokens={self.num_cached_tokens}, " + f"kv_transfer_params={self.kv_transfer_params})" ) diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index f9e965092288..7ba880435812 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -315,7 +315,11 @@ def make_request_output( ) output = self._new_completion_output( - new_token_ids, finish_reason, stop_reason, routed_experts + new_token_ids, + finish_reason, + stop_reason, + routed_experts, + kv_transfer_params, ) if self.parent_req is None: @@ -326,16 +330,13 @@ def make_request_output( return None external_req_id = self.parent_req.external_req_id - return self._new_request_output( - external_req_id, outputs, finished, kv_transfer_params - ) + return self._new_request_output(external_req_id, outputs, finished) def _new_request_output( self, external_req_id: str, outputs: list[CompletionOutput] | list[PoolingOutput], finished: bool, - kv_transfer_params: dict[str, Any] | None = None, ) -> RequestOutput | PoolingRequestOutput: # If prompt embeds were used, put placeholder prompt token ids prompt_token_ids = self.prompt_token_ids @@ -368,7 +369,6 @@ def _new_request_output( prompt_logprobs=prompt_logprobs, outputs=cast(list[CompletionOutput], outputs), finished=finished, - kv_transfer_params=kv_transfer_params, num_cached_tokens=self.num_cached_tokens, metrics=self.stats, ) @@ -379,6 +379,7 @@ def _new_completion_output( finish_reason: FinishReason | None, stop_reason: int | str | None, routed_experts: np.ndarray | None = None, + kv_transfer_params: dict[str, Any] | None = None, ) -> CompletionOutput: assert self.detokenizer is not None assert self.logprobs_processor is not None @@ -404,6 +405,7 @@ def _new_completion_output( cumulative_logprob=self.logprobs_processor.cumulative_logprob, finish_reason=str(finish_reason) if finished else None, stop_reason=stop_reason if finished else None, + kv_transfer_params=kv_transfer_params, ) def _new_pooling_output(self, pooling_output: torch.Tensor) -> PoolingOutput: diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index 8eb6fa057d37..bed1efc1679c 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -66,15 +66,24 @@ def _get_child_sampling_params( Child `sampling_params` instance. """ seed = self.sampling_params.seed + if self.cached_child_sampling_params: # Reuse child sampling_params data structure return self.cached_child_sampling_params # Build child sampling_params child_sampling_params = copy(self.sampling_params) child_sampling_params.n = 1 + extra_args = child_sampling_params.extra_args or {} + kv_transfer = extra_args.get("kv_transfer_params") + if kv_transfer and isinstance(kv_transfer, list): + child_sampling_params.extra_args = { + **extra_args, + "kv_transfer_params": kv_transfer[index], + } if seed is None: - # Cache child sampling_params for later reuse - self.cached_child_sampling_params = child_sampling_params + if not (kv_transfer and isinstance(kv_transfer, list)): + # Cache child sampling_params for later reuse + self.cached_child_sampling_params = child_sampling_params else: # Each child gets a clone with a unique seed child_sampling_params.seed = seed + index