diff --git a/vllm/config/speech_to_text.py b/vllm/config/speech_to_text.py index 3eafff1a3060..fe3532c9742d 100644 --- a/vllm/config/speech_to_text.py +++ b/vllm/config/speech_to_text.py @@ -17,10 +17,11 @@ class SpeechToTextConfig: 16kHz audio input. The input audio will be automatically resampled to this rate before processing.""" - max_audio_clip_s: int = 30 + max_audio_clip_s: int | None = 30 """Maximum duration in seconds for a single audio clip without chunking. Audio longer than this will be split into smaller chunks if - `allow_audio_chunking` evaluates to True, otherwise it will be rejected.""" + `allow_audio_chunking` evaluates to True, otherwise it will be rejected. + `None` means audio duration can be unlimited and won't be chunked.""" overlap_chunk_second: int = 1 """Overlap duration in seconds between consecutive audio chunks when diff --git a/vllm/entrypoints/openai/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text.py index acfc8160a6d7..b6332d1941c1 100644 --- a/vllm/entrypoints/openai/speech_to_text.py +++ b/vllm/entrypoints/openai/speech_to_text.py @@ -477,7 +477,15 @@ async def _create_speech_to_text( } segment_class: type[SpeechToTextSegment] = segments_types[self.task_type] text = "" + chunk_size_in_s = self.asr_config.max_audio_clip_s + if chunk_size_in_s is None: + assert len(list_result_generator) == 1, ( + "`max_audio_clip_s` is set to None, audio cannot be chunked" + ) for idx, result_generator in enumerate(list_result_generator): + start_time = ( + float(idx * chunk_size_in_s) if chunk_size_in_s is not None else 0.0 + ) async for op in result_generator: if request.response_format == "verbose_json": segments: list[SpeechToTextSegment] = ( @@ -485,7 +493,7 @@ async def _create_speech_to_text( tokens=tuple(op.outputs[0].token_ids), segment_class=segment_class, request=request, - start_time=idx * self.asr_config.max_audio_clip_s, + start_time=start_time, ) ) @@ -653,6 +661,10 @@ async def _speech_to_text_stream_generator( def 
_split_audio( self, audio_data: np.ndarray, sample_rate: int ) -> list[np.ndarray]: + assert self.asr_config.max_audio_clip_s is not None, ( + f"{self.asr_config.max_audio_clip_s=} cannot be None to" + " split audio into chunks." + ) chunk_size = sample_rate * self.asr_config.max_audio_clip_s overlap_size = sample_rate * self.asr_config.overlap_chunk_second chunks = [] diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py index cbba1af89190..63b26c789091 100644 --- a/vllm/model_executor/models/voxtral.py +++ b/vllm/model_executor/models/voxtral.py @@ -17,7 +17,11 @@ from mistral_common.protocol.instruct.messages import UserMessage from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.transcription.request import TranscriptionRequest -from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder +from mistral_common.tokens.tokenizers.audio import ( + Audio, + AudioEncoder, + TranscriptionFormat, +) from transformers import BatchFeature, TensorType, WhisperConfig from transformers.tokenization_utils_base import TextInput @@ -157,13 +161,17 @@ def __call__( # pad if necessary # TODO(Patrick) - remove once mistral-common is bumped - sig = inspect.signature(self._audio_processor.pad) - if "is_online_streaming" in sig.parameters: - audio = self._audio_processor.pad( - audio, self.sampling_rate, is_online_streaming=False - ) - else: - audio = self._audio_processor.pad(audio, self.sampling_rate) + if ( + self._audio_processor.audio_config.transcription_format + != TranscriptionFormat.STREAMING + ): + sig = inspect.signature(self._audio_processor.pad) + if "is_online_streaming" in sig.parameters: + audio = self._audio_processor.pad( + audio, self.sampling_rate, is_online_streaming=False + ) + else: + audio = self._audio_processor.pad(audio, self.sampling_rate) audio_tokens = [self.begin_audio_token_id] + [ self.audio_token_id diff --git 
a/vllm/model_executor/models/voxtral_streaming.py b/vllm/model_executor/models/voxtral_streaming.py index 2e79e24e6f19..a89a0eedd8e7 100644 --- a/vllm/model_executor/models/voxtral_streaming.py +++ b/vllm/model_executor/models/voxtral_streaming.py @@ -3,10 +3,19 @@ import math from collections.abc import Mapping +from typing import Literal, cast +import numpy as np import torch +from mistral_common.protocol.instruct.chunk import RawAudio +from mistral_common.protocol.transcription.request import ( + StreamingMode, + TranscriptionRequest, +) +from mistral_common.tokens.tokenizers.audio import Audio -from vllm.config.vllm import VllmConfig +from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig +from vllm.inputs.data import PromptType from vllm.logger import init_logger from vllm.model_executor.models.interfaces import MultiModalEmbeddings from vllm.model_executor.models.voxtral import ( @@ -27,6 +36,7 @@ ) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.tokenizers import cached_tokenizer_from_config from .utils import ( _flatten_embeddings, @@ -205,13 +215,17 @@ def embed_multimodal( "For streaming you must provide an audio input at every step." ) - multiple_of = self.audio_config.raw_audio_length_per_tok - assert all( - (this_audio := audio.shape[0]) % multiple_of == 0 for audio in audio_inputs - ), ( - f"Every input audio waveform has to be a multiple of {multiple_of}, but" - f" one is {this_audio} with {(this_audio / multiple_of)=}." 
- ) + def _truncate_left( + sample: torch.Tensor, mult_of: int, pos: int + ) -> torch.Tensor: + assert pos in [0, 1], pos + if (ctx := sample.shape[pos] % mult_of) != 0: + sample = sample[ctx:] if pos == 0 else sample[:, ctx:] + assert sample.shape[pos] > 0, ( + f"Sample is empty after truncation with ctx {ctx}" + ) + + return sample mel_features = [ self.whisper_encoder.compute_whisper_melspec(audio).to( @@ -219,11 +233,16 @@ def embed_multimodal( ) for audio in audio_inputs ] + + # we truncate the left-most mel feature + # if the sequence length is odd + mel_features = [_truncate_left(mel, 2, 1) for mel in mel_features] + seq_lens = [mel.shape[1] for mel in mel_features] # [total_num_20ms_frames, hidden_size] audio_embeddings = self.whisper_encoder.whisper_encoder.forward_conv( mel_features - )[0] + ) conv_stride = self.whisper_encoder.whisper_encoder.total_stride audio_embeddings_per_sample = audio_embeddings.split( [s // conv_stride for s in seq_lens], dim=0 @@ -231,13 +250,55 @@ def embed_multimodal( # audio_embeddings per sample need to be divisible by 4 pool_size = self.config.audio_config.block_pool_size - assert all( - (this_shape := sample.shape[0]) % pool_size == 0 + + audio_embeddings_per_sample = [ + _truncate_left(sample, pool_size, 0) for sample in audio_embeddings_per_sample - ), f"Every audio embedding has to be a multiple of 4, but one is {this_shape}." 
+ ] audio_embeddings_per_sample = [ e.view(e.shape[0] // pool_size, e.shape[1] * pool_size) for e in audio_embeddings_per_sample ] return audio_embeddings_per_sample + + @classmethod + def get_speech_to_text_config( + cls, model_config: ModelConfig, task_type: str + ) -> SpeechToTextConfig: + tokenizer = cached_tokenizer_from_config(model_config) + audio_config = tokenizer.instruct.audio_encoder.audio_config + sample_rate = audio_config.sampling_rate + return SpeechToTextConfig( + max_audio_clip_s=None, # only limited by memory + sample_rate=sample_rate, + min_energy_split_window_size=None, + ) + + @classmethod + # for speech-to-text transcription + def get_generation_prompt( + cls, + audio: np.ndarray, + model_config: ModelConfig, + stt_config: SpeechToTextConfig, + language: str | None, + task_type: Literal["transcribe", "translate"], + request_prompt: str, + to_language: str | None, + ) -> PromptType: + tokenizer = cached_tokenizer_from_config(model_config) + audio = Audio(audio, int(stt_config.sample_rate), format="wav") # lossless + + req = TranscriptionRequest( + model=model_config.model, + audio=RawAudio.from_audio(audio), + language=language, + streaming=StreamingMode.OFFLINE, + ) + + tokenized = tokenizer.instruct.encode_transcription(req) + audio = (tokenized.audios[0].audio_array, stt_config.sample_rate) + prompts_dict = {"multi_modal_data": {"audio": audio}} + prompts_dict["prompt_token_ids"] = tokenized.tokens + return cast(PromptType, prompts_dict) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index f1bae28debad..d1dadc3d5198 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -469,8 +469,10 @@ def __init__( self.max_source_positions = config.max_source_positions self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 - is_causal = getattr(config, "is_causal", False) - Conv1d = WhisperCausalConv1d if is_causal else partial(nn.Conv1d, 
padding=1) + self.is_causal = getattr(config, "is_causal", False) + Conv1d = ( + WhisperCausalConv1d if self.is_causal else partial(nn.Conv1d, padding=1) + ) self.conv1 = Conv1d(self.num_mel_bins, embed_dim, kernel_size=3) self.conv2 = Conv1d(embed_dim, embed_dim, stride=2, kernel_size=3) @@ -485,7 +487,7 @@ def __init__( ) self.layer_norm = nn.LayerNorm(config.d_model) - if is_causal and self.pos_embed_type != WhisperPosEmbedType.NOPE: + if self.is_causal and self.pos_embed_type != WhisperPosEmbedType.NOPE: raise ValueError( "Only NOPE position embeddings are supported " f"for causal models, but got {self.pos_embed_type}" @@ -536,8 +538,11 @@ def forward_conv( hidden_states.append(embeds) input_is_batched = embeds.ndim > 2 # Input to MHA must be B x T x D - if input_is_batched: + if input_is_batched or self.is_causal: # Models using WhisperEncoder may handle batching internally. + # If WhisperEncoder is causal, sequences + # are not padded to have identical seq length (T) + # => concat over feature dim hidden_states = torch.cat(hidden_states) else: hidden_states = torch.stack(hidden_states, dim=0)