diff --git a/examples/offline_inference/qwen2_5_omni/README.md b/examples/offline_inference/qwen2_5_omni/README.md index 112ef5335e2..14203978972 100644 --- a/examples/offline_inference/qwen2_5_omni/README.md +++ b/examples/offline_inference/qwen2_5_omni/README.md @@ -32,3 +32,11 @@ Then run the command below. ```bash bash run_single_prompt.sh ``` + +### FAQ + +If you encounter error about backend of librosa, try to install ffmpeg with command below. +``` +sudo apt update +sudo apt install ffmpeg +``` diff --git a/examples/offline_inference/qwen2_5_omni/end2end.py b/examples/offline_inference/qwen2_5_omni/end2end.py index e3e496f46d1..be4e2c1a1a2 100644 --- a/examples/offline_inference/qwen2_5_omni/end2end.py +++ b/examples/offline_inference/qwen2_5_omni/end2end.py @@ -1,166 +1,145 @@ -import argparse +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +This example shows how to use vLLM-omni for running offline inference +with the correct prompt format on Qwen2.5-Omni +""" + import os -import os as _os_env_toggle -import random +from typing import NamedTuple -import numpy as np import soundfile as sf -import torch -from utils import make_omni_prompt +from vllm.assets.audio import AudioAsset +from vllm.assets.image import ImageAsset +from vllm.assets.video import VideoAsset +from vllm.multimodal.image import convert_image_mode from vllm.sampling_params import SamplingParams +from vllm.utils import FlexibleArgumentParser -from vllm_omni.entrypoints.omni_llm import OmniLLM - -_os_env_toggle.environ["VLLM_USE_V1"] = "1" +from vllm_omni import OmniLLM SEED = 42 -# Set all random seeds -random.seed(SEED) -np.random.seed(SEED) -torch.manual_seed(SEED) -torch.cuda.manual_seed(SEED) -torch.cuda.manual_seed_all(SEED) -# Make PyTorch deterministic -torch.backends.cudnn.deterministic = True -torch.backends.cudnn.benchmark = False -# Set environment variables for deterministic behavior -os.environ["PYTHONHASHSEED"] = str(SEED) -os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" +class QueryResult(NamedTuple): + inputs: dict + limit_mm_per_prompt: dict[str, int] -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--model", - required=True, - help="Path to merged model directory (will be created if downloading).", - ) - parser.add_argument("--thinker-model", type=str, default=None) - parser.add_argument("--talker-model", type=str, default=None) - parser.add_argument("--code2wav-model", type=str, default=None) - parser.add_argument( - "--hf-hub-id", - default="Qwen/Qwen2.5-Omni-7B", - help="Hugging Face repo id to download if needed.", - ) - parser.add_argument("--hf-revision", default=None, help="Optional HF revision (branch/tag/commit).") - parser.add_argument("--prompts", nargs="+", default=None, help="Input text prompts.") - parser.add_argument("--voice-type", default="default", help="Voice type, e.g., m02, f030, default.") - parser.add_argument( - "--code2wav-dir", - default=None, - help="Path to code2wav folder (contains spk_dict.pt).", - ) - parser.add_argument("--dit-ckpt", default=None, help="Path to DiT checkpoint file (e.g., dit.pt).") - parser.add_argument("--bigvgan-ckpt", default=None, help="Path to BigVGAN checkpoint file.") - parser.add_argument("--dtype", default="bfloat16", choices=["float16", "bfloat16", "float32"]) - parser.add_argument("--max-model-len", type=int, default=32768) - parser.add_argument( - "--init-sleep-seconds", - type=int, - default=20, - help="Sleep seconds after starting each stage process to allow initialization (default: 20)", - ) +# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on +# lower-end GPUs. +# Unless specified, these settings have been tested to work on a single L4. - parser.add_argument("--thinker-only", action="store_true") - parser.add_argument("--text-only", action="store_true") - parser.add_argument("--do-wave", action="store_true") - parser.add_argument( - "--prompt_type", - choices=[ - "text", - "audio", - "audio-long", - "audio-long-chunks", - "audio-long-expand-chunks", - "image", - "video", - "video-frames", - "audio-in-video", - "audio-in-video-v2", - "audio-multi-round", - "badcase-vl", - "badcase-text", - "badcase-image-early-stop", - "badcase-two-audios", - "badcase-two-videos", - "badcase-multi-round", - "badcase-voice-type", - "badcase-voice-type-v2", - "badcase-audio-tower-1", - "badcase-audio-only", - ], - default="text", - ) - parser.add_argument("--use-torchvision", action="store_true") - parser.add_argument("--tokenize", action="store_true") - parser.add_argument( - "--output-wav", - default="output.wav", - help="[Deprecated] Output wav directory (use --output-dir).", +default_system = ( + "You are Qwen, a virtual human developed by the Qwen Team, Alibaba " + "Group, capable of perceiving auditory and visual inputs, as well as " + "generating text and speech." +) + + +def get_text_query(question: str = None) -> QueryResult: + if question is None: + question = "Explain the system architecture for a scalable audio generation pipeline. Answer in 15 words." + prompt = ( + f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n" + f"{question}<|im_end|>\n" + f"<|im_start|>assistant\n" ) - parser.add_argument( - "--output-dir", - default="outputs", - help="Output directory to save text and wav files together.", + return QueryResult( + inputs={ + "prompt": prompt, + }, + limit_mm_per_prompt={}, ) - parser.add_argument( - "--thinker-hidden-states-dir", - default="thinker_hidden_states", - help="Path to thinker hidden states directory.", + + +def get_mixed_modalities_query() -> QueryResult: + question = "What is recited in the audio? What is the content of this image? Why is this video funny?" + prompt = ( + f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>" + "<|vision_bos|><|IMAGE|><|vision_eos|>" + "<|vision_bos|><|VIDEO|><|vision_eos|>" + f"{question}<|im_end|>\n" + f"<|im_start|>assistant\n" ) - parser.add_argument( - "--batch-timeout", - type=int, - default=5, - help="Timeout for batching in seconds (default: 5)", + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": { + "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate, + "image": convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB"), + "video": VideoAsset(name="baby_reading", num_frames=16).np_ndarrays, + }, + }, + limit_mm_per_prompt={"audio": 1, "image": 1, "video": 1}, ) - parser.add_argument( - "--init-timeout", - type=int, - default=300, - help="Timeout for initializing stages in seconds (default: 300)", + + +def get_use_audio_in_video_query() -> QueryResult: + question = "Describe the content of the video, then convert what the baby say into text." + prompt = ( + f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n<|vision_bos|><|VIDEO|><|vision_eos|>" + f"{question}<|im_end|>\n" + f"<|im_start|>assistant\n" ) - parser.add_argument( - "--shm-threshold-bytes", - type=int, - default=65536, - help="Threshold for using shared memory in bytes (default: 65536)", + asset = VideoAsset(name="baby_reading", num_frames=16) + audio = asset.get_audio(sampling_rate=16000) + + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": { + "video": asset.np_ndarrays, + "audio": audio, + }, + "mm_processor_kwargs": { + "use_audio_in_video": True, + }, + }, + limit_mm_per_prompt={"audio": 1, "video": 1}, ) - parser.add_argument( - "--enable-stats", - action="store_true", - default=False, - help="Enable writing detailed statistics (default: disabled)", + + +def get_multi_audios_query() -> QueryResult: + question = "Are these two audio clips the same?" + prompt = ( + f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>" + "<|audio_bos|><|AUDIO|><|audio_eos|>" + f"{question}<|im_end|>\n" + f"<|im_start|>assistant\n" ) - parser.add_argument( - "--txt-prompts", - type=str, - default=None, - help="Path to a .txt file with one prompt per line (preferred).", + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": { + "audio": [ + AudioAsset("winning_call").audio_and_sample_rate, + AudioAsset("mary_had_lamb").audio_and_sample_rate, + ], + }, + }, + limit_mm_per_prompt={ + "audio": 2, + }, ) - args = parser.parse_args() - return args -def main(): - args = parse_args() - model_name = args.model - try: - # Preferred: load from txt file (one prompt per line) - if getattr(args, "txt_prompts", None) and args.prompt_type == "text": - with open(args.txt_prompts, encoding="utf-8") as f: - lines = [ln.strip() for ln in f.readlines()] - args.prompts = [ln for ln in lines if ln != ""] - print(f"[Info] Loaded {len(args.prompts)} prompts from {args.txt_prompts}") - except Exception as e: - print(f"[Error] Failed to load prompts: {e}") - raise - - if args.prompts is None: - raise ValueError("No prompts provided. Use --prompts ... or --txt-prompts (with --prompt_type text)") +query_map = { + "mixed_modalities": get_mixed_modalities_query, + "use_audio_in_video": get_use_audio_in_video_query, + "multi_audios": get_multi_audios_query, + "text": get_text_query, +} + + +def main(args): + model_name = "Qwen/Qwen2.5-Omni-7B" + query_result = query_map[args.query_type]() + omni_llm = OmniLLM( model=model_name, log_stats=args.enable_stats, @@ -205,8 +184,16 @@ def main(): code2wav_sampling_params, ] - prompt = [make_omni_prompt(args, prompt) for prompt in args.prompts] - omni_outputs = omni_llm.generate(prompt, sampling_params_list) + if args.txt_prompts is None: + prompts = [query_result.inputs for _ in range(args.num_prompts)] + else: + assert args.query_type == "text", "txt-prompts is only supported for text query type" + with open(args.txt_prompts, encoding="utf-8") as f: + lines = [ln.strip() for ln in f.readlines()] + prompts = [get_text_query(ln).inputs for ln in lines if ln != ""] + print(f"[Info] Loaded {len(prompts)} prompts from {args.txt_prompts}") + + omni_outputs = omni_llm.generate(prompts, sampling_params_list) # Determine output directory: prefer --output-dir; fallback to --output-wav output_dir = args.output_dir if getattr(args, "output_dir", None) else args.output_wav @@ -217,7 +204,7 @@ def main(): request_id = int(output.request_id) text_output = output.outputs[0].text # Save aligned text file per request - prompt_text = args.prompts[request_id] + prompt_text = prompts[request_id]["prompt"] out_txt = os.path.join(output_dir, f"{request_id:05d}.txt") lines = [] lines.append("Prompt:\n") @@ -239,5 +226,67 @@ def main(): print(f"Request ID: {request_id}, Saved audio to {output_wav}") +def parse_args(): + parser = FlexibleArgumentParser(description="Demo on using vLLM for offline inference with audio language models") + parser.add_argument( + "--query-type", + "-q", + type=str, + default="mixed_modalities", + choices=query_map.keys(), + help="Query type.", + ) + parser.add_argument( + "--enable-stats", + action="store_true", + default=False, + help="Enable writing detailed statistics (default: disabled)", + ) + parser.add_argument( + "--init-sleep-seconds", + type=int, + default=20, + help="Sleep seconds after starting each stage process to allow initialization (default: 20)", + ) + parser.add_argument( + "--batch-timeout", + type=int, + default=5, + help="Timeout for batching in seconds (default: 5)", + ) + parser.add_argument( + "--init-timeout", + type=int, + default=300, + help="Timeout for initializing stages in seconds (default: 300)", + ) + parser.add_argument( + "--shm-threshold-bytes", + type=int, + default=65536, + help="Threshold for using shared memory in bytes (default: 65536)", + ) + parser.add_argument( + "--output-wav", + default="output_audio", + help="[Deprecated] Output wav directory (use --output-dir).", + ) + parser.add_argument( + "--num-prompts", + type=int, + default=1, + help="Number of prompts to generate.", + ) + parser.add_argument( + "--txt-prompts", + type=str, + default=None, + help="Path to a .txt file with one prompt per line (preferred).", + ) + + return parser.parse_args() + + if __name__ == "__main__": - main() + args = parse_args() + main(args) diff --git a/examples/offline_inference/qwen2_5_omni/processing_omni.py b/examples/offline_inference/qwen2_5_omni/processing_omni.py deleted file mode 100644 index a22220dd388..00000000000 --- a/examples/offline_inference/qwen2_5_omni/processing_omni.py +++ /dev/null @@ -1,367 +0,0 @@ -from __future__ import annotations - -import base64 -import logging -import math -import os -import time -import warnings -from functools import lru_cache -from io import BytesIO - -import requests -import torch -import torchvision -from packaging import version -from PIL import Image -from torchvision import io, transforms -from torchvision.transforms import InterpolationMode - -logger = logging.getLogger(__name__) - -IMAGE_FACTOR = 28 -MIN_PIXELS = 4 * 28 * 28 -MAX_PIXELS = 16384 * 28 * 28 -MAX_RATIO = 200 - -VIDEO_MIN_PIXELS = 128 * 28 * 28 -VIDEO_MAX_PIXELS = 768 * 28 * 28 -VIDEO_TOTAL_PIXELS = 24576 * 28 * 28 -FRAME_FACTOR = 2 -FPS = 2.0 -FPS_MIN_FRAMES = 4 -FPS_MAX_FRAMES = 768 - -temporal_patch_size = 2 -spatial_patch_size = 14 -spatial_merge_size = 2 - - -def round_by_factor(number: int, factor: int) -> int: - """Returns the closest integer to 'number' that is divisible by 'factor'.""" - return round(number / factor) * factor - - -def ceil_by_factor(number: int, factor: int) -> int: - """Returns the smallest integer greater than or equal to 'number' that is - divisible by 'factor'.""" - return math.ceil(number / factor) * factor - - -def floor_by_factor(number: int, factor: int) -> int: - """Returns the largest integer less than or equal to 'number' that is - divisible by 'factor'.""" - return math.floor(number / factor) * factor - - -def smart_resize( - height: int, - width: int, - factor: int = IMAGE_FACTOR, - min_pixels: int = MIN_PIXELS, - max_pixels: int = MAX_PIXELS, -) -> tuple[int, int]: - """ - Rescales the image so that the following conditions are met: - - 1. Both dimensions (height and width) are divisible by 'factor'. - - 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. - - 3. The aspect ratio of the image is maintained as closely as possible. - """ - if max(height, width) / min(height, width) > MAX_RATIO: - raise ValueError( - f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}" - ) - h_bar = max(factor, round_by_factor(height, factor)) - w_bar = max(factor, round_by_factor(width, factor)) - if h_bar * w_bar > max_pixels: - beta = math.sqrt((height * width) / max_pixels) - h_bar = floor_by_factor(height / beta, factor) - w_bar = floor_by_factor(width / beta, factor) - elif h_bar * w_bar < min_pixels: - beta = math.sqrt(min_pixels / (height * width)) - h_bar = ceil_by_factor(height * beta, factor) - w_bar = ceil_by_factor(width * beta, factor) - return h_bar, w_bar - - -def fetch_image(ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR) -> Image.Image: - if "image" in ele: - image = ele["image"] - else: - image = ele["image_url"] - image_obj = None - if isinstance(image, Image.Image): - image_obj = image - elif image.startswith("http://") or image.startswith("https://"): - image_obj = Image.open(requests.get(image, stream=True).raw) - elif image.startswith("file://"): - image_obj = Image.open(image[7:]) - elif image.startswith("data:image"): - if "base64," in image: - _, base64_data = image.split("base64,", 1) - data = base64.b64decode(base64_data) - image_obj = Image.open(BytesIO(data)) - else: - image_obj = Image.open(image) - if image_obj is None: - raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}") - image = image_obj.convert("RGB") - # resize - if "resized_height" in ele and "resized_width" in ele: - resized_height, resized_width = smart_resize( - ele["resized_height"], - ele["resized_width"], - factor=size_factor, - ) - else: - width, height = image.size - min_pixels = ele.get("min_pixels", MIN_PIXELS) - max_pixels = ele.get("max_pixels", MAX_PIXELS) - resized_height, resized_width = smart_resize( - height, - width, - factor=size_factor, - min_pixels=min_pixels, - max_pixels=max_pixels, - ) - image = image.resize((resized_width, resized_height)) - - return image - - -def smart_nframes( - ele: dict, - total_frames: int, - video_fps: int | float, -) -> int: - """calculate the number of frames for video used for model inputs. - - Args: - ele (dict): a dict contains the configuration of video. - support either `fps` or `nframes`: - - nframes: the number of frames to extract for model inputs. - - fps: the fps to extract frames for model inputs. - - min_frames: the minimum number of frames of the video, - only used when fps is provided. - - max_frames: the maximum number of frames of the video, - only used when fps is provided. - total_frames (int): the original total number of frames of the video. - video_fps (int | float): the original fps of the video. - - Raises: - ValueError: nframes should in interval [FRAME_FACTOR, total_frames]. - - Returns: - int: the number of frames for video used for model inputs. - """ - assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`" - if "nframes" in ele: - nframes = round_by_factor(ele["nframes"], FRAME_FACTOR) - else: - fps = ele.get("fps", FPS) - min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR) - max_frames = floor_by_factor( - ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), - FRAME_FACTOR, - ) - nframes = total_frames / video_fps * fps - nframes = min(max(nframes, min_frames), max_frames) - nframes = round_by_factor(nframes, FRAME_FACTOR) - if not (FRAME_FACTOR <= nframes and nframes <= total_frames): - raise ValueError(f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.") - return nframes - - -def _read_video_torchvision( - ele: dict, -) -> torch.Tensor: - """read video using torchvision.io.read_video - - Args: - ele (dict): a dict contains the configuration of video. - support keys: - - video: the path of video. support "file://", "http://", - "https://" and local path. - - video_start: the start time of video. - - video_end: the end time of video. - Returns: - torch.Tensor: the video tensor with shape (T, C, H, W). - """ - video_path = ele["video"] - if version.parse(torchvision.__version__) < version.parse("0.19.0"): - if "http://" in video_path or "https://" in video_path: - warnings.warn("torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0.") - if "file://" in video_path: - video_path = video_path[7:] - st = time.time() - video, audio, info = io.read_video( - video_path, - start_pts=ele.get("video_start", 0.0), - end_pts=ele.get("video_end", None), - pts_unit="sec", - output_format="TCHW", - ) - total_frames, video_fps = video.size(0), info["video_fps"] - total_duration = round(total_frames / video_fps, 3) - logger.info( - f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, duration={total_duration}s, time={time.time() - st:.3f}s" - ) - nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps) - idx = torch.linspace(0, total_frames - 1, nframes).round().long() - video = video[idx] - return video, total_duration, nframes - - -def is_decord_available() -> bool: - import importlib.util - - return importlib.util.find_spec("decord") is not None - - -def _read_video_decord( - ele: dict, -) -> torch.Tensor: - """read video using decord.VideoReader - - Args: - ele (dict): a dict contains the configuration of video. - support keys: - - video: the path of video. support "file://", "http://", - "https://" and local path. - - video_start: the start time of video. - - video_end: the end time of video. - Returns: - torch.Tensor: the video tensor with shape (T, C, H, W). - """ - import decord - - video_path = ele["video"] - st = time.time() - vr = decord.VideoReader(video_path) - # TODO: support start_pts and end_pts - if "video_start" in ele or "video_end" in ele: - raise NotImplementedError("not support start_pts and end_pts in decord for now.") - total_frames, video_fps = len(vr), vr.get_avg_fps() - total_duration = round(total_frames / video_fps, 3) - logger.info(f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s") - nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps) - idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist() - video = vr.get_batch(idx).asnumpy() - video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format - return video, total_duration, nframes - - -VIDEO_READER_BACKENDS = { - "decord": _read_video_decord, - "torchvision": _read_video_torchvision, -} - -FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None) - - -@lru_cache(maxsize=1) -def get_video_reader_backend() -> str: - if FORCE_QWENVL_VIDEO_READER is not None: - video_reader_backend = FORCE_QWENVL_VIDEO_READER - elif is_decord_available(): - video_reader_backend = "decord" - else: - video_reader_backend = "torchvision" - # print(f"qwen-vl-utils using {video_reader_backend} to read video.", - # file=sys.stderr) - return video_reader_backend - - -def fetch_video(ele: dict, image_factor: int = IMAGE_FACTOR) -> torch.Tensor | list[Image.Image]: - if isinstance(ele["video"], str): - video_reader_backend = get_video_reader_backend() - video, total_dur, nframes = VIDEO_READER_BACKENDS[video_reader_backend](ele) - frame_timestamps = total_dur * torch.arange(1, nframes + 1) / nframes - grid_timestamps = frame_timestamps[::FRAME_FACTOR] - second_per_grid = grid_timestamps[1] - grid_timestamps[0] - nframes, _, height, width = video.shape - min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS) - total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS) - max_pixels = max( - min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), - int(min_pixels * 1.05), - ) - max_pixels = ele.get("max_pixels", max_pixels) - # min_pixels = (factor ** 2) * 52 - # max_pixels = (factor ** 2) * min(768, (16384 / nframes * temporal_patch_size)) - if "resized_height" in ele and "resized_width" in ele: - resized_height, resized_width = smart_resize( - ele["resized_height"], - ele["resized_width"], - factor=image_factor, - ) - else: - resized_height, resized_width = smart_resize( - height, - width, - factor=image_factor, - min_pixels=min_pixels, - max_pixels=max_pixels, - ) - video = transforms.functional.resize( - video, - [resized_height, resized_width], - interpolation=InterpolationMode.BICUBIC, - antialias=True, - ).float() - return video, total_dur, nframes, second_per_grid - else: - assert isinstance(ele["video"], (list, tuple)) - process_info = ele.copy() - process_info.pop("type", None) - process_info.pop("video", None) - images = [ - fetch_image({"image": video_element, **process_info}, size_factor=image_factor) - for video_element in ele["video"] - ] - nframes = ceil_by_factor(len(images), FRAME_FACTOR) - if len(images) < nframes: - images.extend([images[-1]] * (nframes - len(images))) - return images, None, None, None - - -def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]: - vision_infos = [] - if isinstance(conversations[0], dict): - conversations = [conversations] - for conversation in conversations: - for message in conversation: - if isinstance(message["content"], list): - for ele in message["content"]: - if ( - "image" in ele - or "image_url" in ele - or "video" in ele - or ele["type"] in ("image", "image_url", "video") - ): - vision_infos.append(ele) - return vision_infos - - -def process_vision_info( - conversations: list[dict] | list[list[dict]], -) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None]: - vision_infos = extract_vision_info(conversations) - # Read images or videos - image_inputs = [] - video_inputs = [] - for vision_info in vision_infos: - if "image" in vision_info or "image_url" in vision_info: - image_inputs.append(fetch_image(vision_info)) - elif "video" in vision_info: - video_inputs.append(fetch_video(vision_info)) - else: - raise ValueError("image, image_url or video should in content.") - if len(image_inputs) == 0: - image_inputs = None - if len(video_inputs) == 0: - video_inputs = None - return image_inputs, video_inputs diff --git a/examples/offline_inference/qwen2_5_omni/run_multiple_prompts.sh b/examples/offline_inference/qwen2_5_omni/run_multiple_prompts.sh index 78d3dd54fb0..2ec8a1c57ec 100644 --- a/examples/offline_inference/qwen2_5_omni/run_multiple_prompts.sh +++ b/examples/offline_inference/qwen2_5_omni/run_multiple_prompts.sh @@ -1,8 +1,3 @@ -python end2end.py --model Qwen/Qwen2.5-Omni-7B \ - --voice-type "m02" \ - --dit-ckpt none \ - --bigvgan-ckpt none \ - --output-wav output_audio \ - --prompt_type text \ - --init-sleep-seconds 0 \ - --txt-prompts top100.txt +python end2end.py --output-wav output_audio \ + --query-type text \ + --txt-prompts top10.txt diff --git a/examples/offline_inference/qwen2_5_omni/run_single_prompt.sh b/examples/offline_inference/qwen2_5_omni/run_single_prompt.sh index 739902b2561..5b3c19cdc27 100644 --- a/examples/offline_inference/qwen2_5_omni/run_single_prompt.sh +++ b/examples/offline_inference/qwen2_5_omni/run_single_prompt.sh @@ -1,8 +1,2 @@ -python end2end.py --model Qwen/Qwen2.5-Omni-7B \ - --voice-type "m02" \ - --dit-ckpt none \ - --bigvgan-ckpt none \ - --output-wav output_audio \ - --prompt_type text \ - --init-sleep-seconds 0 \ - --prompts "Explain the system architecture for a scalable audio generation pipeline. Answer in 15 words." +python end2end.py --output-wav output_audio \ + --query-type use_audio_in_video diff --git a/examples/offline_inference/qwen2_5_omni/utils.py b/examples/offline_inference/qwen2_5_omni/utils.py deleted file mode 100644 index 11df5c55a0e..00000000000 --- a/examples/offline_inference/qwen2_5_omni/utils.py +++ /dev/null @@ -1,312 +0,0 @@ -import tempfile -from typing import Optional, Union -from urllib.request import urlopen - -import librosa -import requests -import resampy -import soundfile as sf -import torch -import torchvision.io -from processing_omni import fetch_image, fetch_video -from transformers import AutoConfig, AutoProcessor -from vllm.inputs import TextPrompt - -from vllm_omni.inputs.data import OmniTokensPrompt - -# Simple caches to avoid repeated heavy HF loads per prompt -_PROCESSOR_CACHE: dict[str, "AutoProcessor"] = {} -_CONFIG_CACHE: dict[str, "AutoConfig"] = {} - - -def get_system_prompt(): - return { - "role": "system", - "content": [ - { - "type": "text", - "text": ( - "You are Qwen, a virtual human developed by the Qwen Team, " - "Alibaba Group, capable of perceiving auditory and visual inputs, " - "as well as generating text and speech." - ), - } - ], - } - - -def resample_wav_to_16khz(input_filepath): - data, original_sample_rate = sf.read(input_filepath) - # Only use the first channel - if len(data.shape) > 1: - data = data[:, 0] - # resample to 16kHz - data_resampled = resampy.resample(data, sr_orig=original_sample_rate, sr_new=16000) - return data_resampled - - -def fetch_and_read_video(args, video_url: str, fps=2): - def read_video_with_torchvision(video_file_name: str): - video, audio, info = torchvision.io.read_video( - video_file_name, - start_pts=0.0, - end_pts=None, - pts_unit="sec", - output_format="TCHW", - ) - - total_frames, video_fps = video.size(0), info["video_fps"] - total_duration = round(total_frames / video_fps, 3) - nframes = int(total_frames / video_fps * fps) - - frame_timestamps = total_duration * torch.arange(1, nframes + 1) / nframes - grid_timestamps = frame_timestamps[::2] - second_per_grid = grid_timestamps[1] - grid_timestamps[0] - - idx = torch.linspace(0, video.size(0) - 1, nframes).round().long() - video = video[idx] - - if args.legacy_omni_video: - return [video, total_duration, nframes, second_per_grid.item()] - else: - return video - - def read_video_with_transformers(video_file_name: Union[str, list[str]]): - video, total_duration, nframes, second_per_grid = fetch_video({"video": video_file_name}) - if total_duration is None and nframes is None: - nframes = len(video) - total_duration = 0.5 * nframes - second_per_grid = 1.0 - if args.legacy_omni_video: - return [video, total_duration, nframes, second_per_grid] - else: - return video - - def read_video(video_file_name: str): - if args.use_torchvision: - return read_video_with_torchvision(video_file_name) - else: - return read_video_with_transformers(video_file_name) - - if isinstance(video_url, str) and video_url.startswith("http"): - with tempfile.NamedTemporaryFile(delete=True) as temp_video_file: - resp = requests.get(video_url) - assert resp.status_code == requests.codes.ok, ( - f"Failed to fetch video from {video_url}, status_code:{resp.status_code}, resp:{resp}" - ) - - temp_video_file.write(urlopen(video_url).read()) - temp_video_file_path = temp_video_file.name - video_file_name = temp_video_file_path - return read_video(video_file_name) - else: - video_file_name = video_url - return read_video(video_file_name) - - -def make_inputs_qwen2_omni( - args, - messages: list[dict[str, Union[str, list[dict[str, str]]]]], - use_audio_in_video: Optional[bool] = False, - tokenize: bool = False, -) -> Union[OmniTokensPrompt, TextPrompt]: - from transformers import AutoConfig, AutoProcessor - - # Cached processor/config to prevent per-prompt reloading and repeated warnings - if args.model not in _PROCESSOR_CACHE: - _PROCESSOR_CACHE[args.model] = AutoProcessor.from_pretrained(args.model) - processor = _PROCESSOR_CACHE[args.model] - - config = _CONFIG_CACHE.get(args.model) - if config is None: - try: - config = AutoConfig.from_pretrained(args.model) - except Exception: - config = None - _CONFIG_CACHE[args.model] = config # cache even if None to avoid retry storms - - # Decide legacy flag only once based on config (default True if unknown) - if getattr(args, "legacy_omni_video", None) is None: - if config is not None and hasattr(config, "architectures"): - args.legacy_omni_video = "Qwen2_5OmniModel" not in config.architectures - else: - args.legacy_omni_video = True - - audios, images, videos = [], [], [] - for message in messages: - if not isinstance(message["content"], list): - message["content"] = [ - { - "type": "text", - "text": message["content"], - } - ] - index, num_contents = 0, len(message["content"]) - while index < num_contents: - ele = message["content"][index] - if "type" not in ele: - if "text" in ele: - ele["type"] = "text" - elif "audio" in ele: - ele["type"] = "audio" - elif "audio_url" in ele: - ele["type"] = "audio_url" - elif "image" in ele: - ele["type"] = "image" - elif "image_url" in ele: - ele["type"] = "image_url" - elif "video" in ele: - ele["type"] = "video" - elif "video_url" in ele: - ele["type"] = "video_url" - else: - raise ValueError(f"Unknown ele: {ele}") - - if ele["type"] == "audio" or ele["type"] == "audio_url": - if "audio_url" in ele: - audio_key = "audio_url" - with tempfile.NamedTemporaryFile(delete=True) as temp_audio_file: - temp_audio_file.write(urlopen(ele[audio_key]).read()) - temp_audio_file_path = temp_audio_file.name - audios.append(resample_wav_to_16khz(temp_audio_file_path)) - ele["audio"] = temp_audio_file_path - elif "audio" in ele: - audio_key = "audio" - audios.append(resample_wav_to_16khz(ele[audio_key])) - else: - raise ValueError(f"Unknown ele {ele}") - elif use_audio_in_video and (ele["type"] == "video" or ele["type"] == "video_url"): - # use video as audio as well - if "video_url" in ele: - audio_key = "video_url" - with tempfile.NamedTemporaryFile(delete=True) as temp_video_file: - temp_video_file.write(urlopen(ele[audio_key]).read()) - temp_video_file_path = temp_video_file.name - ele[audio_key] = temp_video_file_path - audios.append(librosa.load(temp_video_file_path, sr=16000)[0]) - videos.append(fetch_and_read_video(args, temp_video_file_path)) - ele["video"] = temp_video_file_path - elif "video" in ele: - audio_key = "video" - audios.append(librosa.load(ele[audio_key], sr=16000)[0]) - videos.append(fetch_and_read_video(args, audio_key)) - else: - raise ValueError(f"Unknown ele {ele}") - # insert a audio after the video - message["content"].insert( - index + 1, - { - "type": "audio", - "audio": ele[audio_key], - }, - ) - # no need to load the added audio again - index += 1 - elif ele["type"] == "video" or ele["type"] == "video_url": - if "video_url" in ele: - video_key = "video_url" - with tempfile.NamedTemporaryFile(delete=True) as temp_video_file: - temp_video_file.write(urlopen(ele["video_url"]).read()) - temp_video_file_path = temp_video_file.name - videos.append(fetch_and_read_video(args, temp_video_file)) - ele["video"] = temp_video_file_path - else: - video_key = "video" - videos.append(fetch_and_read_video(args, ele[video_key])) - elif ele["type"] == "image" or ele["type"] == "image_url": - images.append(fetch_image(ele)) - - # move to the next content - index += 1 - - prompt = processor.apply_chat_template( - messages, - tokenize=tokenize, - add_generation_prompt=True, - add_vision_id=True, - ) - - audios = audios if len(audios) > 0 else None - images = images if len(images) > 0 else None - videos = videos if len(videos) > 0 else None - - multi_modal_data = {} - if audios: - multi_modal_data["audio"] = audios - if images: - multi_modal_data["image"] = images - if videos: - multi_modal_data["video"] = videos - - if isinstance(prompt, list) and isinstance(prompt[0], (list, str)): - prompt = prompt[0] - - if tokenize: - return OmniTokensPrompt( - prompt_token_ids=prompt, - multi_modal_data=multi_modal_data, - ) - else: - return TextPrompt( - prompt=prompt, - multi_modal_data=multi_modal_data, - ) - - -def make_text_prompt(args, prompt): - messages = [ - get_system_prompt(), - { - "role": "user", - "content": [ - {"type": "text", "text": prompt}, - ], - }, - ] - - prompt = make_inputs_qwen2_omni(args, messages, tokenize=args.tokenize) - return prompt - - -def make_audio_in_video_v2_prompt(args): - messages = [ - { - "role": "system", - "content": [ - { - "type": "text", - "text": ( - "You are Qwen, a virtual human developed by the Qwen Team, " - "Alibaba Group, capable of perceiving auditory and visual " - "inputs, as well as generating text and speech." - ), - } - ], - }, - { - "role": "user", - "content": [ - { - "type": "video_url", - "video_url": ("https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-Omni/draw_small.mp4"), - }, - ], - }, - ] - prompt = make_inputs_qwen2_omni( - args, - messages, - use_audio_in_video=True, - tokenize=args.tokenize, - ) - return prompt - - -def make_omni_prompt(args, prompt=None) -> Union[OmniTokensPrompt, list[OmniTokensPrompt]]: - if args.prompt_type == "text": - prompt = make_text_prompt(args, prompt) - elif args.prompt_type == "audio-in-video-v2": - prompt = make_audio_in_video_v2_prompt(args) - else: - raise ValueError(f"Unsupported prompt type: {args.prompt_type}") - return prompt diff --git a/examples/online_serving/README.md b/examples/online_serving/README.md index b64989326dc..ac86f96c900 100644 --- a/examples/online_serving/README.md +++ b/examples/online_serving/README.md @@ -23,10 +23,18 @@ cd examples/online_serving Send request via python ```bash -python openai_chat_completion_client_for_multimodal_generation.py +python openai_chat_completion_client_for_multimodal_generation.py --query-type mixed_modalities ``` Send request via curl ```bash -bash run_curl_multimodal_generation.sh +bash run_curl_multimodal_generation.sh mixed_modalities +``` + +### FAQ + +If you encounter error about backend of librosa, try to install ffmpeg with command below. +``` +sudo apt update +sudo apt install ffmpeg ``` diff --git a/examples/online_serving/openai_chat_completion_client_for_multimodal_generation.py b/examples/online_serving/openai_chat_completion_client_for_multimodal_generation.py index 13d37476d8f..47e109be5ac 100644 --- a/examples/online_serving/openai_chat_completion_client_for_multimodal_generation.py +++ b/examples/online_serving/openai_chat_completion_client_for_multimodal_generation.py @@ -1,6 +1,9 @@ import base64 +import requests from openai import OpenAI +from vllm.assets.audio import AudioAsset +from vllm.utils import FlexibleArgumentParser # Modify OpenAI's API key and API base to use vLLM's API server. openai_api_key = "EMPTY" @@ -15,6 +18,16 @@ SEED = 42 +def encode_base64_content_from_url(content_url: str) -> str: + """Encode a content retrieved from a remote url to base64 format.""" + + with requests.get(content_url) as response: + response.raise_for_status() + result = base64.b64encode(response.content).decode("utf-8") + + return result + + def get_system_prompt(): return { "role": "system", @@ -31,7 +44,106 @@ def get_system_prompt(): } -def run_text_to_audio(model: str) -> None: +def get_text_query(): + question = "Explain the system architecture for a scalable audio generation pipeline. Answer in 15 words." + prompt = { + "role": "user", + "content": [ + { + "type": "text", + "text": f"{question}", + } + ], + } + return prompt + + +def get_mixed_modalities_query(): + question = "What is recited in the audio? What is the content of this image? Why is this video funny?" + prompt = { + "role": "user", + "content": [ + { + "type": "audio_url", + "audio_url": {"url": AudioAsset("mary_had_lamb").url}, + }, + { + "type": "image_url", + "image_url": { + "url": "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/cherry_blossom.jpg" + }, + }, + { + "type": "video_url", + "video_url": { + "url": "https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4" + }, + }, + { + "type": "text", + "text": f"{question}", + }, + ], + } + + return prompt + + +def get_use_audio_in_video_query(): + question = "Describe the content of the video, then convert what the baby say into text." + + prompt = { + "role": "user", + "content": [ + { + "type": "video_url", + "video_url": { + "url": "https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4", + "num_frames": 16, + }, + }, + { + "type": "text", + "text": f"{question}", + }, + ], + } + + return prompt + + +def get_multi_audios_query(): + question = "Are these two audio clips the same?" + prompt = { + "role": "user", + "content": [ + { + "type": "audio_url", + "audio_url": {"url": AudioAsset("mary_had_lamb").url}, + }, + { + "type": "audio_url", + "audio_url": {"url": AudioAsset("winning_call").url}, + }, + { + "type": "text", + "text": f"{question}", + }, + ], + } + return prompt + + +query_map = { + "mixed_modalities": get_mixed_modalities_query, + "use_audio_in_video": get_use_audio_in_video_query, + "multi_audios": get_multi_audios_query, + "text": get_text_query, +} + + +def run_multimodal_generation(args) -> None: + model_name = "Qwen/Qwen2.5-Omni-7B" thinker_sampling_params = { "temperature": 0.0, # Deterministic - no randomness "top_p": 1.0, # Disable nucleus sampling @@ -67,23 +179,21 @@ def run_text_to_audio(model: str) -> None: code2wav_sampling_params, ] + prompt = query_map[args.query_type]() + extra_body = { + "sampling_params_list": sampling_params_list # Optional, it has a default setting in stage_configs of the corresponding model. + } + + if args.query_type == "use_audio_in_video": + extra_body["mm_processor_kwargs"] = {"use_audio_in_video": True} + chat_completion = client.chat.completions.create( messages=[ get_system_prompt(), - { - "role": "user", - "content": [ - { - "type": "text", - "text": "Explain the system architecture for a scalable audio generation pipeline. Answer in 15 words.", - }, - ], - }, + prompt, ], - model=model, - extra_body={ - "sampling_params_list": sampling_params_list - }, # Optional, it has a default setting in stage_configs of the corresponding model. + model=model_name, + extra_body=extra_body, ) count = 0 @@ -99,5 +209,20 @@ def run_text_to_audio(model: str) -> None: print("Chat completion output from text:", choice.message.content) +def parse_args(): + parser = FlexibleArgumentParser(description="Demo on using vLLM for offline inference with audio language models") + parser.add_argument( + "--query-type", + "-q", + type=str, + default="mixed_modalities", + choices=query_map.keys(), + help="Query type.", + ) + + return parser.parse_args() + + if __name__ == "__main__": - run_text_to_audio("Qwen/Qwen2.5-Omni-7B") + args = parse_args() + run_multimodal_generation(args) diff --git a/examples/online_serving/run_curl_multimodal_generation.sh b/examples/online_serving/run_curl_multimodal_generation.sh index 6d2afd4a846..d0c85bc391b 100644 --- a/examples/online_serving/run_curl_multimodal_generation.sh +++ b/examples/online_serving/run_curl_multimodal_generation.sh @@ -1,6 +1,20 @@ #!/usr/bin/env bash set -euo pipefail +# Default query type +QUERY_TYPE="${1:-mixed_modalities}" + +# Validate query type +if [[ ! "$QUERY_TYPE" =~ ^(mixed_modalities|use_audio_in_video|multi_audios|text)$ ]]; then + echo "Error: Invalid query type '$QUERY_TYPE'" + echo "Usage: $0 [mixed_modalities|use_audio_in_video|multi_audios|text]" + echo " mixed_modalities: Audio + Image + Video + Text query" + echo " use_audio_in_video: Video + Text query (with audio extraction from video)" + echo " multi_audios: Two audio clips + Text query" + echo " text: Text query" + exit 1 +fi + SEED=42 thinker_sampling_params='{ @@ -35,18 +49,121 @@ code2wav_sampling_params='{ }' # Above is optional, it has a default setting in stage_configs of the corresponding model. +# Define URLs for assets +MARY_HAD_LAMB_AUDIO_URL="https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/mary_had_lamb.ogg" +WINNING_CALL_AUDIO_URL="https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/winning_call.ogg" +CHERRY_BLOSSOM_IMAGE_URL="https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/cherry_blossom.jpg" +SAMPLE_VIDEO_URL="https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4" + +# Build user content and extra fields based on query type +case "$QUERY_TYPE" in + text) + user_content='[ + { + "type": "text", + "text": "Explain the system architecture for a scalable audio generation pipeline. Answer in 15 words." + } + ]' + sampling_params_list='[ + '"$thinker_sampling_params"', + '"$talker_sampling_params"', + '"$code2wav_sampling_params"' + ]' + mm_processor_kwargs="{}" + ;; + mixed_modalities) + user_content='[ + { + "type": "audio_url", + "audio_url": { + "url": "'"$MARY_HAD_LAMB_AUDIO_URL"'" + } + }, + { + "type": "image_url", + "image_url": { + "url": "'"$CHERRY_BLOSSOM_IMAGE_URL"'" + } + }, + { + "type": "video_url", + "video_url": { + "url": "'"$SAMPLE_VIDEO_URL"'" + } + }, + { + "type": "text", + "text": "What is recited in the audio? What is the content of this image? Why is this video funny?" + } + ]' + sampling_params_list='[ + '"$thinker_sampling_params"', + '"$talker_sampling_params"', + '"$code2wav_sampling_params"' + ]' + mm_processor_kwargs="{}" + ;; + use_audio_in_video) + user_content='[ + { + "type": "video_url", + "video_url": { + "url": "'"$SAMPLE_VIDEO_URL"'" + } + }, + { + "type": "text", + "text": "Describe the content of the video, then convert what the baby say into text." + } + ]' + sampling_params_list='[ + '"$thinker_sampling_params"', + '"$talker_sampling_params"', + '"$code2wav_sampling_params"' + ]' + mm_processor_kwargs='{ + "use_audio_in_video": true + }' + ;; + multi_audios) + user_content='[ + { + "type": "audio_url", + "audio_url": { + "url": "'"$MARY_HAD_LAMB_AUDIO_URL"'" + } + }, + { + "type": "audio_url", + "audio_url": { + "url": "'"$WINNING_CALL_AUDIO_URL"'" + } + }, + { + "type": "text", + "text": "Are these two audio clips the same?" + } + ]' + sampling_params_list='[ + '"$thinker_sampling_params"', + '"$talker_sampling_params"', + '"$code2wav_sampling_params"' + ]' + mm_processor_kwargs="{}" + ;; +esac + +echo "Running query type: $QUERY_TYPE" +echo "" + + output=$(curl -sS -X POST http://localhost:8091/v1/chat/completions \ - -H "Content-Type: application/json" \ - -d @- < "BaseMultiModalContentParser": + return OmniAsyncMultiModalContentParser(self) + + +class OmniAsyncMultiModalContentParser(AsyncMultiModalContentParser): + def __init__(self, tracker: AsyncMultiModalItemTracker) -> None: + super().__init__(tracker=tracker) + self._mm_processor_kwargs: Optional[dict[str, Any]] = None + + def set_mm_processor_kwargs(self, mm_processor_kwargs: Optional[dict[str, Any]]) -> None: + """Set mm_processor_kwargs for use in parsing.""" + self._mm_processor_kwargs = mm_processor_kwargs + + def parse_video(self, video_url: Optional[str], uuid: Optional[str] = None) -> None: + video = self._connector.fetch_video_async(video_url=video_url) if video_url else None + + placeholder = self._tracker.add("video", video, uuid) + self._add_placeholder("video", placeholder) + + # Extract audio from video if use_audio_in_video is True + if video_url and self._mm_processor_kwargs and self._mm_processor_kwargs.get("use_audio_in_video", False): + audio_coro = self._extract_audio_from_video_async(video_url) + audio_placeholder = self._tracker.add("audio", audio_coro, uuid) + self._add_placeholder("audio", audio_placeholder) + + async def _extract_audio_from_video_async(self, video_url: str) -> tuple[np.ndarray, Union[int, float]]: + """ + Extract audio from video URL using librosa. + Returns tuple of (audio_array, sample_rate) compatible with audio format. + + All blocking I/O operations are run in a thread pool to avoid blocking the event loop. + """ + import asyncio + import os + import tempfile + from urllib.parse import urlparse + + # Parse URL to determine type + parsed_url = urlparse(video_url) + temp_video_file_path = None + + def _download_video_sync(url: str) -> bytes: + """Synchronous video download - runs in thread pool.""" + from urllib.request import urlopen + + return urlopen(url).read() + + def _write_temp_file_sync(data: bytes, suffix: str) -> str: + """Synchronous temp file write - runs in thread pool.""" + with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file: + temp_file.write(data) + return temp_file.name + + def _load_audio_sync(file_path: str) -> tuple[np.ndarray, Union[int, float]]: + """Synchronous audio loading with librosa - runs in thread pool.""" + import librosa + + return librosa.load(file_path, sr=16000) + + def _cleanup_file_sync(file_path: str) -> None: + """Synchronous file deletion - runs in thread pool.""" + try: + if os.path.exists(file_path): + os.unlink(file_path) + except OSError: + pass + + try: + if parsed_url.scheme in ("http", "https"): + # Download video from HTTP/HTTPS URL asynchronously + video_data = await asyncio.to_thread(_download_video_sync, video_url) + # Write temp file asynchronously + temp_video_file_path = await asyncio.to_thread(_write_temp_file_sync, video_data, ".mp4") + elif parsed_url.scheme == "file": + # Use file path directly (handle Windows paths) + from urllib.request import url2pathname + + temp_video_file_path = url2pathname(parsed_url.path) + elif parsed_url.scheme == "data": + # Handle data URL (base64 encoded video) + import base64 + + header, data = video_url.split(",", 1) + video_data = base64.b64decode(data) + # Write temp file asynchronously + temp_video_file_path = await asyncio.to_thread(_write_temp_file_sync, video_data, ".mp4") + else: + # Assume it's a local file path + temp_video_file_path = video_url + + # Extract audio using librosa asynchronously (CPU-intensive, runs in thread pool) + audio_array, sample_rate = await asyncio.to_thread(_load_audio_sync, temp_video_file_path) + + return audio_array, sample_rate + finally: + # Clean up temporary file if we created one (asynchronously) + if temp_video_file_path and parsed_url.scheme in ("http", "https", "data"): + await asyncio.to_thread(_cleanup_file_sync, temp_video_file_path) + + +def parse_chat_messages_futures( + messages: list[ChatCompletionMessageParam], + model_config: ModelConfig, + tokenizer: AnyTokenizer, + content_format: _ChatTemplateContentFormat, + mm_processor_kwargs: Optional[dict[str, Any]] = None, +) -> tuple[ + list[ConversationMessage], + Awaitable[Optional[MultiModalDataDict]], + Optional[MultiModalUUIDDict], +]: + conversation: list[ConversationMessage] = [] + mm_tracker = OmniAsyncMultiModalItemTracker(model_config, tokenizer) + + for msg in messages: + sub_messages = _parse_chat_message_content( + msg, + mm_tracker, + content_format, + interleave_strings=( + content_format == "string" + and model_config.multimodal_config is not None + and model_config.multimodal_config.interleave_mm_strings + ), + mm_processor_kwargs=mm_processor_kwargs, + ) + + conversation.extend(sub_messages) + + _postprocess_messages(conversation) + + return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids() + + +def _parse_chat_message_content( + message: ChatCompletionMessageParam, + mm_tracker: BaseMultiModalItemTracker, + content_format: _ChatTemplateContentFormat, + interleave_strings: bool, + mm_processor_kwargs: Optional[dict[str, Any]] = None, +) -> list[ConversationMessage]: + role = message["role"] + content = message.get("content") + + if content is None: + content = [] + elif isinstance(content, str): + content = [ChatCompletionContentPartTextParam(type="text", text=content)] + result = _parse_chat_message_content_parts( + role, + content, # type: ignore + mm_tracker, + wrap_dicts=(content_format == "openai"), + interleave_strings=interleave_strings, + mm_processor_kwargs=mm_processor_kwargs, + ) + + for result_msg in result: + if role == "assistant": + parsed_msg = _AssistantParser(message) + + # The 'tool_calls' is not None check ensures compatibility. + # It's needed only if downstream code doesn't strictly + # follow the OpenAI spec. + if "tool_calls" in parsed_msg and parsed_msg["tool_calls"] is not None: + result_msg["tool_calls"] = list(parsed_msg["tool_calls"]) + elif role == "tool": + parsed_msg = _ToolParser(message) + if "tool_call_id" in parsed_msg: + result_msg["tool_call_id"] = parsed_msg["tool_call_id"] + + if "name" in message and isinstance(message["name"], str): + result_msg["name"] = message["name"] + + return result + + +def _parse_chat_message_content_parts( + role: str, + parts: Iterable[ChatCompletionContentPartParam], + mm_tracker: BaseMultiModalItemTracker, + *, + wrap_dicts: bool, + interleave_strings: bool, + mm_processor_kwargs: Optional[dict[str, Any]] = None, +) -> list[ConversationMessage]: + content = list[_ContentPart]() + + mm_parser = mm_tracker.create_parser() + # Set mm_processor_kwargs if parser supports it + if hasattr(mm_parser, "set_mm_processor_kwargs"): + mm_parser.set_mm_processor_kwargs(mm_processor_kwargs) + + for part in parts: + parse_res = _parse_chat_message_content_part( + part, + mm_parser, + wrap_dicts=wrap_dicts, + interleave_strings=interleave_strings, + ) + if parse_res: + content.append(parse_res) + + if wrap_dicts: + # Parsing wraps images and texts as interleaved dictionaries + return [ConversationMessage(role=role, content=content)] # type: ignore + texts = cast(list[str], content) + mm_placeholder_storage = mm_parser.mm_placeholder_storage() + if mm_placeholder_storage: + text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_storage, texts, interleave_strings) + else: + text_prompt = "\n".join(texts) + + return [ConversationMessage(role=role, content=text_prompt)] diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py index 994dfa423ca..481c4821359 100644 --- a/vllm_omni/entrypoints/omni_stage.py +++ b/vllm_omni/entrypoints/omni_stage.py @@ -50,6 +50,7 @@ def __init__(self, stage_config): self.stage_id = stage_config.stage_id self.engine_args = stage_config.engine_args self.model_stage = stage_config.engine_args.model_stage + self.requires_multimodal_data = getattr(stage_config.runtime, "requires_multimodal_data", False) self.engine_input_source = getattr(stage_config, "engine_input_source", []) self.engine_output_type = stage_config.engine_args.engine_output_type self.engine_outputs = None @@ -205,14 +206,20 @@ def process_engine_inputs( for source_output in source_outputs: engine_input = OmniTokensPrompt( prompt_token_ids=source_output.outputs[0].token_ids, - multi_modal_data=(multi_modal_data[source_output.request_id] if multi_modal_data else None), + multi_modal_data=( + multi_modal_data[source_output.request_id] + if self.requires_multimodal_data and multi_modal_data + else None + ), ) engine_inputs.append(engine_input) return engine_inputs else: engine_input_source = self.engine_input_source - return self.custom_process_input_func(stage_list, engine_input_source, prompt) + return self.custom_process_input_func( + stage_list, engine_input_source, prompt, self.requires_multimodal_data + ) def _stage_worker( @@ -227,13 +234,13 @@ def _stage_worker( import logging as _logging import time as _time - from vllm_omni.entrypoints.log_utils import ( # noqa: WPS433 + from vllm_omni.entrypoints.log_utils import ( compute_and_log_stage_request_stats, count_tokens_from_outputs, log_stage_batch_stats, log_stage_running_avg, ) - from vllm_omni.entrypoints.omni_llm import OmniStageLLM # noqa: WPS433 + from vllm_omni.entrypoints.omni_llm import OmniStageLLM # no inline JSONL/serialization imports; logging handled by utilities @@ -501,8 +508,8 @@ async def _stage_worker_async( import logging as _logging import time as _time - from vllm_omni.entrypoints.async_omni_llm import AsyncOmniStageLLM # noqa: WPS433 - from vllm_omni.entrypoints.log_utils import ( # noqa: WPS433 + from vllm_omni.entrypoints.async_omni_llm import AsyncOmniStageLLM + from vllm_omni.entrypoints.log_utils import ( compute_and_log_stage_request_stats, count_tokens_from_outputs, log_stage_batch_stats, diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index 6e7ade7c777..14c266c85f9 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -3,10 +3,10 @@ import json import time import uuid -from collections.abc import AsyncGenerator, AsyncIterator +from collections.abc import AsyncGenerator, AsyncIterator, Sequence from datetime import datetime, timedelta, timezone from io import BytesIO -from typing import Optional, Union +from typing import Any, Callable, Optional, Union import jinja2 from fastapi import Request @@ -18,7 +18,16 @@ soundfile = None from openai.types.chat.chat_completion_audio import ChatCompletionAudio as OpenAIChatCompletionAudio -from vllm.entrypoints.chat_utils import ConversationMessage, get_history_tool_calls_cnt, make_tool_call_id +from vllm.entrypoints.chat_utils import ( + ChatCompletionMessageParam, + ChatTemplateContentFormatOption, + ConversationMessage, + apply_hf_chat_template, + apply_mistral_chat_template, + get_history_tool_calls_cnt, + make_tool_call_id, + resolve_chat_template_content_format, +) from vllm.entrypoints.harmony_utils import parse_chat_output from vllm.entrypoints.openai.protocol import ( ChatCompletionNamedToolChoiceParam, @@ -35,7 +44,16 @@ UsageInfo, ) from vllm.entrypoints.openai.serving_chat import OpenAIServingChat -from vllm.entrypoints.openai.serving_engine import RequestPrompt, clamp_prompt_logprobs +from vllm.entrypoints.openai.serving_engine import ( + ChatLikeRequest, + EngineTokensPrompt, + RequestPrompt, + ResponsesRequest, + TextTokensPrompt, + clamp_prompt_logprobs, + is_list_of, +) +from vllm.entrypoints.openai.tool_parsers import ToolParser from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall from vllm.inputs.data import PromptType from vllm.logger import init_logger @@ -50,6 +68,7 @@ ) from vllm.utils import as_list +from vllm_omni.entrypoints.chat_utils import parse_chat_messages_futures from vllm_omni.outputs import OmniRequestOutput logger = init_logger(__name__) @@ -215,6 +234,123 @@ async def create_chat_completion( # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) + async def _preprocess_chat( + self, + request: Union[ChatLikeRequest, ResponsesRequest], + tokenizer: AnyTokenizer, + messages: list[ChatCompletionMessageParam], + chat_template: Optional[str], + chat_template_content_format: ChatTemplateContentFormatOption, + add_generation_prompt: bool = True, + continue_final_message: bool = False, + tool_dicts: Optional[list[dict[str, Any]]] = None, + documents: Optional[list[dict[str, str]]] = None, + chat_template_kwargs: Optional[dict[str, Any]] = None, + tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None, + add_special_tokens: bool = False, + ) -> tuple[ + list[ConversationMessage], + Sequence[RequestPrompt], + list[EngineTokensPrompt], + ]: + model_config = self.model_config + + resolved_content_format = resolve_chat_template_content_format( + chat_template, + tool_dicts, + chat_template_content_format, + tokenizer, + model_config=model_config, + ) + conversation, mm_data_future, mm_uuids = parse_chat_messages_futures( + messages, + model_config, + tokenizer, + content_format=resolved_content_format, + mm_processor_kwargs=getattr(request, "mm_processor_kwargs", None), + ) + + _chat_template_kwargs: dict[str, Any] = dict( + chat_template=chat_template, + add_generation_prompt=add_generation_prompt, + continue_final_message=continue_final_message, + tools=tool_dicts, + documents=documents, + ) + _chat_template_kwargs.update(chat_template_kwargs or {}) + + request_prompt: Union[str, list[int]] + + if tokenizer is None: + request_prompt = "placeholder" + elif isinstance(tokenizer, MistralTokenizer): + request_prompt = apply_mistral_chat_template( + tokenizer, + messages=messages, + **_chat_template_kwargs, + ) + else: + request_prompt = apply_hf_chat_template( + tokenizer=tokenizer, + conversation=conversation, + model_config=model_config, + **_chat_template_kwargs, + ) + + mm_data = await mm_data_future + + # tool parsing is done only if a tool_parser has been set and if + # tool_choice is not "none" (if tool_choice is "none" but a tool_parser + # is set, we want to prevent parsing a tool_call hallucinated by the LLM + should_parse_tools = tool_parser is not None and ( + hasattr(request, "tool_choice") and request.tool_choice != "none" + ) + + if should_parse_tools: + if not isinstance(request, ChatCompletionRequest): + msg = "Tool usage is only supported for Chat Completions API" + raise NotImplementedError(msg) + + request = tool_parser(tokenizer).adjust_request( # type: ignore + request=request + ) + + if tokenizer is None: + assert isinstance(request_prompt, str), ( + "Prompt has to be a string", + "when the tokenizer is not initialised", + ) + prompt_inputs = TextTokensPrompt(prompt=request_prompt, prompt_token_ids=[1]) + elif isinstance(request_prompt, str): + prompt_inputs = await self._tokenize_prompt_input_async( + request, + tokenizer, + request_prompt, + add_special_tokens=add_special_tokens, + ) + else: + # For MistralTokenizer + assert is_list_of(request_prompt, int), "Prompt has to be either a string or a list of token ids" + prompt_inputs = TextTokensPrompt( + prompt=tokenizer.decode(request_prompt), + prompt_token_ids=request_prompt, + ) + + engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_inputs["prompt_token_ids"]) + if mm_data is not None: + engine_prompt["multi_modal_data"] = mm_data + + if mm_uuids is not None: + engine_prompt["multi_modal_uuids"] = mm_uuids + + if request.mm_processor_kwargs is not None: + engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs + + if hasattr(request, "cache_salt") and request.cache_salt is not None: + engine_prompt["cache_salt"] = request.cache_salt + + return conversation, [request_prompt], [engine_prompt] + def _to_sampling_params_list(self, sampling_params_list: list[dict]) -> list[SamplingParams]: final_sampling_params_list = [] for sampling_params in sampling_params_list: diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/__init__.py b/vllm_omni/model_executor/models/qwen2_5_omni/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/vllm_omni/model_executor/models/qwen2_5_omni.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py similarity index 80% rename from vllm_omni/model_executor/models/qwen2_5_omni.py rename to vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py index ae5ff25ca58..b650712b59f 100644 --- a/vllm_omni/model_executor/models/qwen2_5_omni.py +++ b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py @@ -7,6 +7,7 @@ import numpy as np import torch import torch.nn as nn +from transformers import PretrainedConfig from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import ( Qwen2_5OmniConfig, Qwen2_5OmniTalkerConfig, @@ -14,13 +15,7 @@ ) from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP -from vllm.model_executor.models.qwen2_5_omni_thinker import ( - Qwen2_5OmniConditionalGenerationMixin, - Qwen2_5OmniThinkerDummyInputsBuilder, - Qwen2_5OmniThinkerMultiModalProcessor, - Qwen2_5OmniThinkerProcessingInfo, -) +from vllm.model_executor.models.interfaces import SupportsMRoPE, SupportsMultiModal, SupportsPP from vllm.model_executor.models.utils import init_vllm_registered_model, maybe_prefix # from vllm.model_executor.models.qwen2_code2wav_dit import Qwen2Code2wav @@ -31,7 +26,14 @@ from vllm.v1.sample.sampler import Sampler from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific -from vllm_omni.model_executor.models.utils import add_prefix_to_loaded_weights +from vllm_omni.model_executor.models.qwen2_5_omni.qwen2_5_omni_thinker import ( + Qwen2_5OmniConditionalGenerationMixin, + Qwen2_5OmniThinkerDummyInputsBuilder, + Qwen2_5OmniThinkerMultiModalProcessor, + Qwen2_5OmniThinkerProcessingInfo, +) +from vllm_omni.model_executor.models.utils import add_prefix_to_loaded_weights, split_list_into_ranges +from vllm_omni.model_executor.models.vision import get_llm_pos_ids_for_vision TALKER_CODEC_EOS_TOKEN_ID = 8294 TALKER_CODEC_BOS_TOKEN_ID = 8293 @@ -54,7 +56,7 @@ class OmniOutput(NamedTuple): dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder, ) class Qwen2_5OmniForConditionalGeneration( - nn.Module, SupportsMultiModal, SupportsPP, Qwen2_5OmniConditionalGenerationMixin + nn.Module, SupportsMultiModal, SupportsPP, Qwen2_5OmniConditionalGenerationMixin, SupportsMRoPE ): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -459,6 +461,170 @@ def forward( multimodal_outputs=None, ) + def get_mrope_input_positions( + self, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: list[list[int]] | torch.Tensor, + video_grid_thw: list[list[int]] | torch.Tensor, + second_per_grid_ts: list[float] | None = None, + context_len: int = 0, + seq_len: int | None = None, + audio_feature_lengths: torch.Tensor | None = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value (Qwen2.5-Omni version). + + Differences from MRotaryEmbedding: + 1. Add audio support (and related `audio_feature_lengths`). + 2. Add `use_audio_in_video` option to read audio from video inputs. + In this case, audio and vision position ids will be split into + chunks and interleaved. + + Example: + + (V_i are vision position ids, A_i are audio position ids) + + |V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|... + |vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |... + """ + + # TODO(fyabc): refactor and share more code with + # _vl_get_input_positions_tensor. + + thinker_config = hf_config.thinker_config + audio_token_id = thinker_config.audio_token_index + image_token_id = thinker_config.image_token_index + video_token_id = thinker_config.video_token_index + audio_start_token_id = thinker_config.audio_start_token_id + audio_end_token_id = thinker_config.audio_end_token_id + vision_start_token_id = thinker_config.vision_start_token_id + vision_end_token_id = thinker_config.vision_end_token_id + seconds_per_chunk = thinker_config.seconds_per_chunk + spatial_merge_size = thinker_config.vision_config.spatial_merge_size + tokens_per_second = getattr(thinker_config.vision_config, "tokens_per_second", 25) + + if isinstance(image_grid_thw, list): + image_grid_thw = torch.tensor(image_grid_thw) + if isinstance(video_grid_thw, list): + video_grid_thw = torch.tensor(video_grid_thw) + + src_item = input_tokens + audio_seqlens = audio_feature_lengths + if not second_per_grid_ts: + second_per_grid_ts = [1] * video_grid_thw.shape[0] + audio_idx = 0 + video_idx = 0 + image_idx = 0 + new_src_item: list[int] = [] + llm_pos_ids_list: list[torch.Tensor] = [] + + idx = 0 + while idx < len(src_item): + new_src_item_len = len(new_src_item) + start_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + if src_item[idx] not in [audio_token_id, video_token_id, image_token_id]: + if use_audio_in_video and idx > 0: + if src_item[idx] == vision_end_token_id and src_item[idx - 1] == audio_end_token_id: + # processing the <|audio_eos|> before <|vision_eos|> + start_idx -= 1 + elif src_item[idx] == audio_start_token_id and src_item[idx - 1] == vision_start_token_id: + # processing the <|audio_bos|> after <|vision_eos|> + start_idx -= 1 + new_src_item.append(src_item[idx]) + llm_pos_ids = torch.tensor([start_idx], dtype=torch.long).expand(3, -1) + llm_pos_ids_list.append(llm_pos_ids) + elif src_item[idx] == audio_token_id: + assert audio_seqlens is not None + audio_seqlen = audio_seqlens[audio_idx] + place_num = ((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1 + new_src_item.extend([audio_token_id] * place_num) + llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx + llm_pos_ids_list.append(llm_pos_ids) + audio_idx += 1 + elif src_item[idx] == image_token_id: + grid_t = image_grid_thw[image_idx][0] + grid_hs = image_grid_thw[:, 1] + grid_ws = image_grid_thw[:, 2] + t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long() + llm_pos_ids = get_llm_pos_ids_for_vision( + start_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + llm_pos_ids_list.append(llm_pos_ids) + vision_seqlen = image_grid_thw[image_idx].prod() // (spatial_merge_size**2) + new_src_item.extend([image_token_id] * vision_seqlen) + image_idx += 1 + elif src_item[idx] == video_token_id and not use_audio_in_video: + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_index = (torch.arange(grid_t) * second_per_grid_ts[video_idx] * tokens_per_second).long() + llm_pos_ids = get_llm_pos_ids_for_vision( + start_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + llm_pos_ids_list.append(llm_pos_ids) + vision_seqlen = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) + new_src_item.extend([video_token_id] * vision_seqlen) + video_idx += 1 + else: + # read audio from video + assert audio_seqlens is not None + audio_seqlen = audio_seqlens[audio_idx] + vision_seqlen = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) + grid_t = video_grid_thw[video_idx][0] + grid_h = video_grid_thw[video_idx][1] + grid_w = video_grid_thw[video_idx][2] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) + t_index = (torch.arange(grid_t) * second_per_grid_ts[video_idx] * tokens_per_second).long() + t_index_split_chunk = split_list_into_ranges(t_index, t_ntoken_per_chunk) + place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2 + pure_audio_len = place_num - 2 + added_audio_len = 0 + audio_llm_pos_ids_list: list[torch.Tensor] = [] + for t_chunk in t_index_split_chunk: + vision_ntoken_per_chunk = len(t_chunk) * grid_h * grid_w // (spatial_merge_size**2) + new_src_item.extend([video_token_id] * vision_ntoken_per_chunk) + vision_llm_pos_ids_list = get_llm_pos_ids_for_vision( + start_idx, + video_idx, + spatial_merge_size, + t_chunk, + grid_hs, + grid_ws, + ).split(1, dim=1) + llm_pos_ids_list.extend(vision_llm_pos_ids_list) + new_src_item.extend(min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) * [audio_token_id]) + audio_start_idx = ( + start_idx if len(audio_llm_pos_ids_list) == 0 else audio_llm_pos_ids_list[-1][0].item() + 1 + ) + if min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) > 0: + audio_llm_pos_ids_list = ( + torch.arange(min(t_ntoken_per_chunk, pure_audio_len - added_audio_len)).expand(3, -1) + + audio_start_idx + ).split(1, dim=1) + else: + audio_llm_pos_ids_list = [] + added_audio_len += min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) + llm_pos_ids_list.extend(audio_llm_pos_ids_list) + if added_audio_len < pure_audio_len: + new_src_item.extend((pure_audio_len - added_audio_len) * [audio_token_id]) + audio_llm_pos_ids_list = ( + torch.arange(pure_audio_len - added_audio_len).expand(3, -1) + llm_pos_ids_list[-1].max() + 1 + ).split(1, dim=1) + llm_pos_ids_list.extend(audio_llm_pos_ids_list) + audio_idx += 1 + video_idx += 1 + # move to the next token + idx += len(new_src_item) - new_src_item_len + + llm_positions = torch.cat(llm_pos_ids_list, dim=1) + mrope_position_delta = torch.cat(llm_pos_ids_list, dim=1).max() + 1 - len(src_item) + llm_positions = llm_positions[:, context_len:seq_len] + + return llm_positions, mrope_position_delta + def generate_audio(self, code, voice_type): token2wav_dev = self._module_device(self.token2wav) if isinstance(code, torch.Tensor): diff --git a/vllm_omni/model_executor/models/qwen2_5_omni_talker.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_talker.py similarity index 100% rename from vllm_omni/model_executor/models/qwen2_5_omni_talker.py rename to vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_talker.py diff --git a/vllm_omni/model_executor/models/qwen2_5_omni_thinker.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py similarity index 69% rename from vllm_omni/model_executor/models/qwen2_5_omni_thinker.py rename to vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py index 41036a56724..28107cb455e 100644 --- a/vllm_omni/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py @@ -1,7 +1,7 @@ +from collections import defaultdict from collections.abc import Iterable, Mapping, Sequence -from copy import copy from functools import partial -from typing import Any, Optional, Union +from typing import Annotated, Any, Callable, Literal, Optional, Union import torch import torch.nn as nn @@ -12,7 +12,7 @@ from transformers.models.whisper import WhisperFeatureExtractor from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsMRoPE, SupportsMultiModal, SupportsPP from vllm.model_executor.models.qwen2_5_vl import ( Qwen2_5_VisionTransformer, Qwen2_5_VLImageEmbeddingInputs, @@ -24,7 +24,6 @@ Qwen2_5_VLVideoPixelInputs, ) from vllm.model_executor.models.qwen2_audio import ( - Qwen2AudioInputs, Qwen2AudioProcessingInfo, _get_feat_extract_output_lengths, ) @@ -42,7 +41,7 @@ ModalityData, MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, + MultiModalKwargsItems, NestedTensors, ) from vllm.multimodal.parse import ( @@ -52,10 +51,19 @@ MultiModalDataItems, MultiModalDataParser, ) -from vllm.multimodal.processing import BaseMultiModalProcessor, PlaceholderFeaturesInfo, PromptReplacement, PromptUpdate +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + MultiModalPromptUpdates, + PlaceholderFeaturesInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, + ResolvedPromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens +from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm_omni.model_executor.layers.mrope import MRotaryEmbedding @@ -67,6 +75,27 @@ logger = init_logger(__name__) +class Qwen2_5OmniAudioFeatureInputs(TensorSchema): + """ + Dimensions: + - na: Number of audios + - nmb: Number of mel bins + - msl: Maximum sequence length + - tsl: Total sequence length + """ + + type: Literal["audio_features"] + input_features: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("nmb", "tsl"), + ] + + feature_attention_mask: Annotated[ + torch.Tensor, + TensorShape("na", "msl"), + ] + + def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]): audio_feature_lengths = hf_inputs.get("audio_feature_lengths", torch.empty((0,))) @@ -90,6 +119,39 @@ def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]): ) +def create_qwen2_5_omni_thinker_field_factory( + spatial_merge_size: int, +) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str, MultiModalFieldConfig]]: + def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]): + audio_feature_lengths = hf_inputs.get("audio_feature_lengths", torch.empty((0,))) + + image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) + image_pixel_grid_sizes = image_grid_thw.prod(-1) + image_embed_grid_sizes = image_pixel_grid_sizes // spatial_merge_size // spatial_merge_size + + video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) + video_grid_sizes = video_grid_thw.prod(-1) + video_embed_grid_sizes = video_grid_sizes // spatial_merge_size // spatial_merge_size + + num_videos = len(video_grid_sizes) + + return dict( + input_audio_features=MultiModalFieldConfig.flat_from_sizes("audio", audio_feature_lengths, dim=1), + feature_attention_mask=MultiModalFieldConfig.batched("audio"), + audio_feature_lengths=MultiModalFieldConfig.batched("audio"), + pixel_values=MultiModalFieldConfig.flat_from_sizes("image", image_pixel_grid_sizes), + image_embeds=MultiModalFieldConfig.flat_from_sizes("image", image_embed_grid_sizes), + image_grid_thw=MultiModalFieldConfig.batched("image"), + pixel_values_videos=MultiModalFieldConfig.flat_from_sizes("video", video_grid_sizes), + video_embeds=MultiModalFieldConfig.flat_from_sizes("video", video_embed_grid_sizes), + video_grid_thw=MultiModalFieldConfig.batched("video"), + second_per_grid_ts=MultiModalFieldConfig.batched("video"), + use_audio_in_video=MultiModalFieldConfig.shared("video", num_videos), + ) + + return _qwen2_5_omni_thinker_field_config + + class Qwen2_5OmniThinkerMultiModalDataParser(Qwen2VLMultiModalDataParser): def _parse_audio_data( self, @@ -166,10 +228,7 @@ def get_dummy_mm_data( "audio": self._get_dummy_audios(length=target_audio_length, num_audios=num_audios), "image": self._get_dummy_images(width=target_width, height=target_height, num_images=num_images), "video": self._get_dummy_videos( - width=target_width, - height=target_height, - num_frames=target_num_frames, - num_videos=num_videos, + width=target_width, height=target_height, num_frames=target_num_frames, num_videos=num_videos ), } @@ -179,7 +238,10 @@ def get_dummy_mm_data( class Qwen2_5OmniThinkerMultiModalProcessor(BaseMultiModalProcessor[Qwen2_5OmniThinkerProcessingInfo]): def _get_data_parser(self) -> MultiModalDataParser: feature_extractor = self.info.get_feature_extractor() - return Qwen2_5OmniThinkerMultiModalDataParser(target_sr=feature_extractor.sampling_rate) + return Qwen2_5OmniThinkerMultiModalDataParser( + spatial_merge_size=self.info.get_hf_config().vision_config.spatial_merge_size, + target_sr=feature_extractor.sampling_rate, + ) def _call_hf_processor( self, @@ -214,6 +276,14 @@ def _call_hf_processor( hf_inputs["input_audio_features"] = input_features if "audio_feature_lengths" not in hf_inputs and feature_attention_mask is not None: hf_inputs["audio_feature_lengths"] = feature_attention_mask.sum(-1) + + video_second_per_grid = hf_inputs.get("video_second_per_grid", None) + if video_second_per_grid is not None: + hf_inputs["second_per_grid_ts"] = video_second_per_grid + + use_audio_in_video = mm_kwargs.get("use_audio_in_video", False) + hf_inputs["use_audio_in_video"] = torch.tensor(use_audio_in_video) + return hf_inputs def _get_mm_fields_config( @@ -221,66 +291,177 @@ def _get_mm_fields_config( hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - return _qwen2_5_omni_thinker_field_config(hf_inputs) + return create_qwen2_5_omni_thinker_field_factory(self.info.get_hf_config().vision_config.spatial_merge_size)( + hf_inputs + ) + + def _apply_prompt_updates( + self, + token_ids: list[int], + mm_prompt_updates: MultiModalPromptUpdates, + ) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]: + tokenizer = self.info.get_tokenizer() + + new_token_ids, match_result = self._apply_token_matches( + token_ids, + mm_prompt_updates, + ) + + # If the search text does not represent a special token, + # it may have different token IDs in the prompt, because + # the tokens may go across the boundaries of the search text. + # ---- + # e.g. when searching for "foo" in "food", if "food" itself makes + # up a token, then the token ID of "foo" will not appear at all + # ---- + # Since it is inefficient to search for all possible tokenizations + # of the search text in the prompt, we instead perform string-based + # updates on the decoded token IDs, then encode them back. + if not all(all(update_idx is not None for update_idx in update_idxs) for update_idxs in match_result.values()): + new_text, match_result = self._apply_text_matches( + decode_tokens(tokenizer, token_ids), + mm_prompt_updates, + ) + + new_token_ids = encode_tokens( + tokenizer, + new_text, + add_special_tokens=False, + ) + + matched_updates = defaultdict[str, list[Sequence[ResolvedPromptUpdate]]](list) + unmatched_audio_items = [] + + for modality, update_idxs in match_result.items(): + for item_idx, update_idx in enumerate(update_idxs): + if update_idx is None: + # Check if this is audio that might be embedded in video + if modality == "audio": + unmatched_audio_items.append(item_idx) + continue + else: + assert False, f"Failed to apply prompt replacement for mm_items[{modality!r}][{item_idx}]" + + matched_updates[modality].append([mm_prompt_updates[modality][item_idx][update_idx]]) + + # If there are unmatched audio items, check if we can derive them from video + if unmatched_audio_items: + # Check if video exists in the original updates (not just matched ones) + if "video" in mm_prompt_updates: + num_videos = len(mm_prompt_updates.get("video", [])) + num_audios = len(mm_prompt_updates.get("audio", [])) + # If counts match, audio might be embedded in video - skip assertion + # The placeholders will be derived from video in _maybe_apply_prompt_updates + if num_audios == num_videos and "video" in matched_updates: + pass # Will be handled by deriving from video placeholders + else: + # Audio items exist but can't be matched + # Check if video was also matched + if "video" not in matched_updates: + # Neither audio nor video matched - this is a real error + assert False, ( + f"Failed to apply prompt replacement for {len(unmatched_audio_items)} " + f"audio item(s). Audio items cannot be matched in prompt." + ) + elif num_audios != num_videos: + # Audio and video counts don't match + assert False, ( + f"Failed to apply prompt replacement for {len(unmatched_audio_items)} " + f"audio item(s). Audio items cannot be matched in prompt and " + f"audio count ({num_audios}) does not match video count ({num_videos})." + ) + else: + # Video matched but audio didn't, and counts match - allow it + pass + else: + # Audio items can't be matched and there are no video items + assert False, ( + f"Failed to apply prompt replacement for {len(unmatched_audio_items)} " + f"audio item(s). Audio items cannot be matched in prompt." + ) + + placeholders = self._find_mm_placeholders( + new_token_ids, + dict(matched_updates), + ) + + return new_token_ids, placeholders def _maybe_apply_prompt_updates( self, mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], prompt_ids: list[int], - mm_kwargs: MultiModalKwargs, + mm_kwargs: MultiModalKwargsItems, + mm_prompt_updates: MultiModalPromptUpdates, is_update_applied: bool, ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: """ Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`. """ - unbound_prompt_updates = self._get_prompt_updates( - mm_items, - hf_processor_mm_kwargs, - mm_kwargs, - ) - mm_prompt_updates = self._bind_and_group_updates(unbound_prompt_updates) - mm_item_counts = mm_items.get_all_counts() self._validate_mm_kwargs(mm_kwargs, mm_item_counts) - use_audio_in_video = hf_processor_mm_kwargs.get("use_audio_in_video", False) + use_audio_in_video = False + if "video" in mm_kwargs: + video_items = [item for item in mm_kwargs["video"] if item is not None] + # only check video items (if there are any) + if video_items: + use_audio_in_video = all(item["use_audio_in_video"].data for item in video_items) if is_update_applied: mm_placeholders = self._find_mm_placeholders( - mm_prompt_updates, prompt_ids, + mm_prompt_updates, + ) + self._validate_mm_placeholders( + mm_placeholders, mm_item_counts, ) - self._validate_mm_placeholders(mm_placeholders, mm_item_counts, use_audio_in_video=use_audio_in_video) - tokenizer = self.info.get_tokenizer() - prompt = decode_tokens(tokenizer, prompt_ids) else: - ( - prompt_ids, - prompt, + # When use_audio_in_video=True, audio tokens are not in the prompt, + # so we need to filter out audio updates before applying replacements + if use_audio_in_video and "audio" in mm_prompt_updates: + # Remove audio from prompt updates (it won't match anything) + filtered_updates = {k: v for k, v in mm_prompt_updates.items() if k != "audio"} + prompt_ids, mm_placeholders = self._apply_prompt_updates( + prompt_ids, + filtered_updates, + ) + # Derive audio placeholders from video placeholders + mm_placeholders = self._derive_audio_from_video_placeholders(mm_placeholders, mm_prompt_updates) + else: + # Apply prompt updates normally + # _apply_prompt_updates will handle unmatched audio items + # when video exists with matching counts + prompt_ids, mm_placeholders = self._apply_prompt_updates( + prompt_ids, + mm_prompt_updates, + ) + + # Check if audio placeholders are missing but should exist + # This can happen when audio items are embedded in video tokens + if "audio" in mm_prompt_updates and "audio" not in mm_placeholders and "video" in mm_placeholders: + num_audios = len(mm_prompt_updates.get("audio", [])) + num_videos = len(mm_placeholders.get("video", [])) + # If counts match, derive audio placeholders from video + if num_audios == num_videos: + mm_placeholders = self._derive_audio_from_video_placeholders(mm_placeholders, mm_prompt_updates) + self._validate_mm_placeholders( mm_placeholders, - ) = self._apply_prompt_updates( - prompt_ids, - mm_prompt_updates, mm_item_counts, ) - self._validate_mm_placeholders(mm_placeholders, mm_item_counts, use_audio_in_video=use_audio_in_video) tokenizer = self.info.get_tokenizer() prompt = decode_tokens(tokenizer, prompt_ids) - if use_audio_in_video: - mm_kwargs["use_audio_in_video"] = True - return prompt_ids, prompt, mm_placeholders def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() @@ -294,8 +475,9 @@ def _get_prompt_updates( image_token_id = vocab[image_token] video_token_id = vocab[video_token] - audio_feature_lengths = out_mm_kwargs.get("audio_feature_lengths") - feature_attention_mask = out_mm_kwargs.get("feature_attention_mask") + out_mm_data = out_mm_kwargs.get_data() + audio_feature_lengths = out_mm_data.get("audio_feature_lengths") + feature_attention_mask = out_mm_data.get("feature_attention_mask") if audio_feature_lengths is None and feature_attention_mask is None: audio_output_lengths = [] elif audio_feature_lengths is not None: @@ -323,7 +505,7 @@ def get_replacement_qwen2_audio(item_idx: int): return [audio_token_id] * num_features def get_replacement_qwen2_vision(item_idx: int, modality: str): - grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx] + grid_thw = out_mm_data[f"{modality}_grid_thw"][item_idx] assert isinstance(grid_thw, torch.Tensor) merge_length = image_processor.merge_size**2 @@ -337,7 +519,7 @@ def get_replacement_qwen2_use_audio_in_video(item_idx: int): nonlocal audio_in_video_item_idx audio_num_features = audio_output_lengths[audio_in_video_item_idx + item_idx] - video_grid_thw = out_mm_kwargs["video_grid_thw"][item_idx] + video_grid_thw = out_mm_data["video_grid_thw"][item_idx] audio_in_video_item_idx += 1 @@ -347,13 +529,15 @@ def get_replacement_qwen2_use_audio_in_video(item_idx: int): else: video_second_per_grid_t = 1.0 - return MRotaryEmbedding.omni_get_updates_use_audio_in_video( + placeholder = MRotaryEmbedding.omni_get_updates_use_audio_in_video( thinker_config=thinker_config, audio_len=audio_num_features, video_grid_thw=video_grid_thw, video_second_per_grid_t=video_second_per_grid_t, ) + return PromptUpdateDetails.select_token_id(placeholder, embed_token_id=video_token_id) + video_replacement_fn = ( get_replacement_qwen2_use_audio_in_video if use_audio_in_video @@ -378,6 +562,50 @@ def get_replacement_qwen2_use_audio_in_video(item_idx: int): ), ] + def _derive_audio_from_video_placeholders( + self, + placeholders: Mapping[str, list[PlaceholderFeaturesInfo]], + mm_prompt_updates: MultiModalPromptUpdates, + ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: + """ + Helper to derive audio placeholders from video placeholders when + use_audio_in_video=True. + """ + if "video" not in placeholders: + return placeholders + + # Validate audio and video counts match + num_videos = len(placeholders["video"]) + num_audios = len(mm_prompt_updates.get("audio", [])) + if num_audios != num_videos: + raise ValueError( + f"use_audio_in_video requires equal number of audio and video items, got {num_audios=}, {num_videos=}" + ) + + tokenizer = self.info.get_tokenizer() + processor = self.info.get_hf_processor() + audio_token_id = tokenizer.get_vocab()[processor.audio_token] + + result_placeholders = dict(placeholders) + audio_placeholders = [] + + # Each video is paired with one audio + for video_idx, video_placeholder in enumerate(placeholders["video"]): + # Create is_embed mask selecting only audio tokens + audio_is_embed = torch.tensor(video_placeholder.tokens) == audio_token_id + + audio_placeholder = PlaceholderFeaturesInfo( + modality="audio", + item_idx=video_idx, + start_idx=video_placeholder.start_idx, + tokens=video_placeholder.tokens, + is_embed=audio_is_embed, + ) + audio_placeholders.append(audio_placeholder) + + result_placeholders["audio"] = audio_placeholders + return result_placeholders + def _apply_hf_processor_main( self, prompt: Union[str, list[int]], @@ -436,30 +664,19 @@ def _apply_hf_processor_mm_only( return mm_processed_data - def _validate_mm_placeholders( - self, - mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]], - mm_item_counts: Mapping[str, int], - use_audio_in_video: bool = False, - ) -> None: - if use_audio_in_video: - mm_item_counts = copy(mm_item_counts) - if "video" in mm_item_counts: - assert "audio" in mm_item_counts - mm_item_counts["audio"] -= mm_item_counts["video"] - super()._validate_mm_placeholders(mm_placeholders, mm_item_counts) - class Qwen2_5OmniConditionalGenerationMixin: def _validate_and_reshape_mm_tensor(self, mm_input: object, name: str, dim: int = 0) -> torch.Tensor: if not isinstance(mm_input, (torch.Tensor, list)): raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}") if isinstance(mm_input, torch.Tensor): + if dim == 0: + return mm_input.reshape(-1, *mm_input.shape[2:]) return torch.concat(list(mm_input), dim=dim) else: return torch.concat(mm_input, dim=dim) - def _parse_and_validate_audio_input(self, **kwargs: object) -> Optional[Qwen2AudioInputs]: + def _parse_and_validate_audio_input(self, **kwargs: object) -> Optional[Qwen2_5OmniAudioFeatureInputs]: input_audio_features = kwargs.pop("input_audio_features", None) audio_feature_lengths = kwargs.pop("audio_feature_lengths", None) feature_attention_mask = kwargs.pop("feature_attention_mask", None) @@ -472,7 +689,8 @@ def _parse_and_validate_audio_input(self, **kwargs: object) -> Optional[Qwen2Aud ) if not isinstance(input_audio_features, (torch.Tensor, list)): raise ValueError(f"Incorrect type of audio input features. Got type: {type(input_audio_features)}") - return Qwen2AudioInputs( + return Qwen2_5OmniAudioFeatureInputs( + type="audio_features", input_features=input_audio_features, audio_feature_lengths=audio_feature_lengths, feature_attention_mask=feature_attention_mask, @@ -497,9 +715,7 @@ def _parse_and_validate_image_input( raise ValueError(f"Incorrect type of image pixel values. Got type: {type(pixel_values)}") return Qwen2_5_VLImagePixelInputs( - type="pixel_values", - pixel_values=pixel_values, - image_grid_thw=image_grid_thw, + type="pixel_values", pixel_values=pixel_values, image_grid_thw=image_grid_thw ) if image_embeds is not None: @@ -509,9 +725,7 @@ def _parse_and_validate_image_input( if not isinstance(image_embeds, torch.Tensor): raise ValueError(f"Incorrect type of image embeddings. Got type: {type(image_embeds)}") return Qwen2_5_VLImageEmbeddingInputs( - type="image_embeds", - image_embeds=image_embeds, - image_grid_thw=image_grid_thw, + type="image_embeds", image_embeds=image_embeds, image_grid_thw=image_grid_thw ) def _parse_and_validate_video_input( @@ -542,14 +756,12 @@ def _parse_and_validate_video_input( if not isinstance(video_embeds, torch.Tensor): raise ValueError(f"Incorrect type of video embeddings. Got type: {type(video_embeds)}") return Qwen2_5_VLVideoEmbeddingInputs( - type="video_embeds", - video_embeds=video_embeds, - video_grid_thw=video_grid_thw, + type="video_embeds", video_embeds=video_embeds, video_grid_thw=video_grid_thw ) def _process_audio_input( self, - audio_input: Qwen2AudioInputs, + audio_input: Qwen2_5OmniAudioFeatureInputs, audio_hashes: list[str] = None, cached_audio_features: torch.Tensor = None, ) -> torch.Tensor: @@ -574,8 +786,7 @@ def _process_audio_input( feature_lens=audio_feature_lengths, aftercnn_lens=audio_feat_lengths, ) - audio_features = audio_outputs.last_hidden_state - return audio_features.split(audio_output_lengths.tolist()) + return audio_outputs.last_hidden_state.split(audio_output_lengths.tolist()) def _process_image_input(self, image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]: if image_input["type"] == "image_embeds": @@ -619,7 +830,7 @@ def _process_video_input( dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder, ) class Qwen2_5OmniThinkerForConditionalGeneration( - nn.Module, SupportsMultiModal, SupportsPP, Qwen2_5OmniConditionalGenerationMixin + nn.Module, SupportsMultiModal, SupportsPP, Qwen2_5OmniConditionalGenerationMixin, SupportsMRoPE ): hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ @@ -660,13 +871,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): "exactly same result as the transformers implementation in the audio tower part." ) - self.audio_tower = Qwen2_5OmniAudioEncoder(thinker_config.audio_config) - self.visual = Qwen2_5_VisionTransformer( - vision_config=thinker_config.vision_config, - norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6), - quant_config=quant_config, - prefix=maybe_prefix(prefix, "visual"), - ) + if multimodal_config.get_limit_per_prompt("audio"): + self.audio_tower = Qwen2_5OmniAudioEncoder(thinker_config.audio_config) + else: + self.audio_tower = None + + if multimodal_config.get_limit_per_prompt("image") or multimodal_config.get_limit_per_prompt("video"): + self.visual = Qwen2_5_VisionTransformer( + vision_config=thinker_config.vision_config, + norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "visual"), + # attn_backend_override=None, + ) + else: + self.visual = None self.quant_config = quant_config self.language_model = init_vllm_registered_model( vllm_config=vllm_config, diff --git a/vllm_omni/model_executor/models/qwen2_5_omni_token2wav.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_token2wav.py similarity index 100% rename from vllm_omni/model_executor/models/qwen2_5_omni_token2wav.py rename to vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_token2wav.py diff --git a/vllm_omni/model_executor/models/qwen2_old.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_old.py similarity index 100% rename from vllm_omni/model_executor/models/qwen2_old.py rename to vllm_omni/model_executor/models/qwen2_5_omni/qwen2_old.py diff --git a/vllm_omni/model_executor/models/registry.py b/vllm_omni/model_executor/models/registry.py index 25b16c6a5c3..8e6977950ea 100644 --- a/vllm_omni/model_executor/models/registry.py +++ b/vllm_omni/model_executor/models/registry.py @@ -2,26 +2,31 @@ _OMNI_MODELS = { "Qwen2_5OmniForConditionalGeneration": ( + "qwen2_5_omni", "qwen2_5_omni", "Qwen2_5OmniForConditionalGeneration", ), "Qwen2_5OmniThinkerModel": ( + "qwen2_5_omni", "qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration", ), # noqa: E501 "Qwen2_5OmniTalkerModel": ( + "qwen2_5_omni", "qwen2_5_omni_talker", "Qwen2_5OmniTalkerForConditionalGeneration", ), # noqa: E501 "Qwen2_5OmniToken2WavModel": ( + "qwen2_5_omni", "qwen2_5_omni_token2wav", "Qwen2_5OmniToken2WavForConditionalGenerationVLLM", ), "Qwen2_5OmniToken2WavDiTModel": ( + "qwen2_5_omni", "qwen2_5_omni_token2wav", "Qwen2_5OmniToken2WavModel", ), - "Qwen2ForCausalLM_old": ("qwen2_old", "Qwen2ForCausalLM"), # need to discuss + "Qwen2ForCausalLM_old": ("qwen2_5_omni", "qwen2_old", "Qwen2ForCausalLM"), # need to discuss } _VLLM_OMNI_MODELS = { @@ -41,10 +46,10 @@ }, **{ model_arch: _LazyRegisteredModel( - module_name=f"vllm_omni.model_executor.models.{mod_relname}", + module_name=f"vllm_omni.model_executor.models.{mod_folder}.{mod_relname}", class_name=cls_name, ) - for model_arch, (mod_relname, cls_name) in _OMNI_MODELS.items() + for model_arch, (mod_folder, mod_relname, cls_name) in _OMNI_MODELS.items() }, } ) diff --git a/vllm_omni/model_executor/models/utils.py b/vllm_omni/model_executor/models/utils.py index 602baacc329..f3854ad4eb0 100644 --- a/vllm_omni/model_executor/models/utils.py +++ b/vllm_omni/model_executor/models/utils.py @@ -1,3 +1,4 @@ +import torch from vllm.model_executor.models.utils import maybe_prefix @@ -6,3 +7,24 @@ def add_prefix_to_loaded_weights(weights: set[str], prefix: str) -> set[str]: Add a prefix to the names of the loaded weights. """ return {maybe_prefix(prefix, name) for name in weights} + + +def split_list_into_ranges(lst: torch.Tensor, interval: int) -> list[list[int]]: + if lst.numel() == 0: + return [] + + # Move to CPU and convert to list once (High Speedup) + # using .item() inside a loop is very slow. + data_list = lst.detach().cpu().tolist() + + # Calculate max on the list or tensor (Tensor max is fast enough) + max_val = int(torch.max(lst).item()) + + # Pre-allocate buckets + ranges: list[list[int]] = [[] for _ in range((max_val // interval) + 1)] + + for num in data_list: + index = int(num // interval) + ranges[index].append(num) + + return ranges diff --git a/vllm_omni/model_executor/models/vision.py b/vllm_omni/model_executor/models/vision.py new file mode 100644 index 00000000000..850286a597d --- /dev/null +++ b/vllm_omni/model_executor/models/vision.py @@ -0,0 +1,23 @@ +import torch + + +def get_llm_pos_ids_for_vision( + start_idx: int, + vision_idx: int, + spatial_merge_size: int, + t_index: list[int], + grid_hs: torch.Tensor, + grid_ws: torch.Tensor, +) -> torch.Tensor: + llm_pos_ids_list = [] + llm_grid_h = grid_hs[vision_idx] // spatial_merge_size + llm_grid_w = grid_ws[vision_idx] // spatial_merge_size + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(len(t_index), -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(len(t_index), llm_grid_h, -1).flatten() + t_index_tensor = ( + torch.Tensor(t_index).to(llm_grid_h.device).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).long().flatten() + ) + _llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index]) + llm_pos_ids_list.append(_llm_pos_ids + start_idx) + llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1) + return llm_pos_ids diff --git a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py index c1a04d9b04f..ce2d281fb70 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py @@ -10,7 +10,12 @@ TALKER_CODEC_END_TOKEN_ID = 8294 -def thinker2talker(stage_list, engine_input_source, prompt: Union[OmniTokensPrompt, TextPrompt] = None): +def thinker2talker( + stage_list, + engine_input_source, + prompt: Union[OmniTokensPrompt, TextPrompt] = None, + requires_multimodal_data: bool = False, +): if not engine_input_source: raise ValueError("engine_input_source cannot be empty") source_stage_id = engine_input_source[0] @@ -48,7 +53,9 @@ def thinker2talker(stage_list, engine_input_source, prompt: Union[OmniTokensProm + [TALKER_CODEC_END_TOKEN_ID], additional_information=additional_information, multi_modal_data=( - multi_modal_data[thinker_output.request_id] if multi_modal_data is not None else None + multi_modal_data[thinker_output.request_id] + if requires_multimodal_data and multi_modal_data is not None + else None ), mm_processor_kwargs=None, ) diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index 0a90dbdb8ee..55e5e1391ed 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -6,6 +6,7 @@ from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.logger import init_logger +from vllm.model_executor.models.interfaces import supports_mrope from vllm.model_executor.models.interfaces_base import VllmModelForPooling from vllm.sampling_params import SamplingType from vllm.utils import LazyLoader, cdiv @@ -29,6 +30,42 @@ class OmniGPUModelRunner(GPUModelRunner): + def _init_mrope_positions(self, req_state: CachedRequestState): + image_grid_thw = [] + video_grid_thw = [] + second_per_grid_ts = [] + audio_feature_lengths = [] + use_audio_in_video = False + for mm_feature in req_state.mm_features: + mm_item = mm_feature.data + if mm_item is None: + continue + mm_input = mm_item.get_data() + if (t := mm_input.get("image_grid_thw")) is not None: + image_grid_thw.append(t.tolist()) + if (t := mm_input.get("video_grid_thw")) is not None: + video_grid_thw.append(t.tolist()) + if (t := mm_input.get("second_per_grid_ts")) is not None: + second_per_grid_ts.append(t) + if (t := mm_input.get("audio_feature_lengths")) is not None: + audio_feature_lengths.append(t) + # Check for use_audio_in_video + use_audio_in_video_value = mm_input.get("use_audio_in_video") + if use_audio_in_video_value is not None: + use_audio_in_video = bool(use_audio_in_video_value.item()) + + assert supports_mrope(self.get_model()), "M-RoPE support is not implemented." + + req_state.mrope_positions, req_state.mrope_position_delta = self.model.get_mrope_input_positions( + req_state.prompt_token_ids, + hf_config=self.model_config.hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: """Update the cached states and the persistent batch with the scheduler output.