def shutdown(self):
    """Stop the simulation binary and the robotic stack.

    The list of monitored processes is emptied *before* either process is
    stopped. The monitor loop re-reads ``self._processes`` on every pass, so
    once the list is empty the intentional exits below are no longer
    reported as unexpected and no SIGINT is sent to the main process while
    the teardown is still waiting on its children.

    NOTE(review): a monitor pass that already copied the list can still
    observe an exit; closing that window needs coordination with the
    monitor loop itself.

    Raises:
        Whatever ``_shutdown_binary`` or ``_shutdown_robotic_stack`` raise;
        the robotic stack is still shut down when the binary teardown fails.
    """
    # Detach the children from the monitor first (see docstring).
    self._processes = []
    try:
        self._shutdown_binary()
    finally:
        # Always attempt to stop the stack so a failed simulator teardown
        # does not leave it running.
        self._shutdown_robotic_stack()
self._processes.append( + Process( + name=psutil.Process(self.current_sim_process.pid).name(), + process=self.current_sim_process, + ) + ) + def _launch_robotic_stack(self, simulation_config: O3DExROS2SimulationConfig): command = shlex.split(simulation_config.robotic_stack_command) self.logger.info(f"Running command: {command}") @@ -319,6 +332,23 @@ def _launch_robotic_stack(self, simulation_config: O3DExROS2SimulationConfig): ): raise RuntimeError("ROS2 stack is not ready in time.") + self._processes.append( + Process( + name=psutil.Process(self.current_robotic_stack_process.pid).name(), + process=self.current_robotic_stack_process, + ) + ) + + parent = psutil.Process(self.current_robotic_stack_process.pid) + children = parent.children(recursive=True) + for child in children: + self._processes.append( + Process( + name=child.name(), + process=child, + ) + ) + def _has_process_started(self, process: subprocess.Popen[Any], timeout: int = 15): start_time = time.time() while time.time() - start_time < timeout: diff --git a/src/rai_sim/rai_sim/simulation_bridge.py b/src/rai_sim/rai_sim/simulation_bridge.py index 9241583e4..fcba550d7 100644 --- a/src/rai_sim/rai_sim/simulation_bridge.py +++ b/src/rai_sim/rai_sim/simulation_bridge.py @@ -13,10 +13,17 @@ # limitations under the License. 
@dataclass(frozen=True)
class Process:
    """Immutable pairing of a display name with a process handle."""

    # Label used when the process is mentioned in log messages.
    name: str
    # Handle used to query liveness: a ``subprocess.Popen`` object or a
    # ``psutil.Process``. Written as a string so the annotation is not
    # evaluated when the class is created.
    process: "Union[subprocess.Popen[Any], psutil.Process]"


# The definitions below are methods of ``SimulationBridge``; they are shown at
# method granularity because the class body is only partly visible here.


def __init__(self, logger: Optional[logging.Logger] = None):
    """Initialise bookkeeping and start the background process monitor.

    Args:
        logger: Logger to report through. When ``None``,
            ``logging.getLogger(__name__)`` is used.
    """
    self.spawned_entities: List[SpawnedEntity] = []
    self._processes: List[Process] = []
    self.logger = logging.getLogger(__name__) if logger is None else logger

    # Polled by the monitor loop; cleared by stop_monitoring() and by the
    # loop itself once it has reported a failure.
    self._monitoring_running = True
    # Daemon thread, so it never keeps the interpreter alive on exit.
    self._process_monitor_thread = threading.Thread(
        target=self._monitor_processes, daemon=True
    )
    self._process_monitor_thread.start()


def _monitor_processes(self):
    """Watch the managed processes and interrupt the program if one dies.

    Once per second every entry of ``self._processes`` is checked: ``Popen``
    handles through ``poll()``, any other handle through ``is_running()``.
    For the first process found to have exited, the event is logged, the
    monitoring flag is cleared and a single SIGINT is sent to this process;
    the loop then ends.

    Signalling once matters: repeating the SIGINT for every dead process,
    or on every later pass, can interrupt the very shutdown that the first
    signal started.
    """
    while self._monitoring_running:
        # Iterate over a copy so the list may be changed while we scan it.
        for process in self._processes[:]:
            handle = process.process
            if isinstance(handle, subprocess.Popen):
                if handle.poll() is None:
                    continue
                outcome = f"exited unexpectedly with code {handle.returncode}"
            elif handle.is_running():
                continue
            else:
                outcome = "exited unexpectedly."

            self.logger.critical(
                f"Process {process.name} with PID {handle.pid} {outcome}"
            )
            self.logger.info("Shutting down main process.")
            # Clear the flag first so a later stop_monitoring() call finds
            # the loop already finished.
            self._monitoring_running = False
            os.kill(os.getpid(), signal.SIGINT)
            return
        time.sleep(1)


def stop_monitoring(self):
    """Ask the monitor loop to finish and wait for its thread to end.

    The loop sleeps one second between passes, so joining a live monitor
    thread can take about that long.
    """
    self._monitoring_running = False
    monitor_thread = self._process_monitor_thread
    if monitor_thread.is_alive():
        monitor_thread.join()
        self.logger.info("Processes monitor thread shut down.")
@patch("psutil.Process") + def test_launch_binary(self, mock_psutil_process: MagicMock, mock_popen: MagicMock): mock_process = MagicMock() mock_process.poll.return_value = None mock_process.pid = 54322 @@ -138,6 +144,7 @@ def test_launch_binary(self, mock_popen): self.bridge._launch_binary(self.test_config) mock_popen.assert_called_once_with(["/path/to/binary"]) + mock_psutil_process.assert_called_once_with(mock_process.pid) self.assertEqual(self.bridge.current_sim_process, mock_process) def test_shutdown_binary(self):