Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions python/sglang/srt/managers/async_mm_data_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Any, Dict, List, Optional, Union

logger = logging.getLogger(__name__)


class AsyncMMDataProcessor:
    """
    Async wrapper for a multimodal processor.

    Behavior:
    - If the underlying processor exposes an async `process_mm_data_async`,
      call/await it directly.
    - Otherwise, fall back to running a synchronous `process_mm_data` in a
      thread pool so the event loop is never blocked.
    - Optionally guard per-call concurrency via an asyncio.Semaphore.
    - Optionally enforce per-call timeout via asyncio.wait_for.
    """

    def __init__(
        self,
        mm_processor: Any,
        *,
        max_concurrent_calls: Optional[int] = None,
        timeout_s: Optional[float] = None,
    ) -> None:
        """
        Args:
            mm_processor: An object exposing either
                - async def process_mm_data_async(...): -> Dict[str, Any]
                or
                - def process_mm_data(...): -> Dict[str, Any]
            max_concurrent_calls: Optional concurrency cap for per-call execution.
                Also used as the thread-pool size on the sync fallback path.
            timeout_s: Optional timeout (seconds) for each `process()` call.
        """
        self.mm_processor = mm_processor
        self.timeout_s = timeout_s

        # Concurrency guard (None -> unlimited)
        self.semaphore = (
            asyncio.Semaphore(max_concurrent_calls) if max_concurrent_calls else None
        )

        # Detect async path; if missing, prepare a fallback executor for the
        # sync path. `iscoroutinefunction(None)` is False, so a processor with
        # no `process_mm_data_async` attribute lands on the fallback too.
        self._proc_async = getattr(mm_processor, "process_mm_data_async", None)
        self.is_async = asyncio.iscoroutinefunction(self._proc_async)
        self.fallback_exec: Optional[ThreadPoolExecutor] = (
            ThreadPoolExecutor(max_workers=max_concurrent_calls)
            if not self.is_async
            else None
        )

    async def process(
        self,
        *,
        image_data: Optional[List[Union[str, bytes]]] = None,
        audio_data: Optional[List[Union[str, bytes]]] = None,
        input_text_or_ids: Union[str, List[int], None] = None,
        request_obj: Any,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """
        Public entrypoint: process a single multimodal request without blocking
        the event loop.

        Args:
            image_data: Optional list of image URLs/paths or raw bytes.
            audio_data: Optional list of audio URLs/paths or raw bytes.
            input_text_or_ids: Prompt text or pre-tokenized ids; forwarded to
                the underlying processor as `input_text`.
            request_obj: The originating request object, forwarded unchanged.
            **kwargs: Extra keyword arguments forwarded to the processor.

        Returns:
            The processor's result dict.

        Raises:
            RuntimeError: If the processor exposes neither
                `process_mm_data_async` nor `process_mm_data`.
            asyncio.TimeoutError: If `timeout_s` elapses. NOTE: on the sync
                fallback path the worker thread cannot be interrupted; only the
                awaiting caller is released, while the thread runs to completion.
        """

        async def _invoke() -> Dict[str, Any]:
            if self.is_async:
                # Native async implementation
                return await self._proc_async(
                    image_data=image_data,
                    audio_data=audio_data,
                    input_text=input_text_or_ids,
                    request_obj=request_obj,
                    **kwargs,
                )

            # Synchronous fallback, off-loaded to the thread pool
            sync_fn = getattr(self.mm_processor, "process_mm_data", None)
            if not callable(sync_fn):
                raise RuntimeError(
                    "mm_processor has neither 'process_mm_data_async' nor 'process_mm_data'."
                )
            loop = asyncio.get_running_loop()
            fn = partial(
                sync_fn,
                image_data=image_data,
                audio_data=audio_data,
                input_text=input_text_or_ids,
                request_obj=request_obj,
                **kwargs,
            )
            return await loop.run_in_executor(self.fallback_exec, fn)

        async def _invoke_with_timeout() -> Dict[str, Any]:
            # Single definition of the timeout guard; the timeout window covers
            # only the processing call itself, not the semaphore wait.
            if self.timeout_s is None:
                return await _invoke()
            return await asyncio.wait_for(_invoke(), timeout=self.timeout_s)

        # Apply optional concurrency guard
        if self.semaphore is not None:
            async with self.semaphore:
                return await _invoke_with_timeout()

        # No concurrency guard
        return await _invoke_with_timeout()

    def shutdown(self) -> None:
        """Gracefully shutdown resources owned by this wrapper."""
        try:
            if self.fallback_exec:
                # wait=False: don't block the caller on in-flight sync work.
                self.fallback_exec.shutdown(wait=False)
        except Exception:
            logger.exception(
                "Error while shutting down fallback executor in AsyncMMDataProcessor"
            )

    def __del__(self):
        # Best-effort shutdown; __del__ must never raise.
        try:
            self.shutdown()
        except Exception:
            pass
10 changes: 8 additions & 2 deletions python/sglang/srt/managers/tokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.lora.lora_registry import LoRARegistry
from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer
from sglang.srt.managers.async_mm_data_processor import AsyncMMDataProcessor
from sglang.srt.managers.disagg_service import start_disagg_service
from sglang.srt.managers.io_struct import (
AbortReq,
Expand Down Expand Up @@ -215,6 +216,11 @@ def __init__(
self.mm_processor = get_mm_processor(
self.model_config.hf_config, server_args, _processor, transport_mode
)
self.mm_data_processor = AsyncMMDataProcessor(
self.mm_processor,
max_concurrent_calls=self.server_args.mm_max_concurrent_calls,
timeout_s=self.server_args.mm_per_request_timeout,
)

if server_args.skip_tokenizer_init:
self.tokenizer = self.processor = None
Expand Down Expand Up @@ -598,10 +604,10 @@ async def _tokenize_one_request(
obj.image_data = [obj.image_data]
if obj.audio_data is not None and not isinstance(obj.audio_data, list):
obj.audio_data = [obj.audio_data]
mm_inputs: Dict = await self.mm_processor.process_mm_data_async(
mm_inputs: Dict = await self.mm_data_processor.process(
image_data=obj.image_data,
audio_data=obj.audio_data,
input_text=input_text or input_ids,
input_text_or_ids=(input_text or input_ids),
request_obj=obj,
max_req_input_len=self.max_req_input_len,
)
Expand Down
18 changes: 18 additions & 0 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,10 @@ class ServerArgs:
pdmux_config_path: Optional[str] = None
sm_group_num: int = 8

# For Multi-Modal
mm_max_concurrent_calls: int = 32
mm_per_request_timeout: float = 10.0

def __post_init__(self):
"""
Orchestrates the handling of various server arguments, ensuring proper configuration and validation.
Expand Down Expand Up @@ -3501,6 +3505,20 @@ def add_cli_args(parser: argparse.ArgumentParser):
help="Read CLI options from a config file. Must be a YAML file with configuration options.",
)

# For Multi-Modal
parser.add_argument(
"--mm-max-concurrent-calls",
type=int,
default=ServerArgs.mm_max_concurrent_calls,
help="The max concurrent calls for async mm data processing.",
)
parser.add_argument(
    "--mm-per-request-timeout",
    # type must be float: the dataclass default is 10.0 and the consumer
    # (AsyncMMDataProcessor.timeout_s) is Optional[float]; with type=int a
    # fractional timeout like 2.5 would be rejected at the CLI.
    type=float,
    default=ServerArgs.mm_per_request_timeout,
    help="The timeout for each multi-modal request in seconds.",
)

@classmethod
def from_cli_args(cls, args: argparse.Namespace):
args.tp_size = args.tensor_parallel_size
Expand Down
Loading
Loading