diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py
index eeefe3f9ba53..18f137d9fcff 100644
--- a/python/sglang/srt/entrypoints/http_server.py
+++ b/python/sglang/srt/entrypoints/http_server.py
@@ -1486,11 +1486,18 @@ async def openai_v1_audio_transcriptions(
     response_format: str = Form(default="json"),
     temperature: float = Form(default=0.0),
     stream: bool = Form(default=False),
+    timestamp_granularities: Optional[List[str]] = Form(
+        default=None, alias="timestamp_granularities[]"
+    ),
 ):
     """OpenAI-compatible audio transcription endpoint."""
-    if response_format not in ["json", "text"]:
+    if response_format not in ["json", "text", "verbose_json"]:
         return ORJSONResponse(
-            content={"error": {"message": "Only 'json' and 'text' formats supported"}},
+            content={
+                "error": {
+                    "message": "Only 'json', 'text', and 'verbose_json' formats supported"
+                }
+            },
             status_code=400,
         )
 
@@ -1504,6 +1511,7 @@ async def openai_v1_audio_transcriptions(
             response_format=response_format,
             temperature=temperature,
             stream=stream,
+            timestamp_granularities=timestamp_granularities,
             raw_request=raw_request,
         )
     )
diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py
index 16fa9cbc807c..837c2f86b381 100644
--- a/python/sglang/srt/entrypoints/openai/protocol.py
+++ b/python/sglang/srt/entrypoints/openai/protocol.py
@@ -1443,6 +1443,7 @@ class TranscriptionRequest(BaseModel):
     language: Optional[str] = None
     response_format: str = "json"
     temperature: float = 0.0
+    timestamp_granularities: Optional[List[str]] = None
     stream: bool = False
     # Internal fields (not from API)
     audio_data: Optional[bytes] = None
@@ -1463,6 +1464,26 @@ class TranscriptionResponse(BaseModel):
     usage: Optional[TranscriptionUsage] = None
 
 
+class TranscriptionSegment(BaseModel):
+    """A segment with timestamp information."""
+
+    id: int
+    start: float
+    end: float
+    text: str
+
+
+class TranscriptionVerboseResponse(BaseModel):
+    """Verbose transcription response with timestamps (OpenAI-compatible)."""
+
+    task: str = "transcribe"
+    language: Optional[str] = None
+    duration: Optional[float] = None
+    text: str
+    segments: List[TranscriptionSegment] = []
+    usage: Optional[TranscriptionUsage] = None
+
+
 class TranscriptionStreamChoice(BaseModel):
     """Delta content for streaming transcription."""
 
diff --git a/python/sglang/srt/entrypoints/openai/serving_transcription.py b/python/sglang/srt/entrypoints/openai/serving_transcription.py
index 2b5661f4967a..bfbad1e0d321 100644
--- a/python/sglang/srt/entrypoints/openai/serving_transcription.py
+++ b/python/sglang/srt/entrypoints/openai/serving_transcription.py
@@ -22,7 +22,7 @@
 import math
 import time
 import uuid
-from typing import TYPE_CHECKING, AsyncGenerator, Optional, Union
+from typing import TYPE_CHECKING, AsyncGenerator, List, Optional, Union
 
 from fastapi import Request
 from fastapi.responses import ORJSONResponse, Response, StreamingResponse
@@ -32,9 +32,11 @@
     ErrorResponse,
     TranscriptionRequest,
     TranscriptionResponse,
+    TranscriptionSegment,
     TranscriptionStreamChoice,
     TranscriptionStreamResponse,
     TranscriptionUsage,
+    TranscriptionVerboseResponse,
 )
 from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
 from sglang.srt.managers.io_struct import GenerateReqInput
@@ -44,6 +46,10 @@
 
 logger = logging.getLogger(__name__)
 
+# Whisper timestamp token constants
+TIMESTAMP_BASE_TOKEN_ID = 50365  # <|0.00|> (large-v3 vocab; fallback only)
+TIMESTAMP_BASE_OFFSET = 0.02  # Each token step = 0.02 seconds
+
 
 class OpenAIServingTranscription(OpenAIServingBase):
     """Handler for /v1/audio/transcriptions requests"""
@@ -72,6 +78,15 @@ def _convert_to_internal_request(
             "language": request.language,  # Pass to WhisperProcessor for language-specific decoding
         }
 
+        # verbose_json always reports segments, so request timestamp tokens even
+        # when the client sent no explicit granularity (OpenAI defaults to
+        # "segment"); otherwise every response collapses into one 0.0-0.0 segment.
+        granularities = request.timestamp_granularities
+        if not granularities and request.response_format == "verbose_json":
+            granularities = ["segment"]
+        if granularities:
+            sampling_params["timestamp_granularities"] = granularities
+
         # For Whisper, we pass audio_data and let the processor handle it
         adapted_request = GenerateReqInput(
             text="",  # Empty text - Whisper processor will set proper decoder tokens
@@ -89,13 +104,95 @@ def _get_audio_duration(self, audio_data: bytes) -> float:
         try:
             import soundfile as sf
 
-            audio_array, sr = sf.read(io.BytesIO(audio_data))
-            duration = len(audio_array) / sr
-            return duration
+            info = sf.info(io.BytesIO(audio_data))
+            return info.duration
         except Exception as e:
             logger.warning(f"Could not calculate audio duration: {e}")
             return 0.0
 
+    def _parse_segments(
+        self, output_ids: List[int], tokenizer
+    ) -> tuple[str, List[TranscriptionSegment]]:
+        """Parse timestamp tokens from output_ids into segments.
+
+        The decoder prompt ends with <|0.00|>, so the first segment starts at
+        t=0. The model then outputs:
+            text_tokens <|end_ts|> [<|start_ts|> text_tokens <|end_ts|> ...]
+        Each timestamp token marks the end of the current segment; its value
+        also becomes the start of the next segment.
+
+        Returns the space-joined segment texts and the parsed segments.
+        """
+        # Token IDs for special tokens we want to strip from segment text
+        eos_token_id = getattr(tokenizer, "eos_token_id", 50257)
+        # The id of <|0.00|> depends on the vocabulary (it shifts between
+        # checkpoints, e.g. large-v2 vs. large-v3), so ask the tokenizer first
+        # and keep the module constant only as a fallback.
+        timestamp_base_id = TIMESTAMP_BASE_TOKEN_ID
+        convert_tokens_to_ids = getattr(tokenizer, "convert_tokens_to_ids", None)
+        if convert_tokens_to_ids is not None:
+            resolved_id = convert_tokens_to_ids("<|0.00|>")
+            unk_token_id = getattr(tokenizer, "unk_token_id", None)
+            if isinstance(resolved_id, int) and resolved_id != unk_token_id:
+                timestamp_base_id = resolved_id
+
+        segments = []
+        full_text_parts = []
+        current_text_tokens = []
+        current_start = 0.0  # First segment starts at 0.0 (from prompt <|0.00|>)
+        seg_id = 0
+
+        for token_id in output_ids:
+            if token_id >= timestamp_base_id:
+                # This is a timestamp token — marks the end of current segment
+                timestamp = (token_id - timestamp_base_id) * TIMESTAMP_BASE_OFFSET
+
+                if current_text_tokens:
+                    text = tokenizer.decode(
+                        current_text_tokens, skip_special_tokens=True
+                    ).strip()
+                    if text:
+                        segments.append(
+                            TranscriptionSegment(
+                                id=seg_id,
+                                start=round(current_start, 2),
+                                end=round(timestamp, 2),
+                                text=text,
+                            )
+                        )
+                        full_text_parts.append(text)
+                        seg_id += 1
+                    current_text_tokens = []
+
+                # Next segment starts at this timestamp
+                current_start = timestamp
+
+            elif token_id == eos_token_id:
+                # Skip end-of-text token
+                continue
+            else:
+                # Regular text token
+                current_text_tokens.append(token_id)
+
+        # Handle any trailing text tokens without a closing timestamp
+        if current_text_tokens:
+            text = tokenizer.decode(
+                current_text_tokens, skip_special_tokens=True
+            ).strip()
+            if text:
+                segments.append(
+                    TranscriptionSegment(
+                        id=seg_id,
+                        start=round(current_start, 2),
+                        end=round(current_start, 2),
+                        text=text,
+                    )
+                )
+                full_text_parts.append(text)
+
+        full_text = " ".join(full_text_parts)
+        return full_text, segments
+
     async def create_transcription(
         self,
         audio_data: bytes,
@@ -105,7 +202,14 @@ async def create_transcription(
         temperature: float,
         stream: bool,
         raw_request: Request,
-    ) -> Union[TranscriptionResponse, StreamingResponse, Response, ORJSONResponse]:
+        timestamp_granularities: Optional[List[str]] = None,
+    ) -> Union[
+        TranscriptionResponse,
+        TranscriptionVerboseResponse,
+        StreamingResponse,
+        Response,
+        ORJSONResponse,
+    ]:
         """Main entry point for transcription requests."""
         # Calculate audio duration for usage reporting
         audio_duration_s = self._get_audio_duration(audio_data)
@@ -117,6 +221,7 @@ async def create_transcription(
             language=language,
             response_format=response_format,
             temperature=temperature,
+            timestamp_granularities=timestamp_granularities,
             stream=stream,
             audio_duration_s=audio_duration_s,
         )
@@ -129,7 +234,13 @@ async def _handle_non_streaming_request(
         adapted_request: GenerateReqInput,
         request: TranscriptionRequest,
         raw_request: Request,
-    ) -> Union[TranscriptionResponse, ErrorResponse, ORJSONResponse, Response]:
+    ) -> Union[
+        TranscriptionResponse,
+        TranscriptionVerboseResponse,
+        ErrorResponse,
+        ORJSONResponse,
+        Response,
+    ]:
         """Handle non-streaming transcription request."""
         try:
             ret = await self.tokenizer_manager.generate_request(
@@ -139,14 +250,26 @@ async def _handle_non_streaming_request(
             return self.create_error_response(str(e))
 
         text = ret.get("text", "")
+        usage = TranscriptionUsage(seconds=int(math.ceil(request.audio_duration_s)))
 
         # Build response based on format
         if request.response_format == "text":
             return Response(content=text, media_type="text/plain")
 
-        # JSON format
-        usage = TranscriptionUsage(seconds=int(math.ceil(request.audio_duration_s)))
+        if request.response_format == "verbose_json":
+            output_ids = ret.get("output_ids", [])
+            tokenizer = self.tokenizer_manager.tokenizer
+            parsed_text, segments = self._parse_segments(output_ids, tokenizer)
+
+            return TranscriptionVerboseResponse(
+                language=request.language or "en",
+                duration=round(request.audio_duration_s, 2),
+                text=parsed_text or text,
+                segments=segments,
+                usage=usage,
+            )
 
+        # Default JSON format
         return TranscriptionResponse(text=text, usage=usage)
 
     async def _handle_streaming_request(
diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index 147863803602..257cadf15829 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -1048,16 +1048,19 @@ def update_cross_attention(
         fixed_split_size: Optional[int] = None,
         disable_split_kv: Optional[bool] = None,
     ):
+        # Copy encoder_lens to the CPU once; reused for kv_lens_cpu and seq_lens_sum.
+        encoder_lens_cpu = encoder_lens.cpu() if encoder_lens is not None else None
         for wrapper_id in range(2):
             if wrapper_id == 0:
-                # Normal attention
                 paged_kernel_lens = seq_lens
                 kv_start_idx = encoder_lens
+                kv_lens_cpu = seq_lens_cpu
             else:
-                # Cross attention
+                # Cross-attention: attend to encoder tokens only
                 paged_kernel_lens = encoder_lens
                 kv_start_idx = torch.zeros_like(encoder_lens)
-                seq_lens_sum = encoder_lens.sum().item()
+                seq_lens_sum = encoder_lens_cpu.sum().item()
+                kv_lens_cpu = encoder_lens_cpu
 
             self.call_begin_forward(
                 decode_wrappers[wrapper_id],
@@ -1067,7 +1070,7 @@ def update_cross_attention(
                 self.kv_indptr[wrapper_id],
                 kv_start_idx,
                 spec_info,
-                seq_lens_cpu=seq_lens_cpu,
+                seq_lens_cpu=kv_lens_cpu,
             )
 
     def call_begin_forward(
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index ea8c93963a32..066d4fedac43 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -590,7 +590,12 @@ def __init__(self, model_runner: ModelRunner):
                 else self.dllm_config.block_size
             )
 
-        self.encoder_len_fill_value = 0
+        # Non-zero encoder length ensures cross-attention kernels are captured in the graph.
+        self.encoder_len_fill_value = (
+            getattr(model_runner.model_config.hf_config, "max_source_positions", 0)
+            if self.is_encoder_decoder
+            else 0
+        )
 
         if self.enable_torch_compile:
             set_torch_compile_config()
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index fc9afafac90b..fd7b709fbe5f 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -2068,7 +2068,11 @@ def _dummy_run(self, batch_size: int, run_ctx=None):
             is_encoder_decoder=self.model_config.is_encoder_decoder,
             require_mlp_tp_gather=require_mlp_tp_gather_,
             seq_len_fill_value=seq_len_fill_value,
-            encoder_len_fill_value=0,
+            encoder_len_fill_value=(
+                getattr(self.model_config.hf_config, "max_source_positions", 0)
+                if self.model_config.is_encoder_decoder
+                else 0
+            ),
             num_tokens_per_bs=num_tokens_per_bs,
             cache_loc_dtype=torch.int64,
             enable_mamba_track=False,
diff --git a/python/sglang/srt/models/whisper.py b/python/sglang/srt/models/whisper.py
index d69fb666d2d8..d9190a2f12ed 100644
--- a/python/sglang/srt/models/whisper.py
+++ b/python/sglang/srt/models/whisper.py
@@ -94,70 +94,16 @@ def forward(
         """Input shape: Batch x Time x Channel"""
 
         if self.is_cross_attention:
+            # Cross-attention: KV cached during prefill, read from pool during decode.
             q, _ = self.q_proj(hidden_states)
+            q = q * self.scaling
             if cross_hidden_states is not None:
                 kv, _ = self.kv_proj(cross_hidden_states)
                 k, v = kv.split([self.kv_size, self.kv_size], dim=-1)
             else:
-                k = torch.zeros_like(q)
-                v = torch.zeros_like(q)
-
-            q = q * self.scaling
-            num_heads = self.attn.tp_q_head_num
-            head_dim = self.attn.head_dim
-
-            q = q.view(-1, num_heads, head_dim)
-            k = k.view(-1, num_heads, head_dim)
-            v = v.view(-1, num_heads, head_dim)
-
-            q_len = q.shape[0]
-            kv_len = k.shape[0]
-
-            q = q.transpose(0, 1)
-            k = k.transpose(0, 1)
-            v = v.transpose(0, 1)
-
-            attn_weights = torch.bmm(q, k.transpose(1, 2))
-
-            # Apply block-diagonal mask for batched cross-attention
-            batch_size = forward_batch.batch_size if forward_batch else 1
-            if batch_size > 1 and kv_len > 0:
-                encoder_len_per_request = kv_len // batch_size
-                if encoder_len_per_request * batch_size == kv_len:
-                    is_decode = forward_batch.forward_mode.is_decode()
-                    if is_decode:
-                        mask = torch.zeros(
-                            (q_len, kv_len), device=q.device, dtype=torch.bool
-                        )
-                        for i in range(batch_size):
-                            enc_start = i * encoder_len_per_request
-                            enc_end = (i + 1) * encoder_len_per_request
-                            mask[i, enc_start:enc_end] = True
-                        attn_weights = attn_weights.masked_fill(
-                            ~mask.unsqueeze(0), float("-inf")
-                        )
-                    else:
-                        seq_lens = forward_batch.seq_lens
-                        if seq_lens is not None and len(seq_lens) == batch_size:
-                            seq_lens_list = seq_lens.tolist()
-                            mask = torch.zeros(
-                                (q_len, kv_len), device=q.device, dtype=torch.bool
-                            )
-                            q_start = 0
-                            for i, dec_len in enumerate(seq_lens_list):
-                                enc_start = i * encoder_len_per_request
-                                enc_end = (i + 1) * encoder_len_per_request
-                                q_end = q_start + dec_len
-                                mask[q_start:q_end, enc_start:enc_end] = True
-                                q_start = q_end
-                            attn_weights = attn_weights.masked_fill(
-                                ~mask.unsqueeze(0), float("-inf")
-                            )
-
-            attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
-            attn_output = torch.bmm(attn_weights, v)
-            attn_output = attn_output.transpose(0, 1)
-            attn_output = attn_output.reshape(q_len, num_heads * head_dim)
+                k = None
+                v = None
+            attn_output = self.attn(q, k, v, forward_batch)
         else:
             qkv, _ = self.qkv_proj(hidden_states)
             q, k, v = qkv.chunk(chunks=3, dim=-1)
@@ -394,6 +340,7 @@ def forward(
         position_ids=None,
     ):
         inputs_embeds = self.embed_tokens(input_ids)
+        position_ids = position_ids.clamp(max=self.max_target_positions - 1)
         positions = self.embed_positions(position_ids)
         hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
 
@@ -420,7 +367,6 @@ def __init__(
         )
         self.logits_processor = LogitsProcessor(config)
         self.config = config
-        self._encoder_cache = {}
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -468,8 +414,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
 
-    def pad_input_ids(self, input_ids: List[int], _mm_inputs: MultimodalInputs):
-        return input_ids
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        # Prepend dummy encoder tokens so that prepare_encoder_info_extend
+        # correctly allocates encoder KV cache locations in the KV pool.
+        # These dummy tokens are stripped before the model forward receives input_ids.
+        encoder_len = self.config.max_source_positions
+        mm_inputs.num_image_tokens = encoder_len
+        pad_ids = [0] * encoder_len
+        return pad_ids + input_ids
 
     def forward(
         self,
@@ -479,29 +431,22 @@ def forward(
         **kwargs: Any,
     ) -> LogitsProcessorOutput:
         dtype = self.encoder.conv1.weight.dtype
-        is_decode = forward_batch.forward_mode.is_decode()
-
-        if is_decode:
-            encoder_outputs = None
-            if forward_batch.req_pool_indices is not None:
-                req_indices = forward_batch.req_pool_indices.tolist()
-                encoder_list = []
-                for req_idx in req_indices:
-                    if req_idx in self._encoder_cache:
-                        encoder_list.append(self._encoder_cache[req_idx])
-                if encoder_list:
-                    encoder_outputs = torch.cat(encoder_list, dim=0)
-        else:
-            encoder_list = []
+
+        # Run encoder for requests that haven't cached encoder output yet.
+        # During decode or when encoder is already cached, encoder_hidden_states
+        # is None and cross-attention reads KV from the pool via RadixAttention.
+        encoder_hidden_states = None
+        if not forward_batch.forward_mode.is_decode():
             mm_inputs_list = forward_batch.mm_inputs if forward_batch.mm_inputs else []
-            req_indices = (
-                forward_batch.req_pool_indices.tolist()
-                if forward_batch.req_pool_indices is not None
-                else []
+            # A missing encoder_cached list means nothing is cached yet; fall back
+            # to all-False instead of letting an empty zip() skip every request.
+            encoder_cached_list = (
+                forward_batch.encoder_cached or [False] * len(mm_inputs_list)
             )
 
-            for req_idx, mm_input in zip(req_indices, mm_inputs_list):
-                if mm_input is None or not mm_input.mm_items:
+            encoder_list = []
+            for mm_input, cached in zip(mm_inputs_list, encoder_cached_list):
+                if cached or mm_input is None or not mm_input.mm_items:
                     continue
 
                 features = mm_input.mm_items[0].feature
@@ -513,21 +458,17 @@ def forward(
                     features.device, non_blocking=True
                 )
 
-                req_encoder_outputs = self.encoder(
+                req_encoder_output = self.encoder(
                     features.to(dtype), encoder_position_ids, forward_batch
                 )
-                req_encoder_outputs = req_encoder_outputs.squeeze(0)
-
-                self._encoder_cache[req_idx] = req_encoder_outputs
-                encoder_list.append(req_encoder_outputs)
+                req_encoder_output = req_encoder_output.squeeze(0)
+                encoder_list.append(req_encoder_output)
 
             if encoder_list:
-                encoder_outputs = torch.cat(encoder_list, dim=0)
-            else:
-                encoder_outputs = None
+                encoder_hidden_states = torch.cat(encoder_list, dim=0)
 
         decoder_outputs = self.decoder(
-            input_ids, encoder_outputs, forward_batch, positions
+            input_ids, encoder_hidden_states, forward_batch, positions
         )
 
         logits = self.logits_processor(
diff --git a/python/sglang/srt/multimodal/processors/whisper.py b/python/sglang/srt/multimodal/processors/whisper.py
index 2737b2862eac..c09aa885426e 100644
--- a/python/sglang/srt/multimodal/processors/whisper.py
+++ b/python/sglang/srt/multimodal/processors/whisper.py
@@ -115,10 +115,9 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
         # Cache tokenizer for language token lookup
         self._tokenizer = getattr(self._processor, "tokenizer", None)
 
-    def _extract_language_from_request(self, request_obj) -> Optional[str]:
+    def _pop_sampling_param(self, request_obj, key: str):
         sampling_params = getattr(request_obj, "sampling_params", None) or {}
-        language = sampling_params.pop("language", None)
-        return normalize_language_to_code(language)
+        return sampling_params.pop(key, None)
 
     def _get_language_token_id(self, language: Optional[str]) -> int:
         # Default to English if not specified
@@ -148,27 +147,35 @@ async def process_mm_data_async(
         # For Whisper, ALWAYS use the proper transcription token sequence
         # and IGNORE any text prompt - Whisper is a pure speech-to-text model
         # The decoder_start_token_id and forced_decoder_ids from generation config
-        # set up: <|startoftranscript|> <|lang|> <|task|> [<|notimestamps|>]
+        # set up: <|startoftranscript|> <|lang|> <|task|> [<|notimestamps|> or <|0.00|>]
 
-        # Extract language from request and get token ID
-        language = self._extract_language_from_request(request_obj)
+        language = normalize_language_to_code(
+            self._pop_sampling_param(request_obj, "language")
+        )
         language_token_id = self._get_language_token_id(language)
+        timestamp_granularities = self._pop_sampling_param(
+            request_obj, "timestamp_granularities"
+        )
 
         # Build decoder input tokens
-        # <|startoftranscript|> + <|lang|> + <|transcribe|> + <|notimestamps|>
         decoder_start_token_id = getattr(
             self.hf_config, "decoder_start_token_id", 50258
         )
         transcribe_token_id = self._tokenizer.convert_tokens_to_ids("<|transcribe|>")
-        notimestamps_token_id = self._tokenizer.convert_tokens_to_ids(
-            "<|notimestamps|>"
-        )
+
+        # Use <|0.00|> to enable timestamp generation, or <|notimestamps|> to disable
+        if timestamp_granularities:
+            timestamp_token_id = self._tokenizer.convert_tokens_to_ids("<|0.00|>")
+        else:
+            timestamp_token_id = self._tokenizer.convert_tokens_to_ids(
+                "<|notimestamps|>"
+            )
 
         input_ids = [
             decoder_start_token_id,
             language_token_id,
             transcribe_token_id,
-            notimestamps_token_id,
+            timestamp_token_id,
         ]
 
         # Whisper expects input features padded to max_length (3000 frames = 30 seconds)
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index c770f3d161f4..a03345c6b9eb 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -2192,6 +2192,12 @@ def _get_default_attn_backend(self, use_mla_backend: bool, model_config):
         2.2 We will use Flashinfer backend on blackwell.
         2.3 Otherwise, we will use triton backend.
         """
+        # Whisper requires flashinfer for cross-attention CUDA graph support
+        if "WhisperForConditionalGeneration" in (
+            model_config.hf_config.architectures or []
+        ):
+            return "flashinfer"
+
         if not use_mla_backend:
             # MHA architecture
             if is_hopper_with_cuda_12_3() and is_no_spec_infer_or_topk_one(self):
@@ -2267,12 +2273,16 @@ def _handle_attention_backend_compatibility(self):
                 self.speculative_algorithm is None
             ), "Speculative decoding is currently not supported with Flex Attention backend"
 
-        # Encoder-decoder models (e.g., Whisper)
-        if model_config.is_encoder_decoder:
-            logger.warning(
-                "Cuda graph is disabled for encoder-decoder models (e.g., Whisper)"
-            )
-            self.disable_cuda_graph = True
+        # Whisper's encoder token padding conflicts with prefix caching.
+        # Only disable for Whisper; other encoder-decoder models (e.g., mllama) use radix cache.
+        if (
+            model_config.is_encoder_decoder
+            and not self.disable_radix_cache
+            and "WhisperForConditionalGeneration"
+            in (model_config.hf_config.architectures or [])
+        ):
+            logger.info("Radix cache is disabled for Whisper")
+            self.disable_radix_cache = True
 
         # Major NVIDIA platforms backends
         if (
diff --git a/test/manual/test_whisper_cuda_graph.py b/test/manual/test_whisper_cuda_graph.py
new file mode 100644
index 000000000000..72d6da16b068
--- /dev/null
+++ b/test/manual/test_whisper_cuda_graph.py
@@ -0,0 +1,178 @@
+"""
+Test Whisper model with CUDA graph support.
+
+This test verifies that:
+1. Whisper model works correctly with CUDA graph enabled (default)
+2. Cross-attention KV cache is properly managed through RadixAttention
+3. Transcription still works with CUDA graph disabled (baseline)
+
+Usage:
+    python test_whisper_cuda_graph.py
+
+Requires:
+    - A GPU with sufficient memory
+    - openai-whisper model (e.g., openai/whisper-large-v3)
+    - An audio file or URL for testing
+"""
+
+import io
+import os
+import unittest
+
+import requests
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.test_utils import (
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
+    popen_launch_server,
+)
+
+WHISPER_MODEL = "openai/whisper-large-v3"
+TEST_AUDIO_URL = "https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/1.flac"
+TEST_AUDIO_LOCAL = "/tmp/test_whisper_audio.flac"
+REQUEST_TIMEOUT_S = 120  # Fail instead of hanging forever if the server stalls
+
+
+def get_audio_bytes():
+    """Get audio bytes, downloading if necessary."""
+    if os.path.exists(TEST_AUDIO_LOCAL):
+        with open(TEST_AUDIO_LOCAL, "rb") as f:
+            return f.read()
+    resp = requests.get(TEST_AUDIO_URL, timeout=30)
+    resp.raise_for_status()
+    with open(TEST_AUDIO_LOCAL, "wb") as f:
+        f.write(resp.content)
+    return resp.content
+
+
+class TestWhisperCudaGraph(CustomTestCase):
+    """Test Whisper with CUDA graph enabled (default behavior)."""
+
+    @classmethod
+    def setUpClass(cls):
+        cls.model = WHISPER_MODEL
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--served-model-name",
+                "whisper",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def _transcribe(self, language="en", extra_fields=None):
+        """Send a transcription request via OpenAI-compatible audio endpoint."""
+        audio_bytes = get_audio_bytes()
+        response = requests.post(
+            self.base_url + "/v1/audio/transcriptions",
+            files={"file": ("audio.flac", io.BytesIO(audio_bytes), "audio/flac")},
+            data={
+                "model": "whisper",
+                "language": language,
+                **(extra_fields or {}),
+            },
+            timeout=REQUEST_TIMEOUT_S,
+        )
+        self.assertEqual(response.status_code, 200, response.text)
+        return response.json()
+
+    def test_basic_transcription(self):
+        """Test that basic transcription works with CUDA graph."""
+        result = self._transcribe()
+        self.assertIn("text", result)
+        text = result["text"]
+        self.assertTrue(len(text) > 0, "Transcription should not be empty")
+        print(f"Transcription: {text}")
+
+    def test_multiple_sequential_requests(self):
+        """Test multiple sequential requests to verify CUDA graph replay consistency."""
+        results = []
+        for i in range(3):
+            result = self._transcribe()
+            self.assertIn("text", result)
+            results.append(result["text"])
+            print(f"Request {i+1}: {result['text'][:80]}...")
+
+        # All transcriptions of the same audio should be identical
+        for i in range(1, len(results)):
+            self.assertEqual(
+                results[0],
+                results[i],
+                f"Transcription {i+1} differs from first transcription",
+            )
+
+    def test_transcription_quality(self):
+        """Test that transcription quality is reasonable (contains expected words)."""
+        result = self._transcribe()
+        text = result["text"].lower()
+        # The test audio is a LibriSpeech sample about stew for dinner
+        self.assertIn("stew", text, f"Expected 'stew' in transcription: {text}")
+        self.assertIn("dinner", text, f"Expected 'dinner' in transcription: {text}")
+        print(f"Quality check passed: {result['text'][:80]}...")
+
+    def test_verbose_json_segments(self):
+        """Test that verbose_json returns ordered, timestamped segments."""
+        result = self._transcribe(
+            extra_fields={
+                "response_format": "verbose_json",
+                "timestamp_granularities[]": "segment",
+            }
+        )
+        self.assertIn("segments", result)
+        segments = result["segments"]
+        self.assertTrue(len(segments) > 0, "Expected at least one segment")
+        for segment in segments:
+            self.assertLessEqual(segment["start"], segment["end"])
+
+
+class TestWhisperNoCudaGraph(CustomTestCase):
+    """Test Whisper with CUDA graph explicitly disabled for comparison."""
+
+    @classmethod
+    def setUpClass(cls):
+        cls.model = WHISPER_MODEL
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--served-model-name",
+                "whisper",
+                "--disable-cuda-graph",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_basic_transcription_no_cuda_graph(self):
+        """Test that transcription works without CUDA graph (baseline)."""
+        audio_bytes = get_audio_bytes()
+        response = requests.post(
+            self.base_url + "/v1/audio/transcriptions",
+            files={"file": ("audio.flac", io.BytesIO(audio_bytes), "audio/flac")},
+            data={
+                "model": "whisper",
+                "language": "en",
+            },
+            timeout=REQUEST_TIMEOUT_S,
+        )
+        self.assertEqual(response.status_code, 200, response.text)
+        result = response.json()
+        self.assertIn("text", result)
+        self.assertTrue(len(result["text"]) > 0)
+        print(f"No CUDA graph transcription: {result['text'][:80]}...")
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=3)