Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions vllm/entrypoints/openai/chat_completion/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class ChatCompletionResponse(OpenAIBaseModel):
# vLLM-specific fields that are not in OpenAI spec
prompt_logprobs: list[dict[int, Logprob] | None] | None = None
prompt_token_ids: list[int] | None = None
kv_transfer_params: dict[str, Any] | None = Field(
kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = Field(
Comment thread
chaunceyjiang marked this conversation as resolved.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we could also comment here when we expect a list

default=None, description="KVTransfer parameters."
)

Expand Down Expand Up @@ -332,7 +332,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
),
)

kv_transfer_params: dict[str, Any] | None = Field(
kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = Field(
default=None,
description="KVTransfer parameters used for disaggregated serving.",
)
Expand Down
4 changes: 2 additions & 2 deletions vllm/entrypoints/openai/completion/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ class CompletionRequest(OpenAIBaseModel):
),
)

kv_transfer_params: dict[str, Any] | None = Field(
kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = Field(
default=None,
description="KVTransfer parameters used for disaggregated serving.",
)
Expand Down Expand Up @@ -481,7 +481,7 @@ class CompletionResponse(OpenAIBaseModel):
usage: UsageInfo

# vLLM-specific fields that are not in OpenAI spec
kv_transfer_params: dict[str, Any] | None = Field(
kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = Field(
default=None, description="KVTransfer parameters."
)

Expand Down
4 changes: 4 additions & 0 deletions vllm/entrypoints/openai/responses/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def append_output(self, output) -> None:
self.num_cached_tokens = output.num_cached_tokens or 0
self.num_output_tokens += len(output.outputs[0].token_ids or [])
if output.kv_transfer_params is not None:
assert isinstance(output.kv_transfer_params, dict)
self.kv_transfer_params = output.kv_transfer_params

# Accumulate text, token_ids, and logprobs for streaming mode
Expand Down Expand Up @@ -318,6 +319,7 @@ def append_output(self, output: RequestOutput) -> None:
self.num_cached_tokens = output.num_cached_tokens or 0
self.num_output_tokens += len(output.outputs[0].token_ids or [])
if output.kv_transfer_params is not None:
assert isinstance(output.kv_transfer_params, dict)
self.kv_transfer_params = output.kv_transfer_params
self.parser.process(output.outputs[0])
output_token_ids = output.outputs[0].token_ids or []
Expand Down Expand Up @@ -565,6 +567,7 @@ def append_output(self, output: RequestOutput) -> None:
self._update_prefill_token_usage(output)
self._update_decode_token_usage(output)
if output.kv_transfer_params is not None:
assert isinstance(output.kv_transfer_params, dict)
self.kv_transfer_params = output.kv_transfer_params
# Append current turn to all turn list for next turn's calculations
self.all_turn_metrics.append(self.current_turn_metrics.copy())
Expand Down Expand Up @@ -878,6 +881,7 @@ def append_output(self, output: RequestOutput) -> None:
self.last_content_delta = last_delta_text
self._update_decode_token_usage(output)
if output.kv_transfer_params is not None:
assert isinstance(output.kv_transfer_params, dict)
self.kv_transfer_params = output.kv_transfer_params

# For streaming, update previous turn when message is complete
Expand Down
4 changes: 2 additions & 2 deletions vllm/entrypoints/serve/disagg/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def validate_token_ids(cls, v: list[int]) -> list[int]:
"if the served model does not use priority scheduling."
),
)
kv_transfer_params: dict[str, Any] | None = Field(
kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = Field(
default=None,
description="KVTransfer parameters used for disaggregated serving.",
)
Expand Down Expand Up @@ -135,7 +135,7 @@ class GenerateResponse(BaseModel):

prompt_logprobs: list[dict[int, Logprob] | None] | None = None

kv_transfer_params: dict[str, Any] | None = Field(
kv_transfer_params: dict[str, Any] | list[dict[str, Any]] | None = Field(
default=None,
description="KVTransfer parameters used for disaggregated serving.",
)
33 changes: 25 additions & 8 deletions vllm/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class CompletionOutput:
to stop, None if the completion finished for some other reason
including encountering the EOS token.
lora_request: The LoRA request that was used to generate the output.
kv_transfer_params: The params for remote K/V transfer.
"""

index: int
Expand All @@ -46,6 +47,7 @@ class CompletionOutput:
finish_reason: str | None = None
stop_reason: int | str | None = None
lora_request: LoRARequest | None = None
kv_transfer_params: dict[str, Any] | None = None

def finished(self) -> bool:
return self.finish_reason is not None
Expand All @@ -59,7 +61,8 @@ def __repr__(self) -> str:
f"cumulative_logprob={self.cumulative_logprob}, "
f"logprobs={self.logprobs}, "
f"finish_reason={self.finish_reason}, "
f"stop_reason={self.stop_reason})"
f"stop_reason={self.stop_reason}, "
f"kv_transfer_params={self.kv_transfer_params})"
)


Expand Down Expand Up @@ -103,7 +106,6 @@ class RequestOutput:
encoder_prompt_token_ids: The token IDs of the encoder prompt.
None if decoder-only.
num_cached_tokens: The number of tokens with prefix cache hit.
kv_transfer_params: The params for remote K/V transfer.
"""

def __init__(
Expand All @@ -119,8 +121,6 @@ def __init__(
encoder_prompt: str | None = None,
encoder_prompt_token_ids: list[int] | None = None,
num_cached_tokens: int | None = None,
*,
kv_transfer_params: dict[str, Any] | None = None,
# Forward compatibility, code that uses args added in new release can
# still run with older versions of vLLM without breaking.
**kwargs: Any,
Expand All @@ -140,14 +140,27 @@ def __init__(
self.encoder_prompt = encoder_prompt
self.encoder_prompt_token_ids = encoder_prompt_token_ids
self.num_cached_tokens = num_cached_tokens
self.kv_transfer_params = kv_transfer_params

@property
def kv_transfer_params(self) -> dict[str, Any] | list[dict[str, Any]] | None:
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it will break backwards compatibility of the python API

@njhill I’ve also taken this case into consideration. I’ve added some backward-compatible handling here, and you can see that the constructor supports **kwargs: Any.

So whether it’s initialization or accessing the kv_transfer_params attribute, it should remain compatible.

params_list = [
o.kv_transfer_params for o in self.outputs if o.kv_transfer_params
]
if len(params_list) == 1:
# keep backward compatibility for the common case where there is
# only one output and its kv_transfer_params is a dict
return params_list[0]
if len(params_list) > 1:
# for the case where there are multiple outputs, we return a list of
# kv_transfer_params dicts. This is for parallel sampling (n > 1)
# where each child request may have different kv_transfer_params.
return params_list
return None

def add(self, next_output: "RequestOutput", aggregate: bool) -> None:
"""Merge subsequent RequestOutput into this one"""

self.finished |= next_output.finished
self.kv_transfer_params = next_output.kv_transfer_params

for next_completion in next_output.outputs:
for i, completion in enumerate(self.outputs):
if completion.index == next_completion.index:
Expand All @@ -165,6 +178,9 @@ def add(self, next_output: "RequestOutput", aggregate: bool) -> None:
)
completion.finish_reason = next_completion.finish_reason
completion.stop_reason = next_completion.stop_reason
completion.kv_transfer_params = (
next_completion.kv_transfer_params
)
else:
# Replace the output with the new one
self.outputs[i] = next_completion
Expand All @@ -184,7 +200,8 @@ def __repr__(self) -> str:
f"finished={self.finished}, "
f"metrics={self.metrics}, "
f"lora_request={self.lora_request}, "
f"num_cached_tokens={self.num_cached_tokens})"
f"num_cached_tokens={self.num_cached_tokens}, "
f"kv_transfer_params={self.kv_transfer_params})"
)


Expand Down
14 changes: 8 additions & 6 deletions vllm/v1/engine/output_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,11 @@ def make_request_output(
)

output = self._new_completion_output(
new_token_ids, finish_reason, stop_reason, routed_experts
new_token_ids,
finish_reason,
stop_reason,
routed_experts,
kv_transfer_params,
)

if self.parent_req is None:
Expand All @@ -326,16 +330,13 @@ def make_request_output(
return None
external_req_id = self.parent_req.external_req_id

return self._new_request_output(
external_req_id, outputs, finished, kv_transfer_params
)
return self._new_request_output(external_req_id, outputs, finished)

def _new_request_output(
self,
external_req_id: str,
outputs: list[CompletionOutput] | list[PoolingOutput],
finished: bool,
kv_transfer_params: dict[str, Any] | None = None,
) -> RequestOutput | PoolingRequestOutput:
# If prompt embeds were used, put placeholder prompt token ids
prompt_token_ids = self.prompt_token_ids
Expand Down Expand Up @@ -368,7 +369,6 @@ def _new_request_output(
prompt_logprobs=prompt_logprobs,
outputs=cast(list[CompletionOutput], outputs),
finished=finished,
kv_transfer_params=kv_transfer_params,
num_cached_tokens=self.num_cached_tokens,
metrics=self.stats,
)
Expand All @@ -379,6 +379,7 @@ def _new_completion_output(
finish_reason: FinishReason | None,
stop_reason: int | str | None,
routed_experts: np.ndarray | None = None,
kv_transfer_params: dict[str, Any] | None = None,
) -> CompletionOutput:
assert self.detokenizer is not None
assert self.logprobs_processor is not None
Expand All @@ -404,6 +405,7 @@ def _new_completion_output(
cumulative_logprob=self.logprobs_processor.cumulative_logprob,
finish_reason=str(finish_reason) if finished else None,
stop_reason=stop_reason if finished else None,
kv_transfer_params=kv_transfer_params,
)

def _new_pooling_output(self, pooling_output: torch.Tensor) -> PoolingOutput:
Expand Down
13 changes: 11 additions & 2 deletions vllm/v1/engine/parallel_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,24 @@ def _get_child_sampling_params(
Child `sampling_params` instance.
"""
seed = self.sampling_params.seed

if self.cached_child_sampling_params:
# Reuse child sampling_params data structure
return self.cached_child_sampling_params
# Build child sampling_params
child_sampling_params = copy(self.sampling_params)
child_sampling_params.n = 1
extra_args = child_sampling_params.extra_args or {}
kv_transfer = extra_args.get("kv_transfer_params")
if kv_transfer and isinstance(kv_transfer, list):
child_sampling_params.extra_args = {
**extra_args,
"kv_transfer_params": kv_transfer[index],
}
if seed is None:
# Cache child sampling_params for later reuse
self.cached_child_sampling_params = child_sampling_params
if not (kv_transfer and isinstance(kv_transfer, list)):
# Cache child sampling_params for later reuse
self.cached_child_sampling_params = child_sampling_params
else:
# Each child gets a clone with a unique seed
child_sampling_params.seed = seed + index
Expand Down
Loading