diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index bf4b24fa05a2..57762532ccc9 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -17,6 +17,8 @@ This file implements python APIs for the inference engine. """ +from __future__ import annotations + import asyncio import atexit import dataclasses @@ -98,6 +100,7 @@ ) from sglang.srt.utils.network import get_zmq_socket from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter +from sglang.srt.utils.watchdog import SubprocessWatchdog from sglang.version import __version__ logger = logging.getLogger(__name__) @@ -185,6 +188,7 @@ def __init__(self, **kwargs): template_manager, port_args, scheduler_init_result, + subprocess_watchdog, ) = self._launch_subprocesses( server_args=server_args, init_tokenizer_manager_func=self.init_tokenizer_manager_func, @@ -194,6 +198,8 @@ def __init__(self, **kwargs): self.tokenizer_manager = tokenizer_manager self.template_manager = template_manager self._scheduler_init_result = scheduler_init_result + if tokenizer_manager is not None: + tokenizer_manager._subprocess_watchdog = subprocess_watchdog self.port_args = port_args self.remote_instance_transfer_engine_info = ( parse_remote_instance_transfer_engine_info_from_scheduler_infos( @@ -505,9 +511,13 @@ def _launch_scheduler_processes( server_args: ServerArgs, port_args: PortArgs, run_scheduler_process_func: Callable, - ) -> SchedulerInitResult: + ) -> Tuple[SchedulerInitResult, Optional[List]]: """Launch scheduler processes using multiprocessing. Override in subclasses for different backends (e.g. Ray). + + Returns: + Tuple of (SchedulerInitResult, scheduler_procs). + scheduler_procs is None for RayEngine (uses Ray actors instead). """ scheduler_procs = [] @@ -592,10 +602,13 @@ def wait_for_completion(): f"terminated with {proc.exitcode}" ) - return SchedulerInitResult( - scheduler_infos=scheduler_infos, - wait_for_ready=wait_for_ready, - wait_for_completion=wait_for_completion, + return ( + SchedulerInitResult( + scheduler_infos=scheduler_infos, + wait_for_ready=wait_for_ready, + wait_for_completion=wait_for_completion, + ), + scheduler_procs, ) @classmethod @@ -606,11 +619,17 @@ def _launch_subprocesses( run_scheduler_process_func: Callable, run_detokenizer_process_func: Callable, port_args: Optional[PortArgs] = None, - ) -> Tuple[TokenizerManager, TemplateManager, PortArgs, SchedulerInitResult]: + ) -> Tuple[ + TokenizerManager, + TemplateManager, + PortArgs, + SchedulerInitResult, + Optional[SubprocessWatchdog], + ]: """Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess. Returns: - Tuple of (tokenizer_manager, template_manager, port_args, scheduler_init_result). + Tuple of (tokenizer_manager, template_manager, port_args, scheduler_init_result, subprocess_watchdog). 
""" # Configure global environment configure_logger(server_args) @@ -624,7 +643,7 @@ def _launch_subprocesses( logger.info(f"{server_args=}") # Launch scheduler processes - scheduler_init_result = cls._launch_scheduler_processes( + scheduler_init_result, scheduler_procs = cls._launch_scheduler_processes( server_args, port_args, run_scheduler_process_func ) @@ -646,6 +665,7 @@ def _launch_subprocesses( None, port_args, scheduler_init_result, + None, ) launch_dummy_health_check_server( @@ -658,6 +678,7 @@ def _launch_subprocesses( None, port_args, scheduler_init_result, + None, ) # Launch detokenizer process @@ -688,15 +709,32 @@ def _launch_subprocesses( "max_req_input_len" ] + # Set up subprocess liveness watchdog to detect crashes + # Note: RayEngine returns scheduler_procs=None as it uses Ray actors instead of mp.Process + processes = list(scheduler_procs or []) + names = [f"scheduler_{i}" for i in range(len(processes))] + processes.append(detoken_proc) + names.append("detokenizer") + subprocess_watchdog = SubprocessWatchdog( + processes=processes, process_names=names + ) + subprocess_watchdog.start() + return ( tokenizer_manager, template_manager, port_args, scheduler_init_result, + subprocess_watchdog, ) def shutdown(self): """Shutdown the engine""" + if ( + self.tokenizer_manager is not None + and self.tokenizer_manager._subprocess_watchdog is not None + ): + self.tokenizer_manager._subprocess_watchdog.stop() kill_process_tree(os.getpid(), include_parent=False) def __enter__(self): diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index eeefe3f9ba53..d64a95332ba8 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -177,6 +177,7 @@ dumps_json, orjson_response, ) +from sglang.srt.utils.watchdog import SubprocessWatchdog from sglang.utils import get_exception_traceback from sglang.version import __version__ @@ -1979,6 +1980,7 @@ def _setup_and_run_http_server( template_manager, port_args: PortArgs, scheduler_infos: List[Dict], + subprocess_watchdog: Optional[SubprocessWatchdog], execute_warmup_func: Callable = _execute_server_warmup, launch_callback: Optional[Callable[[], None]] = None, ): @@ -2001,6 +2003,10 @@ def _setup_and_run_http_server( ) ) + # Store watchdog on tokenizer_manager (single source of truth for SIGQUIT handler) + if tokenizer_manager is not None: + tokenizer_manager._subprocess_watchdog = subprocess_watchdog + if server_args.enable_metrics: add_prometheus_track_response_middleware(app) @@ -2171,13 +2177,17 @@ def launch_server( 2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library. 
""" # Launch subprocesses - tokenizer_manager, template_manager, port_args, scheduler_init_result = ( - Engine._launch_subprocesses( - server_args=server_args, - init_tokenizer_manager_func=init_tokenizer_manager_func, - run_scheduler_process_func=run_scheduler_process_func, - run_detokenizer_process_func=run_detokenizer_process_func, - ) + ( + tokenizer_manager, + template_manager, + port_args, + scheduler_init_result, + subprocess_watchdog, + ) = Engine._launch_subprocesses( + server_args=server_args, + init_tokenizer_manager_func=init_tokenizer_manager_func, + run_scheduler_process_func=run_scheduler_process_func, + run_detokenizer_process_func=run_detokenizer_process_func, ) _setup_and_run_http_server( @@ -2186,6 +2196,7 @@ def launch_server( template_manager, port_args, scheduler_init_result.scheduler_infos, + subprocess_watchdog, execute_warmup_func=execute_warmup_func, launch_callback=launch_callback, ) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 810165e0fa3f..9742fb4098b1 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -213,6 +213,9 @@ def __init__( # Init PD disaggregation and encoder disaggregation self.init_disaggregation() + # Subprocess liveness watchdog — set by Engine or http_server after construction + self._subprocess_watchdog = None + # Init metric collector and watchdog self.init_metric_collector_watchdog() @@ -2563,6 +2566,10 @@ def running_phase_sigquit_handler(self, signum=None, frame=None): logger.error( f"SIGQUIT received. {signum=}, {frame=}. It usually means one child failed." ) + # Stop subprocess watchdog before killing processes to prevent false-positive + # crash detection during normal shutdown + if self.tokenizer_manager._subprocess_watchdog is not None: + self.tokenizer_manager._subprocess_watchdog.stop() self.tokenizer_manager.dump_requests_before_crash() kill_process_tree(os.getpid()) diff --git a/python/sglang/srt/ray/engine.py b/python/sglang/srt/ray/engine.py index 94c26436edde..f36bdd4607b1 100644 --- a/python/sglang/srt/ray/engine.py +++ b/python/sglang/srt/ray/engine.py @@ -90,8 +90,13 @@ def _launch_scheduler_processes( server_args: ServerArgs, port_args: PortArgs, run_scheduler_process_func: Callable, - ) -> SchedulerInitResult: - """Launch schedulers as Ray actors.""" + ) -> tuple[SchedulerInitResult, None]: + """Launch schedulers as Ray actors. + + Returns: + Tuple of (RaySchedulerInitResult, None). + scheduler_procs is None since Ray uses actors instead of mp.Process. + """ if server_args.dp_size > 1: raise NotImplementedError( "Ray support for dp_size > 1 is not yet implemented. 
" @@ -183,8 +188,11 @@ def wait_for_completion(): except Exception as e: logger.error(f"Ray scheduler actor terminated with error: {e}") - return RaySchedulerInitResult( - scheduler_infos=scheduler_infos, - wait_for_completion=wait_for_completion, - scheduler_actors=scheduler_actors, + return ( + RaySchedulerInitResult( + scheduler_infos=scheduler_infos, + wait_for_completion=wait_for_completion, + scheduler_actors=scheduler_actors, + ), + None, ) diff --git a/python/sglang/srt/ray/http_server.py b/python/sglang/srt/ray/http_server.py index c0580838e195..c2acda83e248 100644 --- a/python/sglang/srt/ray/http_server.py +++ b/python/sglang/srt/ray/http_server.py @@ -44,13 +44,17 @@ def launch_server( if execute_warmup_func is None: execute_warmup_func = _execute_server_warmup - tokenizer_manager, template_manager, port_args, scheduler_init_result = ( - RayEngine._launch_subprocesses( - server_args, - init_tokenizer_manager_func=init_tokenizer_manager_func, - run_scheduler_process_func=run_scheduler_process_func, - run_detokenizer_process_func=run_detokenizer_process_func, - ) + ( + tokenizer_manager, + template_manager, + port_args, + scheduler_init_result, + subprocess_watchdog, + ) = RayEngine._launch_subprocesses( + server_args, + init_tokenizer_manager_func=init_tokenizer_manager_func, + run_scheduler_process_func=run_scheduler_process_func, + run_detokenizer_process_func=run_detokenizer_process_func, ) _setup_and_run_http_server( @@ -59,6 +63,7 @@ def launch_server( template_manager, port_args, scheduler_init_result.scheduler_infos, + subprocess_watchdog, execute_warmup_func=execute_warmup_func, launch_callback=launch_callback, ) diff --git a/python/sglang/srt/utils/watchdog.py b/python/sglang/srt/utils/watchdog.py index 651243d5d122..d4fde7406c20 100644 --- a/python/sglang/srt/utils/watchdog.py +++ b/python/sglang/srt/utils/watchdog.py @@ -1,12 +1,14 @@ from __future__ import annotations import logging +import os import signal import sys import threading import time from contextlib import contextmanager -from typing import Callable, Optional +from multiprocessing import Process +from typing import Callable, List, Optional import psutil @@ -159,3 +161,66 @@ def _watchdog_once(self): # Wait for some time so that the parent process can print the error. time.sleep(5) self.parent_process.send_signal(signal.SIGQUIT) + + +class SubprocessWatchdog: + """Monitors subprocess liveness and triggers SIGQUIT when a crash is detected. + + When a subprocess crashes (e.g., NCCL timeout causing C++ std::terminate()), + Python exception handlers never run, leaving the main process as a zombie + service. This watchdog polls subprocess liveness in a daemon thread and + sends SIGQUIT to trigger proper cleanup. 
+ + See: https://github.com/sgl-project/sglang/issues/18421 + """ + + def __init__( + self, + processes: List[Process], + process_names: Optional[List[str]] = None, + interval: float = 1.0, + ): + self._processes = processes + self._names = process_names or [f"process_{i}" for i in range(len(processes))] + self._interval = interval + self._stop_event = threading.Event() + self._thread: Optional[threading.Thread] = None + + def start(self) -> None: + if self._thread is not None or not self._processes: + return + self._thread = threading.Thread( + target=self._monitor_loop, daemon=True, name="subprocess-watchdog" + ) + self._thread.start() + logger.info( + f"SubprocessWatchdog started, monitoring {len(self._processes)} process(es)" + ) + + def stop(self) -> None: + self._stop_event.set() + if self._thread is not None: + self._thread.join(timeout=self._interval * 2) + self._thread = None + + def _monitor_loop(self) -> None: + try: + while not self._stop_event.wait(self._interval): + if self._check_processes(): + return + except Exception as e: + logger.error(f"SubprocessWatchdog thread crashed: {e}", exc_info=True) + + def _check_processes(self) -> bool: + for proc, name in zip(self._processes, self._names): + if proc.is_alive() or proc.exitcode == 0: + continue + + logger.error( + f"Subprocess {name} (pid={proc.pid}) crashed " + f"with exit code {proc.exitcode}. " + f"Triggering SIGQUIT for cleanup..." + ) + os.kill(os.getpid(), signal.SIGQUIT) + return True + return False diff --git a/test/registered/unit/utils/test_subprocess_watchdog.py b/test/registered/unit/utils/test_subprocess_watchdog.py new file mode 100644 index 000000000000..075bec79d26f --- /dev/null +++ b/test/registered/unit/utils/test_subprocess_watchdog.py @@ -0,0 +1,137 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""Tests for SubprocessWatchdog in watchdog.py"""
+
+import multiprocessing as mp
+import os
+import signal
+import threading
+import time
+import unittest.mock
+
+from sglang.srt.utils.watchdog import SubprocessWatchdog
+from sglang.test.ci.ci_register import register_cpu_ci
+from sglang.test.test_utils import CustomTestCase
+
+register_cpu_ci(est_time=10, suite="stage-a-cpu-only")
+
+
+def healthy_worker():
+    time.sleep(10)
+
+
+def crashing_worker():
+    os._exit(1)
+
+
+def slow_crash_worker(delay: float = 0.5):
+    time.sleep(delay)
+    os._exit(42)
+
+
+class TestSubprocessWatchdog(CustomTestCase):
+    def setUp(self):
+        self.sigquit_triggered = threading.Event()
+        self._procs = []
+        self._monitor = None
+
+        original_kill = os.kill
+
+        def mock_kill(pid, sig):
+            if sig == signal.SIGQUIT:
+                self.sigquit_triggered.set()
+            else:
+                original_kill(pid, sig)
+
+        self._patcher = unittest.mock.patch("os.kill", side_effect=mock_kill)
+        self._patcher.start()
+
+    def tearDown(self):
+        if self._monitor is not None:
+            self._monitor.stop()
+        self._patcher.stop()
+        for p in self._procs:
+            if p.is_alive():
+                p.terminate()
+            p.join(timeout=1)
+
+    def _spawn(self, target, args=()):
+        proc = mp.Process(target=target, args=args)
+        proc.start()
+        self._procs.append(proc)
+        return proc
+
+    def _watch(self, procs, names=None, interval=0.1):
+        if not isinstance(procs, list):
+            procs = [procs]
+        self._monitor = SubprocessWatchdog(
+            processes=procs,
+            process_names=names,
+            interval=interval,
+        )
+        self._monitor.start()
+        return self._monitor
+
+    def test_healthy_processes_no_sigquit(self):
+        proc = self._spawn(healthy_worker)
+        self._watch(proc)
+        time.sleep(0.5)
+        self.assertFalse(self.sigquit_triggered.is_set())
+
+    def test_crashed_process_triggers_sigquit(self):
+        proc = self._spawn(slow_crash_worker, args=(0.2,))
+        self._watch(proc)
+        self.assertTrue(
+            self.sigquit_triggered.wait(timeout=2.0),
+            "SIGQUIT was not triggered within timeout",
+        )
+
+    def test_immediate_crash_detection(self):
+        proc = self._spawn(crashing_worker)
+        self._watch(proc, interval=0.05)
+        self.assertTrue(
+            self.sigquit_triggered.wait(timeout=1.0),
+            "Immediate crash was not detected",
+        )
+
+    def test_multiple_processes_one_crashes(self):
+        healthy = self._spawn(healthy_worker)
+        crashing = self._spawn(slow_crash_worker, args=(0.2,))
+        self._watch([healthy, crashing], names=["healthy", "crashing"])
+        self.assertTrue(
+            self.sigquit_triggered.wait(timeout=2.0),
+            "Crash was not detected when one of multiple processes crashed",
+        )
+
+    def test_empty_processes_list(self):
+        self._watch([], interval=0.1)
+        time.sleep(0.3)
+        self.assertFalse(self.sigquit_triggered.is_set())
+
+    def test_normal_exit_no_sigquit(self):
+        # A lambda cannot be pickled under the spawn start method forced in
+        # __main__ below, so use a picklable target that exits with code 0.
+        proc = self._spawn(time.sleep, args=(0,))
+        proc.join(timeout=2)
+        self._watch(proc)
+        time.sleep(0.3)
+        self.assertFalse(
+            self.sigquit_triggered.is_set(),
+            "SIGQUIT should not be triggered for normal exit (exitcode=0)",
+        )
+
+
+if __name__ == "__main__":
+    mp.set_start_method("spawn", force=True)
+    import unittest
+
+    unittest.main()
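
Usage note: the sketch below shows how the SubprocessWatchdog added above is intended to be wired into a host process. It is a minimal standalone example under stated assumptions, not SGLang code: the worker function and the on_sigquit handler are hypothetical stand-ins for the scheduler/detokenizer subprocesses and for running_phase_sigquit_handler in tokenizer_manager.py, and it assumes a POSIX platform (SIGQUIT).

import multiprocessing as mp
import os
import signal
import time

from sglang.srt.utils.watchdog import SubprocessWatchdog


def worker():
    # Hypothetical stand-in for a scheduler/detokenizer subprocess that dies
    # without running Python exception handlers (e.g. C++ std::terminate()).
    time.sleep(1)
    os._exit(1)


def on_sigquit(signum, frame):
    # Hypothetical stand-in for running_phase_sigquit_handler: a real handler
    # would dump in-flight requests and kill the whole process tree.
    print(f"SIGQUIT received ({signum=}); shutting down")
    raise SystemExit(1)


if __name__ == "__main__":
    signal.signal(signal.SIGQUIT, on_sigquit)
    child = mp.Process(target=worker)
    child.start()

    watchdog = SubprocessWatchdog(processes=[child], process_names=["worker"])
    watchdog.start()
    try:
        while True:  # stand-in for the serving loop
            time.sleep(0.5)
    finally:
        # Stop the watchdog before any deliberate teardown so that normal
        # shutdown is not misreported as a crash, mirroring Engine.shutdown
        # and the SIGQUIT handler in the diff above.
        watchdog.stop()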