Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions tests/runner/test_multimodal_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,12 @@ def test_execute_mm_encoder_single_image(self):
# Mock request state
dummy_pixel_values = torch.randn(3, 224, 224, dtype=torch.bfloat16)
dummy_grid_thw = torch.tensor([[1, 1, 1]], dtype=torch.int64)
mm_item = MultiModalKwargsItem.from_elems([
MultiModalFieldElem("image", "pixel_values", dummy_pixel_values,
MultiModalBatchedField()),
MultiModalFieldElem("image", "image_grid_thw", dummy_grid_thw,
MultiModalBatchedField())
])
mm_item = MultiModalKwargsItem({
"pixel_values":
MultiModalFieldElem(dummy_pixel_values, MultiModalBatchedField()),
"image_grid_thw":
MultiModalFieldElem(dummy_grid_thw, MultiModalBatchedField())
})

req_state = CachedRequestState(
req_id="req-1",
Expand Down Expand Up @@ -183,12 +183,12 @@ def test_execute_mm_encoder_multiple_images(self):
px_1 = torch.randn(3, 224, 224, dtype=torch.bfloat16)
grid_1 = torch.tensor([[1, 1, 1]], dtype=torch.int64)

mm_item_1 = MultiModalKwargsItem.from_elems([
MultiModalFieldElem("image", "pixel_values", px_1,
MultiModalBatchedField()),
MultiModalFieldElem("image", "image_grid_thw", grid_1,
MultiModalBatchedField())
])
mm_item_1 = MultiModalKwargsItem({
"pixel_values":
MultiModalFieldElem(px_1, MultiModalBatchedField()),
"image_grid_thw":
MultiModalFieldElem(grid_1, MultiModalBatchedField())
})

req_state_1 = CachedRequestState(
req_id="req-1",
Expand All @@ -210,12 +210,12 @@ def test_execute_mm_encoder_multiple_images(self):

px_2 = torch.randn(3, 224, 224, dtype=torch.bfloat16)
grid_2 = torch.tensor([[1, 2, 2]], dtype=torch.int64)
mm_item_2 = MultiModalKwargsItem.from_elems([
MultiModalFieldElem("image", "pixel_values", px_2,
MultiModalBatchedField()),
MultiModalFieldElem("image", "image_grid_thw", grid_2,
MultiModalBatchedField())
])
mm_item_2 = MultiModalKwargsItem({
"pixel_values":
MultiModalFieldElem(px_2, MultiModalBatchedField()),
"image_grid_thw":
MultiModalFieldElem(grid_2, MultiModalBatchedField())
})

req_state_2 = CachedRequestState(
req_id="req-2",
Expand Down
6 changes: 2 additions & 4 deletions tpu_inference/runner/multimodal_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,15 @@ def execute_mm_encoder(self, scheduler_output: "VllmSchedulerOutput"):
return

# Batch the multi-modal inputs.
mm_kwargs = list[MultiModalKwargsItem]()
mm_kwargs = list[tuple[str, MultiModalKwargsItem]]()
# List of tuple (mm_hash, pos_info)
mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.runner.requests[req_id]
for mm_input_id in encoder_input_ids:
mm_feature = req_state.mm_features[mm_input_id]
mm_hash = mm_feature.identifier
mm_kwargs.append(mm_feature.data)
mm_kwargs.append((mm_feature.modality, mm_feature.data))
mm_hashes_pos.append((mm_hash, mm_feature.mm_position))

# Batch mm inputs as much as we can: if a request in the batch has
Expand Down Expand Up @@ -164,8 +164,6 @@ def execute_mm_encoder(self, scheduler_output: "VllmSchedulerOutput"):
mm_hashes_pos,
encoder_outputs,
):
if req_id not in self.runner.encoder_cache:
Comment thread
mrjunwan-lang marked this conversation as resolved.
self.runner.encoder_cache[req_id] = {}

self.runner.encoder_cache[mm_hash] = scatter_mm_placeholders(
output,
Expand Down
Loading