Merged
41 commits, all by ShangmingCai:
dbb2ab9  Add KVTransferParams, modify all generate interfaces. (Feb 7, 2025)
d031995  Pass kv_transfer_params to model_input in model_runner. (Feb 7, 2025)
3c306b4  Add test for KVTransferParams. (Feb 8, 2025)
ff4b89e  Merge branch 'main' into add_kv_transfer_params (Feb 8, 2025)
56d80c9  fix ruff (Feb 8, 2025)
4d9cc92  fix mypy (Feb 8, 2025)
bcab623  minor (Feb 8, 2025)
10c20c2  fix msgspec type error (Feb 10, 2025)
a45736e  fix v1 error (Feb 11, 2025)
88f81d8  add args (Feb 11, 2025)
1f194db  verified that #12959 can fix distributed-tests-4-gpus ci (Feb 12, 2025)
0395307  retrigger ci (Feb 12, 2025)
0d8ef2c  retrigger ci (Feb 12, 2025)
d33b0dc  Merge branch 'main' into add_kv_transfer_params (Feb 12, 2025)
c373b43  refactor (Feb 17, 2025)
3404057  remove unused import (Feb 17, 2025)
dbece05  fix ruff (Feb 17, 2025)
14b213f  typo (Feb 17, 2025)
943460e  fix isort (Feb 17, 2025)
f261591  minor (Feb 17, 2025)
14da3b9  minor (Feb 17, 2025)
be1a184  add tests (Feb 18, 2025)
5dcad80  fix isort (Feb 18, 2025)
c6f54e0  Merge remote-tracking branch 'upstream/main' into add_kv_transfer_params (Feb 18, 2025)
e428ea4  retrigger ci (Feb 18, 2025)
77701bb  add more tests (Feb 19, 2025)
a1c3f80  Merge remote-tracking branch 'upstream/main' into add_kv_transfer_params (Feb 19, 2025)
35f0961  Merge branch 'main' into add_kv_transfer_params (Mar 3, 2025)
45764f7  retrigger ci after addressing the typing conflicts with 13971 (Mar 3, 2025)
f71be5d  try entrypoints test again (Mar 3, 2025)
c463f48  Revert openapi changes (Mar 14, 2025)
9c55445  Merge remote-tracking branch 'upstream/main' into add_kv_transfer_params (Mar 14, 2025)
05bce94  Refactor mooncake xpyd (Mar 14, 2025)
6b98e2b  Fix doc (Mar 17, 2025)
c56b6d3  fix interface (Mar 17, 2025)
e5736b5  fix proxy (Mar 17, 2025)
2aaef25  fix typo (Mar 17, 2025)
a798d31  address comments (Mar 25, 2025)
b366931  add docstring (Mar 25, 2025)
abd1b3f  fix magic numbers (Mar 27, 2025)
8f96f29  Merge remote-tracking branch 'upstream/main' into add_kv_transfer_params (Mar 29, 2025)
50 changes: 50 additions & 0 deletions tests/kv_transfer/test_kv_transfer_params.py
@@ -0,0 +1,50 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for the KVTransferParams class."""
from vllm import KVTransferParams


def test_all_none():
    # None should be allowed
    KVTransferParams(prefix_prompt_ids=None,
                     kvcache_load_keys=None,
                     kvcache_store_keys=None)
    KVTransferParams(prefix_prompt_ids=[None],
                     kvcache_load_keys=[None],
                     kvcache_store_keys=[None])

    # Note(shangming): KVCache transfer may have different granularity,
    # such as block-level, so the length of kvcache_load_keys and
    # kvcache_store_keys has no strong correspondence with the length of
    # prefix_prompt_ids.

    # prefill node cases
    KVTransferParams(prefix_prompt_ids=[1, 2, 3],
                     kvcache_load_keys=None,
                     kvcache_store_keys=["key1", "key2", "key3"])
    KVTransferParams(prefix_prompt_ids=[None, [1, 2, 3]],
                     kvcache_load_keys=[None, None],
                     kvcache_store_keys=[None, ["key1"]])

    # decode node cases
    KVTransferParams(prefix_prompt_ids=[1, 2, 3],
                     kvcache_load_keys=["key1", "key2", "key3"],
                     kvcache_store_keys=None)
    KVTransferParams(prefix_prompt_ids=[None, [1, 2, 3]],
                     kvcache_load_keys=[None, ["key1"]],
                     kvcache_store_keys=[None, None])

    # prefix cache sharing cases
    KVTransferParams(prefix_prompt_ids=[[1, 2, 3], [1, 2]],
                     kvcache_load_keys=[["key1", "key2", "key3"],
                                        ["key1", "key2"]],
                     kvcache_store_keys=[None, None])
    KVTransferParams(prefix_prompt_ids=[[1, 2, 3], [4, 5, 6]],
                     kvcache_load_keys=[["key1", "key2", "key3"], None],
                     kvcache_store_keys=[None, ["key4", "key5", "key6"]])


if __name__ == "__main__":
    import pytest
    pytest.main([__file__])
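
The comments in this test imply the intended disaggregated workflow: a prefill node computes the prefix and stores its KV cache under a set of keys, and a decode node later loads those same keys instead of recomputing the prefix. A minimal sketch of that pairing, using only the constructor exercised above; the key names are illustrative and would be agreed on between the nodes out of band:

from vllm import KVTransferParams

# Hypothetical block-level keys shared between the two nodes out of band.
shared_keys = ["blk-0", "blk-1", "blk-2"]

# Prefill node: compute the prefix [1, 2, 3] and store its KV cache
# under the shared keys.
prefill_params = KVTransferParams(prefix_prompt_ids=[1, 2, 3],
                                  kvcache_load_keys=None,
                                  kvcache_store_keys=shared_keys)

# Decode node: load the stored KV cache instead of recomputing the prefix.
decode_params = KVTransferParams(prefix_prompt_ids=[1, 2, 3],
                                 kvcache_load_keys=shared_keys,
                                 kvcache_store_keys=None)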
2 changes: 2 additions & 0 deletions vllm/__init__.py
@@ -10,6 +10,7 @@
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.kv_transfer_params import KVTransferParams
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import (ClassificationOutput, ClassificationRequestOutput,
                          CompletionOutput, EmbeddingOutput,
@@ -59,4 +60,5 @@
    "AsyncEngineArgs",
    "initialize_ray_cluster",
    "PoolingParams",
    "KVTransferParams",
]
1 change: 1 addition & 0 deletions vllm/core/scheduler.py
@@ -1380,6 +1380,7 @@ def schedule(
                computed_block_nums=common_computed_block_nums,
                encoder_seq_data=encoder_seq_data,
                cross_block_table=cross_block_table,
                kv_transfer_params=seq_group.kv_transfer_params,
                state=seq_group.state,
                token_type_ids=seq_group.token_type_ids,
                # `multi_modal_data` will only be present for the 1st comm
13 changes: 13 additions & 0 deletions vllm/engine/async_llm_engine.py
@@ -23,6 +23,7 @@
from vllm.executor.executor_base import ExecutorBase
from vllm.inputs import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.kv_transfer_params import KVTransferParams
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
@@ -437,6 +438,7 @@ async def add_request_async(
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        kv_transfer_params: Optional[KVTransferParams] = None,
        priority: int = 0,
    ) -> None:
        ...
@@ -451,6 +453,7 @@ async def add_request_async(
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        kv_transfer_params: Optional[KVTransferParams] = None,
        priority: int = 0,
    ) -> None:
        ...
@@ -468,6 +471,7 @@ async def add_request_async(
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        kv_transfer_params: Optional[KVTransferParams] = None,
        priority: int = 0,
        *,
        inputs: Optional[PromptType] = None,  # DEPRECATED
@@ -519,6 +523,7 @@ async def add_request_async(
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
            kv_transfer_params=kv_transfer_params,
            priority=priority,
        )

@@ -856,6 +861,7 @@ def add_request(
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        kv_transfer_params: Optional[KVTransferParams] = None,
        priority: int = 0,
    ) -> Coroutine[None, None, AsyncGenerator[Union[
            RequestOutput, PoolingRequestOutput], None]]:
@@ -871,6 +877,7 @@ def add_request(
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        kv_transfer_params: Optional[KVTransferParams] = None,
        priority: int = 0,
    ) -> Coroutine[None, None, AsyncGenerator[Union[
            RequestOutput, PoolingRequestOutput], None]]:
@@ -889,6 +896,7 @@ async def add_request(
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        kv_transfer_params: Optional[KVTransferParams] = None,
        priority: int = 0,
        *,
        inputs: Optional[PromptType] = None,  # DEPRECATED
@@ -921,6 +929,7 @@ async def add_request(
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            kv_transfer_params=kv_transfer_params,
            priority=priority,
        )

@@ -934,6 +943,7 @@ async def generate(
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        kv_transfer_params: Optional[KVTransferParams] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.
@@ -951,6 +961,8 @@
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt Adapter request to use
                for generation, if any.
            kv_transfer_params: The KVCache transfer parameters to use
                for generation, if any.
            priority: The priority of the request.
                Only applicable with priority scheduling.

@@ -1010,6 +1022,7 @@ async def generate(
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            kv_transfer_params=kv_transfer_params,
            priority=priority,
        ):
            yield LLMEngine.validate_output(output, RequestOutput)
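
Taken together, these hunks thread kv_transfer_params from the public generate call down through add_request. A minimal decode-side sketch of the resulting API, assuming an engine already configured with a KV-transfer-capable connector (e.g. via --kv-transfer-config); the model name, prompt, and keys are illustrative, and whether a transfer actually occurs depends on that connector configuration:

import asyncio

from vllm import KVTransferParams, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine


async def main():
    # Illustrative engine setup; a real deployment would also configure
    # a KV connector for the transfer to take effect.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))

    # Decode-node request: load the prefix KV cache a prefill node stored
    # under these (hypothetical) keys instead of recomputing it.
    kv_params = KVTransferParams(prefix_prompt_ids=[1, 2, 3],
                                 kvcache_load_keys=["key1", "key2", "key3"],
                                 kvcache_store_keys=None)

    async for output in engine.generate("Hello, my name is",
                                        SamplingParams(max_tokens=16),
                                        request_id="req-0",
                                        kv_transfer_params=kv_params):
        print(output.outputs[0].text)


asyncio.run(main())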
11 changes: 11 additions & 0 deletions vllm/engine/llm_engine.py
@@ -34,6 +34,7 @@
                         PromptType, SingletonInputsAdapter)
from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt
from vllm.inputs.preprocess import InputPreprocessor
from vllm.kv_transfer_params import KVTransferParams
from vllm.logger import init_logger
from vllm.logits_process import get_bad_words_logits_processors
from vllm.lora.request import LoRARequest
@@ -551,6 +552,7 @@ def _add_processed_request(
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        kv_transfer_params: Optional[KVTransferParams] = None,
        priority: int = 0,
    ) -> Optional[SequenceGroup]:
        """Add a processed request to the engine's request pool.
@@ -566,6 +568,7 @@
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                kv_transfer_params=kv_transfer_params,
                priority=priority,
            )
            return None
@@ -601,6 +604,7 @@ def _add_processed_request(
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                encoder_seq=encoder_seq,
                kv_transfer_params=kv_transfer_params,
                priority=priority)
        elif isinstance(params, PoolingParams):
            seq_group = self._create_sequence_group_with_pooling(
@@ -639,6 +643,7 @@ def add_request(
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        kv_transfer_params: Optional[KVTransferParams] = None,
        priority: int = 0,
    ) -> None:
        ...
@@ -655,6 +660,7 @@ def add_request(
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        kv_transfer_params: Optional[KVTransferParams] = None,
        priority: int = 0,
    ) -> None:
        ...
@@ -672,6 +678,7 @@ def add_request(
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        kv_transfer_params: Optional[KVTransferParams] = None,
        priority: int = 0,
        *,
        inputs: Optional[PromptType] = None,  # DEPRECATED
@@ -694,6 +701,7 @@ def add_request(
            lora_request: The LoRA request to add.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: The prompt adapter request to add.
            kv_transfer_params: The KVCache transfer parameters to add.
            priority: The priority of the request.
                Only applicable with priority scheduling.

@@ -764,6 +772,7 @@ def add_request(
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
            kv_transfer_params=kv_transfer_params,
            priority=priority,
        )

@@ -798,6 +807,7 @@ def _create_sequence_group_with_sampling(
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        encoder_seq: Optional[Sequence] = None,
        kv_transfer_params: Optional[KVTransferParams] = None,
        priority: int = 0,
    ) -> SequenceGroup:
        """Creates a SequenceGroup with SamplingParams."""
@@ -829,6 +839,7 @@ def _create_sequence_group_with_sampling(
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            encoder_seq=encoder_seq,
            kv_transfer_params=kv_transfer_params,
            priority=priority)

        return seq_group
6 changes: 6 additions & 0 deletions vllm/engine/multiprocessing/__init__.py
@@ -9,6 +9,7 @@

from vllm import PoolingParams
from vllm.inputs import PromptType
from vllm.kv_transfer_params import KVTransferParams
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
@@ -35,6 +36,7 @@ class RPCProcessRequest:
    lora_request: Optional[LoRARequest] = None
    trace_headers: Optional[Mapping[str, str]] = None
    prompt_adapter_request: Optional[PromptAdapterRequest] = None
    kv_transfer_params: Optional[KVTransferParams] = None
    priority: int = 0

    @overload
@@ -46,6 +48,7 @@ def __init__(
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        kv_transfer_params: Optional[KVTransferParams] = None,
        priority: int = 0,
    ) -> None:
        ...
@@ -61,6 +64,7 @@ def __init__(
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        kv_transfer_params: Optional[KVTransferParams] = None,
        priority: int = 0,
    ) -> None:
        ...
@@ -77,6 +81,7 @@ def __init__(
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        kv_transfer_params: Optional[KVTransferParams] = None,
        priority: int = 0,
        *,
        inputs: Optional[PromptType] = None,  # DEPRECATED
@@ -94,6 +99,7 @@
        self.lora_request = lora_request
        self.trace_headers = trace_headers
        self.prompt_adapter_request = prompt_adapter_request
        self.kv_transfer_params = kv_transfer_params
        self.priority = priority

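On the multiprocessing path, the new field simply rides along on the serialized RPC payload. A sketch of the request object the client-side generate ends up constructing; the prompt, params, and request_id fields are assumed from the part of the class not shown in this hunk, and the values are illustrative:

from vllm import KVTransferParams, SamplingParams
from vllm.engine.multiprocessing import RPCProcessRequest

# The kv_transfer_params field is serialized with the rest of the request
# and handed back to the engine process unchanged.
request = RPCProcessRequest(
    prompt="Hello, my name is",
    params=SamplingParams(max_tokens=16),
    request_id="req-0",
    kv_transfer_params=KVTransferParams(prefix_prompt_ids=[1, 2, 3],
                                        kvcache_load_keys=None,
                                        kvcache_store_keys=["key1"]),
)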
11 changes: 10 additions & 1 deletion vllm/engine/multiprocessing/client.py
@@ -38,6 +38,7 @@
from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.inputs import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.kv_transfer_params import KVTransferParams
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
@@ -441,6 +442,7 @@ def generate(
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        kv_transfer_params: Optional[KVTransferParams] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        ...
@@ -456,6 +458,7 @@ def generate(
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        kv_transfer_params: Optional[KVTransferParams] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        ...
@@ -472,6 +475,7 @@ def generate(
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        kv_transfer_params: Optional[KVTransferParams] = None,
        priority: int = 0,
        *,
        inputs: Optional[PromptType] = None  # DEPRECATED
@@ -491,6 +495,8 @@ def generate(
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt Adapter request to use
                for generation, if any.
            kv_transfer_params: The KVCache transfer parameters to use
                for generation, if any.
            priority: Priority of the request (lower means earlier handling).
                Any priority other than 0 will lead to an error if the
                scheduling policy is not "priority".
@@ -502,7 +508,8 @@

        return self._process_request(prompt, sampling_params, request_id,
                                     lora_request, trace_headers,
-                                     prompt_adapter_request, priority)
+                                     prompt_adapter_request,
+                                     kv_transfer_params, priority)

    @overload
    def encode(
@@ -585,6 +592,7 @@ async def _process_request(
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        kv_transfer_params: Optional[KVTransferParams] = None,
        priority: int = 0,
    ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
            PoolingRequestOutput, None]]:
@@ -638,6 +646,7 @@ async def _process_request(
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                kv_transfer_params=kv_transfer_params,
                priority=priority,
            ))
1 change: 1 addition & 0 deletions vllm/engine/multiprocessing/engine.py
@@ -269,6 +269,7 @@ def _handle_process_request(self, request: RPCProcessRequest):
            lora_request=request.lora_request,
            trace_headers=request.trace_headers,
            prompt_adapter_request=request.prompt_adapter_request,
            kv_transfer_params=request.kv_transfer_params,
            priority=request.priority)

        if self.log_requests:
2 changes: 2 additions & 0 deletions vllm/engine/protocol.py
@@ -10,6 +10,7 @@
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.inputs.preprocess import InputPreprocessor
from vllm.kv_transfer_params import KVTransferParams
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
@@ -55,6 +56,7 @@ def generate(
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        kv_transfer_params: Optional[KVTransferParams] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request."""