diff --git a/tests/runner/test_multimodal_manager.py b/tests/runner/test_multimodal_manager.py index 3441df0e8c..5cc8c779d0 100644 --- a/tests/runner/test_multimodal_manager.py +++ b/tests/runner/test_multimodal_manager.py @@ -98,12 +98,12 @@ def test_execute_mm_encoder_single_image(self): # Mock request state dummy_pixel_values = torch.randn(3, 224, 224, dtype=torch.bfloat16) dummy_grid_thw = torch.tensor([[1, 1, 1]], dtype=torch.int64) - mm_item = MultiModalKwargsItem.from_elems([ - MultiModalFieldElem("image", "pixel_values", dummy_pixel_values, - MultiModalBatchedField()), - MultiModalFieldElem("image", "image_grid_thw", dummy_grid_thw, - MultiModalBatchedField()) - ]) + mm_item = MultiModalKwargsItem({ + "pixel_values": + MultiModalFieldElem(dummy_pixel_values, MultiModalBatchedField()), + "image_grid_thw": + MultiModalFieldElem(dummy_grid_thw, MultiModalBatchedField()) + }) req_state = CachedRequestState( req_id="req-1", @@ -183,12 +183,12 @@ def test_execute_mm_encoder_multiple_images(self): px_1 = torch.randn(3, 224, 224, dtype=torch.bfloat16) grid_1 = torch.tensor([[1, 1, 1]], dtype=torch.int64) - mm_item_1 = MultiModalKwargsItem.from_elems([ - MultiModalFieldElem("image", "pixel_values", px_1, - MultiModalBatchedField()), - MultiModalFieldElem("image", "image_grid_thw", grid_1, - MultiModalBatchedField()) - ]) + mm_item_1 = MultiModalKwargsItem({ + "pixel_values": + MultiModalFieldElem(px_1, MultiModalBatchedField()), + "image_grid_thw": + MultiModalFieldElem(grid_1, MultiModalBatchedField()) + }) req_state_1 = CachedRequestState( req_id="req-1", @@ -210,12 +210,12 @@ def test_execute_mm_encoder_multiple_images(self): px_2 = torch.randn(3, 224, 224, dtype=torch.bfloat16) grid_2 = torch.tensor([[1, 2, 2]], dtype=torch.int64) - mm_item_2 = MultiModalKwargsItem.from_elems([ - MultiModalFieldElem("image", "pixel_values", px_2, - MultiModalBatchedField()), - MultiModalFieldElem("image", "image_grid_thw", grid_2, - MultiModalBatchedField()) - ]) + mm_item_2 = MultiModalKwargsItem({ + "pixel_values": + MultiModalFieldElem(px_2, MultiModalBatchedField()), + "image_grid_thw": + MultiModalFieldElem(grid_2, MultiModalBatchedField()) + }) req_state_2 = CachedRequestState( req_id="req-2", diff --git a/tpu_inference/runner/multimodal_manager.py b/tpu_inference/runner/multimodal_manager.py index 4c76dd895f..fe616094e4 100644 --- a/tpu_inference/runner/multimodal_manager.py +++ b/tpu_inference/runner/multimodal_manager.py @@ -92,7 +92,7 @@ def execute_mm_encoder(self, scheduler_output: "VllmSchedulerOutput"): return # Batch the multi-modal inputs. - mm_kwargs = list[MultiModalKwargsItem]() + mm_kwargs = list[tuple[str, MultiModalKwargsItem]]() # List of tuple (mm_hash, pos_info) mm_hashes_pos = list[tuple[str, PlaceholderRange]]() for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): @@ -100,7 +100,7 @@ def execute_mm_encoder(self, scheduler_output: "VllmSchedulerOutput"): for mm_input_id in encoder_input_ids: mm_feature = req_state.mm_features[mm_input_id] mm_hash = mm_feature.identifier - mm_kwargs.append(mm_feature.data) + mm_kwargs.append((mm_feature.modality, mm_feature.data)) mm_hashes_pos.append((mm_hash, mm_feature.mm_position)) # Batch mm inputs as much as we can: if a request in the batch has @@ -164,8 +164,6 @@ def execute_mm_encoder(self, scheduler_output: "VllmSchedulerOutput"): mm_hashes_pos, encoder_outputs, ): - if req_id not in self.runner.encoder_cache: - self.runner.encoder_cache[req_id] = {} self.runner.encoder_cache[mm_hash] = scatter_mm_placeholders( output,