From 9b3c983ef514cdd543127b16656cdffb5779692a Mon Sep 17 00:00:00 2001 From: Yihao Wang <42559837+AgainstEntropy@users.noreply.github.com> Date: Sat, 4 Apr 2026 18:10:36 +0000 Subject: [PATCH 1/4] [feat] Implement transcription adapter framework for ASR models - Introduced `TranscriptionAdapter` abstract class for model-specific transcription logic. - Added `Qwen3ASRAdapter` and `WhisperAdapter` implementations for respective ASR models. - Implemented adapter registration via `@register_transcription_adapter` decorator. - Updated `OpenAIServingTranscription` to utilize the adapter framework for processing requests. --- .../openai/serving_transcription.py | 169 ++---------------- .../openai/transcription_adapters/__init__.py | 23 +++ .../openai/transcription_adapters/base.py | 77 ++++++++ .../transcription_adapters/qwen3_asr.py | 49 +++++ .../openai/transcription_adapters/whisper.py | 117 ++++++++++++ 5 files changed, 284 insertions(+), 151 deletions(-) create mode 100644 python/sglang/srt/entrypoints/openai/transcription_adapters/__init__.py create mode 100644 python/sglang/srt/entrypoints/openai/transcription_adapters/base.py create mode 100644 python/sglang/srt/entrypoints/openai/transcription_adapters/qwen3_asr.py create mode 100644 python/sglang/srt/entrypoints/openai/transcription_adapters/whisper.py diff --git a/python/sglang/srt/entrypoints/openai/serving_transcription.py b/python/sglang/srt/entrypoints/openai/serving_transcription.py index 1040122b2e15..5540b518f7dd 100644 --- a/python/sglang/srt/entrypoints/openai/serving_transcription.py +++ b/python/sglang/srt/entrypoints/openai/serving_transcription.py @@ -12,7 +12,11 @@ # limitations under the License. # ============================================================================== """ -OpenAI-compatible transcription endpoint handler for Whisper models. +OpenAI-compatible transcription endpoint handler for audio ASR models. 
+ +New ASR models are supported by subclassing ``TranscriptionAdapter`` and +registering via the ``@register_transcription_adapter`` decorator. +See ``transcription_adapters/`` for built-in implementations. """ from __future__ import annotations @@ -32,13 +36,13 @@ ErrorResponse, TranscriptionRequest, TranscriptionResponse, - TranscriptionSegment, TranscriptionStreamChoice, TranscriptionStreamResponse, TranscriptionUsage, TranscriptionVerboseResponse, ) from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase +from sglang.srt.entrypoints.openai.transcription_adapters import resolve_adapter from sglang.srt.managers.io_struct import GenerateReqInput if TYPE_CHECKING: @@ -46,26 +50,16 @@ logger = logging.getLogger(__name__) -# Whisper timestamp token constants -TIMESTAMP_BASE_TOKEN_ID = 50365 # <|0.00|> -TIMESTAMP_BASE_OFFSET = 0.02 # Each token step = 0.02 seconds - -_QWEN3_ASR_TEXT_TAG = "" - - -def _detect_model_family(model_config) -> str: - archs = getattr(getattr(model_config, "hf_config", None), "architectures", []) or [] - if "Qwen3ASRForConditionalGeneration" in archs: - return "qwen3_asr" - return "whisper" - class OpenAIServingTranscription(OpenAIServingBase): """Handler for /v1/audio/transcriptions requests""" def __init__(self, tokenizer_manager: TokenizerManager): super().__init__(tokenizer_manager) - self._model_family = _detect_model_family(tokenizer_manager.model_config) + model_config = tokenizer_manager.model_config + self._adapter = resolve_adapter( + getattr(model_config.hf_config, "architectures", []) + ) def _request_id_prefix(self) -> str: return "trsc-" @@ -81,40 +75,9 @@ def _convert_to_internal_request( raw_request: Request = None, ) -> tuple[GenerateReqInput, TranscriptionRequest]: """Convert transcription request to internal format.""" - if self._model_family == "qwen3_asr": - prompt = ( - "<|im_start|>user\n" - "<|audio_start|><|audio_pad|><|audio_end|>" - "<|im_end|>\n" - "<|im_start|>assistant\n" - ) - sampling_params = 
{ - "temperature": request.temperature, - "max_new_tokens": 1024, - } - adapted_request = GenerateReqInput( - text=prompt, - audio_data=request.audio_data, - sampling_params=sampling_params, - stream=request.stream, - modalities=["audio"], - routing_key=self.extract_routing_key(raw_request), - ) - return adapted_request, request - - # Build sampling params - include language for WhisperProcessor - sampling_params = { - "temperature": request.temperature, - "max_new_tokens": 448, # Whisper default max tokens - "language": request.language, # Pass to WhisperProcessor for language-specific decoding - } - - if request.timestamp_granularities: - sampling_params["timestamp_granularities"] = request.timestamp_granularities - - # For Whisper, we pass audio_data and let the processor handle it + sampling_params = self._adapter.build_sampling_params(request) adapted_request = GenerateReqInput( - text="", # Empty text - Whisper processor will set proper decoder tokens + text="", # Empty text — the multimodal processor sets proper decoder/prompt tokens audio_data=request.audio_data, sampling_params=sampling_params, stream=request.stream, @@ -124,7 +87,8 @@ def _convert_to_internal_request( return adapted_request, request - def _get_audio_duration(self, audio_data: bytes) -> float: + @staticmethod + def _get_audio_duration(audio_data: bytes) -> float: """Calculate audio duration in seconds.""" try: import soundfile as sf @@ -135,77 +99,6 @@ def _get_audio_duration(self, audio_data: bytes) -> float: logger.warning(f"Could not calculate audio duration: {e}") return 0.0 - def _parse_segments( - self, output_ids: List[int], tokenizer - ) -> tuple[str, List[TranscriptionSegment]]: - """Parse timestamp tokens from output_ids into segments. - - The decoder prompt ends with <|0.00|>, so the first segment starts at - t=0. The model then outputs: - text_tokens <|end_ts|> [<|start_ts|> text_tokens <|end_ts|> ...] 
- Each timestamp token marks the end of the current segment; its value - also becomes the start of the next segment. - """ - # Token IDs for special tokens we want to strip from segment text - eos_token_id = getattr(tokenizer, "eos_token_id", 50257) - - segments = [] - full_text_parts = [] - current_text_tokens = [] - current_start = 0.0 # First segment starts at 0.0 (from prompt <|0.00|>) - seg_id = 0 - - for token_id in output_ids: - if token_id >= TIMESTAMP_BASE_TOKEN_ID: - # This is a timestamp token — marks the end of current segment - timestamp = (token_id - TIMESTAMP_BASE_TOKEN_ID) * TIMESTAMP_BASE_OFFSET - - if current_text_tokens: - text = tokenizer.decode( - current_text_tokens, skip_special_tokens=True - ).strip() - if text: - segments.append( - TranscriptionSegment( - id=seg_id, - start=round(current_start, 2), - end=round(timestamp, 2), - text=text, - ) - ) - full_text_parts.append(text) - seg_id += 1 - current_text_tokens = [] - - # Next segment starts at this timestamp - current_start = timestamp - - elif token_id == eos_token_id: - # Skip end-of-text token - continue - else: - # Regular text token - current_text_tokens.append(token_id) - - # Handle any trailing text tokens without a closing timestamp - if current_text_tokens: - text = tokenizer.decode( - current_text_tokens, skip_special_tokens=True - ).strip() - if text: - segments.append( - TranscriptionSegment( - id=seg_id, - start=round(current_start, 2), - end=round(current_start, 2), - text=text, - ) - ) - full_text_parts.append(text) - - full_text = " ".join(full_text_parts) - return full_text, segments - async def create_transcription( self, audio_data: bytes, @@ -262,9 +155,7 @@ async def _handle_non_streaming_request( except ValueError as e: return self.create_error_response(str(e)) - text = ret.get("text", "") - if self._model_family == "qwen3_asr": - text = _postprocess_qwen3_asr(text) + text = self._adapter.postprocess_text(ret.get("text", "")) usage = 
TranscriptionUsage(seconds=int(math.ceil(request.audio_duration_s))) # Build response based on format @@ -272,23 +163,9 @@ async def _handle_non_streaming_request( return Response(content=text, media_type="text/plain") if request.response_format == "verbose_json": - if self._model_family == "whisper": - output_ids = ret.get("output_ids", []) - tokenizer = self.tokenizer_manager.tokenizer - parsed_text, segments = self._parse_segments(output_ids, tokenizer) - return TranscriptionVerboseResponse( - language=request.language or "en", - duration=round(request.audio_duration_s, 2), - text=parsed_text or text, - segments=segments, - usage=usage, - ) - return TranscriptionVerboseResponse( - language=request.language, - duration=round(request.audio_duration_s, 2), - text=text, - segments=[], - usage=usage, + tokenizer = self.tokenizer_manager.tokenizer + return self._adapter.build_verbose_response( + request, text, ret, tokenizer, usage ) # Default JSON format @@ -364,13 +241,3 @@ async def _generate_transcription_stream( yield f"data: {error}\n\n" yield "data: [DONE]\n\n" - - -# TODO (adityavaid): refactor model-specific postprocessing into a plugin/adapter mechanism. 
-def _postprocess_qwen3_asr(text: str) -> str: - if not text: - return "" - if _QWEN3_ASR_TEXT_TAG in text: - _, text_part = text.rsplit(_QWEN3_ASR_TEXT_TAG, 1) - return text_part.strip() - return text.strip() diff --git a/python/sglang/srt/entrypoints/openai/transcription_adapters/__init__.py b/python/sglang/srt/entrypoints/openai/transcription_adapters/__init__.py new file mode 100644 index 000000000000..353196fddfad --- /dev/null +++ b/python/sglang/srt/entrypoints/openai/transcription_adapters/__init__.py @@ -0,0 +1,23 @@ +# Re-export the public API from base so callers can do: +# from ...transcription_adapters import TranscriptionAdapter, register_transcription_adapter +from sglang.srt.entrypoints.openai.transcription_adapters.base import ( # noqa: F401 + TranscriptionAdapter, + register_transcription_adapter, + resolve_adapter, +) + +# Import built-in adapters so they self-register via @register_transcription_adapter. +from sglang.srt.entrypoints.openai.transcription_adapters.qwen3_asr import ( # noqa: F401 + Qwen3ASRAdapter, +) +from sglang.srt.entrypoints.openai.transcription_adapters.whisper import ( # noqa: F401 + WhisperAdapter, +) + +__all__ = [ + "TranscriptionAdapter", + "register_transcription_adapter", + "resolve_adapter", + "WhisperAdapter", + "Qwen3ASRAdapter", +] diff --git a/python/sglang/srt/entrypoints/openai/transcription_adapters/base.py b/python/sglang/srt/entrypoints/openai/transcription_adapters/base.py new file mode 100644 index 000000000000..c2deb05e1b6c --- /dev/null +++ b/python/sglang/srt/entrypoints/openai/transcription_adapters/base.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import List + +from sglang.srt.entrypoints.openai.protocol import ( + TranscriptionRequest, + TranscriptionUsage, + TranscriptionVerboseResponse, +) + + +class TranscriptionAdapter(ABC): + """Abstract base for model-specific transcription logic. 
+ + Subclass this and decorate with ``@register_transcription_adapter("Key")`` + to add support for a new ASR model. See the sibling modules for + the built-in Whisper and Qwen3-ASR implementations. + """ + + @abstractmethod + def build_sampling_params(self, request: TranscriptionRequest) -> dict: + """Return the ``sampling_params`` dict for ``GenerateReqInput``.""" + + def postprocess_text(self, text: str) -> str: + """Strip model-specific markers from raw decoded text. + + The default implementation is a no-op pass-through. + """ + return text + + @abstractmethod + def build_verbose_response( + self, + request: TranscriptionRequest, + text: str, + ret: dict, + tokenizer, + usage: TranscriptionUsage, + ) -> TranscriptionVerboseResponse: + """Build a ``verbose_json`` response with segments / timestamps.""" + + +_ADAPTER_REGISTRY: dict[str, type[TranscriptionAdapter]] = {} +_DEFAULT_ADAPTER_KEY = "Whisper" + + +def register_transcription_adapter( + key: str, +) -> callable: + """Class decorator that registers a ``TranscriptionAdapter`` subclass. + + *key* is matched as a substring against the model's HF ``architectures`` + list at init time (e.g. ``"Whisper"`` matches + ``"WhisperForConditionalGeneration"``). + """ + + def decorator(cls: type[TranscriptionAdapter]) -> type[TranscriptionAdapter]: + _ADAPTER_REGISTRY[key] = cls + return cls + + return decorator + + +def resolve_adapter(architectures: List[str]) -> TranscriptionAdapter: + """Pick the right adapter by matching architecture names against the registry.""" + for arch in architectures or []: + for key, adapter_cls in _ADAPTER_REGISTRY.items(): + if key in arch: + return adapter_cls() + default_cls = _ADAPTER_REGISTRY.get(_DEFAULT_ADAPTER_KEY) + if default_cls is None: + raise RuntimeError( + "No transcription adapters registered. " + "Make sure 'transcription_adapters' package is importable." 
+ ) + return default_cls() diff --git a/python/sglang/srt/entrypoints/openai/transcription_adapters/qwen3_asr.py b/python/sglang/srt/entrypoints/openai/transcription_adapters/qwen3_asr.py new file mode 100644 index 000000000000..aa96d2cad2b8 --- /dev/null +++ b/python/sglang/srt/entrypoints/openai/transcription_adapters/qwen3_asr.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from sglang.srt.entrypoints.openai.protocol import ( + TranscriptionRequest, + TranscriptionUsage, + TranscriptionVerboseResponse, +) +from sglang.srt.entrypoints.openai.transcription_adapters.base import ( + TranscriptionAdapter, + register_transcription_adapter, +) + + +@register_transcription_adapter("Qwen3ASR") +class Qwen3ASRAdapter(TranscriptionAdapter): + ASR_TEXT_TAG = "<asr_text>" + + def build_sampling_params(self, request: TranscriptionRequest) -> dict: + temperature = request.temperature + if temperature == 0.0: + temperature = 0.01 # Qwen3-ASR recommended near-greedy temperature + return { + "temperature": temperature, + "max_new_tokens": 256, # Qwen3-ASR default + } + + def postprocess_text(self, text: str) -> str: + # Qwen3-ASR outputs "language <lang><asr_text>transcription" format; + # strip the prefix to return clean transcription text. 
+ if self.ASR_TEXT_TAG in text: + return text.split(self.ASR_TEXT_TAG, 1)[-1] + return text + + def build_verbose_response( + self, + request: TranscriptionRequest, + text: str, + ret: dict, + tokenizer, + usage: TranscriptionUsage, + ) -> TranscriptionVerboseResponse: + # Qwen3-ASR doesn't natively produce timestamp tokens + return TranscriptionVerboseResponse( + language=request.language or "auto", + duration=round(request.audio_duration_s, 2), + text=text, + segments=[], + usage=usage, + ) diff --git a/python/sglang/srt/entrypoints/openai/transcription_adapters/whisper.py b/python/sglang/srt/entrypoints/openai/transcription_adapters/whisper.py new file mode 100644 index 000000000000..1fa9a6db88aa --- /dev/null +++ b/python/sglang/srt/entrypoints/openai/transcription_adapters/whisper.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +from typing import List + +from sglang.srt.entrypoints.openai.protocol import ( + TranscriptionRequest, + TranscriptionSegment, + TranscriptionUsage, + TranscriptionVerboseResponse, +) +from sglang.srt.entrypoints.openai.transcription_adapters.base import ( + TranscriptionAdapter, + register_transcription_adapter, +) + + +@register_transcription_adapter("Whisper") +class WhisperAdapter(TranscriptionAdapter): + TIMESTAMP_BASE_TOKEN_ID = 50365 # <|0.00|> + TIMESTAMP_BASE_OFFSET = 0.02 # each token step = 0.02 s + + def build_sampling_params(self, request: TranscriptionRequest) -> dict: + params: dict = { + "temperature": request.temperature, + "max_new_tokens": 448, # Whisper default max tokens + "language": request.language, + } + if request.timestamp_granularities: + params["timestamp_granularities"] = request.timestamp_granularities + return params + + def build_verbose_response( + self, + request: TranscriptionRequest, + text: str, + ret: dict, + tokenizer, + usage: TranscriptionUsage, + ) -> TranscriptionVerboseResponse: + output_ids = ret.get("output_ids", []) + parsed_text, segments = self._parse_segments(output_ids, 
tokenizer) + return TranscriptionVerboseResponse( + language=request.language or "en", + duration=round(request.audio_duration_s, 2), + text=parsed_text or text, + segments=segments, + usage=usage, + ) + + @staticmethod + def _parse_segments( + output_ids: List[int], tokenizer + ) -> tuple[str, List[TranscriptionSegment]]: + """Parse Whisper timestamp tokens from *output_ids* into segments. + + The decoder prompt ends with ``<|0.00|>``, so the first segment starts + at t=0. The model then outputs:: + + text_tokens <|end_ts|> [<|start_ts|> text_tokens <|end_ts|> ...] + + Each timestamp token marks the end of the current segment; its value + also becomes the start of the next segment. + """ + eos_token_id = getattr(tokenizer, "eos_token_id", 50257) + ts_base = WhisperAdapter.TIMESTAMP_BASE_TOKEN_ID + ts_step = WhisperAdapter.TIMESTAMP_BASE_OFFSET + + segments: list[TranscriptionSegment] = [] + full_text_parts: list[str] = [] + current_text_tokens: list[int] = [] + current_start = 0.0 # First segment starts at 0.0 (from prompt <|0.00|>) + seg_id = 0 + + for token_id in output_ids: + if token_id >= ts_base: + timestamp = (token_id - ts_base) * ts_step + + if current_text_tokens: + seg_text = tokenizer.decode( + current_text_tokens, skip_special_tokens=True + ).strip() + if seg_text: + segments.append( + TranscriptionSegment( + id=seg_id, + start=round(current_start, 2), + end=round(timestamp, 2), + text=seg_text, + ) + ) + full_text_parts.append(seg_text) + seg_id += 1 + current_text_tokens = [] + + current_start = timestamp + + elif token_id == eos_token_id: + continue + else: + current_text_tokens.append(token_id) + + if current_text_tokens: + seg_text = tokenizer.decode( + current_text_tokens, skip_special_tokens=True + ).strip() + if seg_text: + segments.append( + TranscriptionSegment( + id=seg_id, + start=round(current_start, 2), + end=round(current_start, 2), + text=seg_text, + ) + ) + full_text_parts.append(seg_text) + + return " ".join(full_text_parts), 
segments From 5ef3708e739cbf897abce08f67d4b48ad3aaea7c Mon Sep 17 00:00:00 2001 From: Yihao Wang <42559837+AgainstEntropy@users.noreply.github.com> Date: Mon, 6 Apr 2026 05:34:01 +0000 Subject: [PATCH 2/4] [misc] Share Qwen3-ASR audio placeholder constant and add TODO notes --- .../openai/transcription_adapters/qwen3_asr.py | 2 +- .../sglang/srt/multimodal/processors/qwen3_asr.py | 13 ++++++++----- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/entrypoints/openai/transcription_adapters/qwen3_asr.py b/python/sglang/srt/entrypoints/openai/transcription_adapters/qwen3_asr.py index aa96d2cad2b8..dca58ec84fb0 100644 --- a/python/sglang/srt/entrypoints/openai/transcription_adapters/qwen3_asr.py +++ b/python/sglang/srt/entrypoints/openai/transcription_adapters/qwen3_asr.py @@ -39,7 +39,7 @@ def build_verbose_response( tokenizer, usage: TranscriptionUsage, ) -> TranscriptionVerboseResponse: - # Qwen3-ASR doesn't natively produce timestamp tokens + # TODO: Qwen3-ASR needs ForcedAligner to produce timestamps return TranscriptionVerboseResponse( language=request.language or "auto", duration=round(request.audio_duration_s, 2), diff --git a/python/sglang/srt/multimodal/processors/qwen3_asr.py b/python/sglang/srt/multimodal/processors/qwen3_asr.py index 59ebb921ea99..546dbc13708f 100644 --- a/python/sglang/srt/multimodal/processors/qwen3_asr.py +++ b/python/sglang/srt/multimodal/processors/qwen3_asr.py @@ -10,11 +10,13 @@ MultimodalSpecialTokens, ) +AUDIO_PLACEHOLDER = "<|audio_start|><|audio_pad|><|audio_end|>" + _DEFAULT_ASR_PROMPT = ( - "<|im_start|>user\n" - "<|audio_start|><|audio_pad|><|audio_end|>" - "<|im_end|>\n" - "<|im_start|>assistant\n" + f"<|im_start|>user\n" + f"{AUDIO_PLACEHOLDER}" + f"<|im_end|>\n" + f"<|im_start|>assistant\n" ) @@ -23,7 +25,7 @@ class Qwen3ASRMultimodalProcessor(BaseMultimodalProcessor): def __init__(self, hf_config, server_args, _processor, *args, **kwargs): super().__init__(hf_config, server_args, _processor, *args, **kwargs) - self.AUDIO_TOKEN = 
"<|audio_start|><|audio_pad|><|audio_end|>" + self.AUDIO_TOKEN = AUDIO_PLACEHOLDER self.AUDIO_TOKEN_REGEX = re.compile( r"<\|audio_start\|>(?:<\|audio_pad\|>)+<\|audio_end\|>" ) @@ -41,6 +43,7 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs): self.ATTR_NAME_TO_MODALITY.update({"feature_attention_mask": Modality.AUDIO}) def _build_transcription_prompt(self, input_text: Union[str, list]) -> str: + # TODO: support `force_language` if isinstance(input_text, list): input_text = self._tokenizer.decode(input_text) if not input_text or not input_text.strip(): From a51a32f741a1180a0aacbb46f57f017427949981 Mon Sep 17 00:00:00 2001 From: Yihao Wang <42559837+AgainstEntropy@users.noreply.github.com> Date: Tue, 7 Apr 2026 21:40:49 +0000 Subject: [PATCH 3/4] refactor Qwen3 ASR configuration and add unit tests --- python/sglang/srt/configs/qwen3_asr.py | 130 ++++++++++++------------- test/manual/models/test_qwen3_asr.py | 118 ++++++++++++++++++++++ 2 files changed, 181 insertions(+), 67 deletions(-) create mode 100644 test/manual/models/test_qwen3_asr.py diff --git a/python/sglang/srt/configs/qwen3_asr.py b/python/sglang/srt/configs/qwen3_asr.py index 048eb2d9704d..259ad058cc03 100644 --- a/python/sglang/srt/configs/qwen3_asr.py +++ b/python/sglang/srt/configs/qwen3_asr.py @@ -14,72 +14,6 @@ from sglang.utils import logger -class Qwen3ASRThinkerConfig(PretrainedConfig): - model_type = "qwen3_asr_thinker" - sub_configs = { - "audio_config": Qwen3OmniMoeAudioEncoderConfig, - } - - def __init__( - self, - audio_config=None, - text_config=None, - audio_token_id=151676, - audio_start_token_id=151669, - audio_end_token_id=151670, - **kwargs, - ): - super().__init__(**kwargs) - - if isinstance(audio_config, dict): - audio_config = Qwen3OmniMoeAudioEncoderConfig(**audio_config) - elif audio_config is None: - audio_config = Qwen3OmniMoeAudioEncoderConfig() - self.audio_config = audio_config - - if isinstance(text_config, dict): - from 
transformers.models.qwen3.configuration_qwen3 import ( - Qwen3Config as HFQwen3Config, - ) - - text_config = HFQwen3Config(**text_config) - elif text_config is None: - raise ValueError( - "Qwen3ASRThinkerConfig requires a text_config dict with " - "model parameters (hidden_size, num_attention_heads, etc.). " - "Got None." - ) - - self.text_config = text_config - - self.audio_token_id = audio_token_id - self.audio_start_token_id = audio_start_token_id - self.audio_end_token_id = audio_end_token_id - - -class Qwen3ASRConfig(PretrainedConfig): - model_type = "qwen3_asr" - sub_configs = { - "thinker_config": Qwen3ASRThinkerConfig, - } - - def __init__(self, thinker_config=None, **kwargs): - super().__init__(**kwargs) - if thinker_config is None: - thinker_config = {} - logger.info( - "thinker_config is None. " - "Initializing Qwen3-ASR thinker with default values" - ) - if isinstance(thinker_config, dict): - self.thinker_config = Qwen3ASRThinkerConfig(**thinker_config) - else: - self.thinker_config = thinker_config - - def get_text_config(self, decoder=False) -> PretrainedConfig: - return self.thinker_config.text_config - - class Qwen3ASRProcessor(ProcessorMixin): """Minimal composite processor: WhisperFeatureExtractor + Qwen2Tokenizer. 
@@ -167,6 +101,68 @@ def __call__(self, text=None, audio=None, audio_kwargs=None, **kwargs): return inputs + +class Qwen3ASRThinkerConfig(PretrainedConfig): + model_type = "qwen3_asr_thinker" + sub_configs = { + "audio_config": Qwen3OmniMoeAudioEncoderConfig, + } + + def __init__( + self, + audio_config=None, + text_config=None, + audio_token_id=151676, + audio_start_token_id=151669, + audio_end_token_id=151670, + **kwargs, + ): + super().__init__(**kwargs) + + if isinstance(audio_config, dict): + audio_config = Qwen3OmniMoeAudioEncoderConfig(**audio_config) + elif audio_config is None: + audio_config = Qwen3OmniMoeAudioEncoderConfig() + self.audio_config = audio_config + + from transformers.models.qwen3.configuration_qwen3 import ( + Qwen3Config as HFQwen3Config, + ) + if isinstance(text_config, dict): + text_config = HFQwen3Config(**text_config) + elif text_config is None: + text_config = HFQwen3Config() + + self.text_config = text_config + + self.audio_token_id = audio_token_id + self.audio_start_token_id = audio_start_token_id + self.audio_end_token_id = audio_end_token_id + + +@register_customized_processor(Qwen3ASRProcessor) +class Qwen3ASRConfig(PretrainedConfig): + model_type = "qwen3_asr" + sub_configs = { + "thinker_config": Qwen3ASRThinkerConfig, + } + + def __init__(self, thinker_config=None, **kwargs): + super().__init__(**kwargs) + if thinker_config is None: + thinker_config = {} + logger.info( + "thinker_config is None. 
" + "Initializing Qwen3-ASR thinker with default values" + ) + if isinstance(thinker_config, dict): + self.thinker_config = Qwen3ASRThinkerConfig(**thinker_config) + else: + self.thinker_config = thinker_config + + def get_text_config(self, decoder=False) -> PretrainedConfig: + return self.thinker_config.text_config + + AutoConfig.register("qwen3_asr", Qwen3ASRConfig) AutoConfig.register("qwen3_asr_thinker", Qwen3ASRThinkerConfig) -register_customized_processor(Qwen3ASRProcessor)(Qwen3ASRConfig) diff --git a/test/manual/models/test_qwen3_asr.py b/test/manual/models/test_qwen3_asr.py new file mode 100644 index 000000000000..ec8747ad0cb6 --- /dev/null +++ b/test/manual/models/test_qwen3_asr.py @@ -0,0 +1,118 @@ +""" +Test Qwen3-ASR model support in SGLang. + +Tests /v1/audio/transcriptions endpoint (OpenAI-compatible). + +Usage: + python test/manual/models/test_qwen3_asr.py +""" + +import io +import os +import unittest + +import requests + +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + +MODEL = "Qwen/Qwen3-ASR-0.6B" +# MODEL = "Qwen/Qwen3-ASR-1.7B" +TEST_AUDIO_EN_URL = ( + "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-ASR-Repo/asr_en.wav" +) +TEST_AUDIO_ZH_URL = ( + "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-ASR-Repo/asr_zh.wav" +) +TEST_AUDIO_EN_LOCAL = "/tmp/test_qwen3_asr_en.wav" +TEST_AUDIO_ZH_LOCAL = "/tmp/test_qwen3_asr_zh.wav" + + +def download_audio(url, local_path): + """Download audio file if not already cached.""" + if os.path.exists(local_path): + with open(local_path, "rb") as f: + return f.read() + resp = requests.get(url, timeout=60) + resp.raise_for_status() + with open(local_path, "wb") as f: + f.write(resp.content) + return resp.content + + +class TestQwen3ASRTranscription(CustomTestCase): + """Test Qwen3-ASR via /v1/audio/transcriptions endpoint.""" + + @classmethod + def 
setUpClass(cls): + cls.model = MODEL + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--served-model-name", + "qwen3-asr", + "--trust-remote-code", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def _transcribe(self, audio_url, local_path, language=None): + """Send a transcription request.""" + audio_bytes = download_audio(audio_url, local_path) + data = {"model": "qwen3-asr"} + if language: + data["language"] = language + response = requests.post( + self.base_url + "/v1/audio/transcriptions", + files={"file": ("audio.wav", io.BytesIO(audio_bytes), "audio/wav")}, + data=data, + timeout=120, + ) + self.assertEqual(response.status_code, 200, response.text) + return response.json() + + def test_english_transcription(self): + """Test English audio transcription.""" + result = self._transcribe(TEST_AUDIO_EN_URL, TEST_AUDIO_EN_LOCAL) + self.assertIn("text", result) + text = result["text"] + self.assertTrue(len(text) > 0, "Transcription should not be empty") + print(f"[EN Transcription] {text}") + + def test_chinese_transcription(self): + """Test Chinese audio transcription.""" + result = self._transcribe(TEST_AUDIO_ZH_URL, TEST_AUDIO_ZH_LOCAL) + self.assertIn("text", result) + text = result["text"] + self.assertTrue(len(text) > 0, "Transcription should not be empty") + print(f"[ZH Transcription] {text}") + + def test_multiple_requests_consistency(self): + """Test that repeated requests produce consistent output.""" + results = [] + for _ in range(3): + result = self._transcribe(TEST_AUDIO_EN_URL, TEST_AUDIO_EN_LOCAL) + results.append(result["text"]) + + for i in range(1, len(results)): + self.assertEqual( + results[0], + results[i], + f"Request {i+1} differs from first request", + ) + print(f"[Consistency] All 3 requests match: {results[0][:80]}...") + + +if __name__ == "__main__": + unittest.main(verbosity=3) 
\ No newline at end of file From 0058ae45b5c8ac250823c5c75641c58d3f952072 Mon Sep 17 00:00:00 2001 From: Yihao Wang <42559837+AgainstEntropy@users.noreply.github.com> Date: Wed, 8 Apr 2026 03:12:54 +0000 Subject: [PATCH 4/4] fix lint --- python/sglang/srt/configs/qwen3_asr.py | 2 +- test/manual/models/test_qwen3_asr.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/configs/qwen3_asr.py b/python/sglang/srt/configs/qwen3_asr.py index 259ad058cc03..37fb4ef57d7d 100644 --- a/python/sglang/srt/configs/qwen3_asr.py +++ b/python/sglang/srt/configs/qwen3_asr.py @@ -101,7 +101,6 @@ def __call__(self, text=None, audio=None, audio_kwargs=None, **kwargs): return inputs - class Qwen3ASRThinkerConfig(PretrainedConfig): model_type = "qwen3_asr_thinker" sub_configs = { @@ -128,6 +127,7 @@ def __init__( from transformers.models.qwen3.configuration_qwen3 import ( Qwen3Config as HFQwen3Config, ) + if isinstance(text_config, dict): text_config = HFQwen3Config(**text_config) elif text_config is None: diff --git a/test/manual/models/test_qwen3_asr.py b/test/manual/models/test_qwen3_asr.py index ec8747ad0cb6..c0b772bf5a6e 100644 --- a/test/manual/models/test_qwen3_asr.py +++ b/test/manual/models/test_qwen3_asr.py @@ -115,4 +115,4 @@ def test_multiple_requests_consistency(self): if __name__ == "__main__": - unittest.main(verbosity=3) \ No newline at end of file + unittest.main(verbosity=3)