diff --git a/setup_shell.sh b/setup_shell.sh index 164372ed9..506df2cd2 100755 --- a/setup_shell.sh +++ b/setup_shell.sh @@ -33,3 +33,5 @@ PYTHONPATH="$(dirname "$(dirname "$(poetry run which python)")")/lib/python$(poe PYTHONPATH="src/rai_core:$PYTHONPATH" PYTHONPATH="src/rai_asr:$PYTHONPATH" PYTHONPATH="src/rai_tts:$PYTHONPATH" +PYTHONPATH="src/rai_sim:$PYTHONPATH" +PYTHONPATH="src/rai_bench:$PYTHONPATH" diff --git a/src/rai_bench/rai_bench/benchmark_model.py b/src/rai_bench/rai_bench/benchmark_model.py index 34cb7b776..de9fc43eb 100644 --- a/src/rai_bench/rai_bench/benchmark_model.py +++ b/src/rai_bench/rai_bench/benchmark_model.py @@ -23,7 +23,7 @@ from rclpy.impl.rcutils_logger import RcutilsLogger from rai_sim.simulation_bridge import ( - PoseModel, + Pose, SimulationBridge, SimulationConfig, SimulationConfigT, @@ -88,7 +88,7 @@ def filter_entities_by_prefab_type( """Filter and return only these entities that match provided prefab types""" return [ent for ent in entities if ent.prefab_name in prefab_types] - def euclidean_distance(self, pos1: PoseModel, pos2: PoseModel) -> float: + def euclidean_distance(self, pos1: Pose, pos2: Pose) -> float: """Calculate euclidean distance between 2 positions""" return ( (pos1.translation.x - pos2.translation.x) ** 2 @@ -96,7 +96,7 @@ def euclidean_distance(self, pos1: PoseModel, pos2: PoseModel) -> float: + (pos1.translation.z - pos2.translation.z) ** 2 ) ** 0.5 - def is_adjacent(self, pos1: PoseModel, pos2: PoseModel, threshold_distance: float): + def is_adjacent(self, pos1: Pose, pos2: Pose, threshold_distance: float): """ Check if positions are adjacent to each other, the threshold_distance is a distance in simulation, refering to how close they have to be to classify them as adjacent @@ -107,7 +107,7 @@ def is_adjacent(self, pos1: PoseModel, pos2: PoseModel, threshold_distance: floa return self.euclidean_distance(pos1, pos2) < threshold_distance def is_adjacent_to_any( - self, pos1: PoseModel, positions: List[PoseModel], threshold_distance: float + self, pos1: Pose, positions: List[Pose], threshold_distance: float ) -> bool: """ Check if given position is adjacent to any position in the given list. @@ -117,9 +117,7 @@ def is_adjacent_to_any( self.is_adjacent(pos1, pos2, threshold_distance) for pos2 in positions ) - def count_adjacent( - self, positions: List[PoseModel], threshold_distance: float - ) -> int: + def count_adjacent(self, positions: List[Pose], threshold_distance: float) -> int: """ Count how many adjacent positions are in the given list. Note that position has to be adjacent to only 1 other position diff --git a/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py b/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py index bfa81d0bc..b644567d8 100644 --- a/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py +++ b/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py @@ -32,7 +32,7 @@ from rai_sim.o3de.o3de_bridge import ( O3DEngineArmManipulationBridge, O3DExROS2SimulationConfig, - PoseModel, + Pose, ) from rai_sim.simulation_bridge import Rotation, Translation @@ -152,7 +152,7 @@ ) # custom request to arm - base_arm_pose = PoseModel( + base_arm_pose = Pose( translation=Translation(x=0.5, y=0.1, z=0.3), rotation=Rotation(x=1.0, y=0.0, z=0.0, w=0.0), ) diff --git a/src/rai_sim/README.md b/src/rai_sim/README.md index f6e82bcb6..8008b4d2d 100644 --- a/src/rai_sim/README.md +++ b/src/rai_sim/README.md @@ -6,12 +6,13 @@ The RAI Sim is a package providing interface to implement connection with a spec ### Components -- `SimulationConnector` - An interface for connecting with a specific simulation. It manages scene setup, spawning, despawning objects, getting current state of the scene. +- `SimulationBridge` - An interface for connecting with a specific simulation. It manages scene setup, spawning, despawning objects, getting current state of the scene. -- `SimulationConfig` - base config class to specify the entities to be spawned. For each simulation connector there should be specified custom simulation config specifying additional parameters needed to run and connect with the simulation. +- `SimulationConfig` - base config class to specify the entities to be spawned. For each simulation bridge there should be specified custom simulation config specifying additional parameters needed to run and connect with the simulation. - `SceneState` - stores the current info about spawned entities ### Example implementation -- `O3DExROS2Connector` - An implementation of SimulationConnector for working with simulation based on O3DE and ROS2. +- `O3DExROS2Bridge` - An implementation of SimulationBridge for working with simulation based on O3DE and ROS2. +- `O3DExROS2SimulationConfig` - config class for `O3DExROS2Bridge` diff --git a/src/rai_sim/rai_sim/o3de/o3de_bridge.py b/src/rai_sim/rai_sim/o3de/o3de_bridge.py index d514beb78..8fbe022ee 100644 --- a/src/rai_sim/rai_sim/o3de/o3de_bridge.py +++ b/src/rai_sim/rai_sim/o3de/o3de_bridge.py @@ -18,10 +18,11 @@ import subprocess import time from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Set import yaml -from geometry_msgs.msg import Point, Pose, PoseStamped, Quaternion +from geometry_msgs.msg import Point, PoseStamped, Quaternion +from geometry_msgs.msg import Pose as ROS2Pose from rai.communication.ros2.connectors import ROS2ARIConnector, ROS2ARIMessage from rai.utils.ros_async import get_future_result from std_msgs.msg import Header @@ -30,7 +31,7 @@ from rai_interfaces.srv import ManipulatorMoveTo from rai_sim.simulation_bridge import ( Entity, - PoseModel, + Pose, Rotation, SceneState, SimulationBridge, @@ -43,9 +44,8 @@ class O3DExROS2SimulationConfig(SimulationConfig): binary_path: Path robotic_stack_command: str - required_services: List[str] - required_topics: List[str] - required_actions: List[str] + required_simulation_ros2_interfaces: dict[str, List[str]] + required_robotic_ros2_interfaces: dict[str, List[str]] @classmethod def load_config( @@ -110,13 +110,7 @@ def get_available_spawnable_names(self) -> list[str]: target="get_available_spawnable_names", msg_type="gazebo_msgs/srv/GetWorldProperties", ) - # NOTE (mkotynia) There is a bug in the gazebo_msgs/srv/GetWorldProperties service - payload.success is not set to True even if the service call is successful. It was reported to Kacper DÄ…browski and he is going to fix it. - # PR fixing the bug: https://github.com/o3de/o3de-extras/pull/828 - # TODO (mkotynia) uncomment check if response.payload.success when the bug is fixed and remove workaround check if response.payload.model_names. - - # if response.payload.success: - # return response.payload.model_names - if response.payload.model_names: + if response.payload.success: return response.payload.model_names else: raise RuntimeError( @@ -125,7 +119,7 @@ def get_available_spawnable_names(self) -> list[str]: def _spawn_entity(self, entity: Entity): pose = do_transform_pose( - self.to_ros2_pose(entity.pose), + self._to_ros2_pose(entity.pose), self.connector.get_transform("odom", "world"), ) @@ -176,15 +170,13 @@ def _despawn_entity(self, entity: SpawnedEntity): f"Failed to delete entity {entity.name}. Response: {response.payload.status_message}" ) - def get_object_pose(self, entity: SpawnedEntity) -> PoseModel: + def get_object_pose(self, entity: SpawnedEntity) -> Pose: object_frame = entity.name + "/" ros2_pose = do_transform_pose( - Pose(), self.connector.get_transform(object_frame + "odom", object_frame) - ) - ros2_pose = do_transform_pose( - ros2_pose, self.connector.get_transform("world", "odom") + ROS2Pose(), + self.connector.get_transform(object_frame + "odom", object_frame), ) - return self.from_ros2_pose(ros2_pose) + return self._from_ros2_pose(ros2_pose) def get_scene_state(self) -> SceneState: """ @@ -205,72 +197,112 @@ def get_scene_state(self) -> SceneState: ) return SceneState(entities=entities) - def _is_robotic_stack_ready( - self, simulation_config: O3DExROS2SimulationConfig, retries: int = 30 + def _is_ros2_stack_ready( + self, required_ros2_stack: dict[str, List[str]], retries: int = 30 ) -> bool: - i = 0 - while i < retries: - topics = self.connector.get_topics_names_and_types() - services = self.connector.node.get_service_names_and_types() - topics_names = [tp[0] for tp in topics] - service_names = [srv[0] for srv in services] - self.logger.info( - f"required services: {simulation_config.required_services}" - ) - self.logger.info(f"required topics: {simulation_config.required_topics}") - self.logger.info(f"required actions: {simulation_config.required_actions}") - # NOTE actions will be listed in services and topics - if ( - all(srv in service_names for srv in simulation_config.required_services) - and all(tp in topics_names for tp in simulation_config.required_topics) - and all( - ac in service_names for ac in simulation_config.required_actions + for i in range(retries): + available_topics = self.connector.get_topics_names_and_types() + available_services = self.connector.node.get_service_names_and_types() + available_topics_names = [tp[0] for tp in available_topics] + available_services_names = [srv[0] for srv in available_services] + + # Extract action names + available_actions_names: Set[str] = set() + for service in available_services_names: + if "/_action" in service: + action_name = service.split("/_action")[0] + available_actions_names.add(action_name) + + required_services = required_ros2_stack["services"] + required_topics = required_ros2_stack["topics"] + required_actions = required_ros2_stack["actions"] + self.logger.info(f"required services: {required_services}") + self.logger.info(f"required topics: {required_topics}") + self.logger.info(f"required actions: {required_actions}") + self.logger.info(f"available actions: {available_actions_names}") + + missing_services = [ + service + for service in required_services + if service not in available_services_names + ] + missing_topics = [ + topic + for topic in required_topics + if topic not in available_topics_names + ] + missing_actions = [ + action + for action in required_actions + if action not in available_actions_names + ] + + if missing_services: + self.logger.warning( + f"Waiting for missing services {missing_services} out of required services: {required_services}" + ) + + if missing_topics: + self.logger.warning( + f"Waiting for missing topics: {missing_topics} out of required topics: {required_topics}" ) - ): - self.logger.info("All required services are available.") + + if missing_actions: + self.logger.warning( + f"Waiting for missing actions: {missing_actions} out of required actions: {required_actions}" + ) + + if not (missing_services or missing_topics or missing_actions): + self.logger.info("All required ROS2 stack components are available.") return True - time.sleep(5) - retries += 1 + time.sleep(3) + + self.logger.error( + "Maximum number of retries reached. Required ROS2 stack components are not fully available." + ) return False def setup_scene(self, simulation_config: O3DExROS2SimulationConfig): if self.current_binary_path != simulation_config.binary_path: if self.current_sim_process: self.shutdown() - self._launch_binary(simulation_config.binary_path) - self._launch_robotic_stack(simulation_config.robotic_stack_command) + self._launch_binary(simulation_config) + self._launch_robotic_stack(simulation_config) self.current_binary_path = simulation_config.binary_path else: while self.spawned_entities: self._despawn_entity(self.spawned_entities[0]) - if not self._is_robotic_stack_ready(simulation_config=simulation_config): - raise RuntimeError( - "Not all required services, topics and actions are available" - ) - for entity in simulation_config.entities: self._spawn_entity(entity) - def _launch_binary(self, binary_path: Path): - command = [binary_path.as_posix()] + def _launch_binary(self, simulation_config: O3DExROS2SimulationConfig): + command = [simulation_config.binary_path.as_posix()] self.logger.info(f"Running command: {command}") self.current_sim_process = subprocess.Popen( command, ) if not self._has_process_started(process=self.current_sim_process): raise RuntimeError("Process did not start in time.") + if not self._is_ros2_stack_ready( + required_ros2_stack=simulation_config.required_simulation_ros2_interfaces + ): + raise RuntimeError("ROS2 stack is not ready in time.") - def _launch_robotic_stack(self, robotic_stack_command: str): - command = shlex.split(robotic_stack_command) + def _launch_robotic_stack(self, simulation_config: O3DExROS2SimulationConfig): + command = shlex.split(simulation_config.robotic_stack_command) self.logger.info(f"Running command: {command}") self.current_robotic_stack_process = subprocess.Popen( command, ) if not self._has_process_started(self.current_robotic_stack_process): raise RuntimeError("Process did not start in time.") + if not self._is_ros2_stack_ready( + required_ros2_stack=simulation_config.required_robotic_ros2_interfaces + ): + raise RuntimeError("ROS2 stack is not ready in time.") def _has_process_started(self, process: subprocess.Popen[Any], timeout: int = 15): start_time = time.time() @@ -303,9 +335,9 @@ def _try_service_call( return response # type: ignore # NOTE (mkotynia) probably to be refactored, other bridges may also want to use pose conversion to/from ROS2 format - def to_ros2_pose(self, pose: PoseModel) -> Pose: + def _to_ros2_pose(self, pose: Pose) -> ROS2Pose: """ - Converts pose in PoseModel format to pose in ROS2 Pose format. + Converts pose to pose in ROS2 Pose format. """ position = Point( x=pose.translation.x, y=pose.translation.y, z=pose.translation.z @@ -321,13 +353,13 @@ def to_ros2_pose(self, pose: PoseModel) -> Pose: else: orientation = Quaternion() - ros2_pose = Pose(position=position, orientation=orientation) + ros2_pose = ROS2Pose(position=position, orientation=orientation) return ros2_pose - def from_ros2_pose(self, pose: Pose) -> PoseModel: + def _from_ros2_pose(self, pose: ROS2Pose) -> Pose: """ - Converts ROS2 pose to PoseModel format + Converts ROS2Pose to Pose """ translation = Translation( @@ -343,13 +375,13 @@ def from_ros2_pose(self, pose: Pose) -> PoseModel: w=pose.orientation.w, # type: ignore ) - return PoseModel(translation=translation, rotation=rotation) + return Pose(translation=translation, rotation=rotation) class O3DEngineArmManipulationBridge(O3DExROS2Bridge): def move_arm( self, - pose: PoseModel, + pose: Pose, initial_gripper_state: bool, final_gripper_state: bool, frame_id: str, @@ -357,7 +389,7 @@ def move_arm( """Moves arm to a given position Args: - pose (PoseModel): where to move arm + pose (Pose): where to move arm initial_gripper_state (bool): False means closed grip, True means open grip final_gripper_state (bool): False means closed grip, True means open grip frame_id (str): reference frame diff --git a/src/rai_sim/rai_sim/simulation_bridge.py b/src/rai_sim/rai_sim/simulation_bridge.py index a1e1fcbd2..336406ad0 100644 --- a/src/rai_sim/rai_sim/simulation_bridge.py +++ b/src/rai_sim/rai_sim/simulation_bridge.py @@ -18,48 +18,102 @@ from typing import Generic, List, Optional, TypeVar import yaml -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, Field, field_validator class Translation(BaseModel): - x: float - y: float - z: float + """ + Represents the position of an object in 3D space using + x, y, and z coordinates. + """ + + x: float = Field(description="X coordinate in meters") + y: float = Field(description="Y coordinate in meters") + z: float = Field(description="Z coordinate in meters") class Rotation(BaseModel): - x: float - y: float - z: float - w: float + """ + Represents a 3D rotation using quaternion representation. + """ + x: float = Field(description="X component of the quaternion") + y: float = Field(description="Y component of the quaternion") + z: float = Field(description="Z component of the quaternion") + w: float = Field(description="W component of the quaternion") + + +class Pose(BaseModel): + """ + Represents the complete pose (position and orientation) of an object. + """ -class PoseModel(BaseModel): - translation: Translation - rotation: Optional[Rotation] = None + translation: Translation = Field( + description="The position of the object in 3D space" + ) + rotation: Optional[Rotation] = Field( + default=None, + description="The orientation of the object as a quaternion. Optional if orientation is not needed and default orientation is handled by the bridge", + ) class Entity(BaseModel): - name: str - prefab_name: str - pose: PoseModel + """ + Entity that can be spawned in the simulation environment. + """ + + name: str = Field(description="Unique name for the entity") + prefab_name: str = Field( + description="Name of the prefab resource to use for spawning this entity" + ) + pose: Pose = Field(description="Initial pose of the entity") class SpawnedEntity(Entity): - id: str + """ + Entity that has been spawned in the simulation environment. + """ + + id: str = Field( + description="Unique identifier assigned to the spawned entity instance" + ) class SimulationConfig(BaseModel): """ - Setup of simulation - arrangemenet of objects in the environment. + Setup of simulation - arrangement of objects in the environment. + + Attributes + ---------- + entities : List[Entity] + List of entities to be spawned in the simulation. """ - # NOTE (mkotynia) can be extended by other attributes - entities: List[Entity] + entities: List[Entity] = Field( + description="List of entities to be spawned in the simulation environment" + ) @field_validator("entities") @classmethod def check_unique_names(cls, entities: List[Entity]) -> List[Entity]: + """ + Validates that all entity names in the configuration are unique. + + Parameters + ---------- + entities : List[Entity] + List of entities to validate. + + Returns + ------- + List[Entity] + The validated list of entities. + + Raises + ------ + ValueError + If any entity names are duplicated. + """ names = [entity.name for entity in entities] if len(names) != len(set(names)): raise ValueError("Each entity must have a unique name.") @@ -67,6 +121,19 @@ def check_unique_names(cls, entities: List[Entity]) -> List[Entity]: @classmethod def load_base_config(cls, base_config_path: Path) -> "SimulationConfig": + """ + Loads a simulation configuration from a YAML file. + + Parameters + ---------- + base_config_path : Path + Path to the YAML configuration file. + + Returns + ------- + SimulationConfig + The loaded simulation configuration. + """ with open(base_config_path) as f: content = yaml.safe_load(f) return cls(**content) @@ -74,11 +141,17 @@ def load_base_config(cls, base_config_path: Path) -> "SimulationConfig": class SceneState(BaseModel): """ - Info about current entities' state in the scene. + Info about current state of the scene. + + Attributes + ---------- + entities : List[SpawnedEntity] + List of all entities currently present in the scene. """ - # NOTE (mkotynia) can be extended by other attributes - entities: List[SpawnedEntity] + entities: List[SpawnedEntity] = Field( + description="List of all entities currently spawned in the scene with their current poses" + ) SimulationConfigT = TypeVar("SimulationConfigT", bound=SimulationConfig) @@ -100,20 +173,100 @@ def __init__(self, logger: Optional[logging.Logger] = None): @abstractmethod def setup_scene(self, simulation_config: SimulationConfigT): + """ + Runs and sets up the simulation scene according to the provided configuration. + + Parameters + ---------- + simulation_config : SimulationConfigT + Configuration containing the simulation initialization and setup details including + entities to be spawned and their initial poses. + + Returns + ------- + None + """ pass @abstractmethod def _spawn_entity(self, entity: Entity): + """ + Spawns a single entity in the simulation environment. + + Parameters + ---------- + entity : Entity + Entity object containing the entity's properties + + Returns + ------- + None + + Notes + ----- + The spawned entity should be added to the spawned_entities list maintained + by the simulation bridge. + """ pass @abstractmethod def _despawn_entity(self, entity: SpawnedEntity): + """ + Removes a previously spawned entity from the simulation environment. + + Parameters + ---------- + entity : SpawnedEntity + Entity object representing the spawned entity to be removed. + + Returns + ------- + None + + Notes + ----- + Despawned entity should be removed from the spawned_entities list maintained + by the simulation bridge. + """ pass @abstractmethod - def get_object_pose(self, entity: SpawnedEntity) -> PoseModel: + def get_object_pose(self, entity: SpawnedEntity) -> Pose: + """ + Gets the current pose of a spawned entity. + + This method queries the simulation to get the current position and + orientation of a specific entity. + + Parameters + ---------- + entity : SpawnedEntity + Entity object representing the spawned entity whose pose is + to be retrieved. + + Returns + ------- + Pose + Object containing the entity's current pose. + """ pass @abstractmethod def get_scene_state(self) -> SceneState: + """ + Gets the current state of the simulation scene. + + Parameters + ---------- + None + + Returns + ------- + SceneState + Object containing the current state of the scene. + + Notes + ----- + SceneState should contain the current poses of spawned_entities. + """ pass diff --git a/tests/rai_sim/conftest.py b/tests/rai_sim/conftest.py new file mode 100644 index 000000000..4ebb264a9 --- /dev/null +++ b/tests/rai_sim/conftest.py @@ -0,0 +1,75 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +import pytest + + +@pytest.fixture +def sample_base_yaml_config(tmp_path: Path) -> Path: + yaml_content = """ + entities: + - name: entity1 + prefab_name: cube + pose: + translation: + x: 1.0 + y: 2.0 + z: 3.0 + + - name: entity2 + prefab_name: carrot + pose: + translation: + x: 1.0 + y: 2.0 + z: 3.0 + rotation: + x: 0.1 + y: 0.2 + z: 0.3 + w: 0.4 + """ + file_path = tmp_path / "test_config.yaml" + file_path.write_text(yaml_content) + return file_path + + +@pytest.fixture +def sample_o3dexros2_config(tmp_path: Path) -> Path: + yaml_content = """ + binary_path: /path/to/binary + robotic_stack_command: "ros2 launch robotic_stack.launch.py" + required_simulation_ros2_interfaces: + services: + - /spawn_entity + - /delete_entity + topics: + - /color_image5 + - /depth_image5 + - /color_camera_info5 + actions: [] + required_robotic_ros2_interfaces: + services: + - /grounding_dino_classify + - /grounded_sam_segment + - /manipulator_move_to + topics: [] + actions: + - /execute_trajectory + """ + file_path = tmp_path / "test_o3dexros2_config.yaml" + file_path.write_text(yaml_content) + return file_path diff --git a/tests/rai_sim/test_o3de_bridge.py b/tests/rai_sim/test_o3de_bridge.py new file mode 100644 index 000000000..b7687304a --- /dev/null +++ b/tests/rai_sim/test_o3de_bridge.py @@ -0,0 +1,379 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import signal +import typing +import unittest +from pathlib import Path +from typing import List, Optional, Tuple, get_args, get_origin +from unittest.mock import MagicMock, patch + +import rclpy +import rclpy.qos +from geometry_msgs.msg import Point, Quaternion, TransformStamped +from geometry_msgs.msg import Pose as ROS2Pose +from rai.communication.ros2.connectors import ROS2ARIConnector, ROS2ARIMessage +from rclpy.node import Node +from rclpy.qos import QoSProfile + +from rai_sim.o3de.o3de_bridge import O3DExROS2Bridge, O3DExROS2SimulationConfig +from rai_sim.simulation_bridge import ( + Entity, + Pose, + Rotation, + SpawnedEntity, + Translation, +) + + +def test_load_config(sample_base_yaml_config: Path, sample_o3dexros2_config: Path): + config = O3DExROS2SimulationConfig.load_config( + sample_base_yaml_config, sample_o3dexros2_config + ) + assert isinstance(config, O3DExROS2SimulationConfig) + assert config.binary_path == Path("/path/to/binary") + assert config.robotic_stack_command == "ros2 launch robotic_stack.launch.py" + assert config.required_simulation_ros2_interfaces == { + "services": ["/spawn_entity", "/delete_entity"], + "topics": ["/color_image5", "/depth_image5", "/color_camera_info5"], + "actions": [], + } + assert config.required_robotic_ros2_interfaces == { + "services": [ + "/grounding_dino_classify", + "/grounded_sam_segment", + "/manipulator_move_to", + ], + "topics": [], + "actions": ["/execute_trajectory"], + } + assert isinstance(config.entities, list) + assert all(isinstance(e, Entity) for e in config.entities) + + assert len(config.entities) == 2 + + +class TestO3DExROS2Bridge(unittest.TestCase): + def setUp(self): + self.mock_connector = MagicMock(spec=ROS2ARIConnector) + self.mock_logger = MagicMock() + self.bridge = O3DExROS2Bridge( + connector=self.mock_connector, logger=self.mock_logger + ) + + # Create test data + self.test_entity = Entity( + name="test_entity1", + prefab_name="cube", + pose=Pose( + translation=Translation(x=1.0, y=2.0, z=3.0), + rotation=Rotation(x=0.0, y=0.0, z=0.0, w=1.0), + ), + ) + + self.test_spawned_entity = SpawnedEntity( + id="entity_id_123", + name="test_entity1", + prefab_name="cube", + pose=Pose( + translation=Translation(x=1.0, y=2.0, z=3.0), + rotation=Rotation(x=0.0, y=0.0, z=0.0, w=1.0), + ), + ) + + self.test_config = O3DExROS2SimulationConfig( + binary_path=Path("/path/to/binary"), + robotic_stack_command="ros2 launch robot.launch.py", + entities=[self.test_entity], + required_simulation_ros2_interfaces={ + "services": [], + "topics": [], + "actions": [], + }, + required_robotic_ros2_interfaces={ + "services": [], + "topics": [], + "actions": [], + }, + ) + + def test_init(self): + self.assertEqual(self.bridge.connector, self.mock_connector) + self.assertEqual(self.bridge.logger, self.mock_logger) + self.assertIsNone(self.bridge.current_sim_process) + self.assertIsNone(self.bridge.current_robotic_stack_process) + self.assertIsNone(self.bridge.current_binary_path) + self.assertEqual(self.bridge.spawned_entities, []) + + @patch("subprocess.Popen") + def test_launch_robotic_stack(self, mock_popen): + mock_process = MagicMock() + mock_process.poll.return_value = None + mock_process.pid = 54321 + mock_popen.return_value = mock_process + self.bridge._launch_robotic_stack(self.test_config) + + mock_popen.assert_called_once_with(["ros2", "launch", "robot.launch.py"]) + self.assertEqual(self.bridge.current_robotic_stack_process, mock_process) + + @patch("subprocess.Popen") + def test_launch_binary(self, mock_popen): + mock_process = MagicMock() + mock_process.poll.return_value = None + mock_process.pid = 54322 + mock_popen.return_value = mock_process + + self.bridge._launch_binary(self.test_config) + + mock_popen.assert_called_once_with(["/path/to/binary"]) + self.assertEqual(self.bridge.current_sim_process, mock_process) + + def test_shutdown_binary(self): + mock_process = MagicMock() + mock_process.poll.return_value = 0 + + self.bridge.current_sim_process = mock_process + + self.bridge._shutdown_binary() + + mock_process.send_signal.assert_called_once_with(signal.SIGINT) + mock_process.wait.assert_called_once() + + self.assertIsNone(self.bridge.current_sim_process) + + def test_shutdown_robotic_stack(self): + self.bridge.current_robotic_stack_process = MagicMock() + self.bridge.current_robotic_stack_process.poll.return_value = 0 + + self.bridge._shutdown_robotic_stack() + + self.bridge.current_robotic_stack_process.send_signal.assert_called_once_with( + signal.SIGINT + ) + self.bridge.current_robotic_stack_process.wait.assert_called_once() + + def test_get_available_spawnable_names(self): + # Mock the response + response = MagicMock() + response.payload.model_names = ["cube", "carrot"] + self.bridge._try_service_call = MagicMock(return_value=response) + + names = self.bridge.get_available_spawnable_names() + + self.bridge._try_service_call.assert_called_once() + self.assertEqual(names, ["cube", "carrot"]) + + def test_to_ros2_pose(self): + # Create a pose + pose = Pose( + translation=Translation(x=1.0, y=2.0, z=3.0), + rotation=Rotation(x=0.1, y=0.2, z=0.3, w=0.4), + ) + + # Convert to ROS2 pose + ros2_pose = self.bridge._to_ros2_pose(pose) + + # Check the conversion + self.assertEqual(ros2_pose.position.x, 1.0) + self.assertEqual(ros2_pose.position.y, 2.0) + self.assertEqual(ros2_pose.position.z, 3.0) + self.assertEqual(ros2_pose.orientation.x, 0.1) + self.assertEqual(ros2_pose.orientation.y, 0.2) + self.assertEqual(ros2_pose.orientation.z, 0.3) + self.assertEqual(ros2_pose.orientation.w, 0.4) + + def test_from_ros2_pose(self): + # Create a ROS2 pose + position = Point(x=1.0, y=2.0, z=3.0) + orientation = Quaternion(x=0.1, y=0.2, z=0.3, w=0.4) + ros2_pose = ROS2Pose(position=position, orientation=orientation) + + # Convert from ROS2Pose to Pose + pose = self.bridge._from_ros2_pose(ros2_pose) + + # Check the conversion + self.assertEqual(pose.translation.x, 1.0) + self.assertEqual(pose.translation.y, 2.0) + self.assertEqual(pose.translation.z, 3.0) + self.assertEqual(pose.rotation.x, 0.1) + self.assertEqual(pose.rotation.y, 0.2) + self.assertEqual(pose.rotation.z, 0.3) + self.assertEqual(pose.rotation.w, 0.4) + + +class TestROS2ARIConnectorInterface(unittest.TestCase): + """Tests to ensure the ROS2ARIConnector interface meets the expectations of O3DExROS2Bridge.""" + + def setUp(self): + rclpy.init() + self.connector = ROS2ARIConnector() + + def tearDown(self): + rclpy.shutdown() + + def test_connector_required_methods_exist(self): + """Test that all required methods exist on the ROS2ARIConnector.""" + connector = ROS2ARIConnector() + + # Check that all required methods exist + self.assertTrue( + hasattr(connector, "service_call"), "service_call method is missing" + ) + self.assertTrue( + hasattr(connector, "get_transform"), "get_transform method is missing" + ) + self.assertTrue( + hasattr(connector, "send_message"), "send_message method is missing" + ) + self.assertTrue( + hasattr(connector, "receive_message"), "receive_message method is missing" + ) + self.assertTrue(hasattr(connector, "shutdown"), "shutdown method is missing") + self.assertTrue( + hasattr(connector, "get_topics_names_and_types"), + "get_topics_names_and_types method is missing", + ) + self.assertTrue( + hasattr(connector, "node"), + "node property is missing", + ) + + def resolve_annotation(self, annotation: type) -> type: + """Helper function to unwrap Optional types. Workaround for problem with asserting Optional types.""" + if get_origin(annotation) is typing.Optional: + return get_args(annotation)[0] + return annotation + + def test_get_transform_signature(self): + signature = inspect.signature(self.connector.get_transform) + parameters = signature.parameters + + expected_params: dict[str, type] = { + "target_frame": str, + "source_frame": str, + "timeout_sec": float, + } + + assert list(parameters.keys()) == list(expected_params.keys()), ( + f"Parameter names do not match, expected: {list(expected_params.keys())}, got: {list(parameters.keys())}" + ) + + for param_name, expected_type in expected_params.items(): + param = parameters[param_name] + self.assertEqual( + self.resolve_annotation(param.annotation), + self.resolve_annotation(expected_type), + f"Parameter '{param_name}' has incorrect type, expected: {expected_type}, got: {param.annotation}", + ) + + # Check return type explicitly + assert signature.return_annotation is TransformStamped, ( + f"Return type is incorrect, expected: TransformStamped, got: {signature.return_annotation}" + ) + + def test_send_message_signature(self): + signature = inspect.signature(self.connector.send_message) + parameters = signature.parameters + + expected_params: dict[str, type] = { + "message": ROS2ARIMessage, + "target": str, + "msg_type": str, + "auto_qos_matching": bool, + "qos_profile": Optional[QoSProfile], + } + + self.assertListEqual( + list(parameters.keys())[: len(expected_params)], + list(expected_params.keys()), + f"Parameter names do not match, expected: {list(expected_params.keys())}, got: {list(parameters.keys())}", + ) + + for param_name, expected_type in expected_params.items(): + param = parameters[param_name] + self.assertEqual( + self.resolve_annotation(param.annotation), + self.resolve_annotation(expected_type), + f"Parameter '{param_name}' has incorrect type, expected: {expected_type}, got: {param.annotation}", + ) + + self.assertIs( + signature.return_annotation, + inspect.Signature.empty, + "send_message should have no return value", + ) + + def test_receive_message_signature(self): + signature = inspect.signature(self.connector.receive_message) + parameters = signature.parameters + + expected_params: dict[str, type] = { + "source": str, + "timeout_sec": float, + "msg_type": Optional[str], + "auto_topic_type": bool, + } + + self.assertListEqual( + list(parameters.keys())[: len(expected_params)], + list(expected_params.keys()), + f"Parameter names do not match, expected: {list(expected_params.keys())}, got: {list(parameters.keys())}", + ) + + for param_name, expected_type in expected_params.items(): + param = parameters[param_name] + self.assertEqual( + self.resolve_annotation(param.annotation), + self.resolve_annotation(expected_type), + f"Parameter '{param_name}' has incorrect type, expected: {expected_type}, got: {param.annotation}", + ) + + self.assertIs( + signature.return_annotation, + ROS2ARIMessage, + f"Return type is incorrect, expected: ROS2ARIMessage, got: {signature.return_annotation}", + ) + + def test_get_topics_names_and_types_signature(self): + signature = inspect.signature(self.connector.get_topics_names_and_types) + parameters = signature.parameters + + expected_params: dict[str, type] = {} + + assert list(parameters.keys()) == list(expected_params.keys()), ( + f"Parameter names do not match, expected: {list(expected_params.keys())}, got: {list(parameters.keys())}" + ) + + for param_name, expected_type in expected_params.items(): + param = parameters[param_name] + self.assertEqual( + self.resolve_annotation(param.annotation), + self.resolve_annotation(expected_type), + f"Parameter '{param_name}' has incorrect type, expected: {expected_type}, got: {param.annotation}", + ) + + self.assertEqual( + signature.return_annotation, + List[Tuple[str, List[str]]], + f"Return type is incorrect, expected: List[Tuple[str, List[str]]], got: {signature.return_annotation}", + ) + + def test_node_property(self): + """Test that the node property returns the expected Node instance.""" + mock_node = MagicMock(spec=Node) + self.connector._node = mock_node + + self.assertEqual(self.connector.node, mock_node) + self.assertIsInstance(self.connector.node, Node) diff --git a/tests/rai_sim/test_simulation_bridge.py b/tests/rai_sim/test_simulation_bridge.py new file mode 100644 index 000000000..39f91ab33 --- /dev/null +++ b/tests/rai_sim/test_simulation_bridge.py @@ -0,0 +1,325 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import unittest +from pathlib import Path +from typing import Optional + +import pytest +from pydantic import ValidationError + +from rai_sim.simulation_bridge import ( + Entity, + Pose, + Rotation, + SceneState, + SimulationBridge, + SimulationConfig, + SpawnedEntity, + Translation, +) + + +# Helper Functions +def create_translation(x: float, y: float, z: float) -> Translation: + return Translation(x=x, y=y, z=z) + + +def create_rotation(x: float, y: float, z: float, w: float) -> Rotation: + return Rotation(x=x, y=y, z=z, w=w) + + +def create_pose(translation: Translation, rotation: Optional[Rotation] = None) -> Pose: + return Pose(translation=translation, rotation=rotation) + + +# Test Cases +def test_translation(): + translation = create_translation(x=1.1, y=2.2, z=3.3) + + assert isinstance(translation.x, float) + assert isinstance(translation.y, float) + assert isinstance(translation.z, float) + + assert translation.x == 1.1 + assert translation.y == 2.2 + assert translation.z == 3.3 + + +def test_rotation(): + rotation = Rotation(x=0.1, y=0.2, z=0.3, w=0.4) + + assert isinstance(rotation.x, float) + assert isinstance(rotation.y, float) + assert isinstance(rotation.z, float) + assert isinstance(rotation.w, float) + + assert rotation.x == 0.1 + assert rotation.y == 0.2 + assert rotation.z == 0.3 + assert rotation.w == 0.4 + + +def test_pose(): + translation = create_translation(x=1.1, y=2.2, z=3.3) + rotation = Rotation(x=0.1, y=0.2, z=0.3, w=0.4) + + pose = Pose(translation=translation, rotation=rotation) + + assert isinstance(pose.translation, Translation) + assert isinstance(pose.rotation, Rotation) + + assert pose.translation.x == 1.1 + assert pose.rotation.w == 0.4 + + +def test_optional_rotation(): + translation = create_translation(x=1.1, y=2.2, z=3.3) + pose = create_pose(translation=translation) + + assert isinstance(pose.translation, Translation) + assert pose.rotation is None + + +def test_entity(): + translation = create_translation(x=1.1, y=2.2, z=3.3) + rotation = create_rotation(x=0.1, y=0.2, z=0.3, w=0.4) + pose = create_pose(translation=translation, rotation=rotation) + + entity = Entity(name="test_cube", prefab_name="cube", pose=pose) + + assert isinstance(entity.name, str) + assert isinstance(entity.prefab_name, str) + assert isinstance(entity.pose, Pose) + + assert entity.name == "test_cube" + assert entity.prefab_name == "cube" + assert entity.pose == pose + + +def test_spawned_entity(): + translation = create_translation(x=1.1, y=2.2, z=3.3) + rotation = create_rotation(x=0.1, y=0.2, z=0.3, w=0.4) + pose = create_pose(translation=translation, rotation=rotation) + + spawned_entity = SpawnedEntity( + name="test_cube", + prefab_name="cube", + pose=pose, + id="id_123", + ) + + assert isinstance(spawned_entity.name, str) + assert isinstance(spawned_entity.prefab_name, str) + assert isinstance(spawned_entity.pose, Pose) + assert isinstance(spawned_entity.id, str) + + assert spawned_entity.id == "id_123" + + +def test_simulation_config_unique_names(): + translation = create_translation(x=1.1, y=2.2, z=3.3) + rotation = create_rotation(x=0.1, y=0.2, z=0.3, w=0.4) + pose = create_pose(translation=translation, rotation=rotation) + + entities = [ + Entity(name="entity1", prefab_name="cube", pose=pose), + Entity(name="entity2", prefab_name="carrot", pose=pose), + ] + + config = SimulationConfig(entities=entities) + + assert isinstance(config.entities, list) + assert all(isinstance(e, Entity) for e in config.entities) + + assert len(config.entities) == 2 + + +def test_simulation_config_duplicate_names(): + translation = create_translation(x=1.1, y=2.2, z=3.3) + rotation = create_rotation(x=0.1, y=0.2, z=0.3, w=0.4) + pose = create_pose(translation=translation, rotation=rotation) + + entities = [ + Entity(name="duplicate", prefab_name="cube", pose=pose), + Entity(name="duplicate", prefab_name="carrot", pose=pose), + ] + + with pytest.raises(ValidationError): + SimulationConfig(entities=entities) + + +def test_load_base_config(sample_base_yaml_config: Path): + config = SimulationConfig.load_base_config(sample_base_yaml_config) + + assert isinstance(config.entities, list) + assert all(isinstance(e, Entity) for e in config.entities) + + assert len(config.entities) == 2 + + +class MockSimulationBridge(SimulationBridge[SimulationConfig]): + """Mock implementation of SimulationBridge for testing.""" + + def setup_scene(self, simulation_config: SimulationConfig): + """Mock implementation of setup_scene.""" + for entity in simulation_config.entities: + self._spawn_entity(entity) + + def _spawn_entity(self, entity: Entity): + """Mock implementation of _spawn_entity.""" + spawned_entity = SpawnedEntity( + id=f"id_{entity.name}", + name=entity.name, + prefab_name=entity.prefab_name, + pose=entity.pose, + ) + self.spawned_entities.append(spawned_entity) + + def _despawn_entity(self, entity: SpawnedEntity): + """Mock implementation of _despawn_entity.""" + self.spawned_entities = [e for e in self.spawned_entities if e.id != entity.id] + + def get_object_pose(self, entity: SpawnedEntity) -> Pose: + """Mock implementation of get_object_pose.""" + for spawned_entity in self.spawned_entities: + if spawned_entity.id == entity.id: + return spawned_entity.pose + raise ValueError(f"Entity with id {entity.id} not found") + + def get_scene_state(self) -> SceneState: + """Mock implementation of get_scene_state.""" + return SceneState(entities=self.spawned_entities) + + +class TestSimulationBridge(unittest.TestCase): + def setUp(self): + """Set up test fixtures.""" + self.logger = logging.getLogger("test_logger") + self.bridge = MockSimulationBridge(logger=self.logger) + + # Create test entities + self.test_entity1: Entity = Entity( + name="test_entity1", + prefab_name="test_prefab1", + pose=Pose( + translation=Translation(x=1.0, y=2.0, z=3.0), + rotation=Rotation(x=0.0, y=0.0, z=0.0, w=1.0), + ), + ) + + self.test_entity2: Entity = Entity( + name="test_entity2", + prefab_name="test_prefab2", + pose=Pose(translation=Translation(x=4.0, y=5.0, z=6.0), rotation=None), + ) + + # Create a test configuration + self.test_config = SimulationConfig( + entities=[self.test_entity1, self.test_entity2] + ) + + def test_init(self): + # Test with provided logger + bridge = MockSimulationBridge(logger=self.logger) + self.assertEqual(bridge.logger, self.logger) + self.assertEqual(len(bridge.spawned_entities), 0) + + # Test with default logger + bridge = MockSimulationBridge() + self.assertIsNotNone(bridge.logger) + self.assertEqual(len(bridge.spawned_entities), 0) + + def test_setup_scene(self): + self.bridge.setup_scene(self.test_config) + + # Check if entities were spawned + self.assertEqual(len(self.bridge.spawned_entities), 2) + + # Check if the spawned entities have correct properties + self.assertEqual(self.bridge.spawned_entities[0].name, "test_entity1") + self.assertEqual(self.bridge.spawned_entities[0].prefab_name, "test_prefab1") + self.assertEqual(self.bridge.spawned_entities[0].pose.translation.x, 1.0) + self.assertEqual(self.bridge.spawned_entities[0].pose.translation.y, 2.0) + self.assertEqual(self.bridge.spawned_entities[0].pose.translation.z, 3.0) + assert self.bridge.spawned_entities[0].pose.rotation + self.assertEqual(self.bridge.spawned_entities[0].pose.rotation.x, 0.0) + self.assertEqual(self.bridge.spawned_entities[0].pose.rotation.y, 0.0) + self.assertEqual(self.bridge.spawned_entities[0].pose.rotation.z, 0.0) + self.assertEqual(self.bridge.spawned_entities[0].pose.rotation.w, 1.0) + + self.assertEqual(self.bridge.spawned_entities[1].name, "test_entity2") + self.assertEqual(self.bridge.spawned_entities[1].prefab_name, "test_prefab2") + self.assertEqual(self.bridge.spawned_entities[1].pose.translation.x, 4.0) + self.assertEqual(self.bridge.spawned_entities[1].pose.translation.y, 5.0) + self.assertEqual(self.bridge.spawned_entities[1].pose.translation.z, 6.0) + self.assertIsNone(self.bridge.spawned_entities[1].pose.rotation) + + def test_spawn_entity(self): + self.bridge._spawn_entity(self.test_entity1) # type: ignore + spawned_entity = self.bridge.spawned_entities[0] + # Check if entity was added to spawned_entities + self.assertEqual(len(self.bridge.spawned_entities), 1) + self.assertIsInstance(spawned_entity, SpawnedEntity) + + def test_despawn_entity(self): + # First spawn an entity + self.bridge._spawn_entity(self.test_entity1) # type: ignore + self.assertEqual(len(self.bridge.spawned_entities), 1) + + # Then despawn it + self.bridge._despawn_entity(self.bridge.spawned_entities[0]) # type: ignore + self.assertEqual(len(self.bridge.spawned_entities), 0) + + def test_get_object_pose(self): + # First spawn an entity + self.bridge._spawn_entity(self.test_entity1) # type: ignore + + # Get the pose + pose = self.bridge.get_object_pose(self.bridge.spawned_entities[0]) + + # Check if the pose matches + self.assertEqual(pose.translation.x, 1.0) + self.assertEqual(pose.translation.y, 2.0) + self.assertEqual(pose.translation.z, 3.0) + assert pose.rotation + self.assertEqual(pose.rotation.x, 0.0) + self.assertEqual(pose.rotation.y, 0.0) + self.assertEqual(pose.rotation.z, 0.0) + self.assertEqual(pose.rotation.w, 1.0) + + # Test for non-existent entity + non_existent_entity = SpawnedEntity( + id="non_existent", + name="non_existent", + prefab_name="non_existent", + pose=Pose(translation=Translation(x=0.0, y=0.0, z=0.0)), + ) + + with self.assertRaises(ValueError): + self.bridge.get_object_pose(non_existent_entity) + + def test_get_scene_state(self): + # First spawn some entities + self.bridge._spawn_entity(self.test_entity1) # type: ignore + self.bridge._spawn_entity(self.test_entity2) # type: ignore + + # Get the scene state + scene_state = self.bridge.get_scene_state() + + # Check if the scene state contains the correct entities + self.assertEqual(len(scene_state.entities), 2) + self.assertEqual(scene_state.entities[0].name, "test_entity1") + self.assertEqual(scene_state.entities[1].name, "test_entity2")