From c0cab49882069045c7d2a0c6bb457c92860fe035 Mon Sep 17 00:00:00 2001 From: Peter Salas Date: Wed, 28 Aug 2024 16:56:16 +0000 Subject: [PATCH 1/2] [Model] Add support for multiple audio chunks for Ultravox --- examples/offline_inference_audio_language.py | 58 ++++--- tests/models/test_ultravox.py | 100 ++++++++---- vllm/model_executor/models/ultravox.py | 151 +++++++++++-------- 3 files changed, 191 insertions(+), 118 deletions(-) diff --git a/examples/offline_inference_audio_language.py b/examples/offline_inference_audio_language.py index 56ce8646c20c9..1c6ac06123bbb 100644 --- a/examples/offline_inference_audio_language.py +++ b/examples/offline_inference_audio_language.py @@ -11,25 +11,33 @@ from vllm.assets.audio import AudioAsset from vllm.utils import FlexibleArgumentParser -# Input audio and question -audio_and_sample_rate = AudioAsset("mary_had_lamb").audio_and_sample_rate -question = "What is recited in the audio?" +audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")] +question_per_audio_count = [ + "What is recited in the audio?", + "What sport and what nursery rhyme are referenced?" +] # Ultravox 0.3 -def run_ultravox(question): +def run_ultravox(question, audio_count): model_name = "fixie-ai/ultravox-v0_3" tokenizer = AutoTokenizer.from_pretrained(model_name) messages = [{ - 'role': 'user', - 'content': f"<|reserved_special_token_0|>\n{question}" + 'role': + 'user', + 'content': + "<|reserved_special_token_0|>\n" * audio_count + question }] prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - llm = LLM(model=model_name) + llm = LLM(model=model_name, + enforce_eager=True, + enable_chunked_prefill=False, + max_model_len=8192, + limit_mm_per_prompt={"audio": audio_count}) stop_token_ids = None return llm, prompt, stop_token_ids @@ -44,7 +52,9 @@ def main(args): if model not in model_example_map: raise ValueError(f"Model type {model} is not supported.") - llm, prompt, stop_token_ids = model_example_map[model](question) + audio_count = args.num_audios + llm, prompt, stop_token_ids = model_example_map[model]( + question_per_audio_count[audio_count - 1], audio_count) # We set temperature to 0.2 so that outputs can be different # even when all prompts are identical when running batch inference. 
@@ -53,23 +63,18 @@ def main(args): stop_token_ids=stop_token_ids) assert args.num_prompts > 0 - if args.num_prompts == 1: - # Single inference - inputs = { - "prompt": prompt, - "multi_modal_data": { - "audio": audio_and_sample_rate - }, - } - - else: + inputs = { + "prompt": prompt, + "multi_modal_data": { + "audio": [ + asset.audio_and_sample_rate + for asset in audio_assets[:audio_count] + ] + }, + } + if args.num_prompts > 1: # Batch inference - inputs = [{ - "prompt": prompt, - "multi_modal_data": { - "audio": audio_and_sample_rate - }, - } for _ in range(args.num_prompts)] + inputs = [inputs] * args.num_prompts outputs = llm.generate(inputs, sampling_params=sampling_params) @@ -92,6 +97,11 @@ def main(args): type=int, default=1, help='Number of prompts to run.') + parser.add_argument("--num-audios", + type=int, + default=1, + choices=[1, 2], + help="Number of audio items per prompt.") args = parser.parse_args() main(args) diff --git a/tests/models/test_ultravox.py b/tests/models/test_ultravox.py index 98de10aa08408..aa60cc76b3943 100644 --- a/tests/models/test_ultravox.py +++ b/tests/models/test_ultravox.py @@ -18,36 +18,21 @@ AudioTuple = Tuple[np.ndarray, int] +AUDIO_ASSETS = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")] +VLLM_PLACEHOLDER = "<|reserved_special_token_0|>" +HF_PLACEHOLDER = "<|audio|>" -@pytest.fixture(scope="session") -def audio_and_sample_rate(): - return AudioAsset("mary_had_lamb").audio_and_sample_rate - -@pytest.fixture -def prompts_and_audios(audio_and_sample_rate): +def _get_prompt(audio_count, question, placeholder): tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + placeholder = f"{placeholder}\n" * audio_count - vllm_placeholder = "<|reserved_special_token_0|>" - hf_placeholder = "<|audio|>" - - question = "What's in the audio?" - vllm_prompt = tokenizer.apply_chat_template( - [{ - 'role': 'user', - 'content': f"{vllm_placeholder}\n{question}" - }], - tokenize=False, - add_generation_prompt=True) - hf_prompt = tokenizer.apply_chat_template( - [{ - 'role': 'user', - 'content': f"{hf_placeholder}\n{question}" - }], - tokenize=False, - add_generation_prompt=True) - - return [(vllm_prompt, hf_prompt, audio_and_sample_rate)] + return tokenizer.apply_chat_template([{ + 'role': 'user', + 'content': f"{placeholder}{question}" + }], + tokenize=False, + add_generation_prompt=True) def vllm_to_hf_output(vllm_output: Tuple[List[int], str, @@ -134,15 +119,72 @@ def process(hf_inputs: BatchEncoding): ) +def run_multi_audio_test( + vllm_runner: Type[VllmRunner], + prompts_and_audios: List[Tuple[str, List[AudioTuple]]], + model: str, + *, + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + with vllm_runner(model, + dtype=dtype, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True, + limit_mm_per_prompt={ + "audio": + max((len(audio) for _, audio in prompts_and_audios)) + }) as vllm_model: + vllm_outputs = vllm_model.generate_greedy_logprobs( + [prompt for prompt, _ in prompts_and_audios], + max_tokens, + num_logprobs=num_logprobs, + audios=[audios for _, audios in prompts_and_audios]) + + # The HuggingFace model doesn't support multiple audios yet, so + # just assert that some tokens were generated. 
+ assert all(tokens for tokens, *_ in vllm_outputs) + + @pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("audio", AUDIO_ASSETS) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_models(hf_runner, vllm_runner, prompts_and_audios, dtype: str, - max_tokens: int, num_logprobs: int) -> None: +def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int, + num_logprobs: int) -> None: + + vllm_prompt = _get_prompt(1, "Describe the audio above.", VLLM_PLACEHOLDER) + hf_prompt = _get_prompt(1, "Describe the audio above.", HF_PLACEHOLDER) run_test( hf_runner, vllm_runner, - prompts_and_audios, + [(vllm_prompt, hf_prompt, audio.audio_and_sample_rate)], + MODEL_NAME, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) + + +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("audios", [AUDIO_ASSETS]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models_with_multiple_audios(vllm_runner, audios, dtype: str, + max_tokens: int, + num_logprobs: int) -> None: + + vllm_prompt = _get_prompt(len(audios), + "Describe each of the audios above.", + VLLM_PLACEHOLDER) + run_multi_audio_test( + vllm_runner, + [(vllm_prompt, [audio.audio_and_sample_rate for audio in audios])], MODEL_NAME, dtype=dtype, max_tokens=max_tokens, diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 03d6223225511..2ebcee776dc49 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -29,12 +29,12 @@ QuantizationConfig) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsMultiModal -from vllm.model_executor.models.utils import (filter_weights, +from vllm.model_executor.models.utils import (filter_weights, flatten_bn, init_vllm_registered_model, merge_multimodal_embeddings) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.base import MultiModalInputs, NestedTensors from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SamplerOutput, SequenceData @@ -48,13 +48,13 @@ class UltravoxAudioFeatureInputs(TypedDict): type: Literal["audio_features"] - data: Union[torch.Tensor, List[torch.Tensor]] - """Shape: `(batch_size * num_audios, 80, M)""" + data: NestedTensors + """Shape: `(batch_size, num_audios, 80, M)""" class UltravoxAudioEmbeddingInputs(TypedDict): type: Literal["audio_embeds"] - data: torch.Tensor + data: NestedTensors UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs, @@ -85,24 +85,33 @@ def dummy_data_for_ultravox( audio_count = mm_counts["audio"] - audio_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [ - _AUDIO_PLACEHOLDER_TOKEN - ]) * get_ultravox_max_audio_tokens(ctx) * audio_count + audio_placeholder = array( + VLLM_TOKEN_ID_ARRAY_TYPE, + [_AUDIO_PLACEHOLDER_TOKEN]) * get_ultravox_max_audio_tokens(ctx) + + # Add a separator between each chunk. 
+ audio_token_ids = (audio_placeholder + + array(VLLM_TOKEN_ID_ARRAY_TYPE, [0])) * audio_count other_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * (seq_len - len(audio_token_ids)) audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1) - mm_dict = { - "audio": - audio_and_sr if audio_count == 1 else [audio_and_sr] * audio_count - } + mm_dict = {"audio": [audio_and_sr] * audio_count} return (SequenceData(audio_token_ids + other_token_ids), mm_dict) def input_mapper_for_ultravox(ctx: InputContext, data: object): - if isinstance(data, tuple): - (audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], data) + if not isinstance(data, list): + data = [data] + + audio_features = [] + for audio_input in data: + if not isinstance(audio_input, tuple): + raise NotImplementedError( + f"Unsupported data type: {type(audio_input)}") + + (audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], audio_input) feature_extractor = whisper_feature_extractor(ctx) if sr != feature_extractor.sampling_rate: @@ -116,15 +125,14 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object): # Not enough audio; pad it. audio = np.pad(audio, (0, minimum_audio_length - len(audio))) - return MultiModalInputs({ - "audio_features": - feature_extractor(audio, - sampling_rate=sr, - padding="longest", - return_tensors="pt")["input_features"] - }) + single_audio_features = feature_extractor( + audio, sampling_rate=sr, padding="longest", + return_tensors="pt")["input_features"] - raise NotImplementedError(f"Unsupported data type: {type(data)}") + # Remove the batch dimension because we're wrapping it in a list. + audio_features.append(single_audio_features.squeeze(0)) + + return MultiModalInputs({"audio_features": audio_features}) def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs): @@ -133,25 +141,31 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs): return llm_inputs feature_extractor = whisper_feature_extractor(ctx) - audio_data, sample_rate = multi_modal_data["audio"] - - audio_length = audio_data.shape[0] - if sample_rate != feature_extractor.sampling_rate: - # Account for resampling. - adjustment = feature_extractor.sampling_rate / sample_rate - audio_length = math.ceil(adjustment * audio_length) - - feature_extractor_output_length = math.ceil( - (audio_length - - (feature_extractor.hop_length - 1)) / feature_extractor.hop_length) - - uv_config = ctx.get_hf_config(UltravoxConfig) - audio_num_tokens = min( - max( - 1, - math.ceil(feature_extractor_output_length / - (uv_config.stack_factor * 2))), - get_ultravox_max_audio_tokens(ctx)) + audios = multi_modal_data["audio"] + if not isinstance(audios, list): + audios = [audios] + + audio_token_counts = [] + for audio_data, sample_rate in audios: + audio_length = audio_data.shape[0] + if sample_rate != feature_extractor.sampling_rate: + # Account for resampling. 
+ adjustment = feature_extractor.sampling_rate / sample_rate + audio_length = math.ceil(adjustment * audio_length) + + feature_extractor_output_length = math.ceil( + (audio_length - (feature_extractor.hop_length - 1)) / + feature_extractor.hop_length) + + uv_config = ctx.get_hf_config(UltravoxConfig) + audio_num_tokens = min( + max( + 1, + math.ceil(feature_extractor_output_length / + (uv_config.stack_factor * 2))), + get_ultravox_max_audio_tokens(ctx)) + audio_token_counts.append(audio_num_tokens) + tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer) new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( @@ -159,7 +173,7 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs): llm_inputs.get("prompt"), llm_inputs["prompt_token_ids"], placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN, - repeat_count=audio_num_tokens, + repeat_count=audio_token_counts, ) # NOTE: Create a defensive copy of the original inputs @@ -333,45 +347,52 @@ def _parse_and_validate_audio_input( raise ValueError("Incorrect type of audio features. " f"Got type: {type(audio_features)}") - # Remove the N dimension until multiple audios are supported. - if isinstance(audio_features, torch.Tensor): - audio_features = audio_features.squeeze(1) - else: - audio_features = [t.squeeze(0) for t in audio_features] - return UltravoxAudioFeatureInputs(type="audio_features", data=audio_features) if audio_embeds is not None: - if not isinstance(audio_embeds, torch.Tensor): + if not isinstance(audio_embeds, (torch.Tensor, list)): raise ValueError("Incorrect type of audio embeds. " f"Got type: {type(audio_embeds)}") - # Remove the N dimension until multiple audios are supported. - audio_embeds = audio_embeds.squeeze(1) - return UltravoxAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds) raise AssertionError("This line should be unreachable.") def _process_audio_input( - self, audio_input: UltravoxAudioInputs - ) -> Union[torch.Tensor, List[torch.Tensor]]: + self, audio_input: UltravoxAudioInputs) -> NestedTensors: if audio_input["type"] == "audio_embeds": return audio_input["data"] audio_features = audio_input["data"] - if isinstance(audio_features, list): - # TODO: Batch these through the encoder/projector instead of - # serializing them. - return [ - self._audio_features_to_embeddings( - features.unsqueeze(0)).squeeze(0) - for features in audio_features - ] - else: - return self._audio_features_to_embeddings(audio_features) + if isinstance(audio_features, torch.Tensor): + # Combine the B and N dimensions for the encoder/projector + flattened = flatten_bn(audio_features) + flattened_embeddings = self._audio_features_to_embeddings( + flattened) + + # Restore the original dimensions + embeddings = flattened_embeddings.unflatten( + 0, audio_features.shape[:2]) + return embeddings + + result = [] + # TODO: Batch heterogeneous tensors through the encoder/projector + for audio_features_item in audio_features: + if isinstance(audio_features_item, torch.Tensor): + result.append( + self._audio_features_to_embeddings(audio_features_item)) + else: + embeddings = [ + # Add a batch dimension to embed it, then remove it. + self._audio_features_to_embeddings(tensor.unsqueeze(0) + ).squeeze(0) + for tensor in audio_features_item + ] + result.append(embeddings) + + return result def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], @@ -388,7 +409,7 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, with the `input_ids`. 
Args: - input_features: A batch of audio inputs, [1, 80, M]. + audio_features: A batch of audio inputs [B, N, 80, M]. """ audio_input = self._parse_and_validate_audio_input(**kwargs) if audio_input is not None: From 6e8a4488dd53eec8ec6534aa009eb4cedcc83abc Mon Sep 17 00:00:00 2001 From: Peter Salas Date: Tue, 3 Sep 2024 16:16:58 +0000 Subject: [PATCH 2/2] Defer import in test, add shape comment --- tests/models/test_ultravox.py | 5 ++++- vllm/model_executor/models/ultravox.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/models/test_ultravox.py b/tests/models/test_ultravox.py index 769a4bdcda317..e98db9b65f484 100644 --- a/tests/models/test_ultravox.py +++ b/tests/models/test_ultravox.py @@ -19,11 +19,13 @@ VLLM_PLACEHOLDER = "<|reserved_special_token_0|>" HF_PLACEHOLDER = "<|audio|>" + @pytest.fixture(scope="session") def audio_assets(): from vllm.assets.audio import AudioAsset return [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")] + @pytest.fixture(scope="module", params=("mary_had_lamb", "winning_call")) def audio(request): from vllm.assets.audio import AudioAsset @@ -190,7 +192,8 @@ def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str, VLLM_PLACEHOLDER) run_multi_audio_test( vllm_runner, - [(vllm_prompt, [audio.audio_and_sample_rate for audio in audio_assets])], + [(vllm_prompt, [audio.audio_and_sample_rate + for audio in audio_assets])], MODEL_NAME, dtype=dtype, max_tokens=max_tokens, diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 49ae6841f726a..416fabda831a2 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -55,6 +55,7 @@ class UltravoxAudioFeatureInputs(TypedDict): class UltravoxAudioEmbeddingInputs(TypedDict): type: Literal["audio_embeds"] data: NestedTensors + """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)""" UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
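
Usage sketch (not part of the patch): the snippet below is a minimal end-to-end illustration of the multi-audio path, distilled from the updated examples/offline_inference_audio_language.py in PATCH 1/2. It assumes the fixie-ai/ultravox-v0_3 checkpoint and the two bundled audio assets ("mary_had_lamb", "winning_call") are available; the question string and the SamplingParams values are illustrative rather than taken verbatim from the patch.

# Minimal multi-audio inference sketch based on the updated example script.
# Assumes fixie-ai/ultravox-v0_3 and the bundled audio assets are available.
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

model_name = "fixie-ai/ultravox-v0_3"
audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
audio_count = len(audio_assets)

# One placeholder token per audio item; input_processor_for_ultravox expands
# each placeholder to the token count computed for that audio chunk.
tokenizer = AutoTokenizer.from_pretrained(model_name)
messages = [{
    'role': 'user',
    'content': "<|reserved_special_token_0|>\n" * audio_count +
               "What sport and what nursery rhyme are referenced?"
}]
prompt = tokenizer.apply_chat_template(messages,
                                       tokenize=False,
                                       add_generation_prompt=True)

# limit_mm_per_prompt raises the default cap of one audio item per prompt.
llm = LLM(model=model_name,
          enforce_eager=True,
          enable_chunked_prefill=False,
          max_model_len=8192,
          limit_mm_per_prompt={"audio": audio_count})

inputs = {
    "prompt": prompt,
    "multi_modal_data": {
        # A list of (np.ndarray, sampling_rate) tuples, one per placeholder.
        "audio": [asset.audio_and_sample_rate for asset in audio_assets]
    },
}

outputs = llm.generate(inputs,
                       sampling_params=SamplingParams(temperature=0.2,
                                                      max_tokens=64))
print(outputs[0].outputs[0].text)

Each placeholder in the prompt pairs positionally with one entry in the "audio" list, which is the same pairing exercised by the updated example script and by run_multi_audio_test in tests/models/test_ultravox.py.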