Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 63 additions & 67 deletions python/sglang/srt/configs/qwen3_asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,72 +14,6 @@
from sglang.utils import logger


class Qwen3ASRThinkerConfig(PretrainedConfig):
model_type = "qwen3_asr_thinker"
sub_configs = {
"audio_config": Qwen3OmniMoeAudioEncoderConfig,
}

def __init__(
self,
audio_config=None,
text_config=None,
audio_token_id=151676,
audio_start_token_id=151669,
audio_end_token_id=151670,
**kwargs,
):
super().__init__(**kwargs)

if isinstance(audio_config, dict):
audio_config = Qwen3OmniMoeAudioEncoderConfig(**audio_config)
elif audio_config is None:
audio_config = Qwen3OmniMoeAudioEncoderConfig()
self.audio_config = audio_config

if isinstance(text_config, dict):
from transformers.models.qwen3.configuration_qwen3 import (
Qwen3Config as HFQwen3Config,
)

text_config = HFQwen3Config(**text_config)
elif text_config is None:
raise ValueError(
"Qwen3ASRThinkerConfig requires a text_config dict with "
"model parameters (hidden_size, num_attention_heads, etc.). "
"Got None."
)

self.text_config = text_config

self.audio_token_id = audio_token_id
self.audio_start_token_id = audio_start_token_id
self.audio_end_token_id = audio_end_token_id


class Qwen3ASRConfig(PretrainedConfig):
model_type = "qwen3_asr"
sub_configs = {
"thinker_config": Qwen3ASRThinkerConfig,
}

def __init__(self, thinker_config=None, **kwargs):
super().__init__(**kwargs)
if thinker_config is None:
thinker_config = {}
logger.info(
"thinker_config is None. "
"Initializing Qwen3-ASR thinker with default values"
)
if isinstance(thinker_config, dict):
self.thinker_config = Qwen3ASRThinkerConfig(**thinker_config)
else:
self.thinker_config = thinker_config

def get_text_config(self, decoder=False) -> PretrainedConfig:
return self.thinker_config.text_config


class Qwen3ASRProcessor(ProcessorMixin):
"""Minimal composite processor: WhisperFeatureExtractor + Qwen2Tokenizer.

Expand Down Expand Up @@ -167,6 +101,68 @@ def __call__(self, text=None, audio=None, audio_kwargs=None, **kwargs):
return inputs


class Qwen3ASRThinkerConfig(PretrainedConfig):
model_type = "qwen3_asr_thinker"
sub_configs = {
"audio_config": Qwen3OmniMoeAudioEncoderConfig,
}

def __init__(
self,
audio_config=None,
text_config=None,
audio_token_id=151676,
audio_start_token_id=151669,
audio_end_token_id=151670,
**kwargs,
):
super().__init__(**kwargs)

if isinstance(audio_config, dict):
audio_config = Qwen3OmniMoeAudioEncoderConfig(**audio_config)
elif audio_config is None:
audio_config = Qwen3OmniMoeAudioEncoderConfig()
self.audio_config = audio_config

from transformers.models.qwen3.configuration_qwen3 import (
Qwen3Config as HFQwen3Config,
)

if isinstance(text_config, dict):
text_config = HFQwen3Config(**text_config)
elif text_config is None:
text_config = HFQwen3Config()

self.text_config = text_config

self.audio_token_id = audio_token_id
self.audio_start_token_id = audio_start_token_id
self.audio_end_token_id = audio_end_token_id


@register_customized_processor(Qwen3ASRProcessor)
class Qwen3ASRConfig(PretrainedConfig):
model_type = "qwen3_asr"
sub_configs = {
"thinker_config": Qwen3ASRThinkerConfig,
}

def __init__(self, thinker_config=None, **kwargs):
super().__init__(**kwargs)
if thinker_config is None:
thinker_config = {}
logger.info(
"thinker_config is None. "
"Initializing Qwen3-ASR thinker with default values"
)
if isinstance(thinker_config, dict):
self.thinker_config = Qwen3ASRThinkerConfig(**thinker_config)
else:
self.thinker_config = thinker_config

def get_text_config(self, decoder=False) -> PretrainedConfig:
return self.thinker_config.text_config


AutoConfig.register("qwen3_asr", Qwen3ASRConfig)
AutoConfig.register("qwen3_asr_thinker", Qwen3ASRThinkerConfig)
register_customized_processor(Qwen3ASRProcessor)(Qwen3ASRConfig)
169 changes: 18 additions & 151 deletions python/sglang/srt/entrypoints/openai/serving_transcription.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
# limitations under the License.
# ==============================================================================
"""
OpenAI-compatible transcription endpoint handler for Whisper models.
OpenAI-compatible transcription endpoint handler for audio ASR models.

New ASR models are supported by subclassing ``TranscriptionAdapter`` and
registering via the ``@register_transcription_adapter`` decorator.
See ``transcription_adapters/`` for built-in implementations.
"""

from __future__ import annotations
Expand All @@ -32,40 +36,30 @@
ErrorResponse,
TranscriptionRequest,
TranscriptionResponse,
TranscriptionSegment,
TranscriptionStreamChoice,
TranscriptionStreamResponse,
TranscriptionUsage,
TranscriptionVerboseResponse,
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.entrypoints.openai.transcription_adapters import resolve_adapter
from sglang.srt.managers.io_struct import GenerateReqInput

if TYPE_CHECKING:
from sglang.srt.managers.tokenizer_manager import TokenizerManager

logger = logging.getLogger(__name__)

# Whisper timestamp token constants
TIMESTAMP_BASE_TOKEN_ID = 50365 # <|0.00|>
TIMESTAMP_BASE_OFFSET = 0.02 # Each token step = 0.02 seconds

_QWEN3_ASR_TEXT_TAG = "<asr_text>"


def _detect_model_family(model_config) -> str:
archs = getattr(getattr(model_config, "hf_config", None), "architectures", []) or []
if "Qwen3ASRForConditionalGeneration" in archs:
return "qwen3_asr"
return "whisper"


class OpenAIServingTranscription(OpenAIServingBase):
"""Handler for /v1/audio/transcriptions requests"""

def __init__(self, tokenizer_manager: TokenizerManager):
super().__init__(tokenizer_manager)
self._model_family = _detect_model_family(tokenizer_manager.model_config)
model_config = tokenizer_manager.model_config
self._adapter = resolve_adapter(
getattr(model_config.hf_config, "architectures", [])
)

def _request_id_prefix(self) -> str:
return "trsc-"
Expand All @@ -81,40 +75,9 @@ def _convert_to_internal_request(
raw_request: Request = None,
) -> tuple[GenerateReqInput, TranscriptionRequest]:
"""Convert transcription request to internal format."""
if self._model_family == "qwen3_asr":
prompt = (
"<|im_start|>user\n"
"<|audio_start|><|audio_pad|><|audio_end|>"
"<|im_end|>\n"
"<|im_start|>assistant\n"
)
sampling_params = {
"temperature": request.temperature,
"max_new_tokens": 1024,
}
adapted_request = GenerateReqInput(
text=prompt,
audio_data=request.audio_data,
sampling_params=sampling_params,
stream=request.stream,
modalities=["audio"],
routing_key=self.extract_routing_key(raw_request),
)
return adapted_request, request

# Build sampling params - include language for WhisperProcessor
sampling_params = {
"temperature": request.temperature,
"max_new_tokens": 448, # Whisper default max tokens
"language": request.language, # Pass to WhisperProcessor for language-specific decoding
}

if request.timestamp_granularities:
sampling_params["timestamp_granularities"] = request.timestamp_granularities

# For Whisper, we pass audio_data and let the processor handle it
sampling_params = self._adapter.build_sampling_params(request)
adapted_request = GenerateReqInput(
text="", # Empty text - Whisper processor will set proper decoder tokens
text="", # Empty text — the multimodal processor sets proper decoder/prompt tokens
audio_data=request.audio_data,
sampling_params=sampling_params,
stream=request.stream,
Expand All @@ -124,7 +87,8 @@ def _convert_to_internal_request(

return adapted_request, request

def _get_audio_duration(self, audio_data: bytes) -> float:
@staticmethod
def _get_audio_duration(audio_data: bytes) -> float:
"""Calculate audio duration in seconds."""
try:
import soundfile as sf
Expand All @@ -135,77 +99,6 @@ def _get_audio_duration(self, audio_data: bytes) -> float:
logger.warning(f"Could not calculate audio duration: {e}")
return 0.0

def _parse_segments(
self, output_ids: List[int], tokenizer
) -> tuple[str, List[TranscriptionSegment]]:
"""Parse timestamp tokens from output_ids into segments.

The decoder prompt ends with <|0.00|>, so the first segment starts at
t=0. The model then outputs:
text_tokens <|end_ts|> [<|start_ts|> text_tokens <|end_ts|> ...]
Each timestamp token marks the end of the current segment; its value
also becomes the start of the next segment.
"""
# Token IDs for special tokens we want to strip from segment text
eos_token_id = getattr(tokenizer, "eos_token_id", 50257)

segments = []
full_text_parts = []
current_text_tokens = []
current_start = 0.0 # First segment starts at 0.0 (from prompt <|0.00|>)
seg_id = 0

for token_id in output_ids:
if token_id >= TIMESTAMP_BASE_TOKEN_ID:
# This is a timestamp token — marks the end of current segment
timestamp = (token_id - TIMESTAMP_BASE_TOKEN_ID) * TIMESTAMP_BASE_OFFSET

if current_text_tokens:
text = tokenizer.decode(
current_text_tokens, skip_special_tokens=True
).strip()
if text:
segments.append(
TranscriptionSegment(
id=seg_id,
start=round(current_start, 2),
end=round(timestamp, 2),
text=text,
)
)
full_text_parts.append(text)
seg_id += 1
current_text_tokens = []

# Next segment starts at this timestamp
current_start = timestamp

elif token_id == eos_token_id:
# Skip end-of-text token
continue
else:
# Regular text token
current_text_tokens.append(token_id)

# Handle any trailing text tokens without a closing timestamp
if current_text_tokens:
text = tokenizer.decode(
current_text_tokens, skip_special_tokens=True
).strip()
if text:
segments.append(
TranscriptionSegment(
id=seg_id,
start=round(current_start, 2),
end=round(current_start, 2),
text=text,
)
)
full_text_parts.append(text)

full_text = " ".join(full_text_parts)
return full_text, segments

async def create_transcription(
self,
audio_data: bytes,
Expand Down Expand Up @@ -262,33 +155,17 @@ async def _handle_non_streaming_request(
except ValueError as e:
return self.create_error_response(str(e))

text = ret.get("text", "")
if self._model_family == "qwen3_asr":
text = _postprocess_qwen3_asr(text)
text = self._adapter.postprocess_text(ret.get("text", ""))
usage = TranscriptionUsage(seconds=int(math.ceil(request.audio_duration_s)))

# Build response based on format
if request.response_format == "text":
return Response(content=text, media_type="text/plain")

if request.response_format == "verbose_json":
if self._model_family == "whisper":
output_ids = ret.get("output_ids", [])
tokenizer = self.tokenizer_manager.tokenizer
parsed_text, segments = self._parse_segments(output_ids, tokenizer)
return TranscriptionVerboseResponse(
language=request.language or "en",
duration=round(request.audio_duration_s, 2),
text=parsed_text or text,
segments=segments,
usage=usage,
)
return TranscriptionVerboseResponse(
language=request.language,
duration=round(request.audio_duration_s, 2),
text=text,
segments=[],
usage=usage,
tokenizer = self.tokenizer_manager.tokenizer
return self._adapter.build_verbose_response(
request, text, ret, tokenizer, usage
)

# Default JSON format
Expand Down Expand Up @@ -364,13 +241,3 @@ async def _generate_transcription_stream(
yield f"data: {error}\n\n"

yield "data: [DONE]\n\n"


# TODO (adityavaid): refactor model-specific postprocessing into a plugin/adapter mechanism.
def _postprocess_qwen3_asr(text: str) -> str:
if not text:
return ""
if _QWEN3_ASR_TEXT_TAG in text:
_, text_part = text.rsplit(_QWEN3_ASR_TEXT_TAG, 1)
return text_part.strip()
return text.strip()
Loading
Loading