Draft
177 commits
d919aa8
[Core] Enable HF processing on GPU
DarkLight1337 Aug 1, 2025
9af73f3
Remove unused function
DarkLight1337 Aug 1, 2025
3a560a0
Format
DarkLight1337 Aug 1, 2025
3a3e8c2
Rename
DarkLight1337 Aug 1, 2025
ffff508
Address comment
DarkLight1337 Aug 1, 2025
87819d2
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Aug 1, 2025
9302a3c
Make the test more useful
DarkLight1337 Aug 1, 2025
91a33ad
Update the test
DarkLight1337 Aug 1, 2025
de7549c
Separate preprocessor and model batch size
DarkLight1337 Aug 1, 2025
a08240b
Comments
DarkLight1337 Aug 1, 2025
1f2b4c4
Rename
DarkLight1337 Aug 1, 2025
dbd6159
Update docs
DarkLight1337 Aug 1, 2025
c7c6806
Fix tests
DarkLight1337 Aug 1, 2025
ad59773
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Aug 1, 2025
1d1b419
Use async d2h
DarkLight1337 Aug 1, 2025
6147bec
Reword
DarkLight1337 Aug 1, 2025
a18af4b
Reword
DarkLight1337 Aug 1, 2025
55b90aa
Reword
DarkLight1337 Aug 1, 2025
ef1ec38
Reorganize
DarkLight1337 Aug 1, 2025
6937d5e
Fix incorrect batch size causing hang
DarkLight1337 Aug 1, 2025
4357736
Fix
DarkLight1337 Aug 1, 2025
6aef4b4
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Aug 2, 2025
b1d9367
Consolidate budget calculation
DarkLight1337 Aug 2, 2025
aa0b648
Remove whitespace
DarkLight1337 Aug 2, 2025
ac1e1c1
Eager init
DarkLight1337 Aug 2, 2025
487e312
Fix
DarkLight1337 Aug 2, 2025
a574715
Don't use cache for dummy data
DarkLight1337 Aug 2, 2025
58f9123
Split processor and model data
DarkLight1337 Aug 2, 2025
6b7b970
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Aug 2, 2025
0baf55a
Optimize
DarkLight1337 Aug 2, 2025
423a8aa
Fix
DarkLight1337 Aug 2, 2025
4eb0529
Handle disabled chunked prefill
DarkLight1337 Aug 2, 2025
b452419
Fix naming
DarkLight1337 Aug 2, 2025
b8303db
Rename
DarkLight1337 Aug 2, 2025
61b2d4a
Add guard
DarkLight1337 Aug 2, 2025
a935dd5
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Aug 3, 2025
3132a41
Remove unnecessary register
DarkLight1337 Aug 3, 2025
e743c47
Fix
DarkLight1337 Aug 3, 2025
b0be133
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Aug 5, 2025
8a5ba9e
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Aug 5, 2025
93af843
Update
DarkLight1337 Aug 5, 2025
63ff13e
Don't disable caching
DarkLight1337 Aug 5, 2025
7f23907
Simplify
DarkLight1337 Aug 5, 2025
d921114
Update
DarkLight1337 Aug 5, 2025
b3b662b
Reduce diffs
DarkLight1337 Aug 5, 2025
b2e5843
Reduce diffs
DarkLight1337 Aug 5, 2025
f48ce4f
Update doc
DarkLight1337 Aug 5, 2025
e3f70cb
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Aug 7, 2025
a318923
Clean
DarkLight1337 Aug 7, 2025
a07838a
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Aug 8, 2025
4e8a8a1
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Aug 9, 2025
b943eae
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Aug 13, 2025
1fc7ac8
Revert profiling changes
DarkLight1337 Aug 13, 2025
aa8bdb9
Try profiling processing
DarkLight1337 Aug 13, 2025
6d51d82
Be more precise
DarkLight1337 Aug 13, 2025
6aae217
Try to auto-map GPU processor
DarkLight1337 Aug 13, 2025
55c5e1b
Fix
DarkLight1337 Aug 14, 2025
bceb6bd
Update
DarkLight1337 Aug 14, 2025
8e59dc8
Update docs
DarkLight1337 Aug 14, 2025
646ba93
Update
DarkLight1337 Aug 14, 2025
794cb4e
Test
DarkLight1337 Aug 14, 2025
07f0f1f
Fix typo
DarkLight1337 Aug 14, 2025
bce21a2
Fix arg
DarkLight1337 Aug 14, 2025
61d5422
Simplify
DarkLight1337 Aug 14, 2025
153d971
Update
DarkLight1337 Aug 14, 2025
e50ef02
Run profiling inside processor
DarkLight1337 Aug 14, 2025
72b5d94
Deprecate
DarkLight1337 Aug 14, 2025
e9b6f7b
Avoid conflicting profile runs between API servers
DarkLight1337 Aug 14, 2025
e74ec9d
Rename
DarkLight1337 Aug 14, 2025
fe15a96
Reword
DarkLight1337 Aug 14, 2025
47d5b81
Fix
DarkLight1337 Aug 14, 2025
b66826c
Warn
DarkLight1337 Aug 14, 2025
5e84834
Add TODO
DarkLight1337 Aug 14, 2025
f01d3d1
Comment
DarkLight1337 Aug 14, 2025
5e8a9fc
Comment
DarkLight1337 Aug 14, 2025
f59b27d
Remove redundant reset
DarkLight1337 Aug 14, 2025
210f849
Fix not working on other platforms
DarkLight1337 Aug 14, 2025
90eeeaa
Comment
DarkLight1337 Aug 14, 2025
ea4e97f
Doc
DarkLight1337 Aug 14, 2025
999fe9f
Update tests
DarkLight1337 Aug 14, 2025
1f3fb95
Update
DarkLight1337 Aug 14, 2025
28bde79
Simplify code
DarkLight1337 Aug 14, 2025
91b9563
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Aug 14, 2025
ba6c0d6
Clean up
DarkLight1337 Aug 14, 2025
949cf54
Fix tests
DarkLight1337 Aug 15, 2025
43a537a
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Aug 27, 2025
bdbe1f4
Update `supports_ipc_cache`
DarkLight1337 Aug 27, 2025
d52aa96
[Frontend] Pass API server count to each process
DarkLight1337 Aug 27, 2025
5ff210d
Tests
DarkLight1337 Aug 27, 2025
ed76170
Update
DarkLight1337 Aug 27, 2025
90703bd
Update and fix tests
DarkLight1337 Aug 27, 2025
3f97be4
Update docstring
DarkLight1337 Aug 27, 2025
91ea959
Optimize
DarkLight1337 Aug 27, 2025
6d0c040
Comment
DarkLight1337 Aug 27, 2025
69c9ff0
Improve error message
DarkLight1337 Aug 27, 2025
8e3ea32
Update docstring
DarkLight1337 Aug 27, 2025
082014c
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Aug 27, 2025
1dd5894
Fixture
DarkLight1337 Aug 27, 2025
d06434f
Address comments in serve.py
DarkLight1337 Aug 27, 2025
dac1170
Rename attributes to internal and validate
DarkLight1337 Aug 27, 2025
3f62e0e
Fix
DarkLight1337 Aug 27, 2025
df9f9cb
Update
DarkLight1337 Aug 27, 2025
0ec4e66
Push down
DarkLight1337 Aug 27, 2025
e500a9b
Update
DarkLight1337 Aug 27, 2025
36fb875
Fix
DarkLight1337 Aug 27, 2025
e08e7b7
Try deepcopy
DarkLight1337 Aug 27, 2025
875c7e3
No print
DarkLight1337 Aug 27, 2025
d9a5c81
Simplify
DarkLight1337 Aug 27, 2025
dabe421
Fix
DarkLight1337 Aug 27, 2025
fdc9b6e
Update
DarkLight1337 Aug 27, 2025
94ec51d
Type checking
DarkLight1337 Aug 27, 2025
5191855
Merge branch 'api-server-count-cli' into gpu-mm-processing
DarkLight1337 Aug 27, 2025
013605d
Less diff
DarkLight1337 Aug 27, 2025
2592334
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Aug 27, 2025
3b5db2d
Merge branch 'main' into api-server-count-cli
DarkLight1337 Aug 27, 2025
22914db
Fix
DarkLight1337 Aug 27, 2025
6cb2566
Merge branch 'main' into 'api-server-count-cli'
DarkLight1337 Sep 8, 2025
6aa51a2
Merge branch 'main' into api-server-count-cli
DarkLight1337 Sep 16, 2025
c29495d
Merge branch 'main' into api-server-count-cli
DarkLight1337 Sep 19, 2025
b023ff4
Merge branch 'api-server-count-cli' into gpu-mm-processing
DarkLight1337 Sep 19, 2025
d9b42f1
Fix
DarkLight1337 Sep 19, 2025
0133b09
Fix
DarkLight1337 Sep 19, 2025
ccbc13b
Update tests
DarkLight1337 Sep 19, 2025
72aaa50
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Sep 19, 2025
f7c3e48
Apply device map
DarkLight1337 Sep 19, 2025
7fbc620
Move GPU allocation to `run_multi_api_server`
DarkLight1337 Sep 19, 2025
869ee6c
Fix
DarkLight1337 Sep 19, 2025
dd65208
Add code comment
DarkLight1337 Sep 19, 2025
a508b06
Move allocation to config
DarkLight1337 Sep 19, 2025
cf26387
Remove device map
DarkLight1337 Sep 19, 2025
0c2c4a6
Reduce diff
DarkLight1337 Sep 19, 2025
3426a20
Remove from init
DarkLight1337 Sep 19, 2025
1a5dd58
Clean
DarkLight1337 Sep 19, 2025
e4a7e27
Guard model config
DarkLight1337 Sep 20, 2025
555bd81
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Sep 20, 2025
f4ac4b1
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Sep 22, 2025
99b9e23
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Sep 22, 2025
fe159f6
Help debug
DarkLight1337 Sep 22, 2025
5629f00
Remove debug
DarkLight1337 Sep 23, 2025
d3c8894
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Sep 23, 2025
876cdfa
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Sep 24, 2025
2bb06a5
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Sep 24, 2025
033361e
ruff
DarkLight1337 Oct 6, 2025
78b2108
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Oct 6, 2025
7b7721e
Update
DarkLight1337 Oct 6, 2025
57c0f8b
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Oct 10, 2025
5d1a8d5
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Oct 11, 2025
a01ce77
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Oct 11, 2025
ed5d739
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Oct 13, 2025
ceedc51
ruff
DarkLight1337 Oct 13, 2025
f8ea93d
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Oct 14, 2025
2e3b66b
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Oct 15, 2025
a931694
ruff format
DarkLight1337 Oct 15, 2025
2a4562b
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Oct 15, 2025
8acf7ea
No yapf
DarkLight1337 Oct 15, 2025
d0cd1cf
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Oct 17, 2025
a43ab21
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Dec 15, 2025
84770b1
Fix
DarkLight1337 Dec 15, 2025
1fcbf47
Try fix docs
DarkLight1337 Dec 15, 2025
77bd9fe
Try
DarkLight1337 Dec 15, 2025
93d083d
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Dec 15, 2025
d3e213f
Bad import
DarkLight1337 Dec 15, 2025
8d0515c
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Dec 17, 2025
93b721d
Test
DarkLight1337 Dec 17, 2025
ff96ad4
Fix
DarkLight1337 Dec 17, 2025
db55e11
Simplify
DarkLight1337 Dec 17, 2025
46f9a72
Doc fix
DarkLight1337 Dec 18, 2025
1118961
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Jan 6, 2026
dde6621
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Jan 6, 2026
869926c
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Jan 20, 2026
8af0972
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Feb 11, 2026
e84e760
Reduce diff
DarkLight1337 Feb 11, 2026
0d161e4
Fix mypy
DarkLight1337 Feb 11, 2026
08b8313
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Feb 26, 2026
187602a
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Mar 23, 2026
e304e79
Fix
DarkLight1337 Mar 23, 2026
61d4444
Merge branch 'main' into gpu-mm-processing
DarkLight1337 Apr 4, 2026
38 changes: 38 additions & 0 deletions docs/configuration/optimization.md
@@ -227,6 +227,44 @@ vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -dp 2

This does not impact [multi-modal processor caching](#processor-caching).

### GPU Multi-Modal Processing

You can speed up multi-modal input processing by running Hugging Face processors on the GPU.
To support this, the processor must accept a `device` argument in its call signature.
As of this writing, the following processors are known to support GPU acceleration:

- Descendants of `BaseImageProcessorFast` (requires `use_fast=True`)
- Descendants of `BaseVideoProcessor`
- `WhisperFeatureExtractor`

To run Hugging Face processors on the GPU, you can pass the `device` argument
(and `use_fast` if needed) via `mm_processor_kwargs`:

```python
from vllm import LLM

# Fast image processors require use_fast=True
llm = LLM(
    model="Qwen/Qwen2.5-VL-3B-Instruct",
    mm_processor_kwargs={"use_fast": True, "device": "cuda"},
)

# The Whisper feature extractor does not require use_fast
llm = LLM(
    model="Qwen/Qwen2-Audio-7B-Instruct",
    mm_processor_kwargs={"device": "cuda"},
)
```

!!! note
    vLLM will try to allocate visible GPUs that are not used by the core engine
    for multi-modal processing. If this is not possible, the same GPU
    will be used for both multi-modal processing and the model forward pass,
    resulting in resource contention (both I/O and memory capacity).

!!! important
    The performance improvement from GPU processing varies from model to model.
    In some cases, GPU processing may even be detrimental because of resource contention.
    Make sure to benchmark before enabling this!

## Multi-Modal Caching

Multi-modal caching avoids repeated transfer or processing of the same multi-modal data,
83 changes: 82 additions & 1 deletion tests/entrypoints/llm/test_chat.py
@@ -4,9 +4,17 @@

import pytest

from tests.entrypoints.openai.chat_completion.test_vision import TEST_IMAGE_ASSETS
from tests.entrypoints.openai.chat_completion.test_audio import (
    TEST_AUDIO_URLS,
    dummy_messages_from_audio_url,
)
from tests.entrypoints.openai.chat_completion.test_vision import (
    TEST_IMAGE_ASSETS,
    dummy_messages_from_image_url,
)
from vllm import LLM
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams


@@ -206,3 +214,76 @@ def test_chat_batch_failure_cleanup(llm_for_failure_test):
    outputs_2 = llm.chat(batch_2, sampling_params=sampling_params)
    assert len(outputs_2) == len(batch_2)
    assert llm.llm_engine.get_num_unfinished_requests() == 0


@pytest.mark.parametrize(
    ("model_id", "modality", "mm_init_kwargs"),
    [
        ("Qwen/Qwen2.5-VL-3B-Instruct", "image", {"use_fast": True}),
        ("Qwen/Qwen2-Audio-7B-Instruct", "audio", {}),
    ],
)
@pytest.mark.parametrize(
    "image_urls", [[TEST_IMAGE_ASSETS[0], TEST_IMAGE_ASSETS[1]]], indirect=True
)
def test_mm_processing_gpu(model_id, modality, mm_init_kwargs, image_urls: list[str]):
    device = current_platform.device_name

    num_items = 2
    if modality == "image":
        messages = dummy_messages_from_image_url(image_urls[:num_items])
    elif modality == "audio":
        messages = dummy_messages_from_audio_url(TEST_AUDIO_URLS[:num_items])
    else:
        raise NotImplementedError(modality)

    llm = LLM(
        model=model_id,
        max_model_len=6144,
        max_num_seqs=2,
        enforce_eager=True,
        seed=0,
        limit_mm_per_prompt={modality: num_items},
        mm_processor_kwargs=mm_init_kwargs | {"device": device},
    )

    outputs = llm.chat(messages)
    assert len(outputs) == 1


@pytest.mark.parametrize(
    ("model_id", "modality", "mm_init_kwargs"),
    [
        ("Qwen/Qwen2.5-VL-3B-Instruct", "image", {"use_fast": True}),
        ("Qwen/Qwen2-Audio-7B-Instruct", "audio", {}),
    ],
)
@pytest.mark.parametrize("image_urls", [[TEST_IMAGE_ASSETS[0]]], indirect=True)
def test_mm_processing_gpu_bad_device(
    model_id, modality, mm_init_kwargs, image_urls: list[str]
):
    device = current_platform.device_name
    if device == "cpu":
        pytest.skip("Not applicable to CPU")

    num_items = 1
    if modality == "image":
        messages = dummy_messages_from_image_url(image_urls[:num_items])
    elif modality == "audio":
        messages = dummy_messages_from_audio_url(TEST_AUDIO_URLS[:num_items])
    else:
        raise NotImplementedError(modality)

    llm = LLM(
        model=model_id,
        max_model_len=6144,
        max_num_seqs=2,
        enforce_eager=True,
        seed=0,
        limit_mm_per_prompt={modality: num_items},
        mm_processor_kwargs=mm_init_kwargs,
    )

    match = "cannot override the device for multi-modal preprocessing"
    with pytest.raises(ValueError, match=match):
        llm.chat(messages, mm_processor_kwargs={"device": device})
117 changes: 116 additions & 1 deletion tests/multimodal/test_utils.py
@@ -10,7 +10,122 @@
    MultiModalSharedField,
    PlaceholderRange,
)
from vllm.multimodal.utils import argsort_mm_positions, group_and_batch_mm_items
from vllm.multimodal.utils import (
    allocate_gpu_mm_processors,
    argsort_mm_positions,
    group_and_batch_mm_items,
)


@pytest.mark.parametrize(
    "case",
    [
        # Basic
        dict(
            mm_processor_device="cuda",
            mm_processor_count=0,
            available_device_count=1,
            engine_device_count=1,
            expected_gpu_allocation=[],
        ),
        dict(
            mm_processor_device="cuda",
            mm_processor_count=1,
            available_device_count=1,
            engine_device_count=1,
            expected_gpu_allocation=["cuda:0"],
        ),
        # Use Engine GPUs
        dict(
            mm_processor_device="cuda",
            mm_processor_count=2,
            available_device_count=1,
            engine_device_count=1,
            expected_gpu_allocation=["cuda:0", "cuda:0"],
        ),
        dict(
            mm_processor_device="cuda",
            mm_processor_count=2,
            available_device_count=1,
            engine_device_count=2,
            expected_gpu_allocation=["cuda:0", "cuda:0"],
        ),
        dict(
            mm_processor_device="cuda",
            mm_processor_count=2,
            available_device_count=2,
            engine_device_count=2,
            expected_gpu_allocation=["cuda:0", "cuda:1"],
        ),
        dict(
            mm_processor_device="cuda",
            mm_processor_count=3,
            available_device_count=2,
            engine_device_count=2,
            expected_gpu_allocation=["cuda:0", "cuda:1", "cuda:0"],
        ),
        # Use excess GPUs
        dict(
            mm_processor_device="cuda",
            mm_processor_count=2,
            available_device_count=3,
            engine_device_count=2,
            expected_gpu_allocation=["cuda:2", "cuda:2"],
        ),
        dict(
            mm_processor_device="cuda",
            mm_processor_count=2,
            available_device_count=4,
            engine_device_count=2,
            expected_gpu_allocation=["cuda:2", "cuda:3"],
        ),
        dict(
            mm_processor_device="cuda",
            mm_processor_count=3,
            available_device_count=4,
            engine_device_count=2,
            expected_gpu_allocation=["cuda:2", "cuda:3", "cuda:2"],
        ),
        # Specific device
        dict(
            mm_processor_device="cuda:0",
            mm_processor_count=2,
            available_device_count=4,
            engine_device_count=2,
            expected_gpu_allocation=["cuda:0", "cuda:0"],
        ),
        dict(
            mm_processor_device="cuda:2",
            mm_processor_count=2,
            available_device_count=4,
            engine_device_count=2,
            expected_gpu_allocation=["cuda:2", "cuda:2"],
        ),
        # Out-of-bounds device
        dict(
            mm_processor_device="cuda:4",
            mm_processor_count=2,
            available_device_count=4,
            engine_device_count=2,
            expected_gpu_allocation=["cuda:4", "cuda:4"],
        ),
    ],
)
def test_allocate_gpu_mm_processors(case):
    mm_processor_device = case["mm_processor_device"]
    mm_processor_count = case["mm_processor_count"]
    available_device_count = case["available_device_count"]
    engine_device_count = case["engine_device_count"]
    expected_gpu_allocation = case["expected_gpu_allocation"]

    gpu_allocation = allocate_gpu_mm_processors(
        mm_processor_device,
        mm_processor_count,
        available_device_count=available_device_count,
        engine_device_count=engine_device_count,
    )

    assert gpu_allocation == expected_gpu_allocation
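The expected allocations above pin down the policy: an explicit index such as `cuda:2` is honored verbatim for every processor, while a bare device string round-robins processors over the GPUs left unused by the engine ("excess" GPUs) when any exist, and otherwise over the engine's own GPUs. A minimal sketch consistent with these cases (a hypothetical reimplementation for illustration, not vLLM's actual function):

```python
def allocate_gpu_mm_processors(
    device: str,
    count: int,
    *,
    available_device_count: int,
    engine_device_count: int,
) -> list[str]:
    """Assign one GPU per multi-modal processor (illustrative sketch)."""
    # An explicit device index ("cuda:2") is honored as-is for every processor,
    # even if it is out of bounds (validation happens elsewhere).
    if ":" in device:
        return [device] * count

    # Prefer visible GPUs beyond those the engine uses; otherwise fall back
    # to round-robin over the GPUs the engine is running on.
    excess = list(range(engine_device_count, available_device_count))
    pool = excess or list(range(min(available_device_count, engine_device_count)))
    return [f"{device}:{pool[i % len(pool)]}" for i in range(count)]
```

For example, with 4 visible GPUs, an engine occupying 2 of them, and 3 API processes, the renderers land on `cuda:2`, `cuda:3`, and `cuda:2`, matching the "Use excess GPUs" cases above.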


@pytest.mark.parametrize(
25 changes: 25 additions & 0 deletions vllm/config/multimodal.py
@@ -235,6 +235,15 @@ def _validate_multimodal_config(self):
        )
        return self

    @property
    def mm_processing_device(self) -> str:
        kwargs = self.mm_processor_kwargs or {}
        return str(kwargs.get("device", "cpu"))

    @mm_processing_device.setter
    def mm_processing_device(self, device: str) -> None:
        self.update_mm_processor_kwargs({"device": device})

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
@@ -272,6 +281,12 @@ def get_limit_per_prompt(self, modality: str) -> int:

        return limit_data.count

    def update_mm_processor_kwargs(self, value: dict[str, Any]) -> None:
        if self.mm_processor_kwargs is None:
            self.mm_processor_kwargs = {}

        self.mm_processor_kwargs.update(value)

    def merge_mm_processor_kwargs(
        self,
        inference_kwargs: Mapping[str, object],
@@ -281,6 +296,16 @@ def merge_mm_processor_kwargs(
        according to the extra arguments passed during inference.
        """
        kwargs = self.mm_processor_kwargs or {}

        # This is to avoid breaking assumptions in memory profiling
        init_device = kwargs.get("device", "cpu")
        inference_device = inference_kwargs.get("device", init_device)
        if init_device != inference_device:
            raise ValueError(
                "You cannot override the device for multi-modal preprocessing "
                f"at runtime! Found: {init_device=} vs. {inference_device=}"
            )

        return kwargs | dict(inference_kwargs)

    def is_multimodal_pruning_enabled(self):
9 changes: 9 additions & 0 deletions vllm/config/parallel.py
@@ -339,6 +339,15 @@ class is dynamically inherited by the worker class. This is used to inject
    should only be set by API server scale-out.
    """

    _renderer_gpu_allocation: list[str] | None = None
    """
    The GPU allocated to the renderer of each API process.

    Note:
        This is an internal config that should only be set internally.
    """

    @field_validator("disable_nccl_for_dp_synchronization", mode="wrap")
    @classmethod
    def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
27 changes: 27 additions & 0 deletions vllm/config/vllm.py
@@ -1202,6 +1202,33 @@ def has_blocked_weights():
            self.model_config.disable_cascade_attn = True
            logger.warning_once("Disabling cascade attention when DBO is enabled.")

        mm_config = self.model_config.multimodal_config if self.model_config else None
        if mm_config and mm_config.mm_processing_device != "cpu":
            api_process_count = self.parallel_config._api_process_count
            api_process_rank = self.parallel_config._api_process_rank
            local_gpu_count = (
                self.parallel_config.data_parallel_size_local
                * self.parallel_config.world_size
            )

            if api_process_rank != -1:
                from vllm.multimodal.utils import allocate_gpu_mm_processors

                device_count = current_platform.device_count()  # type: ignore

                gpu_allocation = allocate_gpu_mm_processors(
                    mm_config.mm_processing_device,
                    api_process_count,
                    available_device_count=device_count,
                    engine_device_count=local_gpu_count,
                )
                device = gpu_allocation[api_process_rank]

                logger.info("Multi-modal processor will be run on device %s", device)

                self.parallel_config._renderer_gpu_allocation = gpu_allocation
                mm_config.mm_processing_device = device

        if not self.instance_id:
            self.instance_id = random_uuid()[:5]

25 changes: 23 additions & 2 deletions vllm/multimodal/processing/context.py
@@ -226,15 +226,36 @@
        self,
        output: JSONTree,
    ) -> JSONTree:
        mm_config = self.model_config.get_multimodal_config()
        is_mm_processing_gpu = mm_config.mm_processing_device != "cpu"

        def _postprocess_one(x: object):
            if isinstance(x, torch.Tensor):  # noqa: SIM102
            if isinstance(x, torch.Tensor):
                # This mimics the behavior of transformers.BatchFeature
                if x.is_floating_point():
                    x = x.to(dtype=self.model_config.dtype)

                # This is required because we need to transfer the data
                # to engine core, and the serialization process expects
                # CPU tensors.
                # The dtype of model config is usually lower precision,
                # so we call this last to transfer less data to CPU.
                if is_mm_processing_gpu:
                    x = x.to(device="cpu", non_blocking=True)

Check failure on line 244 in vllm/multimodal/processing/context.py — GitHub Actions / pre-commit: "object" has no attribute "to" [attr-defined]
            return x

        return json_map_leaves(_postprocess_one, output)
        output = json_map_leaves(_postprocess_one, output)

        # Async GPU -> CPU requires explicit synchronization
        if is_mm_processing_gpu:
            from vllm.platforms import current_platform

            synchronize = current_platform.synchronize
            if synchronize is not None:
                synchronize()

        return output
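The `json_map_leaves` helper used above applies a function to every leaf of a nested, JSON-like structure while preserving its shape, which is how the per-tensor dtype cast and device transfer reach each tensor in the processor output. A minimal pure-Python sketch of that behavior (hypothetical; vLLM's actual helper also handles additional container types):

```python
def json_map_leaves(fn, tree):
    """Apply fn to every leaf of a nested dict/list/tuple structure."""
    if isinstance(tree, dict):
        return {k: json_map_leaves(fn, v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        # Preserve the container type (list stays list, tuple stays tuple)
        return type(tree)(json_map_leaves(fn, v) for v in tree)
    # A leaf: apply the transformation (e.g. dtype cast, device transfer)
    return fn(tree)
```

Because the leaf function issues non-blocking GPU-to-CPU copies, the caller must synchronize the device after the traversal, exactly as the code above does, before the tensors are serialized for the engine core.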

    def get_merged_mm_kwargs(self, kwargs: Mapping[str, object]):
        mm_config = self.model_config.get_multimodal_config()