2 changes: 0 additions & 2 deletions tests/tokenization/test_detokenize.py
@@ -64,8 +64,6 @@ def _run_incremental_decode(tokenizer,
request = EngineCoreRequest("",
prompt_token_ids,
None,
None,
None,
params,
None,
None,
22 changes: 13 additions & 9 deletions tests/v1/core/test_kv_cache_utils.py
@@ -7,7 +7,8 @@
import torch

from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import SamplingParams
from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit
from vllm.v1.core.kv_cache_manager import KVCacheManager
@@ -37,17 +38,20 @@ def make_request(
mm_hashes: Optional[list[str]] = None,
cache_salt: Optional[str] = None,
):
if mm_positions is None:
mm_kwargs = None
else:
mm_item = MultiModalKwargsItem.dummy("dummy_m")
mm_kwargs = [mm_item] * len(mm_positions)
mm_features = []
if mm_positions is not None:
for j, position in enumerate(mm_positions):
identifier = mm_hashes[j] if mm_hashes else f"hash_{j}"
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("dummy_m"),
mm_position=position,
identifier=identifier,
modality="image")
mm_features.append(mm_feature)

return Request(request_id=request_id,
prompt_token_ids=prompt_token_ids,
multi_modal_kwargs=mm_kwargs,
multi_modal_hashes=mm_hashes,
multi_modal_placeholders=mm_positions,
mm_features=mm_features if mm_features else None,
sampling_params=SamplingParams(max_tokens=17),
pooling_params=None,
eos_token_id=100,
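For reference, a minimal sketch (not part of this diff) of the construction pattern the updated test helpers follow: one MultiModalFeatureSpec per placeholder, passed to Request through the single mm_features argument. The placeholder values are illustrative, the remaining Request parameters are assumed to keep their defaults, and the Request import path is an assumption based on vLLM's v1 package layout.

    from vllm.multimodal.inputs import (MultiModalFeatureSpec,
                                        MultiModalKwargsItem, PlaceholderRange)
    from vllm.sampling_params import SamplingParams
    from vllm.v1.request import Request  # assumed import path

    # One feature per multimodal placeholder in the prompt.
    mm_positions = [PlaceholderRange(offset=2, length=336)]
    mm_features = [
        MultiModalFeatureSpec(data=MultiModalKwargsItem.dummy("dummy_m"),
                              mm_position=pos,
                              identifier=f"hash_{j}",
                              modality="image")
        for j, pos in enumerate(mm_positions)
    ]

    request = Request(request_id="0",
                      prompt_token_ids=[0] * 16,
                      mm_features=mm_features,
                      sampling_params=SamplingParams(max_tokens=17),
                      pooling_params=None,
                      eos_token_id=100)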
22 changes: 13 additions & 9 deletions tests/v1/core/test_prefix_caching.py
@@ -9,7 +9,8 @@
import torch

from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import SamplingParams
from vllm.utils import sha256, sha256_cbor_64bit
from vllm.v1.core.block_pool import BlockPool
@@ -32,17 +33,20 @@ def make_request(
prompt_logprobs: Optional[int] = None,
cache_salt: Optional[str] = None,
):
if mm_positions is None:
mm_kwargs = None
else:
mm_item = MultiModalKwargsItem.dummy("dummy_m")
mm_kwargs = [mm_item] * len(mm_positions)
mm_features = []
if mm_positions is not None:
for j, position in enumerate(mm_positions):
identifier = mm_hashes[j] if mm_hashes else f"hash_{j}"
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("dummy_m"),
mm_position=position,
identifier=identifier,
modality="image")
mm_features.append(mm_feature)

return Request(request_id=request_id,
prompt_token_ids=prompt_token_ids,
multi_modal_kwargs=mm_kwargs,
multi_modal_hashes=mm_hashes,
multi_modal_placeholders=mm_positions,
mm_features=mm_features if mm_features else None,
sampling_params=SamplingParams(
max_tokens=17, prompt_logprobs=prompt_logprobs),
pooling_params=None,
26 changes: 14 additions & 12 deletions tests/v1/core/test_scheduler.py
@@ -8,7 +8,8 @@

from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
@@ -1307,21 +1308,24 @@ def create_requests_with_priority(
prompt_logprobs=prompt_logprobs)
requests = []
for i in range(num_requests):
mm_features = []
if mm_positions is not None:
mm_position = mm_positions[i]
mm_item = MultiModalKwargsItem.dummy("dummy_m")
mm_kwargs = [mm_item] * len(mm_position)
else:
mm_position = None
mm_kwargs = None
for j, position in enumerate(mm_position):
identifier = f"hash{i}_{j}"
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("dummy_m"),
mm_position=position,
identifier=identifier,
modality="image")
mm_features.append(mm_feature)

request = Request(
request_id=f"{i}",
prompt_token_ids=[i] * num_tokens,
sampling_params=sampling_params,
pooling_params=None,
multi_modal_kwargs=mm_kwargs,
multi_modal_placeholders=mm_position,
multi_modal_hashes=None,
mm_features=mm_features if mm_features else None,
eos_token_id=EOS_TOKEN_ID,
arrival_time=arrival_times[i],
priority=priorities[i],
@@ -1800,9 +1804,7 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
request = Request(
request_id="0",
prompt_token_ids=[0, 1],
multi_modal_kwargs=None,
multi_modal_hashes=None,
multi_modal_placeholders=None,
mm_features=None,
sampling_params=sampling_params,
pooling_params=None,
eos_token_id=EOS_TOKEN_ID,
30 changes: 15 additions & 15 deletions tests/v1/core/utils.py
@@ -6,7 +6,8 @@

from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import SamplingParams
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
init_none_hash)
@@ -139,29 +140,28 @@ def create_requests(
prompt_logprobs=prompt_logprobs)
requests = []
for i in range(num_requests):
mm_features = []
if mm_positions is not None:
mm_position = mm_positions[i]
mm_item = MultiModalKwargsItem.dummy("dummy_m")
mm_kwargs = [mm_item] * len(mm_position)
# Dummy hash for each mm item should be unique
# since encoder cache tracks entries by hash
mm_hashes = [
"hash" + str(i) + "_" + str(j) for j in range(len(mm_position))
]
else:
mm_position = None
mm_kwargs = None
mm_hashes = None
for j, position in enumerate(mm_position):
# Dummy hash for each mm item should be unique
# since encoder cache tracks entries by hash
identifier = f"hash{i}_{j}"
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("dummy_m"),
mm_position=position,
identifier=identifier,
modality="image")
mm_features.append(mm_feature)

prompt_token_ids = ([0] * num_tokens if same_prompt else [i] *
num_tokens)
request = Request(
request_id=f"{i}",
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
pooling_params=None,
multi_modal_kwargs=mm_kwargs,
multi_modal_placeholders=mm_position,
multi_modal_hashes=mm_hashes,
mm_features=mm_features if mm_features else None,
eos_token_id=EOS_TOKEN_ID,
block_hasher=block_hasher,
)
4 changes: 1 addition & 3 deletions tests/v1/engine/test_engine_core.py
@@ -35,9 +35,7 @@ def make_request() -> EngineCoreRequest:
return EngineCoreRequest(
request_id=str(uuid.uuid4()),
prompt_token_ids=PROMPT_TOKENS,
mm_kwargs=None,
mm_hashes=None,
mm_placeholders=None,
mm_features=None,
sampling_params=SamplingParams(),
pooling_params=None,
eos_token_id=None,
4 changes: 1 addition & 3 deletions tests/v1/engine/test_engine_core_client.py
@@ -52,9 +52,7 @@ def make_request(
return EngineCoreRequest(
request_id=str(uuid.uuid4()),
prompt_token_ids=prompt_tokens_ids,
mm_kwargs=None,
mm_hashes=None,
mm_placeholders=None,
mm_features=None,
sampling_params=params,
pooling_params=None,
eos_token_id=None,
18 changes: 8 additions & 10 deletions tests/v1/engine/test_fast_incdec_prefix_err.py
@@ -26,16 +26,14 @@ def test_fast_inc_detok_invalid_utf8_err_case():
prompt_token_ids = [107, 4606, 236787, 107]
params = SamplingParams(skip_special_tokens=True)
request = EngineCoreRequest(
"test",
prompt_token_ids,
None,
None,
None,
params,
None,
None,
0.0,
None,
request_id="test",
prompt_token_ids=prompt_token_ids,
mm_features=None,
sampling_params=params,
pooling_params=None,
eos_token_id=None,
arrival_time=0.0,
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
)
30 changes: 10 additions & 20 deletions tests/v1/engine/test_output_processor.py
@@ -52,11 +52,9 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
requests = [
EngineCoreRequest(request_id=f"request-{idx}",
prompt_token_ids=prompt_tokens,
arrival_time=0,
mm_kwargs=None,
mm_hashes=None,
mm_placeholders=None,
mm_features=None,
eos_token_id=None,
arrival_time=0,
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
@@ -401,11 +399,9 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
requests = [
EngineCoreRequest(request_id=request_id_list[idx],
prompt_token_ids=prompt_tokens,
arrival_time=0,
mm_kwargs=None,
mm_hashes=None,
mm_placeholders=None,
mm_features=None,
eos_token_id=None,
arrival_time=0,
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
@@ -566,11 +562,9 @@ def test_stop_token(include_stop_str_in_output: bool,
request = EngineCoreRequest(
request_id=request_id,
prompt_token_ids=prompt_tokens,
arrival_time=0,
mm_kwargs=None,
mm_hashes=None,
mm_placeholders=None,
mm_features=None,
eos_token_id=eos_token_id,
arrival_time=0,
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
@@ -665,11 +659,9 @@ def test_stop_string(include_stop_str_in_output: bool,
EngineCoreRequest(
request_id=request_id_list[idx],
prompt_token_ids=prompt_tokens,
arrival_time=0,
mm_kwargs=None,
mm_hashes=None,
mm_placeholders=None,
mm_features=None,
eos_token_id=None,
arrival_time=0,
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
@@ -781,11 +773,9 @@ def test_iteration_stats(dummy_test_vectors):
EngineCoreRequest(
request_id=f"request-{idx}",
prompt_token_ids=prompt_tokens,
arrival_time=0,
mm_kwargs=None,
mm_hashes=None,
mm_placeholders=None,
mm_features=None,
eos_token_id=None,
arrival_time=0,
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
4 changes: 1 addition & 3 deletions tests/v1/kv_connector/unit/utils.py
@@ -162,9 +162,7 @@ def create_request(request_id: int,
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
pooling_params=None,
multi_modal_kwargs=None,
multi_modal_placeholders=None,
multi_modal_hashes=None,
mm_features=None,
eos_token_id=EOS_TOKEN_ID,
block_hasher=get_request_block_hasher(block_size, hash_fn),
)
16 changes: 13 additions & 3 deletions vllm/multimodal/cache.py
@@ -12,9 +12,9 @@
from vllm.utils import GiB_bytes, LRUCache
from vllm.utils.jsontree import json_map_leaves, json_reduce_leaves

from .inputs import (MultiModalFieldElem, MultiModalKwargs,
MultiModalKwargsItem, MultiModalKwargsItems,
NestedTensors)
from .inputs import (MultiModalFeatureSpec, MultiModalFieldElem,
MultiModalKwargs, MultiModalKwargsItem,
MultiModalKwargsItems, NestedTensors)

if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
@@ -418,6 +418,16 @@ class BaseMultiModalReceiverCache(
MultiModalKwargsItem]):
"""The required interface for caches on P1."""

def get_and_update_features(
self,
mm_features: list["MultiModalFeatureSpec"],
) -> list["MultiModalFeatureSpec"]:
"""Update multimodal features with cached encoder outputs."""
for feature in mm_features:
feature.data = self.get_and_update_item(feature.data,
feature.identifier)
return mm_features


class MultiModalReceiverCache(BaseMultiModalReceiverCache):
"""
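As a usage illustration (not part of this diff), a hedged sketch of how a P1 receiver cache could apply get_and_update_features to a request's features; the receiver_cache and request objects are assumptions here, only the method itself comes from the change above.

    # Hypothetical usage: `receiver_cache` is an already-constructed
    # BaseMultiModalReceiverCache implementation (e.g. MultiModalReceiverCache)
    # and `request` carries a list of MultiModalFeatureSpec objects.
    # Each feature's MultiModalKwargsItem is looked up or stored under
    # feature.identifier, so items stripped on the sender side are restored
    # before the engine core consumes them.
    request.mm_features = receiver_cache.get_and_update_features(
        request.mm_features)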
23 changes: 23 additions & 0 deletions vllm/multimodal/inputs.py
@@ -198,6 +198,29 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
"""


@dataclass
class MultiModalFeatureSpec:
"""
Represents a single multimodal input with its processed data and metadata.

Used by the V1 engine to track multimodal data through processing and
caching. A request containing multiple multimodal items will have one
MultiModalFeatureSpec per item.
"""

data: Optional["MultiModalKwargsItem"]
"""Multimodal data for this feature"""

modality: str
"""Based on the input, e.g., "image", "audio", "video"."""

identifier: str
"""mm_hash or uuid for caching encoder outputs."""

mm_position: PlaceholderRange
"""e.g., PlaceholderRange(offset=2, length=336)"""


@dataclass
class MultiModalFieldElem:
"""
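For illustration, a minimal sketch (using only the fields declared above) of constructing a single MultiModalFeatureSpec; a request containing several multimodal items would carry one such spec per item in its mm_features list. The concrete values are assumptions taken from the examples in this diff.

    from vllm.multimodal.inputs import (MultiModalFeatureSpec,
                                        MultiModalKwargsItem, PlaceholderRange)

    feature = MultiModalFeatureSpec(
        # Processed inputs for this item; may be None until a cache fills it in.
        data=MultiModalKwargsItem.dummy("dummy_m"),
        modality="image",
        # Content hash (or UUID) used to key the encoder-output cache.
        identifier="hash_0",
        # Where this item's placeholder tokens sit in the prompt.
        mm_position=PlaceholderRange(offset=2, length=336),
    )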