diff --git a/python/sglang/srt/managers/async_mm_data_processor.py b/python/sglang/srt/managers/async_mm_data_processor.py new file mode 100644 index 000000000000..85e8580cb769 --- /dev/null +++ b/python/sglang/srt/managers/async_mm_data_processor.py @@ -0,0 +1,122 @@ +import asyncio +import logging +from concurrent.futures import ThreadPoolExecutor +from functools import partial +from typing import Any, Dict, List, Optional, Union + +logger = logging.getLogger(__name__) + + +class AsyncMMDataProcessor: + """ + Async wrapper for a multimodal processor. + + Behavior: + - If the underlying processor exposes `process_mm_data_async`, call/await it directly. + - Otherwise, fall back to running a synchronous `process_mm_data` in a thread pool. + - Optionally guard per-call concurrency via an asyncio.Semaphore. + - Optionally enforce per-call timeout via asyncio.wait_for. + """ + + def __init__( + self, + mm_processor: Any, + *, + max_concurrent_calls: Optional[int] = None, + timeout_s: Optional[float] = None, + ) -> None: + """ + Args: + mm_processor: An object exposing either + - async def process_mm_data_async(...): -> Dict[str, Any] + or + - def process_mm_data(...): -> Dict[str, Any] + max_concurrent_calls: Optional concurrency cap for per-call execution. + timeout_s: Optional timeout (seconds) for each `process()` call. + """ + self.mm_processor = mm_processor + self.timeout_s = timeout_s + + # Concurrency guard (None -> unlimited) + self.semaphore = ( + asyncio.Semaphore(max_concurrent_calls) if max_concurrent_calls else None + ) + + # Detect async path; if missing, prepare a fallback executor for sync path + self._proc_async = getattr(mm_processor, "process_mm_data_async", None) + self.is_async = asyncio.iscoroutinefunction(self._proc_async) + self.fallback_exec: Optional[ThreadPoolExecutor] = ( + ThreadPoolExecutor(max_workers=max_concurrent_calls) + if not self.is_async + else None + ) + + async def process( + self, + *, + image_data: Optional[List[Union[str, bytes]]] = None, + audio_data: Optional[List[Union[str, bytes]]] = None, + input_text_or_ids: Union[str, List[int], None] = None, + request_obj: Any, + **kwargs: Any, + ) -> Dict[str, Any]: + """ + Public entrypoint: process a single multimodal request without blocking the event loop. + """ + + async def _invoke() -> Dict[str, Any]: + if self.is_async: + # Native async implementation + return await self._proc_async( + image_data=image_data, + audio_data=audio_data, + input_text=input_text_or_ids, + request_obj=request_obj, + **kwargs, + ) + + # Synchronous fallback + sync_fn = getattr(self.mm_processor, "process_mm_data", None) + if not callable(sync_fn): + raise RuntimeError( + "mm_processor has neither 'process_mm_data_async' nor 'process_mm_data'." + ) + loop = asyncio.get_running_loop() + fn = partial( + sync_fn, + image_data=image_data, + audio_data=audio_data, + input_text=input_text_or_ids, + request_obj=request_obj, + **kwargs, + ) + return await loop.run_in_executor(self.fallback_exec, fn) + + # Apply optional concurrency guard + if self.semaphore is not None: + async with self.semaphore: + if self.timeout_s is not None: + return await asyncio.wait_for(_invoke(), timeout=self.timeout_s) + return await _invoke() + + # No concurrency guard + if self.timeout_s is not None: + return await asyncio.wait_for(_invoke(), timeout=self.timeout_s) + return await _invoke() + + def shutdown(self) -> None: + """Gracefully shutdown resources owned by this wrapper.""" + try: + if self.fallback_exec: + self.fallback_exec.shutdown(wait=False) + except Exception: + logger.exception( + "Error while shutting down fallback executor in AsyncMMDataProcessor" + ) + + def __del__(self): + # Best-effort shutdown + try: + self.shutdown() + except Exception: + pass diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index d5d6b2ef18b3..67476fa4cbc2 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -43,6 +43,7 @@ from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.lora.lora_registry import LoRARegistry from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer +from sglang.srt.managers.async_mm_data_processor import AsyncMMDataProcessor from sglang.srt.managers.disagg_service import start_disagg_service from sglang.srt.managers.io_struct import ( AbortReq, @@ -215,6 +216,11 @@ def __init__( self.mm_processor = get_mm_processor( self.model_config.hf_config, server_args, _processor, transport_mode ) + self.mm_data_processor = AsyncMMDataProcessor( + self.mm_processor, + max_concurrent_calls=self.server_args.mm_max_concurrent_calls, + timeout_s=self.server_args.mm_per_request_timeout, + ) if server_args.skip_tokenizer_init: self.tokenizer = self.processor = None @@ -598,10 +604,10 @@ async def _tokenize_one_request( obj.image_data = [obj.image_data] if obj.audio_data is not None and not isinstance(obj.audio_data, list): obj.audio_data = [obj.audio_data] - mm_inputs: Dict = await self.mm_processor.process_mm_data_async( + mm_inputs: Dict = await self.mm_data_processor.process( image_data=obj.image_data, audio_data=obj.audio_data, - input_text=input_text or input_ids, + input_text_or_ids=(input_text or input_ids), request_obj=obj, max_req_input_len=self.max_req_input_len, ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index f1310cc959e2..8b7aa36856e2 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -533,6 +533,10 @@ class ServerArgs: pdmux_config_path: Optional[str] = None sm_group_num: int = 8 + # For Multi-Modal + mm_max_concurrent_calls: int = 32 + mm_per_request_timeout: float = 10.0 + def __post_init__(self): """ Orchestrates the handling of various server arguments, ensuring proper configuration and validation. @@ -3501,6 +3505,20 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Read CLI options from a config file. Must be a YAML file with configuration options.", ) + # For Multi-Modal + parser.add_argument( + "--mm-max-concurrent-calls", + type=int, + default=ServerArgs.mm_max_concurrent_calls, + help="The max concurrent calls for async mm data processing.", + ) + parser.add_argument( + "--mm-per-request-timeout", + type=int, + default=ServerArgs.mm_per_request_timeout, + help="The timeout for each multi-modal request in seconds.", + ) + @classmethod def from_cli_args(cls, args: argparse.Namespace): args.tp_size = args.tensor_parallel_size diff --git a/test/srt/test_async_mm_data_processor.py b/test/srt/test_async_mm_data_processor.py new file mode 100644 index 000000000000..65a8dff528dd --- /dev/null +++ b/test/srt/test_async_mm_data_processor.py @@ -0,0 +1,364 @@ +""" +Unit tests for AsyncMMDataProcessor. + +Covers: + - Async and sync processing paths + - Concurrency limiting via semaphore + - Per-call timeout behavior (async and sync) + - Argument passthrough (images, audios, text/ids, request_obj, kwargs) + - Error propagation and shutdown behavior +""" + +import asyncio +import logging +import threading +import time +from unittest.mock import Mock + +import pytest + +from sglang.srt.managers.async_mm_data_processor import AsyncMMDataProcessor + + +class TestAsyncMMDataProcessor: + """Test suite for AsyncMMDataProcessor.""" + + @pytest.fixture + def async_processor(self): + """Create a processor exposing an async process_mm_data_async.""" + + class AsyncProc: + async def process_mm_data_async( + self, + *, + image_data=None, + audio_data=None, + input_text=None, + request_obj=None, + **kwargs, + ): + # Allow tests to simulate latency via kwargs + delay = kwargs.get("delay_s", 0.0) + if delay: + await asyncio.sleep(delay) + return { + "path": "async", + "images": image_data, + "audios": audio_data, + "text": input_text, + "request": request_obj, + "kwargs": kwargs, + } + + return AsyncProc() + + @pytest.fixture + def sync_processor(self): + """Provide a processor exposing a sync process_mm_data.""" + + class SyncProc: + def process_mm_data( + self, + *, + image_data=None, + audio_data=None, + input_text=None, + request_obj=None, + **kwargs, + ): + delay = kwargs.get("delay_s", 0.0) + if delay: + # Simulate CPU/blocking work + time.sleep(delay) + return { + "path": "sync", + "images": image_data, + "audios": audio_data, + "text": input_text, + "request": request_obj, + "kwargs": kwargs, + } + + return SyncProc() + + @pytest.mark.asyncio + async def test_async_path_basic(self, async_processor): + """Async processor should be awaited directly.""" + proc = AsyncMMDataProcessor(async_processor) + out = await proc.process( + image_data=["img1.png"], + audio_data=["a.wav"], + input_text_or_ids="hello", + request_obj={"rid": 1}, + mode="fast", + ) + assert out["path"] == "async" + assert out["images"] == ["img1.png"] + assert out["audios"] == ["a.wav"] + assert out["text"] == "hello" + assert out["request"] == {"rid": 1} + assert out["kwargs"]["mode"] == "fast" + + @pytest.mark.asyncio + async def test_sync_fallback_basic(self, sync_processor): + """Sync processor should run in fallback executor.""" + proc = AsyncMMDataProcessor(sync_processor) + out = await proc.process( + image_data=[b"\x00\x01"], + audio_data=None, + input_text_or_ids=[1, 2, 3], + request_obj="req-obj", + role="user", + ) + assert out["path"] == "sync" + assert out["images"] == [b"\x00\x01"] + assert out["audios"] is None + assert out["text"] == [1, 2, 3] + assert out["request"] == "req-obj" + assert out["kwargs"]["role"] == "user" + + @pytest.mark.asyncio + async def test_timeout_async(self, async_processor): + """Timeout should raise asyncio.TimeoutError for async path.""" + proc = AsyncMMDataProcessor(async_processor, timeout_s=0.01) + with pytest.raises(asyncio.TimeoutError): + await proc.process( + input_text_or_ids="slow", + request_obj=None, + delay_s=0.05, # longer than timeout + ) + + @pytest.mark.asyncio + async def test_timeout_sync(self, sync_processor): + """Timeout should raise asyncio.TimeoutError for sync fallback path.""" + proc = AsyncMMDataProcessor(sync_processor, timeout_s=0.01) + with pytest.raises(asyncio.TimeoutError): + await proc.process( + input_text_or_ids="slow", + request_obj=None, + delay_s=0.05, # longer than timeout + ) + + @pytest.mark.asyncio + async def test_semaphore_release_after_timeout(self, sync_processor): + """ + If a call times out, the semaphore should be released so a subsequent call can proceed. + Use >=2 fallback workers so the timed-out thread doesn't block the next call. + """ + proc = AsyncMMDataProcessor( + sync_processor, + max_concurrent_calls=2, + timeout_s=0.01, + ) + + # First call will time out + with pytest.raises(asyncio.TimeoutError): + await proc.process( + input_text_or_ids="slow1", request_obj=None, delay_s=0.05 + ) + + # Second call should be able to acquire the semaphore and complete + out = await proc.process(input_text_or_ids="ok", request_obj=None, delay_s=0.0) + assert out["text"] == "ok" + + @pytest.mark.asyncio + async def test_concurrency_limit_async(self): + """Ensure max_concurrent_calls caps concurrency for async path.""" + current = 0 + max_seen = 0 + + class AsyncProc: + async def process_mm_data_async(self, **kwargs): + nonlocal current, max_seen + current += 1 + max_seen = max(max_seen, current) + try: + await asyncio.sleep(0.02) + return {"ok": True} + finally: + current -= 1 + + proc = AsyncMMDataProcessor(AsyncProc(), max_concurrent_calls=2) + + tasks = [ + proc.process(input_text_or_ids=f"t{i}", request_obj=None) for i in range(6) + ] + await asyncio.gather(*tasks) + + assert max_seen <= 2 + + @pytest.mark.asyncio + async def test_concurrency_limit_sync(self): + """Ensure max_concurrent_calls caps concurrency for sync fallback path.""" + current = 0 + max_seen = 0 + lock = threading.Lock() + + class SyncProc: + def process_mm_data(self, **kwargs): + nonlocal current, max_seen + with lock: + current += 1 + max_seen = max(max_seen, current) + try: + time.sleep(0.02) + return {"ok": True} + finally: + with lock: + current -= 1 + + proc = AsyncMMDataProcessor(SyncProc(), max_concurrent_calls=3) + + tasks = [ + proc.process(input_text_or_ids=f"s{i}", request_obj=None) for i in range(9) + ] + await asyncio.gather(*tasks) + + assert max_seen <= 3 + + @pytest.mark.asyncio + async def test_error_from_async_processor(self): + """Exceptions raised by the async processor should propagate.""" + + class BadAsync: + async def process_mm_data_async(self, **_): + await asyncio.sleep(0) + raise ValueError("async boom") + + proc = AsyncMMDataProcessor(BadAsync()) + with pytest.raises(ValueError, match="async boom"): + await proc.process(input_text_or_ids="x", request_obj=None) + + @pytest.mark.asyncio + async def test_error_from_sync_processor(self): + """Exceptions raised by the sync processor should propagate.""" + + class BadSync: + def process_mm_data(self, **_): + raise RuntimeError("sync boom") + + proc = AsyncMMDataProcessor(BadSync()) + with pytest.raises(RuntimeError, match="sync boom"): + await proc.process(input_text_or_ids="x", request_obj=None) + + @pytest.mark.asyncio + async def test_missing_both_methods_raises(self): + """Processor missing both methods should raise at call time.""" + + class Empty: + pass + + proc = AsyncMMDataProcessor(Empty()) + with pytest.raises( + RuntimeError, match="neither 'process_mm_data_async' nor 'process_mm_data'" + ): + await proc.process(input_text_or_ids="x", request_obj=None) + + @pytest.mark.asyncio + async def test_async_attribute_not_coroutine_uses_sync_fallback(self): + """ + If `process_mm_data_async` exists but isn't a coroutine function, + wrapper should treat it as sync and use `process_mm_data`. + """ + + class WeirdProc: + # Not a coroutine function: + def process_mm_data_async(self, **_): + return {"path": "would-be-async"} + + def process_mm_data(self, **_): + return {"path": "sync"} + + proc = AsyncMMDataProcessor(WeirdProc()) + out = await proc.process(input_text_or_ids="x", request_obj=None) + assert out["path"] == "sync" + + @pytest.mark.asyncio + async def test_kwargs_and_request_passthrough_async(self, async_processor): + """Extra kwargs and request_obj should be forwarded on async path.""" + proc = AsyncMMDataProcessor(async_processor) + out = await proc.process( + image_data=["i1", "i2"], + audio_data=["a1"], + input_text_or_ids="hello world", + request_obj={"uid": 42}, + return_meta=True, + delay_s=0.0, + ) + assert out["images"] == ["i1", "i2"] + assert out["audios"] == ["a1"] + assert out["text"] == "hello world" + assert out["request"] == {"uid": 42} + assert out["kwargs"]["return_meta"] is True + + @pytest.mark.asyncio + async def test_kwargs_and_request_passthrough_sync(self, sync_processor): + """Extra kwargs and request_obj should be forwarded on sync path.""" + proc = AsyncMMDataProcessor(sync_processor) + out = await proc.process( + image_data=None, + audio_data=[], + input_text_or_ids=[101, 102], + request_obj=("r", 7), + lang="en", + ) + assert out["images"] is None + assert out["audios"] == [] + assert out["text"] == [101, 102] + assert out["request"] == ("r", 7) + assert out["kwargs"]["lang"] == "en" + + def test_shutdown_on_sync_executor(self, sync_processor): + """Explicit shutdown should close fallback executor for sync path.""" + proc = AsyncMMDataProcessor(sync_processor) + # Swap real executor for a mock to assert shutdown behavior + proc.fallback_exec = Mock() + proc.shutdown() + proc.fallback_exec.shutdown.assert_called_once_with(wait=False) + + def test_del_calls_shutdown(self, sync_processor, caplog): + """__del__ should best-effort shutdown without raising.""" + caplog.set_level(logging.DEBUG) + proc = AsyncMMDataProcessor(sync_processor) + proc.fallback_exec = Mock() + # Simulate object destruction + proc.__del__() + proc.fallback_exec.shutdown.assert_called_once_with(wait=False) + + @pytest.mark.asyncio + async def test_concurrent_mixed_requests(self, async_processor): + """Mix different payloads and ensure all complete with valid outputs.""" + proc = AsyncMMDataProcessor(async_processor, max_concurrent_calls=4) + + tasks = [ + proc.process(input_text_or_ids="t1", request_obj=1), + proc.process(image_data=["i.png"], input_text_or_ids=[9, 8], request_obj=2), + proc.process( + audio_data=["v.wav"], input_text_or_ids="speech", request_obj=3 + ), + proc.process( + image_data=[], audio_data=[], input_text_or_ids=None, request_obj=4 + ), + ] + outs = await asyncio.gather(*tasks) + assert len(outs) == 4 + for out in outs: + assert "path" in out + assert out["path"] == "async" + + @pytest.mark.asyncio + async def test_many_requests_values_match_inputs(self, sync_processor): + """For sync path, ensure each response corresponds to its specific input.""" + proc = AsyncMMDataProcessor(sync_processor, max_concurrent_calls=8) + texts = [f"msg-{i}" for i in range(10)] + tasks = [ + proc.process(input_text_or_ids=t, request_obj=i) + for i, t in enumerate(texts) + ] + outs = await asyncio.gather(*tasks) + got = [o["text"] for o in outs] + assert got == texts + + +if __name__ == "__main__": + pytest.main([__file__])