Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/models/supported_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
|--------------|--------|--------|-------------------|----------------------|---------------------------|
| `AriaForConditionalGeneration` | Aria | T + I<sup>+</sup> | `rhymes-ai/Aria` | | |
| `AudioFlamingo3ForConditionalGeneration` | AudioFlamingo3 | T + A<sup>+</sup> | `nvidia/audio-flamingo-3-hf`, `nvidia/music-flamingo-2601-hf` | ✅︎ | ✅︎ |
| `AudioFlamingo3ForConditionalGeneration` | AudioFlamingo3 | T + A | `nvidia/audio-flamingo-3-hf`, `nvidia/music-flamingo-2601-hf` | ✅︎ | ✅︎ |
| `AyaVisionForConditionalGeneration` | Aya Vision | T + I<sup>+</sup> | `CohereLabs/aya-vision-8b`, `CohereLabs/aya-vision-32b`, etc. | | ✅︎ |
| `BagelForConditionalGeneration` | BAGEL | T + I<sup>+</sup> | `ByteDance-Seed/BAGEL-7B-MoT` | ✅︎ | ✅︎ |
| `BeeForConditionalGeneration` | Bee-8B | T + I<sup>E+</sup> | `Open-Bee/Bee-8B-RL`, `Open-Bee/Bee-8B-SFT` | | ✅︎ |
Expand Down
8 changes: 4 additions & 4 deletions tests/models/multimodal/generation/test_qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,10 @@ def batch_make_video_embeddings(
videos += video_batch

# video to pixel values
image_processor = processor.image_processor
video_processor = processor.video_processor

preprocess_result = image_processor.preprocess(
images=None, videos=videos, return_tensors="pt"
preprocess_result = video_processor.preprocess(
videos=videos, return_tensors="pt"
).data
pixel_values = preprocess_result["pixel_values_videos"]
video_grid_thw = preprocess_result["video_grid_thw"]
Expand All @@ -222,7 +222,7 @@ def get_image_embeds(model):
embed_counter = 0
for video_batch in video_batches_:
cur_batch_video_count = len(video_batch)
merge_size = image_processor.merge_size
merge_size = video_processor.merge_size
cur_batch_embed_len = sum(
grid_thw.prod(-1) // merge_size // merge_size
for grid_thw in video_grid_thw[
Expand Down
2 changes: 1 addition & 1 deletion tests/models/multimodal/pooling/test_llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def _run_test(
# Patch the issue where image_token_id
# exceeds the maximum allowed vocab size
hf_model.model.resize_token_embeddings(
hf_model.model.language_model.vocab_size + 1
hf_model.model.model.language_model.vocab_size + 1
)

all_inputs = hf_model.get_inputs(input_texts, images=input_images)
Expand Down
4 changes: 3 additions & 1 deletion tests/models/multimodal/test_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ def create_dummy_model(repo: str, model_arch: str) -> PreTrainedModel:
model = model_cls._from_config(config)
# TODO(hmellor): Remove this once Transformers has fixed tied weights on meta device
# https://github.com/huggingface/transformers/issues/43522
if getattr(config.get_text_config(), "tie_word_embeddings", False):
if getattr(config.get_text_config(), "tie_word_embeddings", False) or getattr(
config, "tie_word_embeddings", False
):
model.tie_weights()
return model

Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/models/audioflamingo3.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def get_data_parser(self):
)

def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"audio": None}
return {"audio": 1}
Comment thread
zucchini-nlp marked this conversation as resolved.


class AudioFlamingo3DummyInputsBuilder(
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/models/hunyuan_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,7 @@ class HunYuanVLDummyInputsBuilder(BaseDummyInputsBuilder[HunYuanVLProcessingInfo
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
    """Build the dummy prompt text: the image token repeated per image.

    Args:
        mm_counts: Mapping from modality name to item count. Only the
            ``"image"`` entry is read; a missing entry counts as 0.

    Returns:
        The processor's ``image_token`` string repeated ``num_images``
        times (an empty string when there are no images).
    """
    num_images = mm_counts.get("image", 0)

    # Fetch the processor once, with the explicit type. The earlier untyped
    # `get_hf_processor()` call was immediately overwritten by this one, so
    # it only added a redundant lookup.
    hf_processor = self.info.get_hf_processor(typ=HunYuanVLProcessor)
    image_token: str = hf_processor.image_token

    return image_token * num_images
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/models/isaac.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import torch.nn.functional as F
from einops import rearrange
from transformers.image_processing_utils import BatchFeature
from transformers.tokenization_utils import TensorType
from transformers.utils import TensorType
from typing_extensions import TypedDict, Unpack

from vllm.config import VllmConfig
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/models/minimax_vl_01.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
self.vision_feature_layer = config.vision_feature_layer
self.vocab_size = config.text_config.vocab_size
self.pad_token_id = -1
if self.config.pad_token_id is not None:
self.pad_token_id = self.config.pad_token_id
if self.config.text_config.pad_token_id is not None:
self.pad_token_id = self.config.text_config.pad_token_id

self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
Expand Down
46 changes: 26 additions & 20 deletions vllm/transformers_utils/processors/bagel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,18 @@
from transformers import AutoProcessor
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import ProcessorMixin
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput


class BagelProcessorKwargs(ProcessingKwargs, total=False):  # type: ignore[call-arg]
    # Keyword-argument schema for the BAGEL processor.
    #
    # `_defaults` holds the per-group default kwargs. The only default set
    # here is for the image group: `return_tensors` is "pt".
    _defaults = {"images_kwargs": {"return_tensors": "pt"}}


class BagelProcessor(ProcessorMixin):
"""
Constructs a BAGEL processor which wraps a
Expand All @@ -27,34 +35,32 @@ def __call__(
| list[TextInput]
| list[PreTokenizedInput] = None,
images: ImageInput = None,
**kwargs,
**kwargs: Unpack[BagelProcessorKwargs],
):
"""
Main method to prepare for the model one or several sequences(s) and image(s).
"""
output_kwargs = self._merge_kwargs(
BagelProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)

if images is not None:
# Process images with the image processor
# Ensure return_tensors is set to "pt" for PyTorch tensors
image_kwargs = {**kwargs}
if "return_tensors" not in image_kwargs:
image_kwargs["return_tensors"] = "pt"
pixel_values = self.image_processor(images, **image_kwargs)
pixel_values = self.image_processor(
images, **output_kwargs["images_kwargs"]
)
else:
pixel_values = None
pixel_values = {}

text_inputs = self.tokenizer(text, **kwargs) if text is not None else None
text_inputs = (
self.tokenizer(text, **output_kwargs["text_kwargs"])
if text is not None
else {}
)

if pixel_values is not None and text_inputs is not None:
# Combine text and image inputs into BatchFeature
combined = dict(text_inputs)
combined["pixel_values"] = pixel_values["pixel_values"]
return BatchFeature(combined)
elif pixel_values is not None:
return pixel_values
elif text_inputs is not None:
return BatchFeature(dict(text_inputs))
else:
return BatchFeature({})
return BatchFeature(data={**pixel_values, **text_inputs})

def batch_decode(self, *args, **kwargs):
"""
Expand Down
5 changes: 1 addition & 4 deletions vllm/transformers_utils/processors/hunyuan_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def __init__(
self,
image_processor=None,
tokenizer=None,
video_processor=None,
chat_template=None,
**kwargs,
):
Expand All @@ -42,9 +41,7 @@ def __init__(
)
self.pad_id = 120002 # self.tokenizer.pad_token_id

super().__init__(
image_processor, tokenizer, video_processor, chat_template=chat_template
)
super().__init__(image_processor, tokenizer, chat_template=chat_template)

def __call__(
self,
Expand Down
17 changes: 11 additions & 6 deletions vllm/transformers_utils/processors/ovis.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,7 @@ class OvisProcessorKwargs(ProcessingKwargs, total=False): # type: ignore[call-a
"padding": False,
},
"images_kwargs": {
"max_partition": 9,
"covering_threshold": 0.9,
"convert_to_rgb": True,
"do_convert_rgb": True,
"return_tensors": "pt",
},
}
Expand Down Expand Up @@ -143,6 +141,10 @@ def __call__(
- **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
- **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`.
"""

max_partition = kwargs.pop("max_partition", 9)
covering_threshold = kwargs.pop("covering_threshold", 0.9)

output_kwargs = self._merge_kwargs(
OvisProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
Expand All @@ -159,7 +161,10 @@ def __call__(
# Process each image
for image in images if isinstance(images, list) else [images]:
pixel_values, image_placeholders, grid = self.preprocess_image(
image=image, **output_kwargs["images_kwargs"]
image=image,
max_partition=max_partition,
covering_threshold=covering_threshold,
**output_kwargs["images_kwargs"],
)
processed_images.append(pixel_values)
image_placeholders_list.append(image_placeholders)
Expand Down Expand Up @@ -300,7 +305,7 @@ def preprocess_image(
image: PIL.Image.Image,
max_partition,
covering_threshold,
convert_to_rgb,
do_convert_rgb,
return_tensors,
):
def _preprocess(img: PIL.Image.Image, side):
Expand Down Expand Up @@ -394,7 +399,7 @@ def _get_best_grid(img, side):
# pick the partition with maximum covering_ratio and break the tie using #sub_images
return sorted(all_grids, key=lambda x: (-x[1], x[0][0] * x[0][1]))[0][0]

if convert_to_rgb:
if do_convert_rgb:
image = convert_image_mode(image, "RGB")

sides = self.get_image_size()
Expand Down
31 changes: 18 additions & 13 deletions vllm/transformers_utils/processors/ovis2_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,10 @@ class Ovis2_5ProcessorKwargs(ProcessingKwargs, total=False): # type: ignore[cal
"padding": False,
},
"images_kwargs": {
"convert_to_rgb": True,
"min_pixels": MIN_PIXELS,
"max_pixels": MAX_PIXELS,
"do_convert_rgb": True,
},
"videos_kwargs": {
"convert_to_rgb": True,
"min_pixels": MIN_PIXELS,
"max_pixels": MAX_PIXELS,
"do_convert_rgb": True,
},
}

Expand Down Expand Up @@ -160,6 +156,9 @@ def __call__(
- **second_per_grid_ts** -- list of video seconds per time grid.
Returned when `videos` is not `None`.
"""
min_pixels = kwargs.pop("min_pixels", MIN_PIXELS)
max_pixels = kwargs.pop("max_pixels", MAX_PIXELS)

output_kwargs = self._merge_kwargs(
Ovis2_5ProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
Expand All @@ -175,7 +174,10 @@ def __call__(
# Process each image
for image in images if isinstance(images, list) else [images]:
pixel_values, image_placeholders, grid = self.preprocess_multidata(
images=image, **output_kwargs["images_kwargs"]
images=image,
min_pixels=min_pixels,
max_pixels=max_pixels,
**output_kwargs["images_kwargs"],
)
processed_images.append(pixel_values)
image_placeholders_list.append(image_placeholders)
Expand All @@ -194,7 +196,10 @@ def __call__(
# Process each video
for video in videos if isinstance(videos, list) else [videos]:
pixel_values, video_placeholders, grid = self.preprocess_multidata(
video=video, **output_kwargs["videos_kwargs"]
video=video,
min_pixels=min_pixels,
max_pixels=max_pixels,
**output_kwargs["videos_kwargs"],
)
processed_videos.append(pixel_values)
videos_placeholders_list.append(video_placeholders)
Expand Down Expand Up @@ -378,7 +383,7 @@ def preprocess_multidata(
self,
images: PIL.Image.Image | list[PIL.Image.Image] | None = None,
video: list[PIL.Image.Image] | np.ndarray | None = None,
convert_to_rgb: bool | None = True,
do_convert_rgb: bool | None = True,
min_pixels: int = MIN_PIXELS,
max_pixels: int = MAX_PIXELS,
return_tensors: str | None = "pt",
Expand All @@ -404,7 +409,7 @@ def preprocess_multidata(
min_pixels if min_pixels is not None else MIN_PIXELS,
)
images = [
image.convert("RGB") if convert_to_rgb and image.mode != "RGB" else image
image.convert("RGB") if do_convert_rgb and image.mode != "RGB" else image
for image in images
]

Expand All @@ -420,9 +425,9 @@ def preprocess_multidata(
max_pixels=max_pixels,
)
new_size = dict(height=resized_height, width=resized_width)
image_pt = self.image_processor.preprocess(
image, size=new_size, return_tensors="np"
)["pixel_values"][0]
image_pt = self.image_processor.preprocess(image, size=new_size)[
"pixel_values"
][0]
Comment thread
zucchini-nlp marked this conversation as resolved.

processed_images.append(image_pt)

Expand Down