diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md index 56329a6edcc5..989511cfcaf5 100644 --- a/docs/configuration/optimization.md +++ b/docs/configuration/optimization.md @@ -227,6 +227,44 @@ vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -dp 2 This does not impact [multi-modal processor caching](#processor-caching). +### GPU Multi-Modal Processing + +You can speed up multi-modal input processing by running Hugging Face processors on the GPU. +To support this, the processor must accept a `device` argument in its call signature. +As of this writing, the following processors are known to support GPU acceleration: + +- Descendants of `BaseImageProcessorFast` (requires `use_fast=True`) +- Descendants of `BaseVideoProcessor` +- `WhisperFeatureExtractor` + +To run Hugging Face processors on the GPU, you can pass the `device` argument +(and `use_fast` if needed) via `mm_processor_kwargs`: + +```python +# Fast image processor requires use_fast=True +llm = LLM( + model="Qwen/Qwen2.5-VL-3B-Instruct", + mm_processor_kwargs={"use_fast": True, "device": "cuda"}, +) + +# Whisper feature extractor does not require use_fast +llm = LLM( + model="Qwen/Qwen2-Audio-7B-Instruct", + mm_processor_kwargs={"device": "cuda"}, +) +``` + +!!! note + vLLM will try to allocate visible GPUs that are not used by the core engine + for multi-modal processing. If this is not possible, then the same GPU + will be used for multi-modal processing and model forward pass, resulting + in resource contention (both I/O and memory capacity). + +!!! important + The performance improvement from GPU processing varies from model to model. + In some cases, GPU processing may even become detrimental because of resource contention. + Make sure to perform benchmarking before enabling this! 
+ ## Multi-Modal Caching Multi-modal caching avoids repeated transfer or processing of the same multi-modal data, diff --git a/tests/entrypoints/llm/test_chat.py b/tests/entrypoints/llm/test_chat.py index 7d8a09852799..86395850e74f 100644 --- a/tests/entrypoints/llm/test_chat.py +++ b/tests/entrypoints/llm/test_chat.py @@ -4,9 +4,17 @@ import pytest -from tests.entrypoints.openai.chat_completion.test_vision import TEST_IMAGE_ASSETS +from tests.entrypoints.openai.chat_completion.test_audio import ( + TEST_AUDIO_URLS, + dummy_messages_from_audio_url, +) +from tests.entrypoints.openai.chat_completion.test_vision import ( + TEST_IMAGE_ASSETS, + dummy_messages_from_image_url, +) from vllm import LLM from vllm.distributed import cleanup_dist_env_and_memory +from vllm.platforms import current_platform from vllm.sampling_params import SamplingParams @@ -206,3 +214,76 @@ def test_chat_batch_failure_cleanup(llm_for_failure_test): outputs_2 = llm.chat(batch_2, sampling_params=sampling_params) assert len(outputs_2) == len(batch_2) assert llm.llm_engine.get_num_unfinished_requests() == 0 + + +@pytest.mark.parametrize( + ("model_id", "modality", "mm_init_kwargs"), + [ + ("Qwen/Qwen2.5-VL-3B-Instruct", "image", {"use_fast": True}), + ("Qwen/Qwen2-Audio-7B-Instruct", "audio", {}), + ], +) +@pytest.mark.parametrize( + "image_urls", [[TEST_IMAGE_ASSETS[0], TEST_IMAGE_ASSETS[1]]], indirect=True +) +def test_mm_processing_gpu(model_id, modality, mm_init_kwargs, image_urls: list[str]): + device = current_platform.device_name + + num_items = 2 + if modality == "image": + messages = dummy_messages_from_image_url(image_urls[:num_items]) + elif modality == "audio": + messages = dummy_messages_from_audio_url(TEST_AUDIO_URLS[:num_items]) + else: + raise NotImplementedError(modality) + + llm = LLM( + model=model_id, + max_model_len=6144, + max_num_seqs=2, + enforce_eager=True, + seed=0, + limit_mm_per_prompt={modality: num_items}, + mm_processor_kwargs=mm_init_kwargs | {"device": device}, + ) 
+ + outputs = llm.chat(messages) + assert len(outputs) == 1 + + +@pytest.mark.parametrize( + ("model_id", "modality", "mm_init_kwargs"), + [ + ("Qwen/Qwen2.5-VL-3B-Instruct", "image", {"use_fast": True}), + ("Qwen/Qwen2-Audio-7B-Instruct", "audio", {}), + ], +) +@pytest.mark.parametrize("image_urls", [[TEST_IMAGE_ASSETS[0]]], indirect=True) +def test_mm_processing_gpu_bad_device( + model_id, modality, mm_init_kwargs, image_urls: list[str] +): + device = current_platform.device_name + if device == "cpu": + pytest.skip("Not applicable to CPU") + + num_items = 1 + if modality == "image": + messages = dummy_messages_from_image_url(image_urls[:num_items]) + elif modality == "audio": + messages = dummy_messages_from_audio_url(TEST_AUDIO_URLS[:num_items]) + else: + raise NotImplementedError(modality) + + llm = LLM( + model=model_id, + max_model_len=6144, + max_num_seqs=2, + enforce_eager=True, + seed=0, + limit_mm_per_prompt={modality: num_items}, + mm_processor_kwargs=mm_init_kwargs, + ) + + match = "cannot override the device for multi-modal preprocessing" + with pytest.raises(ValueError, match=match): + llm.chat(messages, mm_processor_kwargs={"device": device}) diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index 4e765ab1b8b3..7dfc41b5efbd 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -10,7 +10,122 @@ MultiModalSharedField, PlaceholderRange, ) -from vllm.multimodal.utils import argsort_mm_positions, group_and_batch_mm_items +from vllm.multimodal.utils import ( + allocate_gpu_mm_processors, + argsort_mm_positions, + group_and_batch_mm_items, +) + + +@pytest.mark.parametrize( + "case", + [ + # Basic + dict( + mm_processor_device="cuda", + mm_processor_count=0, + available_device_count=1, + engine_device_count=1, + expected_gpu_allocation=[], + ), + dict( + mm_processor_device="cuda", + mm_processor_count=1, + available_device_count=1, + engine_device_count=1, + expected_gpu_allocation=["cuda:0"], + ), + 
# Use Engine GPUs + dict( + mm_processor_device="cuda", + mm_processor_count=2, + available_device_count=1, + engine_device_count=1, + expected_gpu_allocation=["cuda:0", "cuda:0"], + ), + dict( + mm_processor_device="cuda", + mm_processor_count=2, + available_device_count=1, + engine_device_count=2, + expected_gpu_allocation=["cuda:0", "cuda:0"], + ), + dict( + mm_processor_device="cuda", + mm_processor_count=2, + available_device_count=2, + engine_device_count=2, + expected_gpu_allocation=["cuda:0", "cuda:1"], + ), + dict( + mm_processor_device="cuda", + mm_processor_count=3, + available_device_count=2, + engine_device_count=2, + expected_gpu_allocation=["cuda:0", "cuda:1", "cuda:0"], + ), + # Use excess GPUs + dict( + mm_processor_device="cuda", + mm_processor_count=2, + available_device_count=3, + engine_device_count=2, + expected_gpu_allocation=["cuda:2", "cuda:2"], + ), + dict( + mm_processor_device="cuda", + mm_processor_count=2, + available_device_count=4, + engine_device_count=2, + expected_gpu_allocation=["cuda:2", "cuda:3"], + ), + dict( + mm_processor_device="cuda", + mm_processor_count=3, + available_device_count=4, + engine_device_count=2, + expected_gpu_allocation=["cuda:2", "cuda:3", "cuda:2"], + ), + # Specific device + dict( + mm_processor_device="cuda:0", + mm_processor_count=2, + available_device_count=4, + engine_device_count=2, + expected_gpu_allocation=["cuda:0", "cuda:0"], + ), + dict( + mm_processor_device="cuda:2", + mm_processor_count=2, + available_device_count=4, + engine_device_count=2, + expected_gpu_allocation=["cuda:2", "cuda:2"], + ), + # Out-of-bounds device + dict( + mm_processor_device="cuda:4", + mm_processor_count=2, + available_device_count=4, + engine_device_count=2, + expected_gpu_allocation=["cuda:4", "cuda:4"], + ), + ], +) +def test_allocate_gpu_mm_processors(case): + mm_processor_device = case["mm_processor_device"] + mm_processor_count = case["mm_processor_count"] + available_device_count = 
case["available_device_count"] + engine_device_count = case["engine_device_count"] + expected_gpu_allocation = case["expected_gpu_allocation"] + + gpu_allocation = allocate_gpu_mm_processors( + mm_processor_device, + mm_processor_count, + available_device_count=available_device_count, + engine_device_count=engine_device_count, + ) + + assert gpu_allocation == expected_gpu_allocation @pytest.mark.parametrize( diff --git a/vllm/config/multimodal.py b/vllm/config/multimodal.py index e66511c92ab2..fa4cf0ed96df 100644 --- a/vllm/config/multimodal.py +++ b/vllm/config/multimodal.py @@ -235,6 +235,15 @@ def _validate_multimodal_config(self): ) return self + @property + def mm_processing_device(self) -> str: + kwargs = self.mm_processor_kwargs or {} + return str(kwargs.get("device", "cpu")) + + @mm_processing_device.setter + def mm_processing_device(self, device: str) -> None: + self.update_mm_processor_kwargs({"device": device}) + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, @@ -272,6 +281,12 @@ def get_limit_per_prompt(self, modality: str) -> int: return limit_data.count + def update_mm_processor_kwargs(self, value: dict[str, Any]) -> None: + if self.mm_processor_kwargs is None: + self.mm_processor_kwargs = {} + + self.mm_processor_kwargs.update(value) + def merge_mm_processor_kwargs( self, inference_kwargs: Mapping[str, object], @@ -281,6 +296,16 @@ def merge_mm_processor_kwargs( according to the extra arguments passed during inference. """ kwargs = self.mm_processor_kwargs or {} + + # This is to avoid breaking assumptions in memory profiling + init_device = kwargs.get("device", "cpu") + inference_device = inference_kwargs.get("device", init_device) + if init_device != inference_device: + raise ValueError( + "You cannot override the device for multi-modal preprocessing " + f"at runtime! Found: {init_device=} vs. 
{inference_device=}" + ) + return kwargs | dict(inference_kwargs) def is_multimodal_pruning_enabled(self): diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 107dcfa273eb..73a9e3a8349f 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -339,6 +339,15 @@ class is dynamically inherited by the worker class. This is used to inject should only be set by API server scale-out. """ + _renderer_gpu_allocation: list[str] | None = None + """ + The GPU allocated to the renderer of each API process. + + Note: + This is an internal config that is only valid for and + should only be set internally. + """ + @field_validator("disable_nccl_for_dp_synchronization", mode="wrap") @classmethod def _skip_none_validation(cls, value: Any, handler: Callable) -> Any: diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index fad3e0ed240f..410910a1a7e5 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -1202,6 +1202,33 @@ def has_blocked_weights(): self.model_config.disable_cascade_attn = True logger.warning_once("Disabling cascade attention when DBO is enabled.") + mm_config = self.model_config.multimodal_config if self.model_config else None + if mm_config and mm_config.mm_processing_device != "cpu": + api_process_count = self.parallel_config._api_process_count + api_process_rank = self.parallel_config._api_process_rank + local_gpu_count = ( + self.parallel_config.data_parallel_size_local + * self.parallel_config.world_size + ) + + if api_process_rank != -1: + from vllm.multimodal.utils import allocate_gpu_mm_processors + + device_count = current_platform.device_count() # type: ignore + + gpu_allocation = allocate_gpu_mm_processors( + mm_config.mm_processing_device, + api_process_count, + available_device_count=device_count, + engine_device_count=local_gpu_count, + ) + device = gpu_allocation[api_process_rank] + + logger.info("Multi-modal processor will be run on device %s", device) + + self.parallel_config._renderer_gpu_allocation = 
gpu_allocation + mm_config.mm_processing_device = device + if not self.instance_id: self.instance_id = random_uuid()[:5] diff --git a/vllm/multimodal/processing/context.py b/vllm/multimodal/processing/context.py index ef9710374d81..4d25286dc5fc 100644 --- a/vllm/multimodal/processing/context.py +++ b/vllm/multimodal/processing/context.py @@ -226,15 +226,36 @@ def _postprocess_output( self, output: JSONTree, ) -> JSONTree: + mm_config = self.model_config.get_multimodal_config() + is_mm_processing_gpu = mm_config.mm_processing_device != "cpu" + def _postprocess_one(x: object): - if isinstance(x, torch.Tensor): # noqa: SIM102 + if isinstance(x, torch.Tensor): # This mimics the behavior of transformers.BatchFeature if x.is_floating_point(): x = x.to(dtype=self.model_config.dtype) + # This is required because we need to transfer the data + # to engine core, and the serialization process expects + # CPU tensors. + # The dtype of model config is usually lower precision + # so we call this last to transfer less data to CPU + if is_mm_processing_gpu: + x = x.to(device="cpu", non_blocking=True) + return x - return json_map_leaves(_postprocess_one, output) + output = json_map_leaves(_postprocess_one, output) + + # Async GPU -> CPU requires explicit synchronization + if is_mm_processing_gpu: + from vllm.platforms import current_platform + + synchronize = current_platform.synchronize + if synchronize is not None: + synchronize() + + return output def get_merged_mm_kwargs(self, kwargs: Mapping[str, object]): mm_config = self.model_config.get_multimodal_config() diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 2d321cb67b4e..1ee63486e417 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -109,6 +109,43 @@ def encode_video_url( return f"data:{mimetype};base64,{video_b64}" +def allocate_gpu_mm_processors( + mm_processor_device: str, + mm_processor_count: int, + *, + available_device_count: int, + engine_device_count: int, +) -> list[str]: + 
""" + Allocate each processor to a GPU that is not being used by EngineCore, + if possible. + + Returns: + The device to allocate for each multi-modal processor. + """ + device_type, *rest = mm_processor_device.rsplit(":", 1) + if len(rest) == 0: + # Try to run each processor on a different GPU, preferably those + # that are not used by vLLM engine + if available_device_count > engine_device_count: + remaining_count = available_device_count - engine_device_count + processor_gpu_idxs = [ + engine_device_count + server_idx % remaining_count + for server_idx in range(mm_processor_count) + ] + else: + processor_gpu_idxs = [ + server_idx % available_device_count + for server_idx in range(mm_processor_count) + ] + else: + # Already targeted a specific GPU + (device_idx,) = map(int, rest) + processor_gpu_idxs = [device_idx] * mm_processor_count + + return [f"{device_type}:{gpu_idx}" for gpu_idx in processor_gpu_idxs] + + def argsort_mm_positions( mm_positions: MultiModalPlaceholders, ) -> list[tuple[str, int]]: diff --git a/vllm/v1/engine/input_processor.py b/vllm/v1/engine/input_processor.py index b59d02a46327..959f4b106a96 100644 --- a/vllm/v1/engine/input_processor.py +++ b/vllm/v1/engine/input_processor.py @@ -28,7 +28,10 @@ from vllm.tokenizers import TokenizerLike from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid from vllm.utils.jsontree import json_iter_leaves +from vllm.utils.mem_constants import GiB_bytes +from vllm.utils.mem_utils import MemorySnapshot, memory_profiling from vllm.v1.engine import EngineCoreRequest +from vllm.v1.worker.utils import request_memory logger = init_logger(__name__) @@ -45,6 +48,7 @@ def __init__( self.model_config = model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config self.lora_config = vllm_config.lora_config + self.parallel_config = vllm_config.parallel_config self.scheduler_config = vllm_config.scheduler_config self.speculative_config = vllm_config.speculative_config 
self.structured_outputs_config = vllm_config.structured_outputs_config @@ -56,10 +60,12 @@ def __init__( self.supports_mm_inputs = mm_registry.supports_multimodal_inputs(model_config) self.mm_encoder_cache_size = 0 + self.mm_max_items_per_prompt: Mapping[str, int] = {} self.skip_prompt_length_check = False if self.supports_mm_inputs: mm_budget = MultiModalBudget(vllm_config, mm_registry) self.mm_encoder_cache_size = mm_budget.encoder_cache_size + self.mm_max_items_per_prompt = mm_budget.mm_max_items_per_prompt self.skip_prompt_length_check = ( mm_budget.processor.info.skip_prompt_length_check ) @@ -71,6 +77,8 @@ def __init__( mm_registry=mm_registry, ) + self.profile_run() + @property def tokenizer(self) -> TokenizerLike | None: return self.renderer.tokenizer @@ -442,3 +450,72 @@ def _validate_model_inputs( self._validate_model_input(encoder_input, prompt_type="encoder") self._validate_model_input(decoder_input, prompt_type="decoder") + + def profile_run(self) -> None: + model_config = self.model_config + mm_config = model_config.multimodal_config + if not mm_config: + return + + parallel_config = self.parallel_config + gpu_allocation = parallel_config._renderer_gpu_allocation + if not gpu_allocation: + return + + device = mm_config.mm_processing_device + if device != "cpu": + # Peak memory usage (required for this profiling) + # is only tracked for CUDA + if not current_platform.is_cuda_alike(): + return + + # Only run profiling on the first Processor for each device, + # then multiply the usage by the number of processors for that + # device. + # Compared to running profiling on every Processor in parallel, + # this avoids non-deterministic peak memory usage calculation. 
+ api_process_rank = parallel_config._api_process_rank + if api_process_rank != gpu_allocation.index(device): + return + + baseline_snapshot = MemorySnapshot(device=device) + device_ = baseline_snapshot.device_ + + # Only check init memory if we are sure that the EngineCore is not + # loading weights or running profiling on the same GPU + new_device_index = device_.index + local_gpu_count = ( + parallel_config.data_parallel_size_local * parallel_config.world_size + ) + if new_device_index < local_gpu_count: + logger.warning( + "Multi-modal processing will run on the same GPU (%s) as the " + "core engine. This may result in inaccurate memory " + "profiling, and resource contention during inference.", + device_, + ) + else: + request_memory(baseline_snapshot, self.cache_config) + + with memory_profiling(baseline_snapshot) as diff: + for ( + modality, + max_items_per_prompt, + ) in self.mm_max_items_per_prompt.items(): + self.input_preprocessor.mm_registry.get_dummy_mm_inputs( + model_config=model_config, + mm_counts={modality: max_items_per_prompt}, + ) + + usage_mult = gpu_allocation.count(device) + memory_usage = diff.torch_peak_increase * usage_mult + logger.info( + "Multi-modal processing took %.4f GiB and %.6f seconds on %s", + memory_usage / GiB_bytes, + diff.profile_time, + device_, + ) + if memory_usage > diff.before_profile.free_memory: + raise ValueError( + f"Not enough memory in {device_} for multi-modal processor. " + f"Try reducing `api_server_count` or reverting to CPU processing." + )