Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions python/sglang/srt/entrypoints/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1486,11 +1486,18 @@ async def openai_v1_audio_transcriptions(
response_format: str = Form(default="json"),
temperature: float = Form(default=0.0),
stream: bool = Form(default=False),
timestamp_granularities: Optional[List[str]] = Form(
default=None, alias="timestamp_granularities[]"
),
):
"""OpenAI-compatible audio transcription endpoint."""
if response_format not in ["json", "text"]:
if response_format not in ["json", "text", "verbose_json"]:
return ORJSONResponse(
content={"error": {"message": "Only 'json' and 'text' formats supported"}},
content={
"error": {
"message": "Only 'json', 'text', and 'verbose_json' formats supported"
}
},
status_code=400,
)

Expand All @@ -1504,6 +1511,7 @@ async def openai_v1_audio_transcriptions(
response_format=response_format,
temperature=temperature,
stream=stream,
timestamp_granularities=timestamp_granularities,
raw_request=raw_request,
)
)
Expand Down
21 changes: 21 additions & 0 deletions python/sglang/srt/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1443,6 +1443,7 @@ class TranscriptionRequest(BaseModel):
language: Optional[str] = None
response_format: str = "json"
temperature: float = 0.0
timestamp_granularities: Optional[List[str]] = None
stream: bool = False
# Internal fields (not from API)
audio_data: Optional[bytes] = None
Expand All @@ -1463,6 +1464,26 @@ class TranscriptionResponse(BaseModel):
usage: Optional[TranscriptionUsage] = None


class TranscriptionSegment(BaseModel):
    """A single transcribed span with start/end timestamps.

    One entry of the ``segments`` array in the verbose (OpenAI-compatible)
    transcription response.
    """

    id: int  # sequential segment index, starting at 0
    start: float  # segment start time in seconds (rounded to 2 decimals)
    end: float  # segment end time in seconds (rounded to 2 decimals)
    text: str  # decoded text for this span, stripped of special tokens


class TranscriptionVerboseResponse(BaseModel):
    """Verbose transcription response with timestamps (OpenAI-compatible).

    Returned when the request's ``response_format`` is ``"verbose_json"``.
    """

    task: str = "transcribe"  # fixed task label for the transcription endpoint
    language: Optional[str] = None  # language code from the request ("en" fallback upstream)
    duration: Optional[float] = None  # audio duration in seconds, rounded to 2 decimals
    text: str  # full transcript (segment texts joined)
    # NOTE(review): mutable default — pydantic copies field defaults per
    # instance, so this is not shared state; confirm intent vs default_factory.
    segments: List[TranscriptionSegment] = []
    usage: Optional[TranscriptionUsage] = None  # billing/usage info (audio seconds)


class TranscriptionStreamChoice(BaseModel):
"""Delta content for streaming transcription."""

Expand Down
121 changes: 113 additions & 8 deletions python/sglang/srt/entrypoints/openai/serving_transcription.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import math
import time
import uuid
from typing import TYPE_CHECKING, AsyncGenerator, Optional, Union
from typing import TYPE_CHECKING, AsyncGenerator, List, Optional, Union

from fastapi import Request
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
Expand All @@ -32,9 +32,11 @@
ErrorResponse,
TranscriptionRequest,
TranscriptionResponse,
TranscriptionSegment,
TranscriptionStreamChoice,
TranscriptionStreamResponse,
TranscriptionUsage,
TranscriptionVerboseResponse,
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.managers.io_struct import GenerateReqInput
Expand All @@ -44,6 +46,10 @@

logger = logging.getLogger(__name__)

# Whisper timestamp token constants used by _parse_segments to convert
# generated timestamp tokens into seconds: t = (id - base) * step.
# NOTE(review): the id of <|0.00|> differs between Whisper vocabularies
# (English-only vs multilingual tokenizers) — confirm 50365 matches the
# tokenizer of the served model.
TIMESTAMP_BASE_TOKEN_ID = 50365  # <|0.00|>
TIMESTAMP_BASE_OFFSET = 0.02  # Each token step = 0.02 seconds


class OpenAIServingTranscription(OpenAIServingBase):
"""Handler for /v1/audio/transcriptions requests"""
Expand Down Expand Up @@ -72,6 +78,9 @@ def _convert_to_internal_request(
"language": request.language, # Pass to WhisperProcessor for language-specific decoding
}

if request.timestamp_granularities:
sampling_params["timestamp_granularities"] = request.timestamp_granularities

# For Whisper, we pass audio_data and let the processor handle it
adapted_request = GenerateReqInput(
text="", # Empty text - Whisper processor will set proper decoder tokens
Expand All @@ -89,13 +98,83 @@ def _get_audio_duration(self, audio_data: bytes) -> float:
try:
import soundfile as sf

audio_array, sr = sf.read(io.BytesIO(audio_data))
duration = len(audio_array) / sr
return duration
info = sf.info(io.BytesIO(audio_data))
return info.duration
except Exception as e:
logger.warning(f"Could not calculate audio duration: {e}")
return 0.0

def _parse_segments(
    self,
    output_ids: List[int],
    tokenizer,
    *,
    timestamp_base_token_id: Optional[int] = None,
    timestamp_step_s: Optional[float] = None,
) -> tuple[str, List[TranscriptionSegment]]:
    """Parse Whisper timestamp tokens from ``output_ids`` into segments.

    The decoder prompt ends with <|0.00|>, so the first segment starts at
    t=0. The model then outputs::

        text_tokens <|end_ts|> [<|start_ts|> text_tokens <|end_ts|> ...]

    Each timestamp token marks the end of the current segment; its value
    also becomes the start of the next segment.

    Args:
        output_ids: Generated decoder token ids.
        tokenizer: Tokenizer used to decode text tokens. Its
            ``eos_token_id`` (when set) is skipped explicitly.
        timestamp_base_token_id: Override for the id of <|0.00|>; defaults
            to the module constant. Exposed because the base id differs
            between Whisper vocabularies (English-only vs multilingual).
        timestamp_step_s: Override for the seconds per timestamp-token
            step; defaults to the module constant.

    Returns:
        Tuple of (full_text, segments) where ``full_text`` joins the
        per-segment texts and ``segments`` carries rounded start/end times.
    """
    # Resolve overrides at call time so module constants are only touched
    # at runtime (as in the rest of this class), never at import time.
    base_id = (
        TIMESTAMP_BASE_TOKEN_ID
        if timestamp_base_token_id is None
        else timestamp_base_token_id
    )
    step_s = TIMESTAMP_BASE_OFFSET if timestamp_step_s is None else timestamp_step_s

    # getattr's default only applies when the attribute is MISSING; many
    # tokenizers expose eos_token_id = None, which would silently disable
    # the skip below. Fall back to the Whisper <|endoftext|> id explicitly.
    eos_token_id = getattr(tokenizer, "eos_token_id", None)
    if eos_token_id is None:
        eos_token_id = 50257

    segments: List[TranscriptionSegment] = []
    full_text_parts: List[str] = []
    current_text_tokens: List[int] = []
    current_start = 0.0  # first segment starts at 0.0 (prompt ends with <|0.00|>)
    seg_id = 0

    def _flush(end_time: float) -> None:
        # Close the current segment at end_time if it holds any text.
        nonlocal seg_id, current_text_tokens
        if not current_text_tokens:
            return
        text = tokenizer.decode(
            current_text_tokens, skip_special_tokens=True
        ).strip()
        current_text_tokens = []
        if text:
            segments.append(
                TranscriptionSegment(
                    id=seg_id,
                    start=round(current_start, 2),
                    end=round(end_time, 2),
                    text=text,
                )
            )
            full_text_parts.append(text)
            seg_id += 1

    for token_id in output_ids:
        if token_id >= base_id:
            # Timestamp token — marks the end of the current segment.
            timestamp = (token_id - base_id) * step_s
            _flush(timestamp)
            # This timestamp is also the start of the next segment.
            current_start = timestamp
        elif token_id == eos_token_id:
            # Skip end-of-text token.
            continue
        else:
            # Regular text token.
            current_text_tokens.append(token_id)

    # Trailing text without a closing timestamp becomes a zero-length
    # segment anchored at the last seen timestamp.
    _flush(current_start)

    full_text = " ".join(full_text_parts)
    return full_text, segments

async def create_transcription(
self,
audio_data: bytes,
Expand All @@ -105,7 +184,14 @@ async def create_transcription(
temperature: float,
stream: bool,
raw_request: Request,
) -> Union[TranscriptionResponse, StreamingResponse, Response, ORJSONResponse]:
timestamp_granularities: Optional[List[str]] = None,
) -> Union[
TranscriptionResponse,
TranscriptionVerboseResponse,
StreamingResponse,
Response,
ORJSONResponse,
]:
"""Main entry point for transcription requests."""
# Calculate audio duration for usage reporting
audio_duration_s = self._get_audio_duration(audio_data)
Expand All @@ -117,6 +203,7 @@ async def create_transcription(
language=language,
response_format=response_format,
temperature=temperature,
timestamp_granularities=timestamp_granularities,
stream=stream,
audio_duration_s=audio_duration_s,
)
Expand All @@ -129,7 +216,13 @@ async def _handle_non_streaming_request(
adapted_request: GenerateReqInput,
request: TranscriptionRequest,
raw_request: Request,
) -> Union[TranscriptionResponse, ErrorResponse, ORJSONResponse, Response]:
) -> Union[
TranscriptionResponse,
TranscriptionVerboseResponse,
ErrorResponse,
ORJSONResponse,
Response,
]:
"""Handle non-streaming transcription request."""
try:
ret = await self.tokenizer_manager.generate_request(
Expand All @@ -139,14 +232,26 @@ async def _handle_non_streaming_request(
return self.create_error_response(str(e))

text = ret.get("text", "")
usage = TranscriptionUsage(seconds=int(math.ceil(request.audio_duration_s)))

# Build response based on format
if request.response_format == "text":
return Response(content=text, media_type="text/plain")

# JSON format
usage = TranscriptionUsage(seconds=int(math.ceil(request.audio_duration_s)))
if request.response_format == "verbose_json":
output_ids = ret.get("output_ids", [])
tokenizer = self.tokenizer_manager.tokenizer
parsed_text, segments = self._parse_segments(output_ids, tokenizer)

return TranscriptionVerboseResponse(
language=request.language or "en",
duration=round(request.audio_duration_s, 2),
text=parsed_text or text,
segments=segments,
usage=usage,
)

# Default JSON format
return TranscriptionResponse(text=text, usage=usage)

async def _handle_streaming_request(
Expand Down
9 changes: 6 additions & 3 deletions python/sglang/srt/layers/attention/flashinfer_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1048,16 +1048,19 @@ def update_cross_attention(
fixed_split_size: Optional[int] = None,
disable_split_kv: Optional[bool] = None,
):
# Cache encoder_lens on CPU to avoid GPU→CPU transfer per call
encoder_lens_cpu = encoder_lens.cpu() if encoder_lens is not None else None
for wrapper_id in range(2):
if wrapper_id == 0:
# Normal attention
paged_kernel_lens = seq_lens
kv_start_idx = encoder_lens
kv_lens_cpu = seq_lens_cpu
else:
# Cross attention
# Cross-attention: attend to encoder tokens only
paged_kernel_lens = encoder_lens
kv_start_idx = torch.zeros_like(encoder_lens)
seq_lens_sum = encoder_lens.sum().item()
kv_lens_cpu = encoder_lens_cpu

self.call_begin_forward(
decode_wrappers[wrapper_id],
Expand All @@ -1067,7 +1070,7 @@ def update_cross_attention(
self.kv_indptr[wrapper_id],
kv_start_idx,
spec_info,
seq_lens_cpu=seq_lens_cpu,
seq_lens_cpu=kv_lens_cpu,
)

def call_begin_forward(
Expand Down
7 changes: 6 additions & 1 deletion python/sglang/srt/model_executor/cuda_graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,12 @@ def __init__(self, model_runner: ModelRunner):
else self.dllm_config.block_size
)

self.encoder_len_fill_value = 0
# Non-zero encoder length ensures cross-attention kernels are captured in the graph.
self.encoder_len_fill_value = (
getattr(model_runner.model_config.hf_config, "max_source_positions", 0)
if self.is_encoder_decoder
else 0
)

if self.enable_torch_compile:
set_torch_compile_config()
Expand Down
6 changes: 5 additions & 1 deletion python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2068,7 +2068,11 @@ def _dummy_run(self, batch_size: int, run_ctx=None):
is_encoder_decoder=self.model_config.is_encoder_decoder,
require_mlp_tp_gather=require_mlp_tp_gather_,
seq_len_fill_value=seq_len_fill_value,
encoder_len_fill_value=0,
encoder_len_fill_value=(
getattr(self.model_config.hf_config, "max_source_positions", 0)
if self.model_config.is_encoder_decoder
else 0
),
num_tokens_per_bs=num_tokens_per_bs,
cache_loc_dtype=torch.int64,
enable_mamba_track=False,
Expand Down
Loading
Loading