diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh index 0bf563d5be24..6b5e86a0ebd6 100755 --- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh +++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh @@ -17,7 +17,7 @@ source /etc/environment docker run --privileged --net host --shm-size=16G -it \ -e "HF_TOKEN=$HF_TOKEN" --name tpu-test \ vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \ - && python3 -m pip install pytest tpu-info \ + && python3 -m pip install pytest pytest-asyncio tpu-info \ && python3 -m pip install lm_eval[api]==0.4.4 \ && export VLLM_USE_V1=1 \ && export VLLM_XLA_CHECK_RECOMPILATION=1 \ @@ -42,6 +42,8 @@ docker run --privileged --net host --shm-size=16G -it \ && echo TEST_8 \ && pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py \ && echo TEST_9 \ + && pytest -s -v /workspace/vllm/tests/v1/tpu/test_multimodal.py \ + && echo TEST_10 \ && pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py" \ diff --git a/tests/v1/tpu/test_multimodal.py b/tests/v1/tpu/test_multimodal.py new file mode 100644 index 000000000000..eb62e0e4b201 --- /dev/null +++ b/tests/v1/tpu/test_multimodal.py @@ -0,0 +1,91 @@ +# SPDX-License-Identifier: Apache-2.0 + +import openai +import pytest + +from vllm import envs +from vllm.multimodal.utils import encode_image_base64, fetch_image +from vllm.platforms import current_platform + +from ...entrypoints.openai.test_vision import TEST_IMAGE_URLS +from ...utils import RemoteOpenAIServer + +if not envs.VLLM_USE_V1: + pytest.skip( + "Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.", + allow_module_level=True, + ) + + +@pytest.fixture(scope="session") +def base64_encoded_image() -> dict[str, str]: + return { + image_url: encode_image_base64(fetch_image(image_url)) + for image_url in TEST_IMAGE_URLS + } + + +@pytest.mark.asyncio +@pytest.mark.skipif(not current_platform.is_tpu(), + reason="This test needs a TPU") +@pytest.mark.parametrize("model_name", ["llava-hf/llava-1.5-7b-hf"]) +async def test_basic_vision(model_name: str, base64_encoded_image: dict[str, + str]): + + def whats_in_this_image_msg(b64): + return [{ + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's in this image?" + }, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{b64}" + }, + }, + ], + }] + + server_args = [ + "--max-model-len", + "1024", + "--max-num-seqs", + "16", + "--gpu-memory-utilization", + "0.95", + "--trust-remote-code", + "--max-num-batched-tokens", + "576", + # NOTE: max-num-batched-tokens>=mm_item_size + "--disable_chunked_mm_input", + "--chat-template", + "examples/template_llava.jinja" + ] + + # Server will pre-compile on first startup (takes a long time). 
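
For readers reproducing the request payload outside of vLLM's test utilities: a rough, stdlib-only approximation of the `fetch_image` + `encode_image_base64` pair used by the fixture above (illustrative sketch only; `image_url_to_base64` is a hypothetical helper, not part of the patch, and vLLM's own helper may re-encode the image before base64-encoding it).

```python
# Sketch only: download an image and base64-encode its raw bytes with the
# standard library, roughly what the base64_encoded_image fixture provides.
import base64
from urllib.request import urlopen


def image_url_to_base64(image_url: str) -> str:
    """Return the image bytes at `image_url` as a base64 string."""
    return base64.b64encode(urlopen(image_url).read()).decode("utf-8")


# The test then embeds this as f"data:image/jpeg;base64,{b64}" in the
# image_url field of the chat message.
```
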
+ with RemoteOpenAIServer(model_name, server_args, + max_wait_seconds=600) as remote_server: + client: openai.AsyncOpenAI = remote_server.get_async_client() + + # Other requests now should be much faster + for image_url in TEST_IMAGE_URLS: + image_base64 = base64_encoded_image[image_url] + chat_completion_from_base64 = await client.chat.completions\ + .create( + model=model_name, + messages=whats_in_this_image_msg(image_base64), + max_completion_tokens=24, + temperature=0.0) + result = chat_completion_from_base64 + assert result + choice = result.choices[0] + assert choice.finish_reason == "length" + + message = choice.message + message = result.choices[0].message + assert message.content is not None and len(message.content) >= 10 + assert message.role == "assistant" diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index f31454ab31f7..7eb464660e95 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import bisect +import gc import time from typing import TYPE_CHECKING, Optional, cast from unittest.mock import patch @@ -21,7 +22,8 @@ from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange +from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs, + PlaceholderRange) from vllm.multimodal.utils import group_mm_inputs_by_modality from vllm.sequence import IntermediateTensors from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available @@ -37,8 +39,7 @@ from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch -from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs, - scatter_mm_placeholders) +from .utils import sanity_check_mm_encoder_outputs if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -198,7 +199,7 @@ def __init__( device="cpu") self.slot_mapping_np = self.slot_mapping_cpu.numpy() self.block_table_cpu = torch.zeros( - (self.max_num_tokens, self.max_num_blocks_per_req), + (self.max_num_reqs, self.max_num_blocks_per_req), dtype=self.input_batch.block_table.get_cpu_tensor().dtype, device="cpu") @@ -220,6 +221,37 @@ def __init__( self.num_reqs_paddings = _get_req_paddings( min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs) + # Get maximum number of mm items per modality (batch size). + self.max_num_mm_items_by_modality = dict() + if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0 + and self.encoder_cache_size > 0): + max_tokens_by_modality_dict = ( + MULTIMODAL_REGISTRY. + get_max_tokens_per_item_by_nonzero_modality(self.model_config)) + for modality, max_tokens in max_tokens_by_modality_dict.items(): + # Check how many items of this modality can be supported by + # the encoder budget. + encoder_budget = min(self.max_num_encoder_input_tokens, + self.encoder_cache_size) + + max_num_mm_items_encoder_budget = cdiv(encoder_budget, + max_tokens) + + # Check how many items of this modality can be supported by + # the decoder budget. + max_mm_items_per_req = self.mm_registry.\ + get_mm_limits_per_prompt(self.model_config)[modality] + + # NOTE: We do not consider max_num_batched_tokens on purpose + # because the multimodal embeddings can be generated in advance + # and chunked prefilled. 
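
A worked example of the budget arithmetic computed in the constructor above, with illustrative numbers (the budgets and per-item token count below are assumptions for the sketch, not values taken from the patch):

```python
# How many multimodal items of one modality the runner will pre-compile for:
# the minimum of what the encoder budget and the decoder budget allow.
from math import ceil

max_num_encoder_input_tokens = 2048   # assumed scheduler encoder budget
encoder_cache_size = 2048             # assumed encoder cache budget
max_tokens_per_item = 576             # e.g. one llava-1.5 image placeholder
max_num_reqs = 16                     # --max-num-seqs
max_mm_items_per_req = 1              # assumed per-prompt limit for "image"

encoder_budget = min(max_num_encoder_input_tokens, encoder_cache_size)
max_items_encoder = ceil(encoder_budget / max_tokens_per_item)   # 4
max_items_decoder = max_num_reqs * max_mm_items_per_req          # 16

max_num_mm_items = min(max_items_encoder, max_items_decoder)
print(max_num_mm_items)  # -> 4: the encoder is later pre-compiled for 1..4 items
```
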
+ max_num_mm_items_decoder_budget = self.max_num_reqs * \ + max_mm_items_per_req + + max_num_mm_items = min(max_num_mm_items_encoder_budget, + max_num_mm_items_decoder_budget) + self.max_num_mm_items_by_modality[modality] = max_num_mm_items + def _update_num_xla_graphs(self, case_str): check_comp = self.check_recompilation and not self.enforce_eager if not check_comp: @@ -606,29 +638,36 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): # 2. A list or tuple (length: num_items) of tensors, each of shape # (feature_size, hidden_size) in case the feature size is dynamic # depending on the input multimodal items. + xm.mark_step() curr_group_outputs = self.model.get_multimodal_embeddings( **batched_mm_inputs) + xm.mark_step() sanity_check_mm_encoder_outputs( curr_group_outputs, expected_num_items=len(grouped_mm_inputs), ) - for output in curr_group_outputs: - encoder_outputs.append(output) + if isinstance(curr_group_outputs, torch.Tensor): + encoder_outputs.append(curr_group_outputs) + else: + assert isinstance(curr_group_outputs, (list, tuple)) + for output in curr_group_outputs: + encoder_outputs.append(output) # Cache the encoder outputs. + # NOTE (NickLucche) here we diverge from logic in other runners, as we + # assume to only have whole mm items to process. Hence we avoid the + # intrinsic dynamism that `scatter_mm_placeholders` introduces. for (req_id, input_id, pos_info), output in zip( req_ids_pos, encoder_outputs, ): if req_id not in self.encoder_cache: self.encoder_cache[req_id] = {} - - self.encoder_cache[req_id][input_id] = scatter_mm_placeholders( - output, - is_embed=pos_info.is_embed, - ) + assert pos_info.is_embed is None, "Expected all positions to be"\ + " contiguous and embeddings." + self.encoder_cache[req_id][input_id] = output def _gather_mm_embeddings( self, @@ -641,6 +680,10 @@ def _gather_mm_embeddings( req_state = self.requests[req_id] num_computed_tokens = req_state.num_computed_tokens mm_positions = req_state.mm_positions + # TODO unroll loop and assume/enforce --disable_chunked_mm_input + # NOTE (NickLucche) here we diverge from logic in other runners, as + # we assume to only have whole mm items to process. Hence we avoid + # the intrinsic dynamism that `gather_mm_placeholders` introduces. for i, pos_info in enumerate(mm_positions): start_pos = pos_info.offset num_encoder_tokens = pos_info.length @@ -657,25 +700,33 @@ def _gather_mm_embeddings( # in the decoder's KV cache. continue - start_idx = max(num_computed_tokens - start_pos, 0) - end_idx = min( - num_computed_tokens - start_pos + num_scheduled_tokens, - num_encoder_tokens) - assert start_idx < end_idx assert req_id in self.encoder_cache assert i in self.encoder_cache[req_id] + assert pos_info.is_embed is None, "Expected all positions to"\ + " be contiguous and embeddings." encoder_output = self.encoder_cache[req_id][i] - - if (is_embed := pos_info.is_embed) is not None: - is_embed = is_embed[start_idx:end_idx] - - mm_embeds_item = gather_mm_placeholders( - encoder_output[start_idx:end_idx], - is_embed=is_embed, - ) - mm_embeds.append(mm_embeds_item) + mm_embeds.append(encoder_output) return mm_embeds + def _get_model_inputs(self, input_ids: torch.Tensor, + mm_embeds: list[torch.Tensor]): + if self.is_multimodal_model: + # NOTE(woosuk): To unify token ids and soft tokens (vision + # embeddings), we always use embeddings (rather than token ids) + # as input to the multimodal model, even when the input is text. 
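
A toy illustration of the whole-item gathering rule used in `_gather_mm_embeddings` above: with `--disable_chunked_mm_input`, an item's placeholder span is either entirely inside the scheduled token window or not needed this step, so no per-token `is_embed` mask or `gather_mm_placeholders` slicing is required (standalone sketch, not part of the patch).

```python
def mm_item_needed(start_pos: int, num_encoder_tokens: int,
                   num_computed_tokens: int,
                   num_scheduled_tokens: int) -> bool:
    """Mirror of the window checks kept in _gather_mm_embeddings."""
    window_end = num_computed_tokens + num_scheduled_tokens
    if start_pos >= window_end:
        return False   # item starts after this step's window
    if start_pos + num_encoder_tokens <= num_computed_tokens:
        return False   # item already fully in the decoder's KV cache
    return True        # whole item is consumed in this step


# Example: a 576-token image starting at position 8, with a prefill step
# covering tokens [0, 1024) -> the full item's embeddings are gathered now.
assert mm_item_needed(start_pos=8, num_encoder_tokens=576,
                      num_computed_tokens=0, num_scheduled_tokens=1024)
```
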
+ if mm_embeds: + inputs_embeds = self.model.get_input_embeddings( + input_ids, mm_embeds) + else: + inputs_embeds = self.model.get_input_embeddings(input_ids) + return None, inputs_embeds + else: + # For text-only models, we use token ids as input. + # While it is possible to use embeddings as input just like the + # multimodal models, it is not desirable for performance since + # then the embedding layer is not included in the CUDA graph. + return input_ids, None + @torch.no_grad() def execute_model( self, @@ -694,27 +745,13 @@ def execute_model( mm_embeds = self._gather_mm_embeddings(scheduler_output) else: mm_embeds = [] - + xm.mark_step() # Prepare inputs attn_metadata, logits_indices, padded_num_reqs = self._prepare_inputs( scheduler_output) - if self.is_multimodal_model: - # NOTE(woosuk): To unify token ids and soft tokens (vision - # embeddings), we always use embeddings (rather than token ids) - # as input to the multimodal model, even when the input is text. - if mm_embeds: - inputs_embeds = self.model.get_input_embeddings( - self.input_ids, mm_embeds) - else: - inputs_embeds = self.model.get_input_embeddings(self.input_ids) - input_ids = None - else: - # For text-only models, we use token ids as input. - # While it is possible to use embeddings as input just like the - # multimodal models, it is not desirable for performance since - # then the embedding layer is not included in the CUDA graph. - input_ids = self.input_ids - inputs_embeds = None + input_ids, inputs_embeds = self._get_model_inputs( + self.input_ids, mm_embeds) + xm.mark_step() num_reqs = self.input_batch.num_reqs # Run the decoder with set_forward_context(attn_metadata, self.vllm_config): @@ -890,9 +927,70 @@ def _dummy_run(self, num_tokens: int) -> None: inputs_embeds=inputs_embeds) self._hidden_states_dtype = out.dtype + def _precompile_mm_encoder(self) -> None: + # Pre-compile MM encoder for all supported data modalities. + hf_config = self.vllm_config.model_config.hf_config + for mode, max_items_by_mode in \ + self.max_num_mm_items_by_modality.items(): + logger.info( + "Compiling Multimodal %s Encoder with different input" + " shapes.", mode) + start = time.perf_counter() + # No padding for MM encoder just yet. + for num_items in range(1, max_items_by_mode + 1): + logger.info(" -- mode: %s items: %d", mode, num_items) + batched_dummy_mm_inputs = self._get_mm_dummy_batch( + mode, num_items) + # Run multimodal encoder. + xm.mark_step() + mm_embeds = self.model.\ + get_multimodal_embeddings(**batched_dummy_mm_inputs) + xm.mark_step() + num_patches = mm_embeds[0].shape[0] + items_size = num_patches * num_items + + # NOTE (NickLucche) pre-compile `get_input_embeddings` when mm + # embeddings are present. We assume `--disable-mm-chunked`, + # hence only whole items can be scheduled. This implies we just + # need to compile when `num_items` fit the (padded) `input_ids` + for num_tokens in self.num_tokens_paddings: + if num_tokens >= items_size: + # XLA Workaround: if torch.zeros(..device) is used, XLA + # compiles a scalar+expansion op, which won't match + # the graph generated at runtime. CPU->TPU must be used + placeholders_ids = torch.zeros(num_tokens, + dtype=torch.int32, + device="cpu") + # Align placeholders and actual num mm_embeddings. + placeholders_ids[:items_size] = \ + hf_config.image_token_index + + placeholders_ids = placeholders_ids.to(self.device) + # Assign outputs or the graph will be cut short. 
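
An illustrative enumeration of the (num_items, num_tokens) pairs the pre-compilation loops above end up covering for `get_input_embeddings`: every padded prompt length that can hold all of the item's placeholder tokens. The bucket sizes and patch count below are assumed example values, not taken from the patch.

```python
# Which (num_items, num_tokens) combinations get compiled with mm embeddings.
num_tokens_paddings = [16, 32, 64, 128, 256, 512, 1024]  # assumed buckets
num_patches_per_item = 576                               # e.g. one llava image
max_items = 1                                            # assumed item budget

pairs = []
for num_items in range(1, max_items + 1):
    items_size = num_patches_per_item * num_items
    for num_tokens in num_tokens_paddings:
        if num_tokens >= items_size:
            pairs.append((num_items, num_tokens))

print(pairs)  # -> [(1, 1024)]: only the 1024-token bucket fits one whole image
```
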
+ a, b = self._get_model_inputs(placeholders_ids, + [mm_embeds]) + assert a is None + xm.mark_step() + + # Pre-compile `get_input_embeddings` when mm_embeddings are not + # present. Chunk is only made of text, no mm_placeholders. + for num_tokens in self.num_tokens_paddings: + placeholders_ids = torch.zeros(num_tokens, + dtype=torch.int32, + device="cpu") + placeholders_ids = placeholders_ids.to(self.device) + a, b = self._get_model_inputs(placeholders_ids, []) + assert a is None + xm.mark_step() + + xm.wait_device_ops() + end = time.perf_counter() + logger.info( + "Multimodal %s Encoder compilation finished in in %.2f " + "[secs].", mode, end - start) + def _precompile_backbone(self) -> None: logger.info("Compiling the model with different input shapes.") - start = time.perf_counter() for num_tokens in self.num_tokens_paddings: logger.info(" -- num_tokens: %d", num_tokens) @@ -962,11 +1060,70 @@ def capture_model(self) -> None: """ Precompile all the subgraphs with possible input shapes. """ - # TODO: precompile encoder + self._precompile_mm_encoder() self._precompile_backbone() self._precompile_select_hidden_states() self._precompile_sample_from_hidden() + def profile_run( + self, + num_tokens: int, + ) -> None: + # Profile with multimodal encoder & encoder cache. + # TODO: handle encoder-decoder models once we support them. + if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0 + and self.encoder_cache_size > 0): + + # NOTE: Currently model is profiled with a single non-text + # modality with the max possible input tokens even when + # it supports multiple. + dummy_data_modality, max_num_mm_items = max( + self.max_num_mm_items_by_modality.items(), key=lambda t: t[1]) + + encoder_budget = min(self.max_num_encoder_input_tokens, + self.encoder_cache_size) + + logger.info( + "Encoder cache will be initialized with a budget of %d tokens," + " and profiled with %s %s items of the maximum feature size.", + encoder_budget, max_num_mm_items, dummy_data_modality) + + # Create dummy batch of multimodal inputs. + batched_dummy_mm_inputs = self._get_mm_dummy_batch( + dummy_data_modality, max_num_mm_items) + + # Run multimodal encoder. + # Isolate encoder graph from post-processing to minimize + # impact of recompilation until it's fixed. + start = time.perf_counter() + xm.mark_step() + dummy_encoder_outputs = self.model.get_multimodal_embeddings( + **batched_dummy_mm_inputs) + xm.mark_step() + xm.wait_device_ops() + end = time.perf_counter() + logger.info( + "Multimodal Encoder profiling finished in in %.2f [secs].", + end - start) + + assert len(dummy_encoder_outputs) == max_num_mm_items, ( + "Expected dimension 0 of encoder outputs to match the number " + f"of multimodal data items: {max_num_mm_items}, got " + f"{len(dummy_encoder_outputs)=} instead. This is most likely " + "due to the 'get_multimodal_embeddings' method of the model " + "not implemented correctly.") + + # Cache the dummy encoder outputs. + self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) + + # Trigger compilation for general shape. + self._dummy_run(num_tokens) + + xm.mark_step() + xm.wait_device_ops() + self.encoder_cache.clear() + gc.collect() + def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize KV cache based on `kv_cache_config`. 
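
A small sketch (toy numbers) of how `profile_run` above selects what to profile with: the single non-text modality that admits the most items under the combined encoder/decoder budget.

```python
# Illustrative values; the real dict is built in the runner's constructor.
max_num_mm_items_by_modality = {"image": 4, "audio": 2}

dummy_data_modality, max_num_mm_items = max(
    max_num_mm_items_by_modality.items(), key=lambda t: t[1])

print(dummy_data_modality, max_num_mm_items)  # -> image 4
# The runner then batches 4 dummy image items, runs the encoder once between
# xm.mark_step() calls, caches the outputs under a temporary request id,
# runs _dummy_run for the backbone, and clears the cache before KV sizing.
```
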
@@ -1045,6 +1202,36 @@ def get_multimodal_embeddings(self, *args, **kwargs): def get_input_embeddings(self, *args, **kwargs): return self.model.get_input_embeddings(*args, **kwargs) + def _get_mm_dummy_batch(self, modality: str, + batch_size: int) -> BatchedTensorInputs: + # Dummy data for pre-compiling multimodal models. + dummy_request_data = self.mm_registry.get_decoder_dummy_data( + model_config=self.model_config, + seq_len=self.max_num_tokens, + ) + dummy_mm_data = dummy_request_data.multi_modal_data + + # Dummy data definition in V0 may contain multiple multimodal items + # (e.g, multiple images) for a single request, therefore here we + # always replicate first item by max_num_mm_items times since in V1 + # they are scheduled to be processed separately. + assert isinstance(dummy_mm_data, MultiModalKwargs), ( + "Expected dummy multimodal data to be of type " + f"MultiModalKwargs, got {type(dummy_mm_data)=} instead. " + "This is most likely due to the model not having a merged " + "processor.") + + # When models have a merged processor, their dummy data is + # already batched `MultiModalKwargs`, therefore we take the first + # `MultiModalKwargsItem` from the desired modality to profile on. + dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0) + dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item]) + + batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] * + batch_size) + return MultiModalKwargs.as_kwargs(batched_dummy_mm_inputs, + device=self.device) + def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]: logger.info("Preparing request paddings:") @@ -1088,7 +1275,6 @@ def _get_token_paddings(min_token_size: int, max_token_size: int, if num >= max_token_size: break num *= 2 - else: logger.info("Using incremental token paddings:") while num <= padding_gap: diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 8f2b4acc32c3..2204f037a6d5 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -157,7 +157,7 @@ def determine_available_memory(self) -> int: runner_kv_caches) # `max_num_tokens >= max_num_batched_tokens` due to padding. - self.model_runner._dummy_run(self.model_runner.max_num_tokens) + self.model_runner.profile_run(self.model_runner.max_num_tokens) # Synchronize before measuring the memory usage. xm.wait_device_ops()
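
For reference, a simplified approximation of the padding scheme behind `self.num_tokens_paddings`, which the new pre-compilation loops iterate over: exponential powers of two when `padding_gap == 0`, otherwise powers of two up to the gap followed by fixed increments of the gap. This is a sketch inferred from the exponential/incremental branches visible in `_get_token_paddings`, not the exact helper.

```python
def token_paddings_sketch(min_token_size: int, max_token_size: int,
                          padding_gap: int) -> list[int]:
    """Approximate the padded prompt-length buckets used by the TPU runner."""
    paddings, num = [], min_token_size
    if padding_gap == 0:
        # Exponential paddings only.
        while True:
            paddings.append(num)
            if num >= max_token_size:
                break
            num *= 2
    else:
        # Exponential up to the gap, then incremental steps of the gap.
        while num <= padding_gap:
            paddings.append(num)
            num *= 2
        num = padding_gap
        while num < max_token_size:
            num += padding_gap
            paddings.append(num)
    return paddings


print(token_paddings_sketch(16, 512, 0))    # [16, 32, 64, 128, 256, 512]
print(token_paddings_sketch(16, 512, 128))  # [16, 32, 64, 128, 256, 384, 512]
```
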