diff --git a/docs/reference/cli-commands.md b/docs/reference/cli-commands.md index c43068202..fd239fe6f 100644 --- a/docs/reference/cli-commands.md +++ b/docs/reference/cli-commands.md @@ -709,6 +709,49 @@ NeMo Gym Server Status: --- +### `ng_stop` / `nemo_gym_stop` + +Provides a clean way to stop servers without having to manually kill processes, or use Ctrl+C on multiple terminals. + +**Examples** +```bash +# Stop all servers +ng_stop +all=true + +# Stop specific server by name +ng_stop +name=example_single_tool_call + +# Stop server on port 8001 +ng_stop +port=8001 + +# Force stop all servers +ng_stop +all=true +force=true +``` + +**Parameters** + +```{list-table} +:header-rows: 1 +:widths: 25 10 65 + +* - Parameter + - Type + - Description +* - `all` + - bool + - Stop all servers. +* - `name` + - str + - Stop specific server by name. +* - `port` + - int + - Stop specific server by port. +* - `force` + - bool + - Force stop the specified server(s). +``` + + ## Getting Help For detailed help on any command, run it with `+help=true` or `+h=true`: diff --git a/nemo_gym/cli.py b/nemo_gym/cli.py index ca5b6b6f0..1e743de0e 100644 --- a/nemo_gym/cli.py +++ b/nemo_gym/cli.py @@ -51,7 +51,7 @@ get_global_config_dict, ) from nemo_gym.rollout_collection import E2ERolloutCollectionConfig, RolloutCollectionConfig, RolloutCollectionHelper -from nemo_gym.server_status import StatusCommand +from nemo_gym.server_commands import StatusCommand, StopCommand from nemo_gym.server_utils import ( HEAD_SERVER_KEY_NAME, HeadServer, @@ -81,6 +81,13 @@ class RunConfig(BaseNeMoGymCLIConfig): ) +class StopConfig(BaseNeMoGymCLIConfig): + all: bool = Field(default=False, description="Stop all running servers") + name: Optional[str] = Field(default=None, description="Stop server by name") + port: Optional[int] = Field(default=None, description="Stop server on specific port") + force: bool = Field(default=False, description="Force stop unresponsive servers") + + class TestConfig(RunConfig): """ Test a specific server module by running its pytest suite and optionally validating example data. @@ -249,24 +256,44 @@ def poll(self) -> None: if not self._head_server_thread.is_alive(): raise RuntimeError("Head server finished unexpectedly!") + # Clean up processes that have stopped + processes_to_delete = [] for process_name, process in self._processes.items(): - if process.poll() is not None: - proc_out, proc_err = process.communicate() - print_str = f"Process `{process_name}` finished unexpectedly!" - - if isinstance(proc_out, bytes): - proc_out = proc_out.decode("utf-8") - print_str = f"""{print_str} -Process `{process_name}` stdout: -{proc_out} -""" - if isinstance(proc_err, bytes): - proc_err = proc_err.decode("utf-8") - print_str = f"""{print_str} -Process `{process_name}` stderr: -{proc_err}""" + poll_result = process.poll() + + if poll_result is not None: + # Assume the process exited + exit_code = poll_result + + try: + proc_out, proc_err = process.communicate() + except: + proc_out, proc_err = None, None + + if exit_code <= 0: + processes_to_delete.append(process_name) + else: + print_str = f"Process `{process_name}` finished unexpectedly!" + + if isinstance(proc_out, bytes): + proc_out = proc_out.decode("utf-8") + print_str = f"""{print_str} + Process `{process_name}` stdout: + {proc_out} + """ + if isinstance(proc_err, bytes): + proc_err = proc_err.decode("utf-8") + print_str = f"""{print_str} + Process `{process_name}` stderr: + {proc_err}""" + + raise RuntimeError(print_str) + + for process_name in processes_to_delete: + del self._processes[process_name] - raise RuntimeError(print_str) + if not self._processes: + raise KeyboardInterrupt() def wait_for_dry_run_spinup(self) -> None: sleep_interval = 3 @@ -349,10 +376,29 @@ def run_forever(self) -> None: return async def sleep(): + poll_interval = 60 + sleep_interval = 1 + secs_since_last_poll = 0 + # Indefinitely while True: - self.poll() - await asyncio.sleep(60) + if secs_since_last_poll >= poll_interval: + self.poll() + secs_since_last_poll = 0 + + alive_count = 0 + for proc in self._processes.values(): + if proc.poll() is None: # still running + alive_count += 1 + + if self._processes and alive_count == 0: + print(f"\n{'#' * 100}") + print("All servers stopped. Shutting down head server...") + print(f"{'#' * 100}\n") + return + + await asyncio.sleep(sleep_interval) + secs_since_last_poll += sleep_interval try: asyncio.run(sleep()) @@ -988,6 +1034,39 @@ def version(): # pragma: no cover print(output) +def stop(): # pragma: no cover + global_config_dict = get_global_config_dict() + config = StopConfig.model_validate(global_config_dict) + + stop_cmd = StopCommand() + + # Validation to prevent multiple options from being set + options_set = sum([config.all, config.name is not None, config.port is not None]) + + if options_set == 0: + print("Error: Must specify one of: '+all=', '+name=', or '+port='") + print("\nUsage:") + print(" ng_stop +all=true # Stop all servers") + print(" ng_stop +name=example_single_tool_call # Stop specific server") + print(" ng_stop +port=8001 # Stop server on port 8001") + print(" ng_stop +all=true +force=true # Force stop all servers") + exit(1) + + if options_set > 1: + print("Error: Can only specify one of: '+all=', '+name=', or '+port='") + exit(1) + + if config.all: + results = stop_cmd.stop_all(config.force) + elif config.name: + results = stop_cmd.stop_by_name(config.name, config.force) + elif config.port: + results = stop_cmd.stop_by_port(config.port, config.force) + + stop_cmd.display_results(results) + exit() + + def reinstall(): # pragma: no cover global_config_dict = get_global_config_dict() # Just here for help diff --git a/nemo_gym/server_commands.py b/nemo_gym/server_commands.py new file mode 100644 index 000000000..ce538758b --- /dev/null +++ b/nemo_gym/server_commands.py @@ -0,0 +1,262 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass, field +from signal import SIGINT +from time import time +from typing import List + +import psutil +import requests +from devtools import pprint + +from nemo_gym.server_utils import ServerClient, ServerInstanceDisplayConfig, ServerStatus + + +class StatusCommand: + """Main class to check server status""" + + def check_health(self, server_info: ServerInstanceDisplayConfig) -> ServerStatus: + """Check if server is responding""" + if not server_info.url: + return "unknown_error" + + try: + requests.get(server_info.url, timeout=2) + return "success" + except requests.exceptions.ConnectionError: + return "connection_error" + except requests.exceptions.Timeout: + return "timeout" + except Exception: + return "unknown_error" + + def discover_servers(self) -> List[ServerInstanceDisplayConfig]: + """Find all running NeMo Gym server processes""" + + try: + head_server_config = ServerClient.load_head_server_config() + head_url = f"http://{head_server_config.host}:{head_server_config.port}" + + response = requests.get(f"{head_url}/server_instances", timeout=5) + response.raise_for_status() + instances = response.json() + + servers = [] + current_time = time() + + for inst in instances: + uptime = current_time - inst.get("start_time", current_time) + server_info = ServerInstanceDisplayConfig( + process_name=inst["process_name"], + server_type=inst["server_type"], + name=inst["name"], + host=inst.get("host"), + port=inst.get("port"), + url=inst.get("url"), + entrypoint=inst.get("entrypoint"), + pid=inst.get("pid"), + uptime_seconds=uptime, + status="unknown_error", + ) + server_info.status = self.check_health(server_info) + servers.append(server_info) + + return servers + + except (requests.RequestException, ConnectionError) as e: + print(f""" +Could not connect to head server: {e} +Is the head server running? Start it with: `ng_run` + """) + return [] + + def display_status(self, servers: List[ServerInstanceDisplayConfig]) -> None: + """Show server info in a table""" + + def format_uptime(uptime_seconds: float) -> str: + """Format uptime in a human readable format""" + minutes, seconds = divmod(uptime_seconds, 60) + hours, minutes = divmod(minutes, 60) + days, hours = divmod(hours, 24) + return f"{int(days)}d {int(hours)}h {int(minutes)}m {seconds:.1f}s" + + if not servers: + print("No NeMo Gym servers found running.") + return + + print("\nNeMo Gym Server Status:\n") + + for i, server in enumerate(servers, 1): + status_icon = "✓" if server.status == "success" else "✗" + print(f"[{i}] {status_icon} {server.process_name} ({server.server_type}/{server.name})") + display_dict = { + "server_type": server.server_type, + "name": server.name, + "port": server.port, + "pid": server.pid, + "uptime_seconds": format_uptime(server.uptime_seconds), + } + pprint(display_dict) + + healthy_count = sum(1 for s in servers if s.status == "success") + print(f""" +{len(servers)} servers found ({healthy_count} healthy, {len(servers) - healthy_count} unhealthy) +""") + + +@dataclass +class StopCommand: + """Class to stop gym servers""" + + status_cmd: StatusCommand = field(default_factory=StatusCommand) + + def stop_server(self, server_info: ServerInstanceDisplayConfig, force: bool = False) -> dict: + """Stop a single server process.""" + + try: + proc = psutil.Process(server_info.pid) + children = proc.children(recursive=True) + + if force: + for child in children: + try: + child.kill() + except psutil.NoSuchProcess: + pass + + proc.kill() + proc.wait(timeout=2) + + return { + "server": server_info, + "success": True, + "method": "force", + "message": f"Force stopped {server_info.name}", + } + else: + # Graceful shutdown, then wait + for child in children: + try: + child.send_signal(SIGINT) + except psutil.NoSuchProcess: + pass + + proc.send_signal(SIGINT) + + try: + proc.wait(timeout=10) + + return { + "server": server_info, + "success": True, + "method": "graceful", + "message": f"Gracefully stopped {server_info.name}", + } + except psutil.TimeoutExpired: + # Graceful didn't work, so terminate and wait + for child in children: + try: + child.terminate() + except psutil.NoSuchProcess: + pass + + proc.terminate() + + try: + proc.wait(timeout=5) + + return { + "server": server_info, + "success": True, + "method": "terminate", + "message": f"Terminated {server_info.name} (graceful shutdown timed out)", + } + except psutil.TimeoutExpired: + return { + "server": server_info, + "success": False, + "method": "failed", + "message": f"Failed to stop {server_info.name} - use --force", + } + except psutil.NoSuchProcess: + return { + "server": server_info, + "success": False, + "method": "no_process", + "message": f"{server_info.name} not found", + } + except psutil.AccessDenied: + return { + "server": server_info, + "success": False, + "method": "access_denied", + "message": f"Access denied to stop {server_info.name} (PID: {server_info.pid})", + } + except Exception as e: + return { + "server": server_info, + "success": False, + "method": "error", + "message": f"Error stopping {server_info.name}: {e}", + } + + def stop_all(self, force: bool = False) -> List[dict]: + """Stop all running servers""" + servers = self.status_cmd.discover_servers() + + if not servers: + return [{"success": False, "message": "No servers found"}] + + return [self.stop_server(server, force) for server in servers] + + def stop_by_name(self, name: str, force: bool = False) -> List[dict]: + """Stop a server by name""" + servers = self.status_cmd.discover_servers() + name = name.lower() + matching = next((s for s in servers if s.name.lower() == name), None) + + if not matching: + return [{"success": False, "message": f"No server found with name: {name}"}] + + return [self.stop_server(matching, force)] + + def stop_by_port(self, port: int, force: bool = False) -> List[dict]: + """Stop a server on a specific port""" + servers = self.status_cmd.discover_servers() + matching = next((s for s in servers if s.port == port), None) + + if not matching: + return [{"success": False, "message": f"No server found with port: {port}"}] + + return [self.stop_server(matching, force)] + + def display_results(self, results: List[dict]) -> None: + """Display stop results""" + print("\nStopping NeMo Gym servers...\n") + + success_count = 0 + failure_count = 0 + for result in results: + if result["success"]: + success_count += 1 + icon = "✓" + else: + failure_count += 1 + icon = "✗" + + print(f"{icon} {result['message']}") + + total_count = len(results) + print(f"\n{success_count} of {total_count} servers stopped successfully, {failure_count} failed") diff --git a/nemo_gym/server_status.py b/nemo_gym/server_status.py deleted file mode 100644 index cfa65eb06..000000000 --- a/nemo_gym/server_status.py +++ /dev/null @@ -1,113 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from time import time -from typing import List - -import requests -from devtools import pprint - -from nemo_gym.server_utils import ServerClient, ServerInstanceDisplayConfig, ServerStatus - - -class StatusCommand: - """Main class to check server status""" - - def check_health(self, server_info: ServerInstanceDisplayConfig) -> ServerStatus: - """Check if server is responding""" - if not server_info.url: - return "unknown_error" - - try: - requests.get(server_info.url, timeout=2) - return "success" - except requests.exceptions.ConnectionError: - return "connection_error" - except requests.exceptions.Timeout: - return "timeout" - except Exception: - return "unknown_error" - - def discover_servers(self) -> List[ServerInstanceDisplayConfig]: - """Find all running NeMo Gym server processes""" - - try: - head_server_config = ServerClient.load_head_server_config() - head_url = f"http://{head_server_config.host}:{head_server_config.port}" - - response = requests.get(f"{head_url}/server_instances", timeout=5) - response.raise_for_status() - instances = response.json() - - servers = [] - current_time = time() - - for inst in instances: - uptime = current_time - inst.get("start_time", current_time) - server_info = ServerInstanceDisplayConfig( - process_name=inst["process_name"], - server_type=inst["server_type"], - name=inst["name"], - host=inst.get("host"), - port=inst.get("port"), - url=inst.get("url"), - entrypoint=inst.get("entrypoint"), - pid=inst.get("pid"), - uptime_seconds=uptime, - status="unknown_error", - ) - server_info.status = self.check_health(server_info) - servers.append(server_info) - - return servers - - except (requests.RequestException, ConnectionError) as e: - print(f""" -Could not connect to head server: {e} -Is the head server running? Start it with: `ng_run` - """) - return [] - - def display_status(self, servers: List[ServerInstanceDisplayConfig]) -> None: - """Show server info in a table""" - - def format_uptime(uptime_seconds: float) -> str: - """Format uptime in a human readable format""" - minutes, seconds = divmod(uptime_seconds, 60) - hours, minutes = divmod(minutes, 60) - days, hours = divmod(hours, 24) - return f"{int(days)}d {int(hours)}h {int(minutes)}m {seconds:.1f}s" - - if not servers: - print("No NeMo Gym servers found running.") - return - - print("\nNeMo Gym Server Status:\n") - - for i, server in enumerate(servers, 1): - status_icon = "✓" if server.status == "success" else "✗" - print(f"[{i}] {status_icon} {server.process_name} ({server.server_type}/{server.name})") - display_dict = { - "server_type": server.server_type, - "name": server.name, - "port": server.port, - "pid": server.pid, - "uptime_seconds": format_uptime(server.uptime_seconds), - } - pprint(display_dict) - - healthy_count = sum(1 for s in servers if s.status == "success") - print(f""" -{len(servers)} servers found ({healthy_count} healthy, {len(servers) - healthy_count} unhealthy) -""") diff --git a/nemo_gym/server_utils.py b/nemo_gym/server_utils.py index 07e988cc4..a7612dadf 100644 --- a/nemo_gym/server_utils.py +++ b/nemo_gym/server_utils.py @@ -511,6 +511,25 @@ def set_ulimit(self, target_soft_limit: int = 65535): # pragma: no cover e, ) + def setup_shutdown_message(self, app: FastAPI) -> None: # pragma: no cover + """Wrap lifespan to print shutdown message before uvicorn logs""" + main_app_lifespan = app.router.lifespan_context + + @asynccontextmanager + async def lifespan_wrapper(app): + # Startup + async with main_app_lifespan(app) as maybe_state: + yield maybe_state + + # Shutdown + print(f"\n{'#' * 100}") + print(f"Shutting down {self.config.name}...") + print(f"{'#' * 100}\n") + + sys.stdout.flush() + + app.router.lifespan_context = lifespan_wrapper + def prefix_server_logs(self) -> None: # pragma: no cover # Adapted from https://github.com/vllm-project/vllm/blob/ab74b2a27a4eb88b90356bfb4b452d29edf05574/vllm/utils/system_utils.py#L205 @@ -569,6 +588,7 @@ def run_webserver(cls) -> Optional[FastAPI]: # pragma: no cover server.set_ulimit() server.prefix_server_logs() server.setup_exception_middleware(app) + server.setup_shutdown_message(app) @app.exception_handler(RequestValidationError) async def validation_exception_handler(request: Request, exc): diff --git a/pyproject.toml b/pyproject.toml index 98a54fccd..7b7e80c6a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -379,6 +379,10 @@ ng_pip_list = "nemo_gym.cli:pip_list" nemo_gym_version = "nemo_gym.cli:version" ng_version = "nemo_gym.cli:version" +# Stop server(s) +nemo_gym_stop = "nemo_gym.cli:stop" +ng_stop = "nemo_gym.cli:stop" + # Re-install Gym and dependencies nemo_gym_reinstall = "nemo_gym.cli:reinstall" ng_reinstall = "nemo_gym.cli:reinstall" diff --git a/tests/unit_tests/test_server_commands.py b/tests/unit_tests/test_server_commands.py new file mode 100644 index 000000000..c30dcbb3e --- /dev/null +++ b/tests/unit_tests/test_server_commands.py @@ -0,0 +1,576 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from io import StringIO +from unittest.mock import MagicMock + +import psutil +import requests +from pytest import MonkeyPatch + +from nemo_gym.cli import ServerInstanceDisplayConfig +from nemo_gym.server_commands import StatusCommand, StopCommand + + +class TestServerCommands: + def test_server_process_info_creation_sanity(self) -> None: + ServerInstanceDisplayConfig( + pid=12345, + server_type="resources_server", + name="test_server", + process_name="test_server", + host="127.0.0.1", + port=8000, + url="http://127.0.0.1:8000", + uptime_seconds=100, + status="success", + entrypoint="test_server", + ) + + def test_status_command_check_health(self, monkeypatch: MonkeyPatch) -> None: + cmd = StatusCommand() + + # success + server_info = ServerInstanceDisplayConfig( + pid=123, + server_type="resources_servers", + name="test_server", + process_name="test_server", + host="127.0.0.1", + port=8000, + url="http://127.0.0.1:8000", + uptime_seconds=100, + status="unknown_error", + entrypoint="app.py", + ) + + mock_response = MagicMock() + mock_response.status_code = 200 + monkeypatch.setattr("requests.get", lambda *args, **kwargs: mock_response) + + status = cmd.check_health(server_info) + assert status == "success" + + monkeypatch.setattr( + "requests.get", lambda *args, **kwargs: (_ for _ in ()).throw(requests.exceptions.ConnectionError()) + ) + status = cmd.check_health(server_info) + assert status == "connection_error" + + # timeout + monkeypatch.setattr( + "requests.get", lambda *args, **kwargs: (_ for _ in ()).throw(requests.exceptions.Timeout()) + ) + status = cmd.check_health(server_info) + assert status == "timeout" + + # generic or other exception + monkeypatch.setattr( + "requests.get", lambda *args, **kwargs: (_ for _ in ()).throw(ValueError("Something unexpected happened")) + ) + status = cmd.check_health(server_info) + assert status == "unknown_error" + + # no url + server_info.url = None + status = cmd.check_health(server_info) + assert status == "unknown_error" + + def test_status_command_display_status_no_servers(self, monkeypatch: MonkeyPatch) -> None: + text_trap = StringIO() + monkeypatch.setattr("sys.stdout", text_trap) + + cmd = StatusCommand() + cmd.display_status([]) + + output = text_trap.getvalue() + assert "No NeMo Gym servers found running." in output + + def test_status_command_display_status_with_servers(self, monkeypatch: MonkeyPatch) -> None: + text_trap = StringIO() + monkeypatch.setattr("sys.stdout", text_trap) + + servers = [ + ServerInstanceDisplayConfig( + pid=123, + server_type="resources_servers", + name="test_resource", + process_name="test_resources_server", + host="127.0.0.1", + port=8000, + url="http://127.0.0.1:8000", + uptime_seconds=100, + status="success", + entrypoint="test_server", + ), + ServerInstanceDisplayConfig( + pid=456, + server_type="responses_api_models", + name="test_model", + process_name="test_model", + host="127.0.0.1", + port=8001, + url="http://127.0.0.1:8001", + uptime_seconds=200, + status="connection_error", + entrypoint="test_model", + ), + ] + + cmd = StatusCommand() + cmd.display_status(servers) + + output = text_trap.getvalue() + assert "2 servers found (1 healthy, 1 unhealthy)" in output + assert "123" in output + assert "456" in output + assert "test_resource" in output + assert "test_model" in output + + def test_status_command_discover_servers_connection_error(self, monkeypatch: MonkeyPatch) -> None: + text_trap = StringIO() + + monkeypatch.setattr("sys.stdout", text_trap) + + cmd = StatusCommand() + + mock_config = MagicMock() + mock_config.host = "localhost" + mock_config.port = 11000 + monkeypatch.setattr("nemo_gym.server_commands.ServerClient.load_head_server_config", lambda: mock_config) + + monkeypatch.setattr( + "requests.get", + lambda *args, **kwargs: (_ for _ in ()).throw(requests.exceptions.ConnectionError("Connection refused")), + ) + + servers = cmd.discover_servers() + + assert servers == [] + output = text_trap.getvalue() + assert "Could not connect to head server" in output + assert "Is the head server running? Start it with: `ng_run`" in output + + def test_status_command_discover_servers(self, monkeypatch: MonkeyPatch) -> None: + cmd = StatusCommand() + + mock_config = MagicMock() + mock_config.host = "localhost" + mock_config.port = 11000 + monkeypatch.setattr("nemo_gym.server_commands.ServerClient.load_head_server_config", lambda: mock_config) + + mock_response = MagicMock() + mock_response.json.return_value = [ + { + "process_name": "test_server_config", + "server_type": "resources_servers", + "name": "test_server", + "host": "127.0.0.1", + "port": 8000, + "url": "http://127.0.0.1:8000", + "entrypoint": "app.py", + "pid": 12345, + "start_time": 1000.0, + } + ] + + monkeypatch.setattr("requests.get", lambda *args, **kwargs: mock_response) + monkeypatch.setattr("nemo_gym.server_commands.time", lambda: 1100.0) # 100 seconds later + monkeypatch.setattr(cmd, "check_health", lambda s: "success") + + servers = cmd.discover_servers() + + assert len(servers) == 1 + assert servers[0].name == "test_server" + assert servers[0].pid == 12345 + assert servers[0].port == 8000 + assert servers[0].uptime_seconds == 100.0 + assert servers[0].status == "success" + + def test_stop_command_display_results(self, monkeypatch: MonkeyPatch) -> None: + text_trap = StringIO() + monkeypatch.setattr("sys.stdout", text_trap) + + cmd = StopCommand() + results = [ + {"success": True, "message": "Stopped server1"}, + {"success": True, "message": "Stopped server2"}, + {"success": False, "message": "Failed to stop server3"}, + ] + + cmd.display_results(results) + + output = text_trap.getvalue() + assert "Stopping NeMo Gym servers" in output + assert "✓" in output + assert "✗" in output + assert "2 of 3 servers stopped successfully, 1 failed" in output + + def test_stop_command_stop_server_graceful(self, monkeypatch: MonkeyPatch) -> None: + cmd = StopCommand() + + server_info = ServerInstanceDisplayConfig( + pid=99999, + server_type="resources_servers", + name="test_server", + process_name="test_server", + host="127.0.0.1", + port=8000, + url="http://127.0.0.1:8000", + uptime_seconds=100, + status="success", + entrypoint="app.py", + ) + + mock_child = MagicMock() + mock_proc = MagicMock() + mock_proc.children.return_value = [mock_child] + mock_proc.wait.return_value = None + + monkeypatch.setattr("psutil.Process", lambda pid: mock_proc) + + result = cmd.stop_server(server_info, force=False) + + assert result["success"] is True + assert result["method"] == "graceful" + mock_child.send_signal.assert_called_once() + mock_proc.send_signal.assert_called_once() + + # Mock when child process is already dead + mock_child = MagicMock() + mock_child.send_signal.side_effect = psutil.NoSuchProcess(12345) + + mock_proc = MagicMock() + mock_proc.children.return_value = [mock_child] + mock_proc.wait.return_value = None + + monkeypatch.setattr("psutil.Process", lambda pid: mock_proc) + + result = cmd.stop_server(server_info, force=False) + + assert result["success"] is True + assert result["method"] == "graceful" + mock_child.send_signal.assert_called_once() + + def test_stop_command_stop_server_timeout_then_terminate(self, monkeypatch: MonkeyPatch) -> None: + cmd = StopCommand() + + server_info = ServerInstanceDisplayConfig( + pid=99999, + server_type="resources_servers", + name="test_server", + process_name="test_server", + host="127.0.0.1", + port=8000, + url="http://127.0.0.1:8000", + uptime_seconds=100, + status="success", + entrypoint="app.py", + ) + + mock_child = MagicMock() + mock_proc = MagicMock() + mock_proc.children.return_value = [mock_child] + + mock_proc.wait.side_effect = [psutil.TimeoutExpired(1), None] + + monkeypatch.setattr("psutil.Process", lambda pid: mock_proc) + + result = cmd.stop_server(server_info, force=False) + + assert result["success"] is True + assert result["method"] == "terminate" + mock_child.terminate.assert_called_once() + mock_proc.terminate.assert_called_once() + + # Mock when child process is already dead + mock_child = MagicMock() + mock_child.terminate.side_effect = psutil.NoSuchProcess(99999) + + mock_proc = MagicMock() + mock_proc.children.return_value = [mock_child] + mock_proc.wait.side_effect = [psutil.TimeoutExpired(1), None] + + monkeypatch.setattr("psutil.Process", lambda pid: mock_proc) + + result = cmd.stop_server(server_info, force=False) + + assert result["success"] is True + assert result["method"] == "terminate" + mock_child.terminate.assert_called_once() + + def test_stop_command_stop_server_double_timeout_failure(self, monkeypatch: MonkeyPatch) -> None: + cmd = StopCommand() + + server_info = ServerInstanceDisplayConfig( + pid=99999, + server_type="resources_servers", + name="test_server", + process_name="test_server", + host="127.0.0.1", + port=8000, + url="http://127.0.0.1:8000", + uptime_seconds=100, + status="success", + entrypoint="app.py", + ) + + mock_child = MagicMock() + mock_proc = MagicMock() + mock_proc.children.return_value = [mock_child] + + mock_proc.wait.side_effect = psutil.TimeoutExpired(1) + + monkeypatch.setattr("psutil.Process", lambda pid: mock_proc) + + result = cmd.stop_server(server_info, force=False) + + assert result["success"] is False + assert result["method"] == "failed" + assert "use --force" in result["message"] + + def test_stop_command_stop_server_no_such_process(self, monkeypatch: MonkeyPatch) -> None: + cmd = StopCommand() + + server_info = ServerInstanceDisplayConfig( + pid=99999, + server_type="resources_servers", + name="test_server", + process_name="test_server", + host="127.0.0.1", + port=8000, + url="http://127.0.0.1:8000", + uptime_seconds=100, + status="success", + entrypoint="app.py", + ) + + monkeypatch.setattr("psutil.Process", lambda pid: (_ for _ in ()).throw(psutil.NoSuchProcess(99999))) + + result = cmd.stop_server(server_info, force=False) + + assert result["success"] is False + assert result["method"] == "no_process" + + def test_stop_command_stop_server_access_denied(self, monkeypatch: MonkeyPatch) -> None: + cmd = StopCommand() + + server_info = ServerInstanceDisplayConfig( + pid=99999, + server_type="resources_servers", + name="test_server", + process_name="test_server", + host="127.0.0.1", + port=8000, + url="http://127.0.0.1:8000", + uptime_seconds=100, + status="success", + entrypoint="app.py", + ) + + monkeypatch.setattr("psutil.Process", lambda pid: (_ for _ in ()).throw(psutil.AccessDenied())) + + result = cmd.stop_server(server_info, force=False) + + assert result["success"] is False + assert result["method"] == "access_denied" + + def test_stop_command_stop_server_unexpected_exception(self, monkeypatch: MonkeyPatch) -> None: + cmd = StopCommand() + + server_info = ServerInstanceDisplayConfig( + pid=99999, + server_type="resources_servers", + name="test_server", + process_name="test_server", + host="127.0.0.1", + port=8000, + url="http://127.0.0.1:8000", + uptime_seconds=100, + status="success", + entrypoint="app.py", + ) + + monkeypatch.setattr( + "psutil.Process", lambda pid: (_ for _ in ()).throw(Exception("Something unexpected happened")) + ) + + result = cmd.stop_server(server_info, force=False) + assert result["success"] is False + assert result["method"] == "error" + assert "Error stopping test_server" in result["message"] + assert "Something unexpected happened" in result["message"] + + def test_stop_command_stop_server_force(self, monkeypatch: MonkeyPatch) -> None: + server_info = ServerInstanceDisplayConfig( + pid=99999, + server_type="resources_servers", + name="test_server", + process_name="test_server", + host="127.0.0.1", + port=8000, + url="http://127.0.0.1:8000", + uptime_seconds=100, + status="success", + entrypoint="app.py", + ) + + mock_proc = MagicMock() + mock_process_cls = MagicMock(return_value=mock_proc) + monkeypatch.setattr("psutil.Process", mock_process_cls) + + cmd = StopCommand() + result = cmd.stop_server(server_info, force=True) + + assert result["success"] is True + assert result["method"] == "force" + mock_proc.kill.assert_called_once() + + # Mock when child process is already dead + mock_child = MagicMock() + mock_child.kill.side_effect = psutil.NoSuchProcess(99999) + mock_proc = MagicMock() + mock_proc.children.return_value = [mock_child] + monkeypatch.setattr("psutil.Process", lambda pid: mock_proc) + result = cmd.stop_server(server_info, force=True) + assert result["success"] is True + assert result["method"] == "force" + mock_child.kill.assert_called_once() + mock_proc.kill.assert_called_once() + + def test_stop_command_stop_all(self, monkeypatch: MonkeyPatch) -> None: + cmd = StopCommand() + + # No servers running + monkeypatch.setattr(cmd.status_cmd, "discover_servers", lambda: []) + results = cmd.stop_all(force=False) + + assert len(results) == 1 + assert results[0]["success"] is False + assert "No servers found" in results[0]["message"] + + # With running servers + servers = [ + ServerInstanceDisplayConfig( + pid=123, + server_type="resources_servers", + name="server1", + process_name="server1", + host="127.0.0.1", + port=8000, + url="http://127.0.0.1:8000", + uptime_seconds=100, + status="success", + entrypoint="app.py", + ) + ] + + mock_stop_result = { + "server": servers[0], + "success": True, + "method": "graceful", + "message": "Stopped server1", + } + + monkeypatch.setattr(cmd.status_cmd, "discover_servers", lambda: servers) + monkeypatch.setattr(cmd, "stop_server", lambda s, f: mock_stop_result) + + results = cmd.stop_all(force=False) + + assert len(results) == 1 + assert results[0]["success"] is True + + def test_stop_command_stop_by_name(self, monkeypatch: MonkeyPatch) -> None: + cmd = StopCommand() + + servers = [ + ServerInstanceDisplayConfig( + pid=123, + server_type="resources_servers", + name="test_server", + process_name="test_server", + host="127.0.0.1", + port=8000, + url="http://127.0.0.1:8000", + uptime_seconds=100, + status="success", + entrypoint="app.py", + ) + ] + + # Found by name + mock_stop_result = { + "server": servers[0], + "success": True, + "method": "graceful", + "message": "Stopped test_server", + } + + monkeypatch.setattr(cmd.status_cmd, "discover_servers", lambda: servers) + monkeypatch.setattr(cmd, "stop_server", lambda s, f: mock_stop_result) + + results = cmd.stop_by_name("test_server", force=False) + + assert len(results) == 1 + assert results[0]["success"] is True + + # Not found by name + monkeypatch.setattr(cmd.status_cmd, "discover_servers", lambda: []) + results = cmd.stop_by_name("nonexistent", force=False) + + assert len(results) == 1 + assert results[0]["success"] is False + assert "No server found" in results[0]["message"] + + def test_stop_command_stop_by_port(self, monkeypatch: MonkeyPatch) -> None: + cmd = StopCommand() + + servers = [ + ServerInstanceDisplayConfig( + pid=123, + server_type="resources_servers", + name="test_server", + process_name="test_server", + host="127.0.0.1", + port=8000, + url="http://127.0.0.1:8000", + uptime_seconds=100, + status="success", + entrypoint="app.py", + ) + ] + + # Found on port + mock_stop_result = { + "server": servers[0], + "success": True, + "method": "graceful", + "message": "Stopped server on port 8000", + } + + monkeypatch.setattr(cmd.status_cmd, "discover_servers", lambda: servers) + monkeypatch.setattr(cmd, "stop_server", lambda s, f: mock_stop_result) + + results = cmd.stop_by_port(8000, force=False) + + assert len(results) == 1 + assert results[0]["success"] is True + + # Not found on port + monkeypatch.setattr(cmd.status_cmd, "discover_servers", lambda: []) + results = cmd.stop_by_port(9999, force=False) + + assert len(results) == 1 + assert results[0]["success"] is False + assert "No server found" in results[0]["message"] diff --git a/tests/unit_tests/test_server_status.py b/tests/unit_tests/test_server_status.py index 58afd793c..95aad937c 100644 --- a/tests/unit_tests/test_server_status.py +++ b/tests/unit_tests/test_server_status.py @@ -19,7 +19,7 @@ from pytest import MonkeyPatch from nemo_gym.cli import ServerInstanceDisplayConfig -from nemo_gym.server_status import StatusCommand +from nemo_gym.server_commands import StatusCommand from nemo_gym.server_utils import ServerClient @@ -182,7 +182,7 @@ def test_discover_servers(self, monkeypatch: MonkeyPatch) -> None: monkeypatch.setattr(requests, "get", mock_get) mock_time = MagicMock(return_value=10000.0) - monkeypatch.setattr("nemo_gym.server_status.time", mock_time) + monkeypatch.setattr("nemo_gym.server_commands.time", mock_time) cmd = StatusCommand() servers = cmd.discover_servers()