From e7c24723759519231b74ef3ef5ccfe030be3ebb0 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 29 Jan 2026 10:38:29 +0000 Subject: [PATCH 1/5] [Multimodal] Simplify MM input definitions Signed-off-by: DarkLight1337 --- tests/distributed/test_shm_storage.py | 12 +- .../processing/test_tensor_schema.py | 6 +- tests/multimodal/test_cache.py | 36 ++--- tests/v1/core/test_kv_cache_utils.py | 2 +- tests/v1/core/test_prefix_caching.py | 2 +- .../v1/core/test_priority_scheduler_random.py | 2 +- tests/v1/core/test_scheduler.py | 2 +- tests/v1/core/utils.py | 2 +- .../test_gpu_model_runner_streaming.py | 4 +- .../test_scheduler_streaming.py | 4 +- tests/v1/test_serial_utils.py | 23 ++- vllm/multimodal/cache.py | 31 ++-- vllm/multimodal/inputs.py | 137 +++++++++--------- vllm/multimodal/utils.py | 8 +- vllm/v1/serial_utils.py | 12 +- vllm/v1/worker/gpu/mm/encoder_runner.py | 9 +- vllm/v1/worker/gpu_model_runner.py | 10 +- 17 files changed, 141 insertions(+), 161 deletions(-) diff --git a/tests/distributed/test_shm_storage.py b/tests/distributed/test_shm_storage.py index ea63f4a293af..fb7d5528c0da 100644 --- a/tests/distributed/test_shm_storage.py +++ b/tests/distributed/test_shm_storage.py @@ -23,18 +23,16 @@ ) -def _dummy_elem(modality: str, key: str, size: int): +def _dummy_elem(size: int): return MultiModalFieldElem( - modality=modality, - key=key, data=torch.empty((size,), dtype=torch.int8), field=MultiModalSharedField(batch_size=1), ) -def _dummy_item(modality: str, size_by_key: dict[str, int]): - return MultiModalKwargsItem.from_elems( - [_dummy_elem(modality, key, size) for key, size in size_by_key.items()] +def _dummy_item(size_by_key: dict[str, int]): + return MultiModalKwargsItem( + {key: _dummy_elem(size) for key, size in size_by_key.items()} ) @@ -61,7 +59,7 @@ def tearDown(self): def test_minimal_put_get_cycle(self): """Test basic put and get operations.""" key = "test_key" - value = _dummy_item("text", {"field1": 10, "field2": 20}) + value = _dummy_item({"field1": 10, "field2": 20}) # Put operation address, monotonic_id = self.storage.put(key, value) diff --git a/tests/models/multimodal/processing/test_tensor_schema.py b/tests/models/multimodal/processing/test_tensor_schema.py index b55dad266e03..5ce5217de2f6 100644 --- a/tests/models/multimodal/processing/test_tensor_schema.py +++ b/tests/models/multimodal/processing/test_tensor_schema.py @@ -119,7 +119,11 @@ def create_batched_mm_kwargs( )["mm_kwargs"].require_data() return group_mm_kwargs_by_modality( - [item for modality in supported_mm_limits for item in mm_kwargs[modality]] + [ + (item, modality) + for modality in supported_mm_limits + for item in mm_kwargs[modality] + ] ) diff --git a/tests/multimodal/test_cache.py b/tests/multimodal/test_cache.py index 36220e8f3ad9..d01b94ac9af2 100644 --- a/tests/multimodal/test_cache.py +++ b/tests/multimodal/test_cache.py @@ -36,8 +36,6 @@ def _dummy_elem( - modality: str, - key: str, size: int, *, rng: np.random.RandomState | None = None, @@ -48,21 +46,18 @@ def _dummy_elem( data = torch.from_numpy(rng.randint(4, size=(size,), dtype=np.int8)) return MultiModalFieldElem( - modality=modality, - key=key, data=data, field=MultiModalSharedField(batch_size=1), ) def _dummy_item( - modality: str, size_by_key: dict[str, int], *, rng: np.random.RandomState | None = None, ): - return MultiModalKwargsItem.from_elems( - [_dummy_elem(modality, key, size, rng=rng) for key, size in size_by_key.items()] + return MultiModalKwargsItem( + {key: _dummy_elem(size, rng=rng) for key, size in 
size_by_key.items()} ) @@ -71,19 +66,19 @@ def _dummy_items( *, rng: np.random.RandomState | None = None, ): - return MultiModalKwargsItems.from_seq( - [ - _dummy_item(modality, size_by_key, rng=rng) + return MultiModalKwargsItems( + { + modality: [_dummy_item(size_by_key, rng=rng)] for modality, size_by_key in size_by_key_modality.items() - ] + } ) @pytest.mark.parametrize( ("item", "expected_size"), [ - (_dummy_item("a", {"a1": 100}), 100), - (_dummy_item("a", {"a1": 100, "a2": 110}), 210), + (_dummy_item({"a1": 100}), 100), + (_dummy_item({"a1": 100, "a2": 110}), 210), (_dummy_items({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501 ], ) @@ -143,7 +138,7 @@ def _compare_caches( rng = np.random.RandomState(seed) all_items = [ - _dummy_item("item", {"key": item_size_gb}, rng=rng) + _dummy_item({"key": item_size_gb}, rng=rng) for _ in range(int(item_capacity / hit_rate)) ] all_hashes = [ @@ -245,13 +240,13 @@ def _run_test_cache_eviction_lru( "image_C", ] request1_items = { - h: MultiModalKwargsItem.dummy(h, nbytes=2 * base_item_size) + h: MultiModalKwargsItem.dummy(nbytes=2 * base_item_size) for h in request1_hashes } request2_hashes = ["image_D", "image_E", "image_A", "image_C"] request2_items = { - h: MultiModalKwargsItem.dummy(h, nbytes=1 * base_item_size) + h: MultiModalKwargsItem.dummy(nbytes=1 * base_item_size) for h in request2_hashes } @@ -356,15 +351,14 @@ def _run_test_cache_eviction_shm( ): request1_hashes = ["image_A", "image_B", "image_C"] request1_items = { - h: MultiModalKwargsItem.dummy(h, nbytes=5 * base_item_size) - for h in request1_hashes + h: MultiModalKwargsItem.dummy(5 * base_item_size) for h in request1_hashes } request1_items_p0_result = [] request2_hashes = ["image_G", "image_A"] request2_items = { h: MultiModalKwargsItem.dummy( - h, nbytes=(5 if h in request1_hashes else 2) * base_item_size + (5 if h in request1_hashes else 2) * base_item_size ) for h in request2_hashes } @@ -373,7 +367,7 @@ def _run_test_cache_eviction_shm( request3_hashes = ["image_G", "image_H", "image_I", "image_B"] request3_items = { h: MultiModalKwargsItem.dummy( - h, nbytes=(5 if h in request1_hashes else 2) * base_item_size + (5 if h in request1_hashes else 2) * base_item_size ) for h in request3_hashes } @@ -532,7 +526,7 @@ def test_processor_cache_shared_across_loras(): lora_a_identifier = f"12345:{base_mm_hash}" lora_b_identifier = f"67890:{base_mm_hash}" - item_data = MultiModalKwargsItem.dummy("test_image", nbytes=1024) + item_data = MultiModalKwargsItem.dummy(1024) feature_lora_a = MultiModalFeatureSpec( data=item_data, diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 65d3650c028d..d97362e06c64 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -77,7 +77,7 @@ def make_request( for j, position in enumerate(mm_positions): identifier = mm_hashes[j] if mm_hashes else f"hash_{j}" mm_feature = MultiModalFeatureSpec( - data=MultiModalKwargsItem.dummy("dummy_m"), + data=MultiModalKwargsItem.dummy(), mm_position=position, identifier=identifier, modality="image", diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 486e5f9cd4c8..52c793f48bc0 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -68,7 +68,7 @@ def make_request( for j, position in enumerate(mm_positions): identifier = mm_hashes[j] if mm_hashes else f"hash_{j}" mm_feature = MultiModalFeatureSpec( - 
data=MultiModalKwargsItem.dummy("dummy_m"), + data=MultiModalKwargsItem.dummy(), mm_position=position, identifier=identifier, modality="image", diff --git a/tests/v1/core/test_priority_scheduler_random.py b/tests/v1/core/test_priority_scheduler_random.py index 429b179b61dc..cb4dfc04618f 100644 --- a/tests/v1/core/test_priority_scheduler_random.py +++ b/tests/v1/core/test_priority_scheduler_random.py @@ -56,7 +56,7 @@ def _create_random_request( for j, position in enumerate(mm_positions): identifier = f"{request_id}_hash_{j}" mm_feature = MultiModalFeatureSpec( - data=MultiModalKwargsItem.dummy("dummy_m"), + data=MultiModalKwargsItem.dummy(), mm_position=position, identifier=identifier, modality="image", diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index d8e9e2e3c09b..537a02464d0b 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1707,7 +1707,7 @@ def create_requests_with_priority( # Unique dummy hash for each mm item identifier = f"hash{i}_{j}" mm_feature = MultiModalFeatureSpec( - data=MultiModalKwargsItem.dummy("dummy_m"), + data=MultiModalKwargsItem.dummy(), mm_position=position, identifier=identifier, modality="image", diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py index 751a29795634..00eb61285ab5 100644 --- a/tests/v1/core/utils.py +++ b/tests/v1/core/utils.py @@ -236,7 +236,7 @@ def create_requests( # Unique dummy hash for each mm item identifier = f"hash{i}_{j}" mm_feature = MultiModalFeatureSpec( - data=MultiModalKwargsItem.dummy("dummy_m"), + data=MultiModalKwargsItem.dummy(), mm_position=position, identifier=identifier, modality="image", diff --git a/tests/v1/streaming_input/test_gpu_model_runner_streaming.py b/tests/v1/streaming_input/test_gpu_model_runner_streaming.py index c9a641632ffa..0ed7b6cb3efc 100644 --- a/tests/v1/streaming_input/test_gpu_model_runner_streaming.py +++ b/tests/v1/streaming_input/test_gpu_model_runner_streaming.py @@ -131,7 +131,7 @@ def test_e2e_streaming_with_multimodal_features(mock_model_runner_with_input_bat # Step 1: Create initial request state with one multimodal feature mm_feature_1 = MultiModalFeatureSpec( - data=MultiModalKwargsItem.dummy("audio"), + data=MultiModalKwargsItem.dummy(), modality="audio", identifier="audio_1", mm_position=PlaceholderRange(offset=2, length=10), @@ -158,7 +158,7 @@ def test_e2e_streaming_with_multimodal_features(mock_model_runner_with_input_bat # The scheduler has already set prompt_token_ids to the full sequence # (original prompt + intermediate outputs + new prompt with new multimodal feature) mm_feature_2 = MultiModalFeatureSpec( - data=MultiModalKwargsItem.dummy("audio"), + data=MultiModalKwargsItem.dummy(), modality="audio", identifier="audio_2", mm_position=PlaceholderRange(offset=15, length=5), diff --git a/tests/v1/streaming_input/test_scheduler_streaming.py b/tests/v1/streaming_input/test_scheduler_streaming.py index 0387d31c98e9..f8d8c3cb3fdc 100644 --- a/tests/v1/streaming_input/test_scheduler_streaming.py +++ b/tests/v1/streaming_input/test_scheduler_streaming.py @@ -174,7 +174,7 @@ def test_update_request_as_session_with_multimodal(self): scheduler = create_scheduler() mm_feature = MultiModalFeatureSpec( - data=MultiModalKwargsItem.dummy("audio"), + data=MultiModalKwargsItem.dummy(), modality="audio", identifier="", mm_position=PlaceholderRange(offset=1, length=1), @@ -187,7 +187,7 @@ def test_update_request_as_session_with_multimodal(self): session.num_computed_tokens = len(session.prompt_token_ids) mm_feature = 
MultiModalFeatureSpec( - data=MultiModalKwargsItem.dummy("audio"), + data=MultiModalKwargsItem.dummy(), modality="audio", identifier="", mm_position=PlaceholderRange(offset=2, length=1), diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index dbbbfce97d28..1f12ebc1fd20 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -104,14 +104,10 @@ class MyRequest(msgspec.Struct): def test_multimodal_kwargs(): e1 = MultiModalFieldElem( - "audio", - "a0", torch.zeros(1000, dtype=torch.bfloat16), MultiModalBatchedField(), ) e2 = MultiModalFieldElem( - "video", - "v0", [torch.zeros(1000, dtype=torch.int8) for _ in range(4)], MultiModalFlatField( slices=[[slice(1, 2, 3), slice(4, 5, 6)], [slice(None, 2)]], @@ -119,21 +115,20 @@ def test_multimodal_kwargs(): ), ) e3 = MultiModalFieldElem( - "image", - "i0", torch.zeros(1000, dtype=torch.int32), MultiModalSharedField(batch_size=4), ) e4 = MultiModalFieldElem( - "image", - "i1", torch.zeros(1000, dtype=torch.int32), MultiModalFlatField(slices=[slice(1, 2, 3), slice(4, 5, 6)], dim=2), ) - audio = MultiModalKwargsItem.from_elems([e1]) - video = MultiModalKwargsItem.from_elems([e2]) - image = MultiModalKwargsItem.from_elems([e3, e4]) - mm = MultiModalKwargsItems.from_seq([audio, video, image]) + mm = MultiModalKwargsItems( + { + "audio": [MultiModalKwargsItem({"a0": e1})], + "video": [MultiModalKwargsItem({"v0": e2})], + "image": [MultiModalKwargsItem({"i0": e3, "i1": e4})], + } + ) # pack mm kwargs into a mock request so that it can be decoded properly req = MyRequest([mm]) @@ -147,8 +142,8 @@ def test_multimodal_kwargs(): total_len = sum(memoryview(x).cast("B").nbytes for x in encoded) - # expected total encoding length, should be 14395, +-20 for minor changes - assert 14375 <= total_len <= 14425 + # expected total encoding length, should be 14319, +-20 for minor changes + assert 14300 <= total_len <= 14440 decoded = decoder.decode(encoded).mm[0] assert isinstance(decoded, MultiModalKwargsItems) diff --git a/vllm/multimodal/cache.py b/vllm/multimodal/cache.py index 2a0f59099c4f..c0df19d4f483 100644 --- a/vllm/multimodal/cache.py +++ b/vllm/multimodal/cache.py @@ -463,8 +463,8 @@ def __init__(self, vllm_config: "VllmConfig") -> None: ring_buffer=ring_buffer, serde_class=MsgpackSerde, ) - # cache (prompt_updates, modality) for P0 only - self._p0_cache: dict[str, tuple[Sequence[ResolvedPromptUpdate], str]] = {} + # cache prompt_updates for P0 only + self._p0_cache: dict[str, Sequence[ResolvedPromptUpdate]] = {} self._hits = 0 self._total = 0 @@ -495,23 +495,22 @@ def get_and_update_item( self._total += 1 address, monotonic_id = self._shm_cache.get_cached(mm_hash) - prompt_updates, modality = self._p0_cache[mm_hash] - return self.address_as_item(address, monotonic_id, modality), prompt_updates + prompt_updates = self._p0_cache[mm_hash] + return self.address_as_item(address, monotonic_id), prompt_updates assert mm_item is not None, f"Expected a cached item for {mm_hash=}" + item, prompt_updates = mm_item self._total += 1 try: - address, monotonic_id = self._shm_cache.put(mm_hash, mm_item[0]) + address, monotonic_id = self._shm_cache.put(mm_hash, item) # Try to remove dangling items if p0 cache is too large. 
if len(self._p0_cache) >= 2 * len(self._shm_cache.key_index): self.remove_dangling_items() - self._p0_cache[mm_hash] = mm_item[1], mm_item[0].modality - address_item = self.address_as_item( - address, monotonic_id, mm_item[0].modality - ) - return address_item, mm_item[1] + + self._p0_cache[mm_hash] = prompt_updates + return self.address_as_item(address, monotonic_id), prompt_updates except (ValueError, MemoryError) as e: # put may fail if the object is too large or # the cache is full. @@ -550,22 +549,20 @@ def remove_dangling_items(self) -> None: del self._p0_cache[mm_hash] def address_as_item( - self, address: int, monotonic_id: int, modality: str + self, + address: int, + monotonic_id: int, ) -> MultiModalKwargsItem: addr_elem = MultiModalFieldElem( - modality=modality, - key="address", data=address, field=MultiModalBatchedField(), ) id_elem = MultiModalFieldElem( - modality=modality, - key="monotonic_id", data=monotonic_id, field=MultiModalBatchedField(), ) - mm_item = MultiModalKwargsItem.from_elems([addr_elem, id_elem]) - return mm_item + + return MultiModalKwargsItem({"address": addr_elem, "monotonic_id": id_elem}) class BaseMultiModalReceiverCache( diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 8ce1e3587b11..46a2b777b867 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -23,7 +23,7 @@ from PIL.Image import Image from typing_extensions import NotRequired, TypeVar -from vllm.utils.collection_utils import full_groupby, is_list_of +from vllm.utils.collection_utils import is_list_of from vllm.utils.import_utils import LazyLoader from vllm.utils.jsontree import json_map_leaves @@ -336,25 +336,33 @@ class MultiModalFeatureSpec: """ Represents a single multimodal input with its processed data and metadata. - Used by the V1 engine to track multimodal data through processing and - caching. A request containing multiple multimodal items will have one - MultiModalFeatureSpec per item. + Used to track multimodal data through processing and caching. + A request containing multiple multimodal items will have one + `MultiModalFeatureSpec` per item. """ data: Optional["MultiModalKwargsItem"] - """Multimodal data for this feature""" + """ + Represents multimodal data for this feature. + + Can be `None` if the item is cached, to skip IPC between API server + and engine core processes. + """ modality: str - """Based on the input, e.g., "image", "audio", "video".""" + """The input modality, e.g., `"image"`, `"audio"`, `"video"`.""" identifier: str - """mm_hash or uuid for caching encoder outputs.""" + """The hash for caching encoder outputs (with LoRA prefix if applicable).""" mm_position: PlaceholderRange - """e.g., PlaceholderRange(offset=2, length=336)""" + """ + The location of the `modality` tokens corresponding to this item + in the prompt, e.g., `PlaceholderRange(offset=2, length=336)`. + """ mm_hash: str | None = None - """Base mm_hash for processor cache (without LoRA prefix).""" + """The hash for caching processor outputs (without LoRA prefix).""" @staticmethod def gather_kwargs(features: list["MultiModalFeatureSpec"], keys: set[str]): @@ -373,23 +381,10 @@ def gather_kwargs(features: list["MultiModalFeatureSpec"], keys: set[str]): @dataclass class MultiModalFieldElem: """ - Represents a keyword argument inside a + Represents a processed keyword argument to pass to a model for a [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]. """ - modality: str - """ - The modality of the corresponding multi-modal item. 
- Each multi-modal item can consist of multiple keyword arguments. - """ - - key: str - """ - The key of this field in - [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem], - i.e. the name of the keyword argument to be passed to the model. - """ - data: NestedTensors """ The tensor data of this field in @@ -417,11 +412,7 @@ def __eq__(self, other: object) -> bool: else: data_equal = nested_tensors_equal(self.data, other.data) - return ( - (self.modality, self.key) == (other.modality, other.key) - and data_equal - and type(self.field) is type(other.field) - ) # noqa: E721 + return data_equal and type(self.field) is type(other.field) # noqa: E721 @dataclass(frozen=True, kw_only=True) @@ -438,13 +429,8 @@ class BaseMultiModalField(ABC): when `MultiModalKwargsItems.get_data()` is called to batch the data. """ - def _field_factory(self, *, modality: str, key: str): - f = partial( - MultiModalFieldElem, - modality=modality, - key=key, - field=self, - ) + def _field_factory(self): + f = partial(MultiModalFieldElem, field=self) # Allow passing data as positional argument def factory(data: NestedTensors) -> MultiModalFieldElem: @@ -519,7 +505,7 @@ def build_elems( key: str, data: NestedTensors, ) -> Sequence[MultiModalFieldElem]: - field_factory = self._field_factory(modality=modality, key=key) + field_factory = self._field_factory() return [field_factory(item) for item in data] def _reduce_data( @@ -565,7 +551,7 @@ def build_elems( key: str, data: NestedTensors, ) -> Sequence[MultiModalFieldElem]: - field_factory = self._field_factory(modality=modality, key=key) + field_factory = self._field_factory() if not is_list_of(self.slices, slice, check="all"): assert isinstance(data, torch.Tensor), ( "torch.Tensor is required for multiple slices" @@ -623,7 +609,7 @@ def build_elems( key: str, data: NestedTensors, ) -> Sequence[MultiModalFieldElem]: - field_factory = self._field_factory(modality=modality, key=key) + field_factory = self._field_factory() return [field_factory(data)] * self.batch_size def _reduce_data( @@ -858,37 +844,19 @@ def build_elems( class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]): """ - A collection of - [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem] - corresponding to a data item in + A dictionary of processed keyword arguments to pass to the model, + corresponding to a single item in [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems]. 
""" @staticmethod - def dummy(modality: str, nbytes: int = 1): + def dummy(nbytes: int = 1): """Convenience class for testing.""" mm_elem = MultiModalFieldElem( - modality=modality, - key="dummy", data=torch.empty(nbytes, dtype=torch.uint8), field=MultiModalSharedField(batch_size=1), ) - return MultiModalKwargsItem.from_elems([mm_elem]) - - @staticmethod - def from_elems(elems: Sequence[MultiModalFieldElem]): - return MultiModalKwargsItem({elem.key: elem for elem in elems}) - - def __init__(self, data: Mapping[str, MultiModalFieldElem] = {}) -> None: - super().__init__(data) - - modalities = {elem.modality for elem in self.values()} - assert len(modalities) == 1, f"Found different modalities={modalities}" - self._modality = next(iter(modalities)) - - @property - def modality(self) -> str: - return self._modality + return MultiModalKwargsItem({"dummy": mm_elem}) def get_data(self) -> dict[str, NestedTensors]: return {key: elem.data for key, elem in self.items()} @@ -904,9 +872,38 @@ def get_data(self) -> dict[str, NestedTensors]: class MultiModalKwargsItems(UserDict[str, Sequence[_I]]): """ - A dictionary of - [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s - by modality. + A dictionary of processed multi-modal inputs by modality. + + For example, given a processor that processes + images into `pixel_values` and `image_grid_thw`, + and audios into `input_audio_features`, + a prompt with 2 images and 1 audio will be processed + into a `MultiModalKwargsItems` with the following structure: + + ```python + MultiModalKwargsItems( + { + "image": [ + # For the first image + MultiModalKwargsItem({"pixel_values": ..., "image_grid_thw": ...}), + # For the second imgae + MultiModalKwargsItem({"pixel_values": ..., "image_grid_thw": ...}), + ], + "audio": [ + # For the first audio + MultiModalKwargsItem({"input_audio_features": ...}), + ], + } + ) + ``` + + Unlike HF processing which returns all items + in a single dictionary with batched keyword arguments, + we split up the items because some of them may already be cached. + Also, items from multiple requests may be batched together to improve throughput, + using the logic defined by the + [`BaseMultiModalField`][vllm.multimodal.inputs.BaseMultiModalField] + for each keyword argument. 
""" @staticmethod @@ -926,7 +923,7 @@ def from_hf_inputs( elems_by_key[key] = elems keys_by_modality[config.modality].add(key) - items = list[MultiModalKwargsItem]() + items_by_modality = dict[str, list[MultiModalKwargsItem]]() for modality, keys in keys_by_modality.items(): elems_in_modality = {k: elems_by_key[k] for k in keys} batch_sizes = {k: len(v) for k, v in elems_in_modality.items()} @@ -938,15 +935,11 @@ def from_hf_inputs( ) batch_size = next(iter(batch_sizes.values())) - for item_idx in range(batch_size): - elems = [v[item_idx] for v in elems_in_modality.values()] - items.append(MultiModalKwargsItem.from_elems(elems)) - - return MultiModalKwargsItems.from_seq(items) + items_by_modality[modality] = [ + MultiModalKwargsItem({k: v[i] for k, v in elems_in_modality.items()}) + for i in range(batch_size) + ] - @staticmethod - def from_seq(items: Sequence[MultiModalKwargsItem]): - items_by_modality = full_groupby(items, key=lambda x: x.modality) return MultiModalKwargsItems(items_by_modality) def __getitem__(self, modality: str) -> Sequence[_I]: diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 85a699acea71..0a38a6e2ae9e 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -467,7 +467,7 @@ def argsort_mm_positions( def group_mm_kwargs_by_modality( - mm_kwargs: list[MultiModalKwargsItem], + mm_kwargs: list[tuple[str, MultiModalKwargsItem]], *, device: torch.types.Device = None, pin_memory: bool = False, @@ -485,9 +485,9 @@ def group_mm_kwargs_by_modality( """ from vllm.multimodal.inputs import MultiModalKwargsItems - for modality, items in groupby(mm_kwargs, key=lambda item: item.modality): - items_lst = list(items) - mm_kwargs_items = MultiModalKwargsItems.from_seq(items_lst) + for modality, group in groupby(mm_kwargs, key=lambda x: x[0]): + items_lst = [item for _, item in group] + mm_kwargs_items = MultiModalKwargsItems({modality: items_lst}) mm_kwargs_data = mm_kwargs_items.get_data( device=device, pin_memory=pin_memory, diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index a3c30e368b82..0c03de71c20a 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -242,13 +242,11 @@ def _encode_mm_items(self, items: MultiModalKwargsItems) -> dict[str, Any]: for modality, itemlist in items.items() } - def _encode_mm_item(self, item: MultiModalKwargsItem) -> list[dict[str, Any]]: - return [self._encode_mm_field_elem(elem) for elem in item.values()] + def _encode_mm_item(self, item: MultiModalKwargsItem) -> dict[str, Any]: + return {key: self._encode_mm_field_elem(elem) for key, elem in item.items()} def _encode_mm_field_elem(self, elem: MultiModalFieldElem) -> dict[str, Any]: return { - "modality": elem.modality, - "key": elem.key, "data": ( None if elem.data is None else self._encode_nested_tensors(elem.data) ), @@ -383,9 +381,9 @@ def _decode_mm_items(self, obj: dict[str, Any]) -> MultiModalKwargsItems: } ) - def _decode_mm_item(self, obj: list[Any]) -> MultiModalKwargsItem: - return MultiModalKwargsItem.from_elems( - [self._decode_mm_field_elem(v) for v in obj] + def _decode_mm_item(self, obj: dict[str, Any]) -> MultiModalKwargsItem: + return MultiModalKwargsItem( + {key: self._decode_mm_field_elem(elem) for key, elem in obj.items()} ) def _decode_mm_field_elem(self, obj: dict[str, Any]) -> MultiModalFieldElem: diff --git a/vllm/v1/worker/gpu/mm/encoder_runner.py b/vllm/v1/worker/gpu/mm/encoder_runner.py index d5e5bc04e9b5..bfe0bf1a36e7 100644 --- a/vllm/v1/worker/gpu/mm/encoder_runner.py +++ 
b/vllm/v1/worker/gpu/mm/encoder_runner.py @@ -43,9 +43,9 @@ def remove_request(self, req_id: str) -> None: def prepare_mm_inputs( self, scheduled_encoder_inputs: dict[str, list[int]], - ) -> tuple[list[str], list[MultiModalKwargsItem]]: + ) -> tuple[list[str], list[tuple[str, MultiModalKwargsItem]]]: mm_hashes: list[str] = [] - mm_kwargs: list[MultiModalKwargsItem] = [] + mm_kwargs: list[tuple[str, MultiModalKwargsItem]] = [] for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): mm_features = self.req_id_to_mm_features[req_id] for mm_input_id in encoder_input_ids: @@ -53,7 +53,8 @@ def prepare_mm_inputs( if mm_feature.data is None: continue mm_hashes.append(mm_feature.identifier) - mm_kwargs.append(mm_feature.data) + mm_kwargs.append((mm_feature.modality, mm_feature.data)) + return mm_hashes, mm_kwargs @torch.inference_mode() @@ -61,7 +62,7 @@ def execute_mm_encoder( self, model: SupportsMultiModal, mm_hashes: list[str], - mm_kwargs: list[MultiModalKwargsItem], + mm_kwargs: list[tuple[str, MultiModalKwargsItem]], ) -> list[torch.Tensor]: if not mm_hashes: return [] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 96dab077d23d..692e5a450267 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1217,11 +1217,11 @@ def _extract_mm_kwargs( if not scheduler_output or not self.is_multimodal_raw_input_only_model: return {} - mm_kwargs = list[MultiModalKwargsItem]() + mm_kwargs = list[tuple[str, MultiModalKwargsItem]]() for req in scheduler_output.scheduled_new_reqs: for feature in req.mm_features: if feature.data is not None: - mm_kwargs.append(feature.data) + mm_kwargs.append((feature.modality, feature.data)) # Input all modalities at once mm_kwargs_combined: BatchedTensorInputs = {} @@ -2219,7 +2219,7 @@ def _batch_mm_inputs_from_scheduler( scheduler_output: "SchedulerOutput", ) -> tuple[ list[str], - list[MultiModalKwargsItem], + list[tuple[str, MultiModalKwargsItem]], list[tuple[str, PlaceholderRange]], ]: """Batch multimodal inputs from scheduled encoder inputs. 
@@ -2239,7 +2239,7 @@ def _batch_mm_inputs_from_scheduler( return [], [], [] mm_hashes = list[str]() - mm_kwargs = list[MultiModalKwargsItem]() + mm_kwargs = list[tuple[str, MultiModalKwargsItem]]() # Multimodal LoRA reference info to map each multimodal item # back to its request & position mm_lora_refs = list[tuple[str, PlaceholderRange]]() @@ -2252,7 +2252,7 @@ def _batch_mm_inputs_from_scheduler( continue mm_hashes.append(mm_feature.identifier) - mm_kwargs.append(mm_feature.data) + mm_kwargs.append((mm_feature.modality, mm_feature.data)) mm_lora_refs.append((req_id, mm_feature.mm_position)) return mm_hashes, mm_kwargs, mm_lora_refs From e63269318dd8444753c0da78bed11f268a8936df Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 29 Jan 2026 10:42:10 +0000 Subject: [PATCH 2/5] Fix Signed-off-by: DarkLight1337 --- tests/models/multimodal/processing/test_tensor_schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/multimodal/processing/test_tensor_schema.py b/tests/models/multimodal/processing/test_tensor_schema.py index 5ce5217de2f6..3229a4a59348 100644 --- a/tests/models/multimodal/processing/test_tensor_schema.py +++ b/tests/models/multimodal/processing/test_tensor_schema.py @@ -120,7 +120,7 @@ def create_batched_mm_kwargs( return group_mm_kwargs_by_modality( [ - (item, modality) + (modality, item) for modality in supported_mm_limits for item in mm_kwargs[modality] ] From f892638fc5f1fa76779d723b5aae4ad335772125 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 29 Jan 2026 10:43:08 +0000 Subject: [PATCH 3/5] Typo Signed-off-by: DarkLight1337 --- tests/v1/test_serial_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index 1f12ebc1fd20..a5dc1773d477 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -143,7 +143,7 @@ def test_multimodal_kwargs(): total_len = sum(memoryview(x).cast("B").nbytes for x in encoded) # expected total encoding length, should be 14319, +-20 for minor changes - assert 14300 <= total_len <= 14440 + assert 14300 <= total_len <= 14340 decoded = decoder.decode(encoded).mm[0] assert isinstance(decoded, MultiModalKwargsItems) From 7a8912f3e624f8298911b5578607a74b26b0d003 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 29 Jan 2026 11:23:51 +0000 Subject: [PATCH 4/5] Fix Signed-off-by: DarkLight1337 --- vllm/v1/worker/gpu_model_runner.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 692e5a450267..1d66754605c0 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4475,12 +4475,10 @@ def _get_mm_dummy_batch( # but not read from the cache assert dummy_mm_item is not None, "Item should not already be cached" - dummy_mm_items = [dummy_mm_item] * max_items_per_batch - return next( mm_kwargs_group for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( - dummy_mm_items, + [(modality, dummy_mm_item) for _ in range(max_items_per_batch)], device=self.device, pin_memory=self.pin_memory, ) From 5143b343f3a7b30fab1a31f362c145231897f618 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 29 Jan 2026 11:24:34 +0000 Subject: [PATCH 5/5] Simplify Signed-off-by: DarkLight1337 --- vllm/v1/worker/gpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 
1d66754605c0..8e21dea6900a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4478,7 +4478,7 @@ def _get_mm_dummy_batch( return next( mm_kwargs_group for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( - [(modality, dummy_mm_item) for _ in range(max_items_per_batch)], + [(modality, dummy_mm_item)] * max_items_per_batch, device=self.device, pin_memory=self.pin_memory, )
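
Taken together, the series removes `modality` and `key` from `MultiModalFieldElem`, turns `MultiModalKwargsItem` into a plain mapping from keyword-argument name to element, keys `MultiModalKwargsItems` by modality, and threads explicit `(modality, item)` tuples through `group_mm_kwargs_by_modality`. Below is a minimal sketch of the resulting construction path; it is illustrative only and not part of the patches, and the `"pixel_values"` key, tensor shape, and `"image"` modality are placeholder assumptions.

    import torch

    from vllm.multimodal.inputs import (
        MultiModalFieldElem,
        MultiModalKwargsItem,
        MultiModalKwargsItems,
        MultiModalSharedField,
    )
    from vllm.multimodal.utils import group_mm_kwargs_by_modality

    # Elements no longer carry `modality`/`key`: the kwarg name lives on the
    # enclosing MultiModalKwargsItem, and the modality on MultiModalKwargsItems.
    elem = MultiModalFieldElem(
        data=torch.zeros(16, dtype=torch.uint8),
        field=MultiModalSharedField(batch_size=1),
    )
    item = MultiModalKwargsItem({"pixel_values": elem})
    items = MultiModalKwargsItems({"image": [item]})

    # Batching helpers now take explicit (modality, item) tuples instead of
    # reading `item.modality`, mirroring the
    # `[(modality, dummy_mm_item)] * max_items_per_batch` call in the final patch.
    for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
        [("image", item)] * 2,
    ):
        print(modality, num_items, list(mm_kwargs_group))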