diff --git a/python/sglang/srt/managers/async_mm_data_processor.py b/python/sglang/srt/managers/async_mm_data_processor.py deleted file mode 100644 index 85e8580cb769..000000000000 --- a/python/sglang/srt/managers/async_mm_data_processor.py +++ /dev/null @@ -1,122 +0,0 @@ -import asyncio -import logging -from concurrent.futures import ThreadPoolExecutor -from functools import partial -from typing import Any, Dict, List, Optional, Union - -logger = logging.getLogger(__name__) - - -class AsyncMMDataProcessor: - """ - Async wrapper for a multimodal processor. - - Behavior: - - If the underlying processor exposes `process_mm_data_async`, call/await it directly. - - Otherwise, fall back to running a synchronous `process_mm_data` in a thread pool. - - Optionally guard per-call concurrency via an asyncio.Semaphore. - - Optionally enforce per-call timeout via asyncio.wait_for. - """ - - def __init__( - self, - mm_processor: Any, - *, - max_concurrent_calls: Optional[int] = None, - timeout_s: Optional[float] = None, - ) -> None: - """ - Args: - mm_processor: An object exposing either - - async def process_mm_data_async(...): -> Dict[str, Any] - or - - def process_mm_data(...): -> Dict[str, Any] - max_concurrent_calls: Optional concurrency cap for per-call execution. - timeout_s: Optional timeout (seconds) for each `process()` call. - """ - self.mm_processor = mm_processor - self.timeout_s = timeout_s - - # Concurrency guard (None -> unlimited) - self.semaphore = ( - asyncio.Semaphore(max_concurrent_calls) if max_concurrent_calls else None - ) - - # Detect async path; if missing, prepare a fallback executor for sync path - self._proc_async = getattr(mm_processor, "process_mm_data_async", None) - self.is_async = asyncio.iscoroutinefunction(self._proc_async) - self.fallback_exec: Optional[ThreadPoolExecutor] = ( - ThreadPoolExecutor(max_workers=max_concurrent_calls) - if not self.is_async - else None - ) - - async def process( - self, - *, - image_data: Optional[List[Union[str, bytes]]] = None, - audio_data: Optional[List[Union[str, bytes]]] = None, - input_text_or_ids: Union[str, List[int], None] = None, - request_obj: Any, - **kwargs: Any, - ) -> Dict[str, Any]: - """ - Public entrypoint: process a single multimodal request without blocking the event loop. - """ - - async def _invoke() -> Dict[str, Any]: - if self.is_async: - # Native async implementation - return await self._proc_async( - image_data=image_data, - audio_data=audio_data, - input_text=input_text_or_ids, - request_obj=request_obj, - **kwargs, - ) - - # Synchronous fallback - sync_fn = getattr(self.mm_processor, "process_mm_data", None) - if not callable(sync_fn): - raise RuntimeError( - "mm_processor has neither 'process_mm_data_async' nor 'process_mm_data'." - ) - loop = asyncio.get_running_loop() - fn = partial( - sync_fn, - image_data=image_data, - audio_data=audio_data, - input_text=input_text_or_ids, - request_obj=request_obj, - **kwargs, - ) - return await loop.run_in_executor(self.fallback_exec, fn) - - # Apply optional concurrency guard - if self.semaphore is not None: - async with self.semaphore: - if self.timeout_s is not None: - return await asyncio.wait_for(_invoke(), timeout=self.timeout_s) - return await _invoke() - - # No concurrency guard - if self.timeout_s is not None: - return await asyncio.wait_for(_invoke(), timeout=self.timeout_s) - return await _invoke() - - def shutdown(self) -> None: - """Gracefully shutdown resources owned by this wrapper.""" - try: - if self.fallback_exec: - self.fallback_exec.shutdown(wait=False) - except Exception: - logger.exception( - "Error while shutting down fallback executor in AsyncMMDataProcessor" - ) - - def __del__(self): - # Best-effort shutdown - try: - self.shutdown() - except Exception: - pass diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 3c1bec753100..34497a6e1dc0 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -43,7 +43,6 @@ from sglang.srt.environ import envs from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer -from sglang.srt.managers.async_mm_data_processor import AsyncMMDataProcessor from sglang.srt.managers.disagg_service import start_disagg_service from sglang.srt.managers.io_struct import ( AbortReq, @@ -271,11 +270,6 @@ def init_tokenizer_and_processor(self): self.mm_processor = get_mm_processor( self.model_config.hf_config, server_args, _processor, transport_mode ) - self.mm_data_processor = AsyncMMDataProcessor( - self.mm_processor, - max_concurrent_calls=self.server_args.mm_max_concurrent_calls, - timeout_s=self.server_args.mm_per_request_timeout, - ) if server_args.skip_tokenizer_init: self.tokenizer = self.processor = None @@ -734,10 +728,10 @@ async def _tokenize_one_request( need_wait_for_mm_inputs=obj.need_wait_for_mm_inputs, ) if mm_inputs is None: - mm_inputs: Dict = await self.mm_data_processor.process( + mm_inputs: Dict = await self.mm_processor.process_mm_data_async( image_data=obj.image_data, audio_data=obj.audio_data, - input_text_or_ids=(input_text or input_ids), + input_text=(input_text or input_ids), request_obj=obj, max_req_input_len=self.max_req_input_len, ) @@ -748,10 +742,10 @@ async def _tokenize_one_request( ): # In language_only mode with zmq_to_scheduler, if we didn't dispatch # to encoder (e.g., only one image), process locally like non-language_only mode - mm_inputs: Dict = await self.mm_data_processor.process( + mm_inputs: Dict = await self.mm_processor.process_mm_data_async( image_data=obj.image_data, audio_data=obj.audio_data, - input_text_or_ids=(input_text or input_ids), + input_text=(input_text or input_ids), request_obj=obj, max_req_input_len=self.max_req_input_len, ) diff --git a/python/sglang/srt/multimodal/processors/llava.py b/python/sglang/srt/multimodal/processors/llava.py index 55a9fa686a18..bdc28f1d9a83 100644 --- a/python/sglang/srt/multimodal/processors/llava.py +++ b/python/sglang/srt/multimodal/processors/llava.py @@ -1,4 +1,5 @@ import asyncio +import os from typing import Dict, List, Optional, Union import numpy as np @@ -96,7 +97,7 @@ async def _process_single_image( ): if self.cpu_executor is not None: loop = asyncio.get_running_loop() - return await loop.run_in_executor( + fut = loop.run_in_executor( self.cpu_executor, LlavaImageProcessor._process_single_image_task, image_data, @@ -104,6 +105,8 @@ async def _process_single_image( grid_pinpoints, self._processor, ) + timeout = int(os.environ.get("REQUEST_TIMEOUT", "10")) + return await asyncio.wait_for(fut, timeout=timeout) else: return self._process_single_image_task( image_data, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 92e524a6be45..1f59f7263af2 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -720,8 +720,6 @@ class ServerArgs: sm_group_num: int = 8 # For Multi-Modal - mm_max_concurrent_calls: int = 32 - mm_per_request_timeout: float = 10.0 enable_broadcast_mm_inputs_process: bool = False enable_prefix_mm_cache: bool = False mm_enable_dp_encoder: bool = False @@ -5809,18 +5807,6 @@ def add_cli_args(parser: argparse.ArgumentParser): ) # For Multi-Modal - parser.add_argument( - "--mm-max-concurrent-calls", - type=int, - default=ServerArgs.mm_max_concurrent_calls, - help="The max concurrent calls for async mm data processing.", - ) - parser.add_argument( - "--mm-per-request-timeout", - type=int, - default=ServerArgs.mm_per_request_timeout, - help="The timeout for each multi-modal request in seconds.", - ) parser.add_argument( "--enable-broadcast-mm-inputs-process", action="store_true", diff --git a/test/manual/test_async_mm_data_processor.py b/test/manual/test_async_mm_data_processor.py deleted file mode 100644 index 0edc2f5ccc8e..000000000000 --- a/test/manual/test_async_mm_data_processor.py +++ /dev/null @@ -1,365 +0,0 @@ -""" -Unit tests for AsyncMMDataProcessor. - -Covers: - - Async and sync processing paths - - Concurrency limiting via semaphore - - Per-call timeout behavior (async and sync) - - Argument passthrough (images, audios, text/ids, request_obj, kwargs) - - Error propagation and shutdown behavior -""" - -import asyncio -import logging -import sys -import threading -import time -from unittest.mock import Mock - -import pytest - -from sglang.srt.managers.async_mm_data_processor import AsyncMMDataProcessor - - -class TestAsyncMMDataProcessor: - """Test suite for AsyncMMDataProcessor.""" - - @pytest.fixture - def async_processor(self): - """Create a processor exposing an async process_mm_data_async.""" - - class AsyncProc: - async def process_mm_data_async( - self, - *, - image_data=None, - audio_data=None, - input_text=None, - request_obj=None, - **kwargs, - ): - # Allow tests to simulate latency via kwargs - delay = kwargs.get("delay_s", 0.0) - if delay: - await asyncio.sleep(delay) - return { - "path": "async", - "images": image_data, - "audios": audio_data, - "text": input_text, - "request": request_obj, - "kwargs": kwargs, - } - - return AsyncProc() - - @pytest.fixture - def sync_processor(self): - """Provide a processor exposing a sync process_mm_data.""" - - class SyncProc: - def process_mm_data( - self, - *, - image_data=None, - audio_data=None, - input_text=None, - request_obj=None, - **kwargs, - ): - delay = kwargs.get("delay_s", 0.0) - if delay: - # Simulate CPU/blocking work - time.sleep(delay) - return { - "path": "sync", - "images": image_data, - "audios": audio_data, - "text": input_text, - "request": request_obj, - "kwargs": kwargs, - } - - return SyncProc() - - @pytest.mark.asyncio - async def test_async_path_basic(self, async_processor): - """Async processor should be awaited directly.""" - proc = AsyncMMDataProcessor(async_processor) - out = await proc.process( - image_data=["img1.png"], - audio_data=["a.wav"], - input_text_or_ids="hello", - request_obj={"rid": 1}, - mode="fast", - ) - assert out["path"] == "async" - assert out["images"] == ["img1.png"] - assert out["audios"] == ["a.wav"] - assert out["text"] == "hello" - assert out["request"] == {"rid": 1} - assert out["kwargs"]["mode"] == "fast" - - @pytest.mark.asyncio - async def test_sync_fallback_basic(self, sync_processor): - """Sync processor should run in fallback executor.""" - proc = AsyncMMDataProcessor(sync_processor) - out = await proc.process( - image_data=[b"\x00\x01"], - audio_data=None, - input_text_or_ids=[1, 2, 3], - request_obj="req-obj", - role="user", - ) - assert out["path"] == "sync" - assert out["images"] == [b"\x00\x01"] - assert out["audios"] is None - assert out["text"] == [1, 2, 3] - assert out["request"] == "req-obj" - assert out["kwargs"]["role"] == "user" - - @pytest.mark.asyncio - async def test_timeout_async(self, async_processor): - """Timeout should raise asyncio.TimeoutError for async path.""" - proc = AsyncMMDataProcessor(async_processor, timeout_s=0.01) - with pytest.raises(asyncio.TimeoutError): - await proc.process( - input_text_or_ids="slow", - request_obj=None, - delay_s=0.05, # longer than timeout - ) - - @pytest.mark.asyncio - async def test_timeout_sync(self, sync_processor): - """Timeout should raise asyncio.TimeoutError for sync fallback path.""" - proc = AsyncMMDataProcessor(sync_processor, timeout_s=0.01) - with pytest.raises(asyncio.TimeoutError): - await proc.process( - input_text_or_ids="slow", - request_obj=None, - delay_s=0.05, # longer than timeout - ) - - @pytest.mark.asyncio - async def test_semaphore_release_after_timeout(self, sync_processor): - """ - If a call times out, the semaphore should be released so a subsequent call can proceed. - Use >=2 fallback workers so the timed-out thread doesn't block the next call. - """ - proc = AsyncMMDataProcessor( - sync_processor, - max_concurrent_calls=2, - timeout_s=0.01, - ) - - # First call will time out - with pytest.raises(asyncio.TimeoutError): - await proc.process( - input_text_or_ids="slow1", request_obj=None, delay_s=0.05 - ) - - # Second call should be able to acquire the semaphore and complete - out = await proc.process(input_text_or_ids="ok", request_obj=None, delay_s=0.0) - assert out["text"] == "ok" - - @pytest.mark.asyncio - async def test_concurrency_limit_async(self): - """Ensure max_concurrent_calls caps concurrency for async path.""" - current = 0 - max_seen = 0 - - class AsyncProc: - async def process_mm_data_async(self, **kwargs): - nonlocal current, max_seen - current += 1 - max_seen = max(max_seen, current) - try: - await asyncio.sleep(0.02) - return {"ok": True} - finally: - current -= 1 - - proc = AsyncMMDataProcessor(AsyncProc(), max_concurrent_calls=2) - - tasks = [ - proc.process(input_text_or_ids=f"t{i}", request_obj=None) for i in range(6) - ] - await asyncio.gather(*tasks) - - assert max_seen <= 2 - - @pytest.mark.asyncio - async def test_concurrency_limit_sync(self): - """Ensure max_concurrent_calls caps concurrency for sync fallback path.""" - current = 0 - max_seen = 0 - lock = threading.Lock() - - class SyncProc: - def process_mm_data(self, **kwargs): - nonlocal current, max_seen - with lock: - current += 1 - max_seen = max(max_seen, current) - try: - time.sleep(0.02) - return {"ok": True} - finally: - with lock: - current -= 1 - - proc = AsyncMMDataProcessor(SyncProc(), max_concurrent_calls=3) - - tasks = [ - proc.process(input_text_or_ids=f"s{i}", request_obj=None) for i in range(9) - ] - await asyncio.gather(*tasks) - - assert max_seen <= 3 - - @pytest.mark.asyncio - async def test_error_from_async_processor(self): - """Exceptions raised by the async processor should propagate.""" - - class BadAsync: - async def process_mm_data_async(self, **_): - await asyncio.sleep(0) - raise ValueError("async boom") - - proc = AsyncMMDataProcessor(BadAsync()) - with pytest.raises(ValueError, match="async boom"): - await proc.process(input_text_or_ids="x", request_obj=None) - - @pytest.mark.asyncio - async def test_error_from_sync_processor(self): - """Exceptions raised by the sync processor should propagate.""" - - class BadSync: - def process_mm_data(self, **_): - raise RuntimeError("sync boom") - - proc = AsyncMMDataProcessor(BadSync()) - with pytest.raises(RuntimeError, match="sync boom"): - await proc.process(input_text_or_ids="x", request_obj=None) - - @pytest.mark.asyncio - async def test_missing_both_methods_raises(self): - """Processor missing both methods should raise at call time.""" - - class Empty: - pass - - proc = AsyncMMDataProcessor(Empty()) - with pytest.raises( - RuntimeError, match="neither 'process_mm_data_async' nor 'process_mm_data'" - ): - await proc.process(input_text_or_ids="x", request_obj=None) - - @pytest.mark.asyncio - async def test_async_attribute_not_coroutine_uses_sync_fallback(self): - """ - If `process_mm_data_async` exists but isn't a coroutine function, - wrapper should treat it as sync and use `process_mm_data`. - """ - - class WeirdProc: - # Not a coroutine function: - def process_mm_data_async(self, **_): - return {"path": "would-be-async"} - - def process_mm_data(self, **_): - return {"path": "sync"} - - proc = AsyncMMDataProcessor(WeirdProc()) - out = await proc.process(input_text_or_ids="x", request_obj=None) - assert out["path"] == "sync" - - @pytest.mark.asyncio - async def test_kwargs_and_request_passthrough_async(self, async_processor): - """Extra kwargs and request_obj should be forwarded on async path.""" - proc = AsyncMMDataProcessor(async_processor) - out = await proc.process( - image_data=["i1", "i2"], - audio_data=["a1"], - input_text_or_ids="hello world", - request_obj={"uid": 42}, - return_meta=True, - delay_s=0.0, - ) - assert out["images"] == ["i1", "i2"] - assert out["audios"] == ["a1"] - assert out["text"] == "hello world" - assert out["request"] == {"uid": 42} - assert out["kwargs"]["return_meta"] is True - - @pytest.mark.asyncio - async def test_kwargs_and_request_passthrough_sync(self, sync_processor): - """Extra kwargs and request_obj should be forwarded on sync path.""" - proc = AsyncMMDataProcessor(sync_processor) - out = await proc.process( - image_data=None, - audio_data=[], - input_text_or_ids=[101, 102], - request_obj=("r", 7), - lang="en", - ) - assert out["images"] is None - assert out["audios"] == [] - assert out["text"] == [101, 102] - assert out["request"] == ("r", 7) - assert out["kwargs"]["lang"] == "en" - - def test_shutdown_on_sync_executor(self, sync_processor): - """Explicit shutdown should close fallback executor for sync path.""" - proc = AsyncMMDataProcessor(sync_processor) - # Swap real executor for a mock to assert shutdown behavior - proc.fallback_exec = Mock() - proc.shutdown() - proc.fallback_exec.shutdown.assert_called_once_with(wait=False) - - def test_del_calls_shutdown(self, sync_processor, caplog): - """__del__ should best-effort shutdown without raising.""" - caplog.set_level(logging.DEBUG) - proc = AsyncMMDataProcessor(sync_processor) - proc.fallback_exec = Mock() - # Simulate object destruction - proc.__del__() - proc.fallback_exec.shutdown.assert_called_once_with(wait=False) - - @pytest.mark.asyncio - async def test_concurrent_mixed_requests(self, async_processor): - """Mix different payloads and ensure all complete with valid outputs.""" - proc = AsyncMMDataProcessor(async_processor, max_concurrent_calls=4) - - tasks = [ - proc.process(input_text_or_ids="t1", request_obj=1), - proc.process(image_data=["i.png"], input_text_or_ids=[9, 8], request_obj=2), - proc.process( - audio_data=["v.wav"], input_text_or_ids="speech", request_obj=3 - ), - proc.process( - image_data=[], audio_data=[], input_text_or_ids=None, request_obj=4 - ), - ] - outs = await asyncio.gather(*tasks) - assert len(outs) == 4 - for out in outs: - assert "path" in out - assert out["path"] == "async" - - @pytest.mark.asyncio - async def test_many_requests_values_match_inputs(self, sync_processor): - """For sync path, ensure each response corresponds to its specific input.""" - proc = AsyncMMDataProcessor(sync_processor, max_concurrent_calls=8) - texts = [f"msg-{i}" for i in range(10)] - tasks = [ - proc.process(input_text_or_ids=t, request_obj=i) - for i, t in enumerate(texts) - ] - outs = await asyncio.gather(*tasks) - got = [o["text"] for o in outs] - assert got == texts - - -if __name__ == "__main__": - sys.exit(pytest.main([__file__]))