2 changes: 1 addition & 1 deletion tests/distributed/test_shm_storage.py
@@ -28,7 +28,7 @@ def _dummy_elem(modality: str, key: str, size: int):
         modality=modality,
         key=key,
         data=torch.empty((size,), dtype=torch.int8),
-        field=MultiModalSharedField(1),
+        field=MultiModalSharedField(batch_size=1),
     )
2 changes: 1 addition & 1 deletion tests/multimodal/test_cache.py
@@ -51,7 +51,7 @@ def _dummy_elem(
         modality=modality,
         key=key,
         data=data,
-        field=MultiModalSharedField(1),
+        field=MultiModalSharedField(batch_size=1),
     )
21 changes: 15 additions & 6 deletions tests/v1/test_serial_utils.py
@@ -104,22 +104,31 @@ class MyRequest(msgspec.Struct):

 def test_multimodal_kwargs():
     e1 = MultiModalFieldElem(
-        "audio", "a0", torch.zeros(1000, dtype=torch.bfloat16), MultiModalBatchedField()
+        "audio",
+        "a0",
+        torch.zeros(1000, dtype=torch.bfloat16),
+        MultiModalBatchedField(),
     )
     e2 = MultiModalFieldElem(
         "video",
         "v0",
         [torch.zeros(1000, dtype=torch.int8) for _ in range(4)],
-        MultiModalFlatField([[slice(1, 2, 3), slice(4, 5, 6)], [slice(None, 2)]], 0),
+        MultiModalFlatField(
+            slices=[[slice(1, 2, 3), slice(4, 5, 6)], [slice(None, 2)]],
+            dim=0,
+        ),
     )
     e3 = MultiModalFieldElem(
-        "image", "i0", torch.zeros(1000, dtype=torch.int32), MultiModalSharedField(4)
+        "image",
+        "i0",
+        torch.zeros(1000, dtype=torch.int32),
+        MultiModalSharedField(batch_size=4),
     )
     e4 = MultiModalFieldElem(
         "image",
         "i1",
         torch.zeros(1000, dtype=torch.int32),
-        MultiModalFlatField([slice(1, 2, 3), slice(4, 5, 6)], 2),
+        MultiModalFlatField(slices=[slice(1, 2, 3), slice(4, 5, 6)], dim=2),
     )
     audio = MultiModalKwargsItem.from_elems([e1])
     video = MultiModalKwargsItem.from_elems([e2])
@@ -138,8 +147,8 @@ def test_multimodal_kwargs():

     total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)

-    # expected total encoding length, should be 14306, +-20 for minor changes
-    assert 14275 <= total_len <= 14325
+    # expected total encoding length, should be 14395, +-20 for minor changes
+    assert 14375 <= total_len <= 14425
     decoded = decoder.decode(encoded).mm[0]
     assert isinstance(decoded, MultiModalKwargsItems)
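The updated test exercises the new keyword-argument construction of the multimodal field types. A minimal sketch of the pattern, assuming these classes are importable from vllm.multimodal.inputs (import path and example values are illustrative, not taken from the PR):

import torch

from vllm.multimodal.inputs import (
    MultiModalFieldElem,
    MultiModalFlatField,
    MultiModalSharedField,
)

# Keyword arguments make the field parameters self-documenting:
# `batch_size` for shared fields, `slices`/`dim` for flat fields.
shared = MultiModalSharedField(batch_size=4)
flat = MultiModalFlatField(slices=[slice(1, 2, 3), slice(4, 5, 6)], dim=2)

# A field element pairs a modality and key with its data and field type,
# mirroring the positional (modality, key, data, field) order in the test.
elem = MultiModalFieldElem(
    "image",
    "i0",
    torch.zeros(1000, dtype=torch.int32),
    shared,
)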
29 changes: 11 additions & 18 deletions vllm/model_executor/models/glm4_1v.py
@@ -787,10 +787,10 @@ def compute_attn_mask_seqlen(
     def forward(
         self,
         x: torch.Tensor,
-        grid_thw: list[list[int]],
+        grid_thw: torch.Tensor | list[list[int]],
     ) -> torch.Tensor:
-        # Convert grid_thw to tensor (always expecting list format now)
-        grid_thw = torch.tensor(grid_thw, device=x.device, dtype=torch.long)
+        if isinstance(grid_thw, list):
+            grid_thw = torch.tensor(grid_thw, dtype=torch.int32)

         # patchify
         x = x.to(device=self.device, dtype=self.dtype)
@@ -805,7 +805,8 @@ def forward(
         cu_seqlens = torch.repeat_interleave(
             grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
         ).cumsum(dim=0, dtype=torch.int32)
-        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
+        cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
+        cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)

         # pre-compute max_seqlen for attn mask to reduce cuMemcpy operations
         max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
@@ -1548,7 +1549,6 @@ def _process_image_input(
     ) -> tuple[torch.Tensor, ...]:
         grid_thw = image_input["image_grid_thw"]
         assert grid_thw.ndim == 2
-        grid_thw_list = grid_thw.tolist()

         if image_input["type"] == "image_embeds":
             image_embeds = image_input["image_embeds"].type(self.visual.dtype)
@@ -1559,20 +1559,17 @@
                 self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
             )
         else:
-            image_embeds = self.visual(pixel_values, grid_thw=grid_thw.tolist())
+            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)

         merge_size = self.visual.spatial_merge_size
-        sizes = (
-            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
-            // (merge_size * merge_size)
-        ).tolist()
+        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
         return image_embeds.split(sizes)

     def _process_video_input(
         self, video_input: Glm4vVideoInputs
     ) -> tuple[torch.Tensor, ...]:
         grid_thw = video_input["video_grid_thw"]
         assert grid_thw.ndim == 2
-        grid_thw_list = grid_thw.tolist()

         if video_input["type"] == "video_embeds":
             video_embeds = video_input["video_embeds"].type(self.visual.dtype)
@@ -1588,15 +1585,11 @@ def _process_video_input(
                 rope_type="rope_3d",
             )
         else:
-            video_embeds = self.visual(
-                pixel_values_videos, grid_thw=grid_thw.tolist()
-            )
+            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)

+        # Split concatenated embeddings for each video item.
         merge_size = self.visual.spatial_merge_size
-        sizes = (
-            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
-            // (merge_size * merge_size)
-        ).tolist()
+        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
         return video_embeds.split(sizes)

     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
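A standalone sketch of the two equivalences the glm4_1v changes rely on, with made-up example values: torch.cat with new_zeros reproduces F.pad's leading zero while inheriting dtype and device from the input (so the host-to-device copy can be deferred to one explicit non-blocking transfer), and the split sizes can be computed directly on the grid tensor without a round trip through a Python list:

import torch
import torch.nn.functional as F

# Leading-zero prefix for cumulative sequence lengths.
seq = torch.tensor([3, 10, 22], dtype=torch.int32)  # example cumsum
assert torch.equal(
    F.pad(seq, (1, 0), "constant", 0),
    torch.cat([seq.new_zeros(1), seq]),
)  # both yield [0, 3, 10, 22]

# Per-item split sizes straight from the (t, h, w) grid tensor.
grid_thw = torch.tensor([[1, 4, 4], [1, 2, 2]])  # example grids
merge_size = 2
sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
assert sizes == [4, 1]  # 16 // 4 and 4 // 4 patches after spatial merging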
4 changes: 1 addition & 3 deletions vllm/model_executor/models/hunyuan_vision.py
@@ -563,7 +563,7 @@ def _hunyuan_vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
     return dict(
         pixel_values=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes),
         image_embeds=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes),
-        image_grid_thw=MultiModalFieldConfig.batched("image"),
+        image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True),
     )

@@ -786,8 +786,6 @@ class HunYuanVLForConditionalGeneration(
     SupportsQuant,
     SupportsXDRoPE,
 ):
-    multimodal_cpu_fields = {"image_grid_thw"}
-
     # To ensure correct weight loading and mapping.
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
13 changes: 11 additions & 2 deletions vllm/model_executor/models/interfaces.py
@@ -84,9 +84,9 @@ class SupportsMultiModal(Protocol):
     `vllm.multimodal.utils.group_mm_kwargs_by_modality` to use.
     """

-    multimodal_cpu_fields: ClassVar[Set[str]] = frozenset()
+    multimodal_cpu_fields: ClassVar[Set[str] | None] = None
     """
-    A set indicating CPU-only multimodal fields.
+    [DEPRECATED] A set indicating CPU-only multimodal fields.
     """

     _processor_factory: ClassVar[_ProcessorFactories]
@@ -279,6 +279,15 @@ def supports_multimodal(
             "please remove the override from your model."
         )

+    multimodal_cpu_fields = getattr(model, "multimodal_cpu_fields", None)
+    if multimodal_cpu_fields is not None:
+        raise ValueError(
+            "`multimodal_cpu_fields` is no longer effective, "
+            "please set `keep_on_cpu=True` in `MultiModalFieldConfig` "
+            "(refer to https://github.com/vllm-project/vllm/pull/30181), "
+            "and then remove the override from your model."
+        )
+
     return res
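For model authors migrating off the removed class attribute, a minimal before/after sketch (the field-config function name is hypothetical; keep_on_cpu is the parameter this PR adds, and the import path is assumed to match the files above):

from vllm.multimodal.inputs import MultiModalFieldConfig

# Before -- a class-level override, which now raises ValueError
# during the supports_multimodal() check:
#
#     class MyModel(nn.Module, SupportsMultiModal):
#         multimodal_cpu_fields = {"image_grid_thw"}

# After -- mark the individual field in the processor's field config:
def _my_field_config(hf_inputs):
    return dict(
        image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True),
    )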
2 changes: 0 additions & 2 deletions vllm/model_executor/models/opencua.py
@@ -201,8 +201,6 @@ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
     dummy_inputs=OpenCUADummyInputsBuilder,
 )
 class OpenCUAForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
-    multimodal_cpu_fields = {"image_grid_thw"}
-
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],
2 changes: 0 additions & 2 deletions vllm/model_executor/models/qwen2_5_vl.py
@@ -1039,8 +1039,6 @@ class Qwen2_5_VLForConditionalGeneration(
     SupportsMultiModalPruning,
     SupportsMRoPE,
 ):
-    multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}
-
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],
12 changes: 6 additions & 6 deletions vllm/model_executor/models/qwen2_vl.py
@@ -811,14 +811,14 @@ def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
         image_embeds=MultiModalFieldConfig.flat_from_sizes(
             "image", image_embed_grid_sizes
         ),
-        image_grid_thw=MultiModalFieldConfig.batched("image"),
+        image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True),
         pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
             "video", video_grid_sizes
         ),
         video_embeds=MultiModalFieldConfig.flat_from_sizes(
             "video", video_embed_grid_sizes
         ),
-        video_grid_thw=MultiModalFieldConfig.batched("video"),
+        video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True),
     )

     return _qwen2vl_field_config
@@ -1131,8 +1131,6 @@ def _get_mm_fields_config(
 class Qwen2VLForConditionalGeneration(
     nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
 ):
-    multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}
-
     # To ensure correct weight loading and mapping.
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
@@ -1393,9 +1391,11 @@ def _process_video_input(
         else:
             pixel_values_videos = video_input["pixel_values_videos"]
         if self.use_data_parallel:
-            grid_thw_list = grid_thw.tolist()
             return run_dp_sharded_mrope_vision_model(
-                self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d"
+                self.visual,
+                pixel_values_videos,
+                grid_thw.tolist(),
+                rope_type="rope_3d",
             )
         else:
             video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
6 changes: 2 additions & 4 deletions vllm/model_executor/models/qwen3_vl.py
@@ -984,14 +984,14 @@ def _get_mm_fields_config(
             image_embeds=MultiModalFieldConfig.flat_from_sizes(
                 "image", image_grid_sizes
             ),
-            image_grid_thw=MultiModalFieldConfig.batched("image"),
+            image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True),
             pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
                 "video", video_grid_sizes
            ),
             video_embeds=MultiModalFieldConfig.flat_from_sizes(
                 "video", video_grid_sizes
             ),
-            video_grid_thw=MultiModalFieldConfig.batched("video"),
+            video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True),
         )

     def _get_prompt_updates(
@@ -1190,8 +1190,6 @@ class Qwen3VLForConditionalGeneration(
     SupportsMRoPE,
     SupportsEagle3,
 ):
-    multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}
-
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",