Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .buildkite/vllm_lkg.version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
64a40a7ab4d0053830fae04c83763fa67f2183e6
010ec0c30ef1fa481fc69b8a5d2c205052d93607
22 changes: 11 additions & 11 deletions tpu_inference/runner/multimodal_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
from vllm.v1.worker.utils import (gather_mm_placeholders,
scatter_mm_placeholders)

from tpu_inference.models.jax.utils.multi_modal_utils import (
flatten_embeddings, sanity_check_mm_encoder_outputs)
Expand Down Expand Up @@ -160,15 +158,12 @@ def execute_mm_encoder(self, scheduler_output: "VllmSchedulerOutput"):
encoder_outputs.append(output)

# Cache the encoder outputs.
for (mm_hash, pos_info), output in zip(
for (mm_hash, _), output in zip(
mm_hashes_pos,
encoder_outputs,
):

self.runner.encoder_cache[mm_hash] = scatter_mm_placeholders(
output,
is_embed=pos_info.is_embed,
)
self.runner.encoder_cache[mm_hash] = output

def gather_mm_embeddings(self, scheduler_output: "VllmSchedulerOutput",
target_pad_len: int) -> list[jax.Array]:
Expand Down Expand Up @@ -201,6 +196,11 @@ def gather_mm_embeddings(self, scheduler_output: "VllmSchedulerOutput",
num_computed_tokens - start_pos + num_scheduled_tokens,
num_encoder_tokens)
assert start_idx < end_idx
curr_embeds_start, curr_embeds_end = (
pos_info.get_embeds_indices_in_range(start_idx, end_idx))
if curr_embeds_start == curr_embeds_end:
continue

mm_hash = mm_feature.identifier
encoder_output = self.runner.encoder_cache.get(mm_hash, None)
assert encoder_output is not None,\
Expand All @@ -209,11 +209,11 @@ def gather_mm_embeddings(self, scheduler_output: "VllmSchedulerOutput",

if (is_embed := pos_info.is_embed) is not None:
is_embed = is_embed[start_idx:end_idx]
mm_embeds_item = encoder_output[
curr_embeds_start:curr_embeds_end]
else:
mm_embeds_item = encoder_output[start_idx:end_idx]

mm_embeds_item = gather_mm_placeholders(
encoder_output[start_idx:end_idx],
is_embed=is_embed,
)
mm_embeds.append(mm_embeds_item)
if not mm_embeds:
return None
Expand Down
Loading