Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions vllm/entrypoints/openai/chat_completion/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class ChatCompletionResponse(OpenAIBaseModel):
# vLLM-specific fields that are not in OpenAI spec
prompt_logprobs: list[dict[int, Logprob] | None] | None = None
prompt_token_ids: list[int] | None = None
kv_transfer_params: dict[str, Any] | None = Field(
kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = Field(
Comment thread
chaunceyjiang marked this conversation as resolved.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we could also add a comment here explaining when a list is expected

default=None, description="KVTransfer parameters."
)

Expand Down Expand Up @@ -332,7 +332,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
),
)

kv_transfer_params: dict[str, Any] | None = Field(
kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = Field(
default=None,
description="KVTransfer parameters used for disaggregated serving.",
)
Expand Down
19 changes: 16 additions & 3 deletions vllm/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def __init__(
encoder_prompt_token_ids: list[int] | None = None,
num_cached_tokens: int | None = None,
*,
kv_transfer_params: dict[str, Any] | None = None,
kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = None,
# Forward compatibility, code that uses args added in new release can
# still run with older versions of vLLM without breaking.
**kwargs: Any,
Expand All @@ -140,13 +140,25 @@ def __init__(
self.encoder_prompt = encoder_prompt
self.encoder_prompt_token_ids = encoder_prompt_token_ids
self.num_cached_tokens = num_cached_tokens
self.kv_transfer_params = kv_transfer_params
self.kv_transfer_params_list = []
if kv_transfer_params:
if isinstance(kv_transfer_params, list):
self.kv_transfer_params_list = kv_transfer_params
else:
self.kv_transfer_params_list = [kv_transfer_params]

@property
def kv_transfer_params(self) -> dict[str, Any] | list[dict[str, Any]] | None:

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

> it will break backwards compatibility of the Python API

@njhill I’ve taken this case into account and added backward-compatible handling here; note that the constructor also accepts `**kwargs: Any`.

So both initializing the object and accessing the `kv_transfer_params` attribute should remain compatible.

if len(self.kv_transfer_params_list) == 1:
return self.kv_transfer_params_list[0]
if len(self.kv_transfer_params_list):
return self.kv_transfer_params_list
return None

def add(self, next_output: "RequestOutput", aggregate: bool) -> None:
"""Merge subsequent RequestOutput into this one"""

self.finished |= next_output.finished
self.kv_transfer_params = next_output.kv_transfer_params

for next_completion in next_output.outputs:
for i, completion in enumerate(self.outputs):
Expand All @@ -171,6 +183,7 @@ def add(self, next_output: "RequestOutput", aggregate: bool) -> None:
break
else:
self.outputs.append(next_completion)
self.kv_transfer_params_list.extend(next_output.kv_transfer_params_list)
Comment thread
chaunceyjiang marked this conversation as resolved.
Outdated

def __repr__(self) -> str:
return (
Expand Down
10 changes: 8 additions & 2 deletions vllm/v1/engine/output_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,13 @@
if self.parent_req is None:
outputs = [output]
else:
outputs, finished = self.parent_req.get_outputs(self.request_id, output)
if kv_transfer_params is None:
outputs, finished = self.parent_req.get_outputs(self.request_id, output)
else:
output_with_kv_transfer = self.parent_req.aggre_kv_transfer_params(
self.request_id, output, kv_transfer_params
)
outputs, finished, kv_transfer_params = output_with_kv_transfer

Check failure on line 330 in vllm/v1/engine/output_processor.py

View workflow job for this annotation

GitHub Actions / pre-commit

Incompatible types in assignment (expression has type "list[dict[str, Any]]", variable has type "dict[str, Any] | None") [assignment]

Check failure on line 330 in vllm/v1/engine/output_processor.py

View workflow job for this annotation

GitHub Actions / pre-commit

Incompatible types in assignment (expression has type "list[dict[str, Any]]", variable has type "dict[str, Any] | None") [assignment]
if not outputs:
return None
external_req_id = self.parent_req.external_req_id
Expand All @@ -335,7 +341,7 @@
external_req_id: str,
outputs: list[CompletionOutput] | list[PoolingOutput],
finished: bool,
kv_transfer_params: dict[str, Any] | None = None,
kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = None,
) -> RequestOutput | PoolingRequestOutput:
# If prompt embeds were used, put placeholder prompt token ids
prompt_token_ids = self.prompt_token_ids
Expand Down
29 changes: 24 additions & 5 deletions vllm/v1/engine/parallel_sampling.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from copy import copy
from typing import cast
from copy import deepcopy
Comment thread
chaunceyjiang marked this conversation as resolved.
Outdated
from typing import Any, cast

from vllm.outputs import CompletionOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams
Expand All @@ -26,6 +26,8 @@

# To aggregate child completions when not streaming
output_aggregator: list[CompletionOutput]
# To store kv_transfer_params for child request
output_kv_transfer_params_list: list[dict[str, Any]]

# To find the max number of generated tokens across all children
max_num_generation_tokens: int
Expand All @@ -48,6 +50,7 @@
)
self.max_num_generation_tokens = 0
self.cached_child_sampling_params = None
self.output_kv_transfer_params_list = []
Comment thread
chaunceyjiang marked this conversation as resolved.
Outdated

def _get_child_sampling_params(
self,
Expand All @@ -66,15 +69,21 @@
Child `sampling_params` instance.
"""
seed = self.sampling_params.seed
no_caching = seed is None and self.sampling_params.n > 1
Comment thread
chaunceyjiang marked this conversation as resolved.
Outdated
if self.cached_child_sampling_params:
# Reuse child sampling_params data structure
return self.cached_child_sampling_params
# Build child sampling_params
child_sampling_params = copy(self.sampling_params)
child_sampling_params = deepcopy(self.sampling_params)
child_sampling_params.n = 1
kv_transfer = child_sampling_params.extra_args.get("kv_transfer_params")

Check failure on line 79 in vllm/v1/engine/parallel_sampling.py

View workflow job for this annotation

GitHub Actions / pre-commit

Item "None" of "dict[str, Any] | None" has no attribute "get" [union-attr]

Check failure on line 79 in vllm/v1/engine/parallel_sampling.py

View workflow job for this annotation

GitHub Actions / pre-commit

Item "None" of "dict[str, Any] | None" has no attribute "get" [union-attr]
if kv_transfer is not None and isinstance(kv_transfer, list):
child_sampling_params.extra_args["kv_transfer_params"] = kv_transfer[index]

Check failure on line 81 in vllm/v1/engine/parallel_sampling.py

View workflow job for this annotation

GitHub Actions / pre-commit

Unsupported target for indexed assignment ("dict[str, Any] | None") [index]

Check failure on line 81 in vllm/v1/engine/parallel_sampling.py

View workflow job for this annotation

GitHub Actions / pre-commit

Unsupported target for indexed assignment ("dict[str, Any] | None") [index]
Comment thread
chaunceyjiang marked this conversation as resolved.
Outdated

if seed is None:
# Cache child sampling_params for later reuse
self.cached_child_sampling_params = child_sampling_params
if not no_caching:
# Cache child sampling_params for later reuse
self.cached_child_sampling_params = child_sampling_params
else:
# Each child gets a clone with a unique seed
child_sampling_params.seed = seed + index
Expand Down Expand Up @@ -125,6 +134,16 @@
finished = not self.child_requests
return outputs, finished

def aggre_kv_transfer_params(
self,
child_request_id: str,
completion_output: CompletionOutput,
kv_transfer_params: dict[str, Any],
) -> tuple[list[CompletionOutput], bool, list[dict[str, Any]]]:
outputs, finished = self.get_outputs(child_request_id, completion_output)
self.output_kv_transfer_params_list.append(kv_transfer_params)
Comment thread
chaunceyjiang marked this conversation as resolved.
Outdated
return outputs, finished, self.output_kv_transfer_params_list

def observe_num_generation_tokens(self, num_generation_tokens: int):
self.max_num_generation_tokens = max(
num_generation_tokens, self.max_num_generation_tokens
Expand Down
Loading