
Commit 2be8ec6
[Model] Add Ultravox support for multiple audio chunks (#7963)
1 parent e16fa99

3 files changed: +198 -115 lines changed

examples/offline_inference_audio_language.py (+34 -24)
@@ -11,25 +11,33 @@
 from vllm.assets.audio import AudioAsset
 from vllm.utils import FlexibleArgumentParser
 
-# Input audio and question
-audio_and_sample_rate = AudioAsset("mary_had_lamb").audio_and_sample_rate
-question = "What is recited in the audio?"
+audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
+question_per_audio_count = [
+    "What is recited in the audio?",
+    "What sport and what nursery rhyme are referenced?"
+]
 
 
 # Ultravox 0.3
-def run_ultravox(question):
+def run_ultravox(question, audio_count):
     model_name = "fixie-ai/ultravox-v0_3"
 
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     messages = [{
-        'role': 'user',
-        'content': f"<|reserved_special_token_0|>\n{question}"
+        'role':
+        'user',
+        'content':
+        "<|reserved_special_token_0|>\n" * audio_count + question
     }]
     prompt = tokenizer.apply_chat_template(messages,
                                            tokenize=False,
                                            add_generation_prompt=True)
 
-    llm = LLM(model=model_name)
+    llm = LLM(model=model_name,
+              enforce_eager=True,
+              enable_chunked_prefill=False,
+              max_model_len=8192,
+              limit_mm_per_prompt={"audio": audio_count})
     stop_token_ids = None
     return llm, prompt, stop_token_ids
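As a rough sketch of what the updated helper builds for a two-audio prompt (question text and flag values are taken from the diff above; the final prompt string also depends on the model's chat template):

    # Hedged illustration only: mirrors run_ultravox(question, audio_count=2) above.
    audio_count = 2
    question = "What sport and what nursery rhyme are referenced?"

    # One audio placeholder per chunk, followed by the question.
    content = "<|reserved_special_token_0|>\n" * audio_count + question
    # content == "<|reserved_special_token_0|>\n<|reserved_special_token_0|>\n"
    #            "What sport and what nursery rhyme are referenced?"

    # limit_mm_per_prompt={"audio": audio_count} keeps the engine's per-prompt
    # audio limit in step with the number of placeholders in the prompt.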

@@ -44,7 +52,9 @@ def main(args):
     if model not in model_example_map:
         raise ValueError(f"Model type {model} is not supported.")
 
-    llm, prompt, stop_token_ids = model_example_map[model](question)
+    audio_count = args.num_audios
+    llm, prompt, stop_token_ids = model_example_map[model](
+        question_per_audio_count[audio_count - 1], audio_count)
 
     # We set temperature to 0.2 so that outputs can be different
     # even when all prompts are identical when running batch inference.
@@ -53,23 +63,18 @@ def main(args):
                                      stop_token_ids=stop_token_ids)
 
     assert args.num_prompts > 0
-    if args.num_prompts == 1:
-        # Single inference
-        inputs = {
-            "prompt": prompt,
-            "multi_modal_data": {
-                "audio": audio_and_sample_rate
-            },
-        }
-
-    else:
+    inputs = {
+        "prompt": prompt,
+        "multi_modal_data": {
+            "audio": [
+                asset.audio_and_sample_rate
+                for asset in audio_assets[:audio_count]
+            ]
+        },
+    }
+    if args.num_prompts > 1:
         # Batch inference
-        inputs = [{
-            "prompt": prompt,
-            "multi_modal_data": {
-                "audio": audio_and_sample_rate
-            },
-        } for _ in range(args.num_prompts)]
+        inputs = [inputs] * args.num_prompts
 
     outputs = llm.generate(inputs, sampling_params=sampling_params)
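For reference, a minimal sketch of the request this produces with --num-audios 2; each item in the "audio" list is the (waveform, sample rate) tuple returned by AudioAsset.audio_and_sample_rate, and the dict (or a list of such dicts for batching) is what gets passed to llm.generate. The prompt, llm, and sampling_params objects are the ones built earlier in this example:

    # Hedged sketch assuming the two assets loaded at the top of this example.
    from vllm.assets.audio import AudioAsset

    audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]

    inputs = {
        "prompt": prompt,  # built by run_ultravox(..., audio_count=2)
        "multi_modal_data": {
            # One (numpy waveform, sample rate) tuple per placeholder in the prompt.
            "audio": [asset.audio_and_sample_rate for asset in audio_assets],
        },
    }
    outputs = llm.generate(inputs, sampling_params=sampling_params)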

@@ -92,6 +97,11 @@ def main(args):
                         type=int,
                         default=1,
                         help='Number of prompts to run.')
+    parser.add_argument("--num-audios",
+                        type=int,
+                        default=1,
+                        choices=[1, 2],
+                        help="Number of audio items per prompt.")
 
     args = parser.parse_args()
     main(args)
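With the new flag, a two-audio run of this example looks something like the line below (the script's existing model-selection arguments are unchanged and omitted here):

    python examples/offline_inference_audio_language.py --num-audios 2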

tests/models/test_ultravox.py (+77 -26)
@@ -16,37 +16,32 @@
 
 AudioTuple = Tuple[np.ndarray, int]
 
+VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"
+HF_PLACEHOLDER = "<|audio|>"
+
 
 @pytest.fixture(scope="session")
-def audio_and_sample_rate():
+def audio_assets():
     from vllm.assets.audio import AudioAsset
-    return AudioAsset("mary_had_lamb").audio_and_sample_rate
+    return [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
 
 
-@pytest.fixture
-def prompts_and_audios(audio_and_sample_rate):
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+@pytest.fixture(scope="module", params=("mary_had_lamb", "winning_call"))
+def audio(request):
+    from vllm.assets.audio import AudioAsset
+    return AudioAsset(request.param)
 
-    vllm_placeholder = "<|reserved_special_token_0|>"
-    hf_placeholder = "<|audio|>"
 
-    question = "What's in the audio?"
-    vllm_prompt = tokenizer.apply_chat_template(
-        [{
-            'role': 'user',
-            'content': f"{vllm_placeholder}\n{question}"
-        }],
-        tokenize=False,
-        add_generation_prompt=True)
-    hf_prompt = tokenizer.apply_chat_template(
-        [{
-            'role': 'user',
-            'content': f"{hf_placeholder}\n{question}"
-        }],
-        tokenize=False,
-        add_generation_prompt=True)
+def _get_prompt(audio_count, question, placeholder):
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    placeholder = f"{placeholder}\n" * audio_count
 
-    return [(vllm_prompt, hf_prompt, audio_and_sample_rate)]
+    return tokenizer.apply_chat_template([{
+        'role': 'user',
+        'content': f"{placeholder}{question}"
+    }],
+                                          tokenize=False,
+                                          add_generation_prompt=True)
 
 
 def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
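To illustrate the _get_prompt helper above (a sketch only; the rendered string depends on the tokenizer's chat template), the tests render the same question with vLLM's reserved placeholder and with the HuggingFace <|audio|> placeholder, repeated once per audio:

    # Hedged example of calling the _get_prompt helper defined above.
    vllm_prompt = _get_prompt(2, "Describe each of the audios above.",
                              VLLM_PLACEHOLDER)
    hf_prompt = _get_prompt(2, "Describe each of the audios above.",
                            HF_PLACEHOLDER)
    # Both prompts start with their placeholder repeated twice (one per audio);
    # only the placeholder token differs between the vLLM and HF variants.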
@@ -134,15 +129,71 @@ def process(hf_inputs: BatchEncoding):
     )
 
 
+def run_multi_audio_test(
+    vllm_runner: Type[VllmRunner],
+    prompts_and_audios: List[Tuple[str, List[AudioTuple]]],
+    model: str,
+    *,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
+    with vllm_runner(model,
+                     dtype=dtype,
+                     tensor_parallel_size=tensor_parallel_size,
+                     distributed_executor_backend=distributed_executor_backend,
+                     enforce_eager=True,
+                     limit_mm_per_prompt={
+                         "audio":
+                         max((len(audio) for _, audio in prompts_and_audios))
+                     }) as vllm_model:
+        vllm_outputs = vllm_model.generate_greedy_logprobs(
+            [prompt for prompt, _ in prompts_and_audios],
+            max_tokens,
+            num_logprobs=num_logprobs,
+            audios=[audios for _, audios in prompts_and_audios])
+
+    # The HuggingFace model doesn't support multiple audios yet, so
+    # just assert that some tokens were generated.
+    assert all(tokens for tokens, *_ in vllm_outputs)
+
+
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])
-def test_models(hf_runner, vllm_runner, prompts_and_audios, dtype: str,
-                max_tokens: int, num_logprobs: int) -> None:
+def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
+                num_logprobs: int) -> None:
+
+    vllm_prompt = _get_prompt(1, "Describe the audio above.", VLLM_PLACEHOLDER)
+    hf_prompt = _get_prompt(1, "Describe the audio above.", HF_PLACEHOLDER)
     run_test(
         hf_runner,
         vllm_runner,
-        prompts_and_audios,
+        [(vllm_prompt, hf_prompt, audio.audio_and_sample_rate)],
+        MODEL_NAME,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        tensor_parallel_size=1,
+    )
+
+
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
+                                     max_tokens: int,
+                                     num_logprobs: int) -> None:
+
+    vllm_prompt = _get_prompt(len(audio_assets),
+                              "Describe each of the audios above.",
+                              VLLM_PLACEHOLDER)
+    run_multi_audio_test(
+        vllm_runner,
+        [(vllm_prompt, [audio.audio_and_sample_rate
+                        for audio in audio_assets])],
         MODEL_NAME,
         dtype=dtype,
         max_tokens=max_tokens,
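Assuming the repository's usual pytest setup, the new multi-audio case can be exercised on its own with something along the lines of:

    pytest tests/models/test_ultravox.py -k multiple_audios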
