diff --git a/.buildkite/vllm_lkg.version b/.buildkite/vllm_lkg.version index 54b2b36583..5800c242dc 100644 --- a/.buildkite/vllm_lkg.version +++ b/.buildkite/vllm_lkg.version @@ -1 +1 @@ -64a40a7ab4d0053830fae04c83763fa67f2183e6 +010ec0c30ef1fa481fc69b8a5d2c205052d93607 diff --git a/tpu_inference/runner/multimodal_manager.py b/tpu_inference/runner/multimodal_manager.py index fe616094e4..86aef70cb3 100644 --- a/tpu_inference/runner/multimodal_manager.py +++ b/tpu_inference/runner/multimodal_manager.py @@ -20,8 +20,6 @@ from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange from vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput -from vllm.v1.worker.utils import (gather_mm_placeholders, - scatter_mm_placeholders) from tpu_inference.models.jax.utils.multi_modal_utils import ( flatten_embeddings, sanity_check_mm_encoder_outputs) @@ -160,15 +158,12 @@ def execute_mm_encoder(self, scheduler_output: "VllmSchedulerOutput"): encoder_outputs.append(output) # Cache the encoder outputs. - for (mm_hash, pos_info), output in zip( + for (mm_hash, _), output in zip( mm_hashes_pos, encoder_outputs, ): - self.runner.encoder_cache[mm_hash] = scatter_mm_placeholders( - output, - is_embed=pos_info.is_embed, - ) + self.runner.encoder_cache[mm_hash] = output def gather_mm_embeddings(self, scheduler_output: "VllmSchedulerOutput", target_pad_len: int) -> list[jax.Array]: @@ -201,6 +196,11 @@ def gather_mm_embeddings(self, scheduler_output: "VllmSchedulerOutput", num_computed_tokens - start_pos + num_scheduled_tokens, num_encoder_tokens) assert start_idx < end_idx + curr_embeds_start, curr_embeds_end = ( + pos_info.get_embeds_indices_in_range(start_idx, end_idx)) + if curr_embeds_start == curr_embeds_end: + continue + mm_hash = mm_feature.identifier encoder_output = self.runner.encoder_cache.get(mm_hash, None) assert encoder_output is not None,\ @@ -209,11 +209,11 @@ def gather_mm_embeddings(self, scheduler_output: "VllmSchedulerOutput", if (is_embed := pos_info.is_embed) is not None: is_embed = is_embed[start_idx:end_idx] + mm_embeds_item = encoder_output[ + curr_embeds_start:curr_embeds_end] + else: + mm_embeds_item = encoder_output[start_idx:end_idx] - mm_embeds_item = gather_mm_placeholders( - encoder_output[start_idx:end_idx], - is_embed=is_embed, - ) mm_embeds.append(mm_embeds_item) if not mm_embeds: return None