Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions examples/offline_inference/qwen2_5_omni/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,11 @@ Then run the command below.
```bash
bash run_single_prompt.sh
```

### FAQ

If you encounter error about backend of librosa, try to install ffmpeg with command below.
```
sudo apt update
sudo apt install ffmpeg
```
339 changes: 194 additions & 145 deletions examples/offline_inference/qwen2_5_omni/end2end.py
Original file line number Diff line number Diff line change
@@ -1,166 +1,145 @@
import argparse
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use vLLM-omni for running offline inference
with the correct prompt format on Qwen2.5-Omni
"""

import os
import os as _os_env_toggle
import random
from typing import NamedTuple

import numpy as np
import soundfile as sf
import torch
from utils import make_omni_prompt
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.multimodal.image import convert_image_mode
from vllm.sampling_params import SamplingParams
from vllm.utils import FlexibleArgumentParser

from vllm_omni.entrypoints.omni_llm import OmniLLM

_os_env_toggle.environ["VLLM_USE_V1"] = "1"
from vllm_omni import OmniLLM

SEED = 42
# Set all random seeds
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Make PyTorch deterministic
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Set environment variables for deterministic behavior
os.environ["PYTHONHASHSEED"] = str(SEED)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
class QueryResult(NamedTuple):
inputs: dict
limit_mm_per_prompt: dict[str, int]


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model",
required=True,
help="Path to merged model directory (will be created if downloading).",
)
parser.add_argument("--thinker-model", type=str, default=None)
parser.add_argument("--talker-model", type=str, default=None)
parser.add_argument("--code2wav-model", type=str, default=None)
parser.add_argument(
"--hf-hub-id",
default="Qwen/Qwen2.5-Omni-7B",
help="Hugging Face repo id to download if needed.",
)
parser.add_argument("--hf-revision", default=None, help="Optional HF revision (branch/tag/commit).")
parser.add_argument("--prompts", nargs="+", default=None, help="Input text prompts.")
parser.add_argument("--voice-type", default="default", help="Voice type, e.g., m02, f030, default.")
parser.add_argument(
"--code2wav-dir",
default=None,
help="Path to code2wav folder (contains spk_dict.pt).",
)
parser.add_argument("--dit-ckpt", default=None, help="Path to DiT checkpoint file (e.g., dit.pt).")
parser.add_argument("--bigvgan-ckpt", default=None, help="Path to BigVGAN checkpoint file.")
parser.add_argument("--dtype", default="bfloat16", choices=["float16", "bfloat16", "float32"])
parser.add_argument("--max-model-len", type=int, default=32768)
parser.add_argument(
"--init-sleep-seconds",
type=int,
default=20,
help="Sleep seconds after starting each stage process to allow initialization (default: 20)",
)
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.

parser.add_argument("--thinker-only", action="store_true")
parser.add_argument("--text-only", action="store_true")
parser.add_argument("--do-wave", action="store_true")
parser.add_argument(
"--prompt_type",
choices=[
"text",
"audio",
"audio-long",
"audio-long-chunks",
"audio-long-expand-chunks",
"image",
"video",
"video-frames",
"audio-in-video",
"audio-in-video-v2",
"audio-multi-round",
"badcase-vl",
"badcase-text",
"badcase-image-early-stop",
"badcase-two-audios",
"badcase-two-videos",
"badcase-multi-round",
"badcase-voice-type",
"badcase-voice-type-v2",
"badcase-audio-tower-1",
"badcase-audio-only",
],
default="text",
)
parser.add_argument("--use-torchvision", action="store_true")
parser.add_argument("--tokenize", action="store_true")
parser.add_argument(
"--output-wav",
default="output.wav",
help="[Deprecated] Output wav directory (use --output-dir).",
default_system = (
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
"Group, capable of perceiving auditory and visual inputs, as well as "
"generating text and speech."
)


def get_text_query(question: str = None) -> QueryResult:
if question is None:
question = "Explain the system architecture for a scalable audio generation pipeline. Answer in 15 words."
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
parser.add_argument(
"--output-dir",
default="outputs",
help="Output directory to save text and wav files together.",
return QueryResult(
inputs={
"prompt": prompt,
},
limit_mm_per_prompt={},
)
parser.add_argument(
"--thinker-hidden-states-dir",
default="thinker_hidden_states",
help="Path to thinker hidden states directory.",


def get_mixed_modalities_query() -> QueryResult:
question = "What is recited in the audio? What is the content of this image? Why is this video funny?"
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
"<|vision_bos|><|IMAGE|><|vision_eos|>"
"<|vision_bos|><|VIDEO|><|vision_eos|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
parser.add_argument(
"--batch-timeout",
type=int,
default=5,
help="Timeout for batching in seconds (default: 5)",
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
"image": convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB"),
"video": VideoAsset(name="baby_reading", num_frames=16).np_ndarrays,
},
},
limit_mm_per_prompt={"audio": 1, "image": 1, "video": 1},
)
parser.add_argument(
"--init-timeout",
type=int,
default=300,
help="Timeout for initializing stages in seconds (default: 300)",


def get_use_audio_in_video_query() -> QueryResult:
question = "Describe the content of the video, then convert what the baby say into text."
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|vision_bos|><|VIDEO|><|vision_eos|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
parser.add_argument(
"--shm-threshold-bytes",
type=int,
default=65536,
help="Threshold for using shared memory in bytes (default: 65536)",
asset = VideoAsset(name="baby_reading", num_frames=16)
audio = asset.get_audio(sampling_rate=16000)

return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"video": asset.np_ndarrays,
"audio": audio,
},
"mm_processor_kwargs": {
"use_audio_in_video": True,
},
},
limit_mm_per_prompt={"audio": 1, "video": 1},
)
parser.add_argument(
"--enable-stats",
action="store_true",
default=False,
help="Enable writing detailed statistics (default: disabled)",


def get_multi_audios_query() -> QueryResult:
question = "Are these two audio clips the same?"
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
"<|audio_bos|><|AUDIO|><|audio_eos|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
parser.add_argument(
"--txt-prompts",
type=str,
default=None,
help="Path to a .txt file with one prompt per line (preferred).",
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"audio": [
AudioAsset("winning_call").audio_and_sample_rate,
AudioAsset("mary_had_lamb").audio_and_sample_rate,
],
},
},
limit_mm_per_prompt={
"audio": 2,
},
)
args = parser.parse_args()
return args


def main():
args = parse_args()
model_name = args.model
try:
# Preferred: load from txt file (one prompt per line)
if getattr(args, "txt_prompts", None) and args.prompt_type == "text":
with open(args.txt_prompts, encoding="utf-8") as f:
lines = [ln.strip() for ln in f.readlines()]
args.prompts = [ln for ln in lines if ln != ""]
print(f"[Info] Loaded {len(args.prompts)} prompts from {args.txt_prompts}")
except Exception as e:
print(f"[Error] Failed to load prompts: {e}")
raise

if args.prompts is None:
raise ValueError("No prompts provided. Use --prompts ... or --txt-prompts <file.txt> (with --prompt_type text)")
query_map = {
"mixed_modalities": get_mixed_modalities_query,
"use_audio_in_video": get_use_audio_in_video_query,
"multi_audios": get_multi_audios_query,
"text": get_text_query,
}


def main(args):
model_name = "Qwen/Qwen2.5-Omni-7B"
query_result = query_map[args.query_type]()

omni_llm = OmniLLM(
model=model_name,
log_stats=args.enable_stats,
Expand Down Expand Up @@ -205,8 +184,16 @@ def main():
code2wav_sampling_params,
]

prompt = [make_omni_prompt(args, prompt) for prompt in args.prompts]
omni_outputs = omni_llm.generate(prompt, sampling_params_list)
if args.txt_prompts is None:
prompts = [query_result.inputs for _ in range(args.num_prompts)]
else:
assert args.query_type == "text", "txt-prompts is only supported for text query type"
with open(args.txt_prompts, encoding="utf-8") as f:
lines = [ln.strip() for ln in f.readlines()]
prompts = [get_text_query(ln).inputs for ln in lines if ln != ""]
print(f"[Info] Loaded {len(prompts)} prompts from {args.txt_prompts}")

omni_outputs = omni_llm.generate(prompts, sampling_params_list)

# Determine output directory: prefer --output-dir; fallback to --output-wav
output_dir = args.output_dir if getattr(args, "output_dir", None) else args.output_wav
Expand All @@ -217,7 +204,7 @@ def main():
request_id = int(output.request_id)
text_output = output.outputs[0].text
# Save aligned text file per request
prompt_text = args.prompts[request_id]
prompt_text = prompts[request_id]["prompt"]
out_txt = os.path.join(output_dir, f"{request_id:05d}.txt")
lines = []
lines.append("Prompt:\n")
Expand All @@ -239,5 +226,67 @@ def main():
print(f"Request ID: {request_id}, Saved audio to {output_wav}")


def parse_args():
parser = FlexibleArgumentParser(description="Demo on using vLLM for offline inference with audio language models")
parser.add_argument(
"--query-type",
"-q",
type=str,
default="mixed_modalities",
choices=query_map.keys(),
help="Query type.",
)
parser.add_argument(
"--enable-stats",
action="store_true",
default=False,
help="Enable writing detailed statistics (default: disabled)",
)
parser.add_argument(
"--init-sleep-seconds",
type=int,
default=20,
help="Sleep seconds after starting each stage process to allow initialization (default: 20)",
)
parser.add_argument(
"--batch-timeout",
type=int,
default=5,
help="Timeout for batching in seconds (default: 5)",
)
parser.add_argument(
"--init-timeout",
type=int,
default=300,
help="Timeout for initializing stages in seconds (default: 300)",
)
parser.add_argument(
"--shm-threshold-bytes",
type=int,
default=65536,
help="Threshold for using shared memory in bytes (default: 65536)",
)
parser.add_argument(
"--output-wav",
default="output_audio",
help="[Deprecated] Output wav directory (use --output-dir).",
)
parser.add_argument(
"--num-prompts",
type=int,
default=1,
help="Number of prompts to generate.",
)
parser.add_argument(
"--txt-prompts",
type=str,
default=None,
help="Path to a .txt file with one prompt per line (preferred).",
)

return parser.parse_args()


if __name__ == "__main__":
main()
args = parse_args()
main(args)
Loading