From e6bf7f411289e8d1bcb36ed4d27e98f3ac6115a3 Mon Sep 17 00:00:00 2001 From: Yihao Wang <42559837+AgainstEntropy@users.noreply.github.com> Date: Fri, 3 Apr 2026 19:11:31 +0000 Subject: [PATCH] [feat] add Qwen3-ASR model support and related configurations - Introduced Qwen3-ASR model with configuration and processor classes. - Updated entry points to handle Qwen3-ASR in the transcription endpoint. - Enhanced multimodal processing to support Qwen3-ASR. - Added tests for Qwen3-ASR transcription functionality. - Updated existing files to include Qwen3ASR in relevant imports and configurations. --- python/sglang/srt/configs/__init__.py | 2 + python/sglang/srt/configs/model_config.py | 10 +- python/sglang/srt/configs/qwen3_asr.py | 231 ++++++++++++++++ .../openai/serving_transcription.py | 90 +++++-- python/sglang/srt/models/qwen3_asr.py | 247 ++++++++++++++++++ .../multimodal/processors/base_processor.py | 1 + .../srt/multimodal/processors/qwen3_asr.py | 151 +++++++++++ test/manual/models/test_qwen3_asr.py | 118 +++++++++ 8 files changed, 833 insertions(+), 17 deletions(-) create mode 100644 python/sglang/srt/configs/qwen3_asr.py create mode 100644 python/sglang/srt/models/qwen3_asr.py create mode 100644 python/sglang/srt/multimodal/processors/qwen3_asr.py create mode 100644 test/manual/models/test_qwen3_asr.py diff --git a/python/sglang/srt/configs/__init__.py b/python/sglang/srt/configs/__init__.py index 3a3b37f54c0c..da38dde4dbaa 100644 --- a/python/sglang/srt/configs/__init__.py +++ b/python/sglang/srt/configs/__init__.py @@ -21,6 +21,7 @@ from sglang.srt.configs.nano_nemotron_vl import NemotronH_Nano_VL_V2_Config from sglang.srt.configs.nemotron_h import NemotronHConfig from sglang.srt.configs.olmo3 import Olmo3Config +from sglang.srt.configs.qwen3_asr import Qwen3ASRConfig from sglang.srt.configs.qwen3_5 import Qwen3_5Config, Qwen3_5MoeConfig from sglang.srt.configs.qwen3_next import Qwen3NextConfig from sglang.srt.configs.step3_vl import ( @@ -47,6 +48,7 @@ "Olmo3Config", "KimiLinearConfig", "KimiK25Config", + "Qwen3ASRConfig", "Qwen3NextConfig", "Qwen3_5Config", "Qwen3_5MoeConfig", diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index a7f66c8443f5..4fe15032f3de 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -196,8 +196,12 @@ def __init__( self.is_image_understandable_model = enable_multimodal and hasattr( self.hf_config, "vision_config" ) - self.is_audio_understandable_model = enable_multimodal and hasattr( - self.hf_config, "audio_config" + self.is_audio_understandable_model = enable_multimodal and ( + hasattr(self.hf_config, "audio_config") + or ( + hasattr(self.hf_config, "thinker_config") + and hasattr(self.hf_config.thinker_config, "audio_config") + ) ) self.is_multimodal_chunked_prefill_supported = ( @@ -1326,6 +1330,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal "Qwen3VLMoeForConditionalGeneration", "Qwen3_5ForConditionalGeneration", "Qwen3_5MoeForConditionalGeneration", + "Qwen3ASRForConditionalGeneration", "Qwen3OmniMoeForConditionalGeneration", "KimiVLForConditionalGeneration", "InternVLChatModel", @@ -1373,6 +1378,7 @@ def is_multimodal_model(model_architectures: List[str]): def is_audio_model(model_architectures: List[str]): models = [ "WhisperForConditionalGeneration", + "Qwen3ASRForConditionalGeneration", ] return any(model in model_architectures for model in models) diff --git a/python/sglang/srt/configs/qwen3_asr.py b/python/sglang/srt/configs/qwen3_asr.py new file mode 100644 index 000000000000..693135e7c5d1 --- /dev/null +++ b/python/sglang/srt/configs/qwen3_asr.py @@ -0,0 +1,231 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Configuration and processor classes for Qwen3-ASR model.""" + +from transformers import ( + AutoConfig, + AutoFeatureExtractor, + AutoTokenizer, + PretrainedConfig, + ProcessorMixin, +) + +from sglang.srt.configs.qwen3_omni import Qwen3OmniMoeAudioEncoderConfig +from sglang.srt.multimodal.customized_mm_processor_utils import ( + register_customized_processor, +) +from sglang.utils import logger + + +class Qwen3ASRThinkerConfig(PretrainedConfig): + model_type = "qwen3_asr_thinker" + sub_configs = { + "audio_config": Qwen3OmniMoeAudioEncoderConfig, + } + + def __init__( + self, + audio_config=None, + text_config=None, + audio_token_id=151676, + audio_start_token_id=151669, + audio_end_token_id=151670, + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + self.initializer_range = initializer_range + + if isinstance(audio_config, dict): + audio_config = Qwen3OmniMoeAudioEncoderConfig(**audio_config) + elif audio_config is None: + audio_config = Qwen3OmniMoeAudioEncoderConfig() + self.audio_config = audio_config + + if isinstance(text_config, dict): + # Use the proper Qwen3Config so rope_parameters property works + from transformers.models.qwen3.configuration_qwen3 import ( + Qwen3Config as HFQwen3Config, + ) + + text_config = HFQwen3Config(**text_config) + elif text_config is None: + text_config = PretrainedConfig() + self.text_config = text_config + + self.audio_token_id = audio_token_id + self.audio_start_token_id = audio_start_token_id + self.audio_end_token_id = audio_end_token_id + + +class Qwen3ASRConfig(PretrainedConfig): + model_type = "qwen3_asr" + sub_configs = { + "thinker_config": Qwen3ASRThinkerConfig, + } + + def __init__( + self, + thinker_config=None, + support_languages=None, + **kwargs, + ): + super().__init__(**kwargs) + if thinker_config is None: + thinker_config = {} + logger.info( + "thinker_config is None. Initializing Qwen3-ASR thinker with default values" + ) + + if isinstance(thinker_config, dict): + self.thinker_config = Qwen3ASRThinkerConfig(**thinker_config) + else: + self.thinker_config = thinker_config + self.support_languages = support_languages or [] + + def get_text_config(self, decoder=False) -> "PretrainedConfig": + return self.thinker_config.text_config + + +class Qwen3ASRProcessor(ProcessorMixin): + """Custom processor combining WhisperFeatureExtractor + Qwen2Tokenizer. + + AutoProcessor.from_pretrained() for Qwen3-ASR returns just a tokenizer + because the model repo doesn't ship a proper ProcessorMixin class. + This wrapper provides the composite processor that SGLang expects. + """ + + attributes = ["feature_extractor", "tokenizer"] + feature_extractor_class = "WhisperFeatureExtractor" + tokenizer_class = "AutoTokenizer" + + def __init__(self, feature_extractor=None, tokenizer=None, **kwargs): + super().__init__(feature_extractor=feature_extractor, tokenizer=tokenizer) + self.audio_token = "<|audio_pad|>" + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + trust_remote_code = kwargs.pop("trust_remote_code", True) + feature_extractor = AutoFeatureExtractor.from_pretrained( + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + **{k: v for k, v in kwargs.items() if k in ("revision",)}, + ) + tokenizer = AutoTokenizer.from_pretrained( + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + **{k: v for k, v in kwargs.items() if k in ("revision", "use_fast")}, + ) + return cls(feature_extractor=feature_extractor, tokenizer=tokenizer) + + def _get_feat_extract_output_lengths(self, input_lengths): + """Compute the number of audio tokens from mel feature lengths.""" + import torch + + if not isinstance(input_lengths, torch.Tensor): + input_lengths = torch.tensor(input_lengths) + input_lengths_leave = input_lengths % 100 + feat_lengths = (input_lengths_leave - 1) // 2 + 1 + output_lengths = ( + ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + + (input_lengths // 100) * 13 + ) + return output_lengths + + def __call__( + self, + text=None, + audio=None, + audio_kwargs=None, + **kwargs, + ): + import torch + + if audio_kwargs is None: + audio_kwargs = {} + + if audio is not None: + audio_inputs = self.feature_extractor( + audio, + sampling_rate=self.feature_extractor.sampling_rate, + return_attention_mask=True, + return_tensors=kwargs.get("return_tensors"), + **audio_kwargs, + ) + # Rename attention_mask -> feature_attention_mask + inputs = {"input_features": audio_inputs["input_features"]} + if "attention_mask" in audio_inputs: + inputs["feature_attention_mask"] = audio_inputs["attention_mask"] + else: + inputs = {} + + if text is not None: + text_inputs = self.tokenizer( + text, + return_tensors=kwargs.get("return_tensors"), + padding=kwargs.get("padding", False), + ) + input_ids = text_inputs["input_ids"] + + # Expand <|audio_pad|> tokens based on audio feature lengths + if audio is not None and "feature_attention_mask" in inputs: + audio_pad_id = self.tokenizer.convert_tokens_to_ids( + self.audio_token + ) + feat_mask = inputs["feature_attention_mask"] + feat_lengths = feat_mask.sum(dim=-1) # actual mel lengths + audio_token_counts = self._get_feat_extract_output_lengths( + feat_lengths + ) + + # Expand each sequence's audio_pad tokens + expanded_ids_list = [] + for seq_idx in range(input_ids.shape[0]): + seq_ids = input_ids[seq_idx].tolist() + audio_idx = 0 + new_ids = [] + for tid in seq_ids: + if tid == audio_pad_id and audio_idx < len( + audio_token_counts + ): + count = int(audio_token_counts[audio_idx].item()) + new_ids.extend([audio_pad_id] * count) + audio_idx += 1 + else: + new_ids.append(tid) + expanded_ids_list.append(new_ids) + + # Pad to same length and convert to tensor + max_len = max(len(ids) for ids in expanded_ids_list) + padded = [ + ids + [self.tokenizer.pad_token_id or 0] * (max_len - len(ids)) + for ids in expanded_ids_list + ] + input_ids = torch.tensor(padded, dtype=torch.long) + + inputs["input_ids"] = input_ids + + return inputs + + +AutoConfig.register("qwen3_asr", Qwen3ASRConfig) +AutoConfig.register("qwen3_asr_thinker", Qwen3ASRThinkerConfig) + + +@register_customized_processor(Qwen3ASRProcessor) +class _Qwen3ASRConfigForProcessorRegistration(Qwen3ASRConfig): + """Shim so that ``_CUSTOMIZED_MM_PROCESSOR["qwen3_asr"]`` resolves to + ``Qwen3ASRProcessor`` when ``get_processor()`` loads the model.""" + + model_type = "qwen3_asr" diff --git a/python/sglang/srt/entrypoints/openai/serving_transcription.py b/python/sglang/srt/entrypoints/openai/serving_transcription.py index bfbad1e0d321..59b41a5cb632 100644 --- a/python/sglang/srt/entrypoints/openai/serving_transcription.py +++ b/python/sglang/srt/entrypoints/openai/serving_transcription.py @@ -12,7 +12,8 @@ # limitations under the License. # ============================================================================== """ -OpenAI-compatible transcription endpoint handler for Whisper models. +OpenAI-compatible transcription endpoint handler for audio ASR models. +Supports Whisper and Qwen3-ASR models. """ from __future__ import annotations @@ -51,11 +52,20 @@ TIMESTAMP_BASE_OFFSET = 0.02 # Each token step = 0.02 seconds +def _is_qwen3_asr_model(architectures: List[str]) -> bool: + return any("Qwen3ASR" in arch for arch in (architectures or [])) + + class OpenAIServingTranscription(OpenAIServingBase): """Handler for /v1/audio/transcriptions requests""" def __init__(self, tokenizer_manager: TokenizerManager): super().__init__(tokenizer_manager) + # Detect if the loaded model is Qwen3-ASR + model_config = tokenizer_manager.model_config + self._is_qwen3_asr = _is_qwen3_asr_model( + getattr(model_config.hf_config, "architectures", []) + ) def _request_id_prefix(self) -> str: return "trsc-" @@ -71,17 +81,25 @@ def _convert_to_internal_request( raw_request: Request = None, ) -> tuple[GenerateReqInput, TranscriptionRequest]: """Convert transcription request to internal format.""" - # Build sampling params - include language for WhisperProcessor + if self._is_qwen3_asr: + return self._convert_qwen3_asr_request(request, raw_request) + return self._convert_whisper_request(request, raw_request) + + def _convert_whisper_request( + self, + request: TranscriptionRequest, + raw_request: Request = None, + ) -> tuple[GenerateReqInput, TranscriptionRequest]: + """Convert transcription request for Whisper models.""" sampling_params = { "temperature": request.temperature, "max_new_tokens": 448, # Whisper default max tokens - "language": request.language, # Pass to WhisperProcessor for language-specific decoding + "language": request.language, } if request.timestamp_granularities: sampling_params["timestamp_granularities"] = request.timestamp_granularities - # For Whisper, we pass audio_data and let the processor handle it adapted_request = GenerateReqInput( text="", # Empty text - Whisper processor will set proper decoder tokens audio_data=request.audio_data, @@ -90,7 +108,33 @@ def _convert_to_internal_request( modalities=["audio"], routing_key=self.extract_routing_key(raw_request), ) + return adapted_request, request + def _convert_qwen3_asr_request( + self, + request: TranscriptionRequest, + raw_request: Request = None, + ) -> tuple[GenerateReqInput, TranscriptionRequest]: + """Convert transcription request for Qwen3-ASR models.""" + temperature = request.temperature + if temperature == 0.0: + temperature = 0.01 # Qwen3-ASR recommended near-greedy temperature + + sampling_params = { + "temperature": temperature, + "max_new_tokens": 256, # Qwen3-ASR default + } + + # Qwen3-ASR uses chat format with audio_url in content + # The processor handles the prompt construction from audio_data + adapted_request = GenerateReqInput( + text="", # Processor will construct the proper prompt + audio_data=request.audio_data, + sampling_params=sampling_params, + stream=request.stream, + modalities=["audio"], + routing_key=self.extract_routing_key(raw_request), + ) return adapted_request, request def _get_audio_duration(self, audio_data: bytes) -> float: @@ -232,6 +276,12 @@ async def _handle_non_streaming_request( return self.create_error_response(str(e)) text = ret.get("text", "") + + # Qwen3-ASR outputs "language transcription" format + # Strip the prefix to return clean transcription text + if self._is_qwen3_asr and "" in text: + text = text.split("", 1)[-1] + usage = TranscriptionUsage(seconds=int(math.ceil(request.audio_duration_s))) # Build response based on format @@ -239,17 +289,27 @@ async def _handle_non_streaming_request( return Response(content=text, media_type="text/plain") if request.response_format == "verbose_json": - output_ids = ret.get("output_ids", []) - tokenizer = self.tokenizer_manager.tokenizer - parsed_text, segments = self._parse_segments(output_ids, tokenizer) - - return TranscriptionVerboseResponse( - language=request.language or "en", - duration=round(request.audio_duration_s, 2), - text=parsed_text or text, - segments=segments, - usage=usage, - ) + if self._is_qwen3_asr: + # Qwen3-ASR doesn't natively produce timestamp tokens + return TranscriptionVerboseResponse( + language=request.language or "auto", + duration=round(request.audio_duration_s, 2), + text=text, + segments=[], + usage=usage, + ) + else: + output_ids = ret.get("output_ids", []) + tokenizer = self.tokenizer_manager.tokenizer + parsed_text, segments = self._parse_segments(output_ids, tokenizer) + + return TranscriptionVerboseResponse( + language=request.language or "en", + duration=round(request.audio_duration_s, 2), + text=parsed_text or text, + segments=segments, + usage=usage, + ) # Default JSON format return TranscriptionResponse(text=text, usage=usage) diff --git a/python/sglang/srt/models/qwen3_asr.py b/python/sglang/srt/models/qwen3_asr.py new file mode 100644 index 000000000000..85361509a871 --- /dev/null +++ b/python/sglang/srt/models/qwen3_asr.py @@ -0,0 +1,247 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Inference-only Qwen3-ASR model compatible with HuggingFace weights.""" + +import logging +from typing import Any, Iterable, List, Optional, Tuple + +import torch +import torch.nn as nn + +from sglang.srt.configs.qwen3_asr import Qwen3ASRConfig, Qwen3ASRThinkerConfig +from sglang.srt.configs.qwen3_omni import Qwen3OmniMoeAudioEncoderConfig +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.managers.mm_utils import ( + MultiModalityDataPaddingPatternMultimodalTokens, + general_mm_embed_routine, +) +from sglang.srt.managers.schedule_batch import ( + Modality, + MultimodalDataItem, + MultimodalInputs, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.qwen3 import Qwen3ForCausalLM +from sglang.srt.models.qwen3_omni_moe import ( + Qwen3OmniMoeAudioEncoder, + _get_feat_extract_output_lengths, +) +from sglang.srt.utils import add_prefix + +logger = logging.getLogger(__name__) + + +class Qwen3ASRForConditionalGeneration(nn.Module): + # BitandBytes specific attributes + default_bitsandbytes_target_modules = [ + ".gate_proj.", + ".down_proj.", + ".up_proj.", + ".q_proj.", + ".k_proj.", + ".v_proj.", + ".o_proj.", + ] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + + def __init__( + self, + config: Qwen3ASRConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + + # Extract the thinker_config which contains audio_config and text_config + thinker_config = config.thinker_config + if not isinstance(thinker_config, Qwen3ASRThinkerConfig): + thinker_config = Qwen3ASRThinkerConfig( + **( + thinker_config + if isinstance(thinker_config, dict) + else thinker_config.__dict__ + ) + ) + + audio_config = thinker_config.audio_config + if not isinstance(audio_config, Qwen3OmniMoeAudioEncoderConfig): + audio_config = Qwen3OmniMoeAudioEncoderConfig( + **( + audio_config + if isinstance(audio_config, dict) + else audio_config.__dict__ + ) + ) + + self.audio_tower = Qwen3OmniMoeAudioEncoder(audio_config) + self.language_model = Qwen3ForCausalLM( + thinker_config.text_config, + quant_config, + prefix=add_prefix("language_model", prefix), + ) + self.pattern = MultiModalityDataPaddingPatternMultimodalTokens() + + def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs): + return self.pattern.pad_input_tokens(input_ids, mm_inputs) + + def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: + device = next(self.audio_tower.parameters()).device + + input_features = ( + torch.cat([item.feature for item in items]) + .type(self.audio_tower.dtype) + .to(device) + ) + + # Check if feature_attention_mask is available (not present during warmup) + has_mask = hasattr(items[0], "feature_attention_mask") and getattr( + items[0], "feature_attention_mask", None + ) is not None + + if has_mask: + feature_attention_mask = torch.cat( + [item.feature_attention_mask for item in items], dim=0 + ).type(torch.long).to(device) + + # Compute actual audio lengths from attention mask + audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) + + # Extract valid features using the mask (remove padding) + input_features = input_features.permute(0, 2, 1)[ + feature_attention_mask.bool() + ].permute(1, 0) + else: + # No mask: assume all features are valid (e.g., during warmup) + # input_features shape: (batch, num_mel_bins, time_steps) + batch_size = input_features.shape[0] + time_steps = input_features.shape[-1] + audio_feature_lengths = torch.full( + (batch_size,), time_steps, dtype=torch.long, device=device + ) + # Flatten batch dim: (num_mel_bins, total_time_steps) + input_features = input_features.permute(0, 2, 1).reshape( + -1, input_features.shape[1] + ).permute(1, 0) + + # Run through audio encoder + audio_outputs = self.audio_tower( + input_features, + feature_lens=audio_feature_lengths, + ) + audio_features = audio_outputs.last_hidden_state + + return audio_features + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + **kwargs: Any, + ) -> torch.Tensor: + hidden_states = general_mm_embed_routine( + input_ids=input_ids, + forward_batch=forward_batch, + language_model=self.language_model, + data_embedding_funcs={ + Modality.AUDIO: self.get_audio_feature, + }, + positions=positions, + ) + + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + # Stacked params for the LLM decoder + llm_stacked_params = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + # Audio tower VisionAttention uses qkv_proj (fused) and proj (output) + # HF weights have separate q_proj, k_proj, v_proj, out_proj + audio_stacked_params = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + + for name, loaded_weight in weights: + # Remap weight names from HuggingFace checkpoint format + if name.startswith("thinker.audio_tower."): + name = name.replace("thinker.audio_tower.", "audio_tower.", 1) + elif name.startswith("thinker.lm_head."): + name = name.replace("thinker.lm_head.", "language_model.lm_head.", 1) + elif name.startswith("thinker.model."): + name = name.replace("thinker.model.", "language_model.model.", 1) + elif name.startswith("thinker."): + name = name.replace("thinker.", "", 1) + + # Skip talker and code2wav weights (not used for ASR) + if "talker" in name or "code2wav" in name: + continue + + if "rotary_emb.inv_freq" in name: + continue + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + continue + + text_config = self.config.thinker_config.text_config + if getattr(text_config, "tie_word_embeddings", False) and "lm_head.weight" in name: + continue + + # Audio tower: remap out_proj -> proj for VisionAttention + if "audio_tower" in name and "out_proj" in name: + name = name.replace("out_proj", "proj") + + # Select appropriate stacked params mapping + is_audio = "audio_tower" in name + stacked_params = audio_stacked_params if is_audio else llm_stacked_params + + for param_name, weight_name, shard_id in stacked_params: + if weight_name not in name: + continue + name_tmp = name.replace(weight_name, param_name) + + if name_tmp.endswith(".bias") and name_tmp not in params_dict: + continue + if name_tmp not in params_dict: + continue + param = params_dict[name_tmp] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +EntryClass = Qwen3ASRForConditionalGeneration diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py index 839d5b74e079..bd7aaac4bc65 100644 --- a/python/sglang/srt/multimodal/processors/base_processor.py +++ b/python/sglang/srt/multimodal/processors/base_processor.py @@ -388,6 +388,7 @@ def process_mm_data( "Gemma3nProcessor", "GlmAsrProcessor", "Qwen2AudioProcessor", + "Qwen3ASRProcessor", "Qwen3OmniMoeProcessor", }: # Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107 diff --git a/python/sglang/srt/multimodal/processors/qwen3_asr.py b/python/sglang/srt/multimodal/processors/qwen3_asr.py new file mode 100644 index 000000000000..72b90398361c --- /dev/null +++ b/python/sglang/srt/multimodal/processors/qwen3_asr.py @@ -0,0 +1,151 @@ +import re +from typing import Optional, Union + +import torch + +from sglang.srt.managers.schedule_batch import ( + Modality, + MultimodalProcessorOutput, +) +from sglang.srt.models.qwen3_asr import Qwen3ASRForConditionalGeneration +from sglang.srt.multimodal.processors.base_processor import ( + BaseMultimodalProcessor, + MultimodalSpecialTokens, +) + +# Default ASR prompt template for Qwen3-ASR transcription endpoint +_DEFAULT_ASR_PROMPT = ( + "<|im_start|>user\n" + "<|audio_start|><|audio_pad|><|audio_end|>" + "<|im_end|>\n" + "<|im_start|>assistant\n" +) + + +class Qwen3ASRMultimodalProcessor(BaseMultimodalProcessor): + models = [Qwen3ASRForConditionalGeneration] + + def __init__(self, hf_config, server_args, _processor, *args, **kwargs): + # Access the thinker_config for token IDs + if hasattr(hf_config, "thinker_config"): + thinker_config = hf_config.thinker_config + else: + thinker_config = hf_config + + super().__init__(hf_config, server_args, _processor, *args, **kwargs) + + # Audio special tokens + self.AUDIO_TOKEN = "<|audio_start|><|audio_pad|><|audio_end|>" + self.AUDIO_TOKEN_REGEX = re.compile( + r"<\|audio_start\|>(?:<\|audio_pad\|>)+<\|audio_end\|>" + ) + + tokenizer = self._processor.tokenizer + self.audio_start_id = tokenizer.convert_tokens_to_ids("<|audio_start|>") + self.audio_token_id = tokenizer.convert_tokens_to_ids("<|audio_pad|>") + self.audio_end_id = tokenizer.convert_tokens_to_ids("<|audio_end|>") + + self.mm_tokens = MultimodalSpecialTokens( + audio_token=self.AUDIO_TOKEN, + audio_token_regex=self.AUDIO_TOKEN_REGEX, + audio_token_id=self.audio_token_id, + ).build(_processor) + + self.ATTR_NAME_TO_MODALITY.update( + {"feature_attention_mask": Modality.AUDIO} + ) + + def _build_transcription_prompt(self, input_text: Union[str, list]) -> str: + """Build a prompt for the transcription endpoint. + + When the input text is empty (from /v1/audio/transcriptions), + construct the default Qwen3-ASR chat prompt with an audio placeholder. + """ + if isinstance(input_text, list): + # Token IDs - decode to text first + input_text = self._tokenizer.decode(input_text) + + if not input_text or not input_text.strip(): + return _DEFAULT_ASR_PROMPT + return input_text + + def _compute_mrope_positions( + self, + input_ids: torch.Tensor, + audio_feature_lengths: Optional[torch.Tensor] = None, + ): + """Compute MRoPE positions for Qwen3-ASR. + + For audio-only model, all 3 MRoPE dimensions get the same sequential + positions. Audio tokens get sequential positions just like text tokens. + """ + seq_len = input_ids.shape[0] + if input_ids.dim() > 1: + seq_len = input_ids.shape[-1] + + # For Qwen3-ASR, all 3 dimensions get identical sequential positions + # since audio tokens don't have spatial structure + positions = torch.arange(seq_len, dtype=torch.long) + mrope_positions = positions.unsqueeze(0).expand(3, -1).clone() + mrope_position_delta = torch.tensor([0], dtype=torch.long) + + return mrope_positions, mrope_position_delta + + def compute_mrope_positions(self, input_ids, mm_items): + """Compute M-RoPE positions for Qwen3-ASR. + + All 3 dimensions get identical sequential positions since audio + tokens have no spatial structure. + """ + if isinstance(input_ids, list): + seq_len = len(input_ids) + else: + seq_len = input_ids.shape[-1] if input_ids.dim() > 1 else input_ids.shape[0] + + positions = torch.arange(seq_len, dtype=torch.long) + mrope_positions = positions.unsqueeze(0).expand(3, -1).clone() + return mrope_positions, torch.tensor([0], dtype=torch.long) + + async def process_mm_data_async( + self, + audio_data=None, + input_text=None, + request_obj=None, + **kwargs, + ): + if not audio_data: + return None + + # Build the prompt - handles empty text from transcription endpoint + prompt = self._build_transcription_prompt(input_text) + + base_output = self.load_mm_data( + prompt=prompt, + audio_data=audio_data, + multimodal_tokens=self.mm_tokens, + ) + if base_output is None: + return None + + mm_items, input_ids, ret = self.process_and_combine_mm_data( + base_output, self.mm_tokens + ) + + # The feature_attention_mask is automatically set on audio items + # by the base processor's collect_mm_items_from_processor_output() + # since we registered it in ATTR_NAME_TO_MODALITY + + # Compute MRoPE positions + mrope_positions, mrope_position_delta = self._compute_mrope_positions( + input_ids + ) + + return MultimodalProcessorOutput( + mm_items=mm_items, + input_ids=input_ids.tolist(), + audio_start_id=self.audio_start_id, + audio_token_id=self.audio_token_id, + audio_end_id=self.audio_end_id, + mrope_positions=mrope_positions, + mrope_position_delta=mrope_position_delta, + ) diff --git a/test/manual/models/test_qwen3_asr.py b/test/manual/models/test_qwen3_asr.py new file mode 100644 index 000000000000..dd600659a300 --- /dev/null +++ b/test/manual/models/test_qwen3_asr.py @@ -0,0 +1,118 @@ +""" +Test Qwen3-ASR model support in SGLang. + +Tests /v1/audio/transcriptions endpoint (OpenAI-compatible). + +Usage: + python test/manual/models/test_qwen3_asr.py +""" + +import io +import os +import unittest + +import requests + +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + +MODEL = "Qwen/Qwen3-ASR-1.7B" +TEST_AUDIO_EN_URL = ( + "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-ASR-Repo/asr_en.wav" +) +TEST_AUDIO_ZH_URL = ( + "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-ASR-Repo/asr_zh.wav" +) +TEST_AUDIO_EN_LOCAL = "/tmp/test_qwen3_asr_en.wav" +TEST_AUDIO_ZH_LOCAL = "/tmp/test_qwen3_asr_zh.wav" + + +def download_audio(url, local_path): + """Download audio file if not already cached.""" + if os.path.exists(local_path): + with open(local_path, "rb") as f: + return f.read() + resp = requests.get(url, timeout=60) + resp.raise_for_status() + with open(local_path, "wb") as f: + f.write(resp.content) + return resp.content + + +class TestQwen3ASRTranscription(CustomTestCase): + """Test Qwen3-ASR via /v1/audio/transcriptions endpoint.""" + + @classmethod + def setUpClass(cls): + cls.model = MODEL + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--served-model-name", + "qwen3-asr", + "--trust-remote-code", + "--disable-cuda-graph", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def _transcribe(self, audio_url, local_path, language=None): + """Send a transcription request.""" + audio_bytes = download_audio(audio_url, local_path) + data = {"model": "qwen3-asr"} + if language: + data["language"] = language + response = requests.post( + self.base_url + "/v1/audio/transcriptions", + files={"file": ("audio.wav", io.BytesIO(audio_bytes), "audio/wav")}, + data=data, + timeout=120, + ) + self.assertEqual(response.status_code, 200, response.text) + return response.json() + + def test_english_transcription(self): + """Test English audio transcription.""" + result = self._transcribe(TEST_AUDIO_EN_URL, TEST_AUDIO_EN_LOCAL) + self.assertIn("text", result) + text = result["text"] + self.assertTrue(len(text) > 0, "Transcription should not be empty") + print(f"[EN Transcription] {text}") + + def test_chinese_transcription(self): + """Test Chinese audio transcription.""" + result = self._transcribe(TEST_AUDIO_ZH_URL, TEST_AUDIO_ZH_LOCAL) + self.assertIn("text", result) + text = result["text"] + self.assertTrue(len(text) > 0, "Transcription should not be empty") + print(f"[ZH Transcription] {text}") + + def test_multiple_requests_consistency(self): + """Test that repeated requests produce consistent output.""" + results = [] + for _ in range(3): + result = self._transcribe(TEST_AUDIO_EN_URL, TEST_AUDIO_EN_LOCAL) + results.append(result["text"]) + + for i in range(1, len(results)): + self.assertEqual( + results[0], + results[i], + f"Request {i+1} differs from first request", + ) + print(f"[Consistency] All 3 requests match: {results[0][:80]}...") + + +if __name__ == "__main__": + unittest.main(verbosity=3)