From ff0585ee4420a7e20da51a383a076802423ef117 Mon Sep 17 00:00:00 2001 From: Magdalena Kotynia Date: Wed, 26 Feb 2025 09:41:30 +0100 Subject: [PATCH 1/8] feat: rai_sim (#415) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Kacper Dąbrowski Signed-off-by: Piotr Jaroszek Co-authored-by: Jakub Matejczyk Co-authored-by: Kacper Dąbrowski Co-authored-by: Piotr Jaroszek <10896036+pijaro@users.noreply.github.com> --- .../manipulation-demo-no-binary.launch.py | 59 ++++ poetry.lock | 18 +- pyproject.toml | 1 + .../rai/communication/ros2/connectors.py | 8 +- src/rai_sim/README.md | 17 + src/rai_sim/pyproject.toml | 23 ++ src/rai_sim/rai_sim/__init__.py | 15 + src/rai_sim/rai_sim/o3de/o3de_bridge.py | 306 ++++++++++++++++++ src/rai_sim/rai_sim/simulation_bridge.py | 119 +++++++ 9 files changed, 561 insertions(+), 5 deletions(-) create mode 100755 examples/manipulation-demo-no-binary.launch.py create mode 100644 src/rai_sim/README.md create mode 100644 src/rai_sim/pyproject.toml create mode 100644 src/rai_sim/rai_sim/__init__.py create mode 100644 src/rai_sim/rai_sim/o3de/o3de_bridge.py create mode 100644 src/rai_sim/rai_sim/simulation_bridge.py diff --git a/examples/manipulation-demo-no-binary.launch.py b/examples/manipulation-demo-no-binary.launch.py new file mode 100755 index 000000000..494deea7b --- /dev/null +++ b/examples/manipulation-demo-no-binary.launch.py @@ -0,0 +1,59 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from launch import LaunchDescription +from launch.actions import ( + IncludeLaunchDescription, +) +from launch.launch_description_sources import PythonLaunchDescriptionSource +from launch_ros.actions import Node +from launch_ros.substitutions import FindPackageShare + + +# TODO (mkotynia) think about separation of launches +def generate_launch_description(): + launch_moveit = IncludeLaunchDescription( + PythonLaunchDescriptionSource( + [ + "src/examples/rai-manipulation-demo/Project/Examples/panda_moveit_config_demo.launch.py", + ] + ) + ) + + launch_robotic_manipulation = Node( + package="robotic_manipulation", + executable="robotic_manipulation", + # name="robotic_manipulation_node", + output="screen", + parameters=[ + {"use_sim_time": True}, + ], + ) + + launch_openset = IncludeLaunchDescription( + PythonLaunchDescriptionSource( + [ + FindPackageShare("rai_bringup"), + "/launch/openset.launch.py", + ] + ), + ) + + return LaunchDescription( + [ + launch_openset, + launch_moveit, + launch_robotic_manipulation, + ] + ) diff --git a/poetry.lock b/poetry.lock index 93059b345..2d22b986d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5700,6 +5700,22 @@ torchaudio = "^2.3.1" type = "directory" url = "src/rai_asr" +[[package]] +name = "rai-sim" +version = "0.0.1" +description = "Package to run simulations" +optional = false +python-versions = "^3.10, <3.13" +files = [] +develop = true + +[package.dependencies] +PyYAML = "*" + +[package.source] +type = "directory" +url = "src/rai_sim" + [[package]] name = "rai-tts" version = "1.0.0" @@ -8302,4 +8318,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "^3.10, <3.13" -content-hash = "242e440c4ce4b31fa629d198a3d79b0854e84284d013d819a4b7a24e633a1706" +content-hash = "a53906ce2c798e5e0a02c7db25cf00cf36e021186a79429ba1bd8f0836b12db2" diff --git a/pyproject.toml b/pyproject.toml index b313dd299..863309947 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ python = "^3.10, <3.13" rai = {path = "src/rai_core", develop = true} rai_asr = {path = "src/rai_asr", develop = true} rai_tts = {path = "src/rai_tts", develop = true} +rai_sim = {path = "src/rai_sim", develop = true} langchain-core = "^0.3" langchain = "*" diff --git a/src/rai_core/rai/communication/ros2/connectors.py b/src/rai_core/rai/communication/ros2/connectors.py index 452dae22e..af7953d78 100644 --- a/src/rai_core/rai/communication/ros2/connectors.py +++ b/src/rai_core/rai/communication/ros2/connectors.py @@ -69,6 +69,7 @@ def __init__( self._topic_api = ROS2TopicAPI(self._node) self._service_api = ROS2ServiceAPI(self._node) self._actions_api = ROS2ActionAPI(self._node) + self._tf_buffer = Buffer(node=self._node) self._executor = MultiThreadedExecutor() self._executor.add_node(self._node) @@ -179,16 +180,15 @@ def get_transform( source_frame: str, timeout_sec: float = 5.0, ) -> TransformStamped: - tf_buffer = Buffer(node=self._node) - tf_listener = TransformListener(tf_buffer, self._node) + tf_listener = TransformListener(self._tf_buffer, self._node) transform_available = self.wait_for_transform( - tf_buffer, target_frame, source_frame, timeout_sec + self._tf_buffer, target_frame, source_frame, timeout_sec ) if not transform_available: raise LookupException( f"Could not find transform from {source_frame} to {target_frame} in {timeout_sec} seconds" ) - transform: TransformStamped = tf_buffer.lookup_transform( + transform: TransformStamped = self._tf_buffer.lookup_transform( target_frame, source_frame, rclpy.time.Time(), diff --git a/src/rai_sim/README.md b/src/rai_sim/README.md new file mode 100644 index 000000000..f6e82bcb6 --- /dev/null +++ b/src/rai_sim/README.md @@ -0,0 +1,17 @@ +## RAI Sim + +## Description + +The RAI Sim is a package providing interface to implement connection with a specific simulation. + +### Components + +- `SimulationConnector` - An interface for connecting with a specific simulation. It manages scene setup, spawning, despawning objects, getting current state of the scene. + +- `SimulationConfig` - base config class to specify the entities to be spawned. For each simulation connector there should be specified custom simulation config specifying additional parameters needed to run and connect with the simulation. + +- `SceneState` - stores the current info about spawned entities + +### Example implementation + +- `O3DExROS2Connector` - An implementation of SimulationConnector for working with simulation based on O3DE and ROS2. diff --git a/src/rai_sim/pyproject.toml b/src/rai_sim/pyproject.toml new file mode 100644 index 000000000..9042be051 --- /dev/null +++ b/src/rai_sim/pyproject.toml @@ -0,0 +1,23 @@ +[tool.poetry] +name = "rai-sim" +version = "0.0.1" +description = "Package to run simulations" +authors = ["Magdalena Kotynia ", "Kacper Dąbrowski ", "Jakub Matejczyk "] +readme = "README.md" +classifiers = [ + "Programming Language :: Python :: 3", + "Development Status :: 4 - Beta", + "License :: OSI Approved :: Apache Software License", +] +packages = [ + { include = "rai_sim", from = "." }, +] + +[tool.poetry.dependencies] +python = "^3.10, <3.13" +PyYAML = "*" + + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" diff --git a/src/rai_sim/rai_sim/__init__.py b/src/rai_sim/rai_sim/__init__.py new file mode 100644 index 000000000..f792966c5 --- /dev/null +++ b/src/rai_sim/rai_sim/__init__.py @@ -0,0 +1,15 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""RAI Simulations package.""" diff --git a/src/rai_sim/rai_sim/o3de/o3de_bridge.py b/src/rai_sim/rai_sim/o3de/o3de_bridge.py new file mode 100644 index 000000000..45e377d39 --- /dev/null +++ b/src/rai_sim/rai_sim/o3de/o3de_bridge.py @@ -0,0 +1,306 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import shlex +import signal +import subprocess +import time +from pathlib import Path +from typing import Any, Dict, Optional + +import yaml +from geometry_msgs.msg import Point, Pose, Quaternion +from rai.communication.ros2.connectors import ROS2ARIConnector, ROS2ARIMessage +from tf2_geometry_msgs import do_transform_pose + +from rai_sim.simulation_bridge import ( + Entity, + PoseModel, + Rotation, + SceneState, + SimulationBridge, + SimulationConfig, + SpawnedEntity, + Translation, +) + + +class O3DExROS2SimulationConfig(SimulationConfig): + binary_path: Path + robotic_stack_command: str + + @classmethod + def load_config( + cls, base_config_path: Path, connector_config_path: Path + ) -> "O3DExROS2SimulationConfig": + base_config = SimulationConfig.load_base_config(base_config_path) + + with open(connector_config_path) as f: + connector_content: dict[str, Any] = yaml.safe_load(f) + return cls(**base_config.model_dump(), **connector_content) + + +class O3DExROS2Bridge(SimulationBridge[O3DExROS2SimulationConfig]): + def __init__( + self, connector: ROS2ARIConnector, logger: Optional[logging.Logger] = None + ): + super().__init__(logger=logger) + self.connector = connector + self.current_sim_process = None + self.current_robotic_stack_process = None + self.current_binary_path = None + + def shutdown(self): + self._shutdown_binary() + self._shutdown_robotic_stack() + + def _shutdown_binary(self): + if not self.current_sim_process: + return + self.current_sim_process.send_signal(signal.SIGINT) + self.current_sim_process.wait() + + if self.current_sim_process.poll() is None: + self.logger.error( + f"Parent process PID: {self.current_sim_process.pid} is still running." + ) + raise RuntimeError( + f"Failed to terminate main process PID {self.current_sim_process.pid}" + ) + + self.current_sim_process = None + + def _shutdown_robotic_stack(self): + if not self.current_robotic_stack_process: + return + + self.current_robotic_stack_process.send_signal(signal.SIGINT) + self.current_robotic_stack_process.wait() + + if self.current_robotic_stack_process.poll() is None: + self.logger.error( + f"Parent process PID: {self.current_robotic_stack_process.pid} is still running." + ) + raise RuntimeError( + f"Failed to terminate robotic stack process PID {self.current_robotic_stack_process.pid}" + ) + + def get_available_spawnable_names(self) -> list[str]: + msg = ROS2ARIMessage({}) + response = self._try_service_call( + msg, + target="get_available_spawnable_names", + msg_type="gazebo_msgs/srv/GetWorldProperties", + ) + # NOTE (mkotynia) There is a bug in the gazebo_msgs/srv/GetWorldProperties service - payload.success is not set to True even if the service call is successful. It was reported to Kacper Dąbrowski and he is going to fix it. + # PR fixing the bug: https://github.com/o3de/o3de-extras/pull/828 + # TODO (mkotynia) uncomment check if response.payload.success when the bug is fixed and remove workaround check if response.payload.model_names. + + # if response.payload.success: + # return response.payload.model_names + if response.payload.model_names: + return response.payload.model_names + else: + raise RuntimeError( + f"Failed to get available spawnable names. Response: {response.payload.status_message}" + ) + + def _spawn_entity(self, entity: Entity): + pose = do_transform_pose( + self.to_ros2_pose(entity.pose), + self.connector.get_transform("odom", "world"), + ) + + msg_content: Dict[str, Any] = { + "name": entity.prefab_name, + "xml": "", + "robot_namespace": entity.name, + "initial_pose": { + "position": { + "x": pose.position.x, # type: ignore + "y": pose.position.y, # type: ignore + "z": pose.position.z, # type: ignore + }, + "orientation": { + "x": pose.orientation.x, # type: ignore + "y": pose.orientation.y, # type: ignore + "z": pose.orientation.z, # type: ignore + "w": pose.orientation.w, # type: ignore + }, + }, + } + + msg = ROS2ARIMessage(payload=msg_content) + response = self._try_service_call( + msg, target="spawn_entity", msg_type="gazebo_msgs/srv/SpawnEntity" + ) + if response and response.payload.success: + self.spawned_entities.append( + SpawnedEntity(id=response.payload.status_message, **entity.model_dump()) + ) + else: + raise RuntimeError( + f"Failed to spawn entity {entity.name}. Response: {response.payload.status_message}" + ) + + def _despawn_entity(self, entity: SpawnedEntity): + msg_content = {"name": entity.id} + + msg = ROS2ARIMessage(payload=msg_content) + + response = self._try_service_call( + msg, target="delete_entity", msg_type="gazebo_msgs/srv/DeleteEntity" + ) + if response.payload.success: + self.spawned_entities.remove(entity) + else: + raise RuntimeError( + f"Failed to delete entity {entity.name}. Response: {response.payload.status_message}" + ) + + def get_object_pose(self, entity: SpawnedEntity) -> PoseModel: + object_frame = entity.name + "/" + ros2_pose = do_transform_pose( + Pose(), self.connector.get_transform(object_frame + "odom", object_frame) + ) + ros2_pose = do_transform_pose( + ros2_pose, self.connector.get_transform("world", "odom") + ) + return self.from_ros2_pose(ros2_pose) + + def get_scene_state(self) -> SceneState: + """ + Get the current scene state. + """ + if not self.current_sim_process: + raise RuntimeError("Simulation is not running.") + entities: list[SpawnedEntity] = [] + for entity in self.spawned_entities: + current_pose = self.get_object_pose(entity) + entities.append( + SpawnedEntity( + id=entity.id, + name=entity.name, + prefab_name=entity.prefab_name, + pose=current_pose, + ) + ) + return SceneState(entities=entities) + + def setup_scene(self, simulation_config: O3DExROS2SimulationConfig): + if self.current_binary_path != simulation_config.binary_path: + if self.current_sim_process: + self.shutdown() + self._launch_binary(simulation_config.binary_path) + self._launch_robotic_stack(simulation_config.robotic_stack_command) + self.current_binary_path = simulation_config.binary_path + + else: + while self.spawned_entities: + self._despawn_entity(self.spawned_entities[0]) + + for entity in simulation_config.entities: + self._spawn_entity(entity) + + def _launch_binary(self, binary_path: Path): + command = [binary_path.as_posix()] + self.logger.info(f"Running command: {command}") + self.current_sim_process = subprocess.Popen( + command, + ) + if not self._has_process_started(process=self.current_sim_process): + raise RuntimeError("Process did not start in time.") + + def _launch_robotic_stack(self, robotic_stack_command: str): + command = shlex.split(robotic_stack_command) + self.logger.info(f"Running command: {command}") + self.current_robotic_stack_process = subprocess.Popen( + command, + ) + if not self._has_process_started(self.current_robotic_stack_process): + raise RuntimeError("Process did not start in time.") + + def _has_process_started(self, process: subprocess.Popen[Any], timeout: int = 15): + start_time = time.time() + while time.time() - start_time < timeout: + if process.poll() is None: + self.logger.info(f"Process started with PID {process.pid}") + return True + time.sleep(1) + return False + + def _try_service_call( + self, msg: ROS2ARIMessage, target: str, msg_type: str, n_retries: int = 3 + ) -> ROS2ARIMessage: + if n_retries < 1: + raise ValueError("Number of retries must be at least 1") + for _ in range(n_retries): + try: + response = self.connector.service_call( + msg, target=target, msg_type=msg_type + ) + except Exception as e: + error_message = f"Error while calling service {target} with msg_type {msg_type}: {e}" + self.logger.error(error_message) + raise RuntimeError(error_message) + if response.payload.success: + return response + self.logger.warning( + f"Retrying {target}, response success: {response.payload.success}" + ) + return response # type: ignore + + # NOTE (mkotynia) probably to be refactored, other bridges may also want to use pose conversion to/from ROS2 format + def to_ros2_pose(self, pose: PoseModel) -> Pose: + """ + Converts pose in PoseModel format to pose in ROS2 Pose format. + """ + position = Point( + x=pose.translation.x, y=pose.translation.y, z=pose.translation.z + ) + + if pose.rotation is not None: + orientation = Quaternion( + x=pose.rotation.x, + y=pose.rotation.y, + z=pose.rotation.z, + w=pose.rotation.w, + ) + else: + orientation = Quaternion() + + ros2_pose = Pose(position=position, orientation=orientation) + + return ros2_pose + + def from_ros2_pose(self, pose: Pose) -> PoseModel: + """ + Converts ROS2 pose to PoseModel format + """ + + translation = Translation( + x=pose.position.x, # type: ignore + y=pose.position.y, # type: ignore + z=pose.position.z, # type: ignore + ) + + rotation = Rotation( + x=pose.orientation.x, # type: ignore + y=pose.orientation.y, # type: ignore + z=pose.orientation.z, # type: ignore + w=pose.orientation.w, # type: ignore + ) + + return PoseModel(translation=translation, rotation=rotation) diff --git a/src/rai_sim/rai_sim/simulation_bridge.py b/src/rai_sim/rai_sim/simulation_bridge.py new file mode 100644 index 000000000..93959382d --- /dev/null +++ b/src/rai_sim/rai_sim/simulation_bridge.py @@ -0,0 +1,119 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Generic, List, Optional, TypeVar + +import yaml +from pydantic import BaseModel, field_validator + + +class Translation(BaseModel): + x: float + y: float + z: float + + +class Rotation(BaseModel): + x: float + y: float + z: float + w: float + + +class PoseModel(BaseModel): + translation: Translation + rotation: Optional[Rotation] + + +class Entity(BaseModel): + name: str + prefab_name: str + pose: PoseModel + + +class SpawnedEntity(Entity): + id: str + + +class SimulationConfig(BaseModel): + """ + Setup of simulation - arrangemenet of objects in the environment. + """ + + # NOTE (mkotynia) can be extended by other attributes + entities: List[Entity] + + @field_validator("entities") + @classmethod + def check_unique_names(cls, entities: List[Entity]) -> List[Entity]: + names = [entity.name for entity in entities] + if len(names) != len(set(names)): + raise ValueError("Each entity must have a unique name.") + return entities + + @classmethod + def load_base_config(cls, base_config_path: Path) -> "SimulationConfig": + with open(base_config_path) as f: + content = yaml.safe_load(f) + return cls(**content) + + +class SceneState(BaseModel): + """ + Info about current entities' state in the scene. + """ + + # NOTE (mkotynia) can be extended by other attributes + entities: List[SpawnedEntity] + + +SimulationConfigT = TypeVar("SimulationConfigT", bound=SimulationConfig) + + +class SimulationBridge(ABC, Generic[SimulationConfigT]): + """ + Responsible for communication with simulation. + """ + + def __init__(self, logger: Optional[logging.Logger] = None): + self.spawned_entities: List[ + SpawnedEntity + ] = [] # list of spawned entities with their initial poses + if logger is None: + self.logger = logging.getLogger(__name__) + else: + self.logger = logger + + @abstractmethod + def setup_scene(self, simulation_config: SimulationConfigT): + pass + + @abstractmethod + def _spawn_entity(self, entity: Entity): + pass + + @abstractmethod + def _despawn_entity(self, entity: SpawnedEntity): + pass + + @abstractmethod + def get_object_pose(self, entity: SpawnedEntity) -> PoseModel: + pass + + @abstractmethod + def get_scene_state(self) -> SceneState: + pass From 3dc3d15f28b114348d95ed76faa55e3f10dfb316 Mon Sep 17 00:00:00 2001 From: Jakub Matejczyk <58983084+jmatejcz@users.noreply.github.com> Date: Wed, 26 Feb 2025 09:58:57 +0100 Subject: [PATCH 2/8] feat: rai_bench (#436) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Maciej Majek Co-authored-by: Bartłomiej Boczek Co-authored-by: MagdalenaKotynia chore: pre-commit --- .gitignore | 2 + examples/manipulation-demo.launch.py | 49 +-- examples/manipulation-demo.py | 20 +- poetry.lock | 15 +- pyproject.toml | 3 +- setup_shell.sh | 3 + src/rai_bench/README.md | 60 ++++ src/rai_bench/pyproject.toml | 17 + src/rai_bench/rai_bench/benchmark_model.py | 290 ++++++++++++++++++ src/rai_bench/rai_bench/main.py | 187 +++++++++++ .../rai_bench/o3de_test_bench/__init__.py | 13 + .../o3de_test_bench/configs/scene1.yaml | 25 ++ .../o3de_test_bench/configs/scene2.yaml | 51 +++ .../o3de_test_bench/configs/scene3.yaml | 25 ++ .../o3de_test_bench/configs/scene4.yaml | 50 +++ .../o3de_test_bench/tasks/__init__.py | 17 + .../o3de_test_bench/tasks/grab_carrot_task.py | 101 ++++++ .../o3de_test_bench/tasks/place_cubes_task.py | 104 +++++++ .../rai/agents/conversational_agent.py | 2 +- src/rai_core/rai/agents/tool_runner.py | 7 +- .../rai/communication/ros2/connectors.py | 33 +- src/rai_core/rai/tools/ros/manipulation.py | 53 +--- src/rai_core/rai/tools/ros/utils.py | 4 +- .../rai_open_set_vision/examples/talker.py | 8 +- .../services/grounded_sam.py | 7 +- .../services/grounding_dino.py | 8 +- .../rai_open_set_vision/tools/gdino_tools.py | 31 +- .../tools/segmentation_tools.py | 114 +++++-- src/rai_sim/rai_sim/o3de/o3de_bridge.py | 93 +++++- src/rai_sim/rai_sim/simulation_bridge.py | 2 +- 30 files changed, 1229 insertions(+), 165 deletions(-) create mode 100644 src/rai_bench/README.md create mode 100644 src/rai_bench/pyproject.toml create mode 100644 src/rai_bench/rai_bench/benchmark_model.py create mode 100644 src/rai_bench/rai_bench/main.py create mode 100644 src/rai_bench/rai_bench/o3de_test_bench/__init__.py create mode 100644 src/rai_bench/rai_bench/o3de_test_bench/configs/scene1.yaml create mode 100644 src/rai_bench/rai_bench/o3de_test_bench/configs/scene2.yaml create mode 100644 src/rai_bench/rai_bench/o3de_test_bench/configs/scene3.yaml create mode 100644 src/rai_bench/rai_bench/o3de_test_bench/configs/scene4.yaml create mode 100644 src/rai_bench/rai_bench/o3de_test_bench/tasks/__init__.py create mode 100644 src/rai_bench/rai_bench/o3de_test_bench/tasks/grab_carrot_task.py create mode 100644 src/rai_bench/rai_bench/o3de_test_bench/tasks/place_cubes_task.py diff --git a/.gitignore b/.gitignore index 48d048836..9ee309471 100644 --- a/.gitignore +++ b/.gitignore @@ -172,3 +172,5 @@ logs/ src/examples/*-demo artifact_database.pkl + +imgui.ini diff --git a/examples/manipulation-demo.launch.py b/examples/manipulation-demo.launch.py index a9210698f..35720a6af 100644 --- a/examples/manipulation-demo.launch.py +++ b/examples/manipulation-demo.launch.py @@ -12,22 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import rclpy -from launch import LaunchContext, LaunchDescription +from launch import LaunchDescription from launch.actions import ( DeclareLaunchArgument, ExecuteProcess, IncludeLaunchDescription, - OpaqueFunction, - RegisterEventHandler, ) -from launch.event_handlers import OnExecutionComplete, OnProcessStart from launch.launch_description_sources import PythonLaunchDescriptionSource from launch.substitutions import LaunchConfiguration from launch_ros.actions import Node from launch_ros.substitutions import FindPackageShare -from rclpy.qos import QoSProfile, ReliabilityPolicy -from rosgraph_msgs.msg import Clock def generate_launch_description(): @@ -46,21 +40,6 @@ def generate_launch_description(): output="screen", ) - def wait_for_clock_message(context: LaunchContext, *args, **kwargs): - rclpy.init() - node = rclpy.create_node("wait_for_game_launcher") - node.create_subscription( - Clock, - "/clock", - lambda msg: rclpy.shutdown(), - QoSProfile(depth=1, reliability=ReliabilityPolicy.BEST_EFFORT), - ) - rclpy.spin(node) - return None - - # Game launcher will start publishing the clock message after loading the simulation - wait_for_game_launcher = OpaqueFunction(function=wait_for_clock_message) - launch_moveit = IncludeLaunchDescription( PythonLaunchDescriptionSource( [ @@ -72,7 +51,7 @@ def wait_for_clock_message(context: LaunchContext, *args, **kwargs): launch_robotic_manipulation = Node( package="robotic_manipulation", executable="robotic_manipulation", - name="robotic_manipulation_node", + # name="robotic_manipulation_node", output="screen", parameters=[ {"use_sim_time": True}, @@ -90,28 +69,10 @@ def wait_for_clock_message(context: LaunchContext, *args, **kwargs): return LaunchDescription( [ - # Include the game_launcher argument game_launcher_arg, - # Launch the game launcher and wait for it to load launch_game_launcher, - RegisterEventHandler( - event_handler=OnProcessStart( - target_action=launch_game_launcher, - on_start=[ - wait_for_game_launcher, - ], - ) - ), - # Launch the MoveIt node after loading the simulation - RegisterEventHandler( - event_handler=OnExecutionComplete( - target_action=wait_for_game_launcher, - on_completion=[ - launch_openset, - launch_moveit, - launch_robotic_manipulation, - ], - ) - ), + launch_openset, + launch_moveit, + launch_robotic_manipulation, ] ) diff --git a/examples/manipulation-demo.py b/examples/manipulation-demo.py index 6f73248c2..92c820031 100644 --- a/examples/manipulation-demo.py +++ b/examples/manipulation-demo.py @@ -12,37 +12,37 @@ # See the License for the specific language goveself.rning permissions and # limitations under the License. -import threading import rclpy import rclpy.qos from langchain_core.messages import HumanMessage from rai.agents.conversational_agent import create_conversational_agent -from rai.node import RaiBaseNode +from rai.communication.ros2.connectors import ROS2ARIConnector from rai.tools.ros.manipulation import GetObjectPositionsTool, MoveToPointTool -from rai.tools.ros.native import GetCameraImage, Ros2GetTopicsNamesAndTypesTool +from rai.tools.ros2.topics import GetROS2ImageTool, GetROS2TopicsNamesAndTypesTool from rai.utils.model_initialization import get_llm_model +from rai_open_set_vision.tools import GetGrabbingPointTool def create_agent(): rclpy.init() - node = RaiBaseNode(node_name="manipulation_demo") + connector = ROS2ARIConnector() + node = connector.node node.declare_parameter("conversion_ratio", 1.0) - threading.Thread(target=node.spin).start() - tools = [ GetObjectPositionsTool( - node=node, + connector=connector, target_frame="panda_link0", source_frame="RGBDCamera5", camera_topic="/color_image5", depth_topic="/depth_image5", camera_info_topic="/color_camera_info5", + get_grabbing_point_tool=GetGrabbingPointTool(connector=connector), ), - MoveToPointTool(node=node, manipulator_frame="panda_link0"), - GetCameraImage(node=node), - Ros2GetTopicsNamesAndTypesTool(node=node), + MoveToPointTool(connector=connector, manipulator_frame="panda_link0"), + GetROS2ImageTool(connector=connector), + GetROS2TopicsNamesAndTypesTool(connector=connector), ] llm = get_llm_model(model_type="complex_model", streaming=True) diff --git a/poetry.lock b/poetry.lock index 2d22b986d..9eb39f9fa 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5700,6 +5700,19 @@ torchaudio = "^2.3.1" type = "directory" url = "src/rai_asr" +[[package]] +name = "rai-bench" +version = "0.1.0" +description = "" +optional = false +python-versions = "^3.10" +files = [] +develop = true + +[package.source] +type = "directory" +url = "src/rai_bench" + [[package]] name = "rai-sim" version = "0.0.1" @@ -8318,4 +8331,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "^3.10, <3.13" -content-hash = "a53906ce2c798e5e0a02c7db25cf00cf36e021186a79429ba1bd8f0836b12db2" +content-hash = "c5469635a5db79c258554ad9f4e49331515940e406fbf912822651a0e0c33dda" diff --git a/pyproject.toml b/pyproject.toml index 863309947..a3edd3730 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ rai = {path = "src/rai_core", develop = true} rai_asr = {path = "src/rai_asr", develop = true} rai_tts = {path = "src/rai_tts", develop = true} rai_sim = {path = "src/rai_sim", develop = true} +rai_bench = {path = "src/rai_bench", develop = true} langchain-core = "^0.3" langchain = "*" @@ -30,7 +31,6 @@ requests = "^2.32.2" pre-commit = "^3.7.0" openai = "^1.23.3" coloredlogs = "^15.0.1" -opencv-python = "^4.9.0.80" markdown = "^3.6" boto3 = "^1.34.98" tqdm = "^4.66.4" @@ -62,6 +62,7 @@ pytest-timeout = "^2.3.1" tomli-w = "^1.1.0" faster-whisper = "^1.1.1" pydub = "^0.25.1" +opencv-python = "^4.11.0.86" [tool.poetry.group.dev.dependencies] ipykernel = "^6.29.4" diff --git a/setup_shell.sh b/setup_shell.sh index cc67a5369..164372ed9 100755 --- a/setup_shell.sh +++ b/setup_shell.sh @@ -30,3 +30,6 @@ esac export PYTHONPATH PYTHONPATH="$(dirname "$(dirname "$(poetry run which python)")")/lib/python$(poetry run python --version | awk '{print $2}' | cut -d. -f1,2)/site-packages:$PYTHONPATH" +PYTHONPATH="src/rai_core:$PYTHONPATH" +PYTHONPATH="src/rai_asr:$PYTHONPATH" +PYTHONPATH="src/rai_tts:$PYTHONPATH" diff --git a/src/rai_bench/README.md b/src/rai_bench/README.md new file mode 100644 index 000000000..e2abd3d4f --- /dev/null +++ b/src/rai_bench/README.md @@ -0,0 +1,60 @@ +## RAI Benchmark + +## Description + +The RAI Bench is a package including benchmarks and providing frame for creating new benchmarks + +## Frame Components + +Frame components can be found in `src/rai_bench/rai_bench/benchmark_model.py` + +- `Task` - abstract class for creating specific task. It introduces helper funtions that make it easier to calculate metrics/scores. Your custom tasks must implement a prompt got agent to do, a way to calculate a result and a validation if given scene config suits the task. +- +- `Scenario` - class defined by a Scene and Task. Can be created manually like: + + ```python + + ``` + +- `Benchmark` - class responsible for running and logging scenarios. + +### O3DE TEST BENCHMARK + +O3DE Test Benchmark (src/rai_bench/rai_bench/o3de_test_bench/), contains 2 Tasks(tasks/) - GrabCarrotTask and PlaceCubesTask (these tasks implement calculating scores) and 4 scene_configs(configs/) for O3DE robotic arm simulation. + +Both tasks calculate score, taking into consideration 4 values: + +- initially_misplaced_now_correct - when the object which was in the incorrect place at the start, is in a correct place at the end +- initially_misplaced_still_incorrect - when the object which was in the incorrect place at the start, is in a incorrect place at the end +- initially_correct_still_correct - when the object which was in the correct place at the start, is in a correct place at the end +- initially_correct_now_incorrect - when the object which was in the correct place at the start, is in a incorrect place at the end + +The result is a value between 0 and 1, calculated like (initially_misplaced_now_correct + initially_correct_still_correct) / number_of_initial_objects. +This score is calculated at the beggining and at the end of each scenario. + +### Example usage + +Example of how to load scenes, define scenarios and run benchmark can be found in `src/rai_bench/rai_bench/benchmark_main.py` + +Scenarios can be loaded manually like: + +```python +one_carrot_simulation_config = O3DExROS2SimulationConfig.load_config( + base_config_path=Path("path_to_scene.yaml"), + connector_config_path=Path("path_to_o3de_config.yaml"), + ) + +Scenario(task=GrabCarrotTask(logger=some_logger), simulation_config=one_carrot_simulation_config) +``` + +or automatically like: + +```python +scenarios = Benchmark.create_scenarios( + tasks=tasks, simulation_configs=simulations_configs + ) +``` + +which will result in list of scenarios with combination of every possible task and scene(task decides if scene config is suitable for it). + +Both approaches can be found in `main.py` diff --git a/src/rai_bench/pyproject.toml b/src/rai_bench/pyproject.toml new file mode 100644 index 000000000..52255eb9a --- /dev/null +++ b/src/rai_bench/pyproject.toml @@ -0,0 +1,17 @@ +[tool.poetry] +name = "rai-bench" +version = "0.1.0" +description = "" +authors = ["jmatejcz "] +readme = "README.md" + +packages = [ + { include = "rai_bench", from = "." }, +] +[tool.poetry.dependencies] +python = "^3.10" + + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" diff --git a/src/rai_bench/rai_bench/benchmark_model.py b/src/rai_bench/rai_bench/benchmark_model.py new file mode 100644 index 000000000..bc47ef407 --- /dev/null +++ b/src/rai_bench/rai_bench/benchmark_model.py @@ -0,0 +1,290 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import csv +import logging +import time +from abc import ABC, abstractmethod +from typing import Any, Dict, Generic, List, Union + +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage +from rai.messages import HumanMultimodalMessage +from rclpy.impl.rcutils_logger import RcutilsLogger + +from rai_sim.simulation_bridge import ( + PoseModel, + SimulationBridge, + SimulationConfig, + SimulationConfigT, + SpawnedEntity, +) + +loggers_type = Union[RcutilsLogger, logging.Logger] + + +class EntitiesMismatchException(Exception): + def __init__(self, message: str) -> None: + super().__init__(message) + + +class Task(ABC): + """ + Task to perform. + Specyfic implementation should implement a way to calculate results. + Abstract provides utility functions for common calculations, that can be usefull when + creating metrics + """ + + def __init__( + self, + logger: loggers_type | None = None, + ) -> None: + if logger: + self.logger = logger + else: + self.logger = logging.getLogger(__name__) + + @abstractmethod + def get_prompt(self) -> str: + pass + + @abstractmethod + def validate_config(self, simulation_config: SimulationConfig) -> bool: + """Task should be able to verify if given config is suitable for specific task + + Args: + simulation_config (SimulationConfig): initial scene setup + Returns: + bool: True is suitable, False otherwise + """ + pass + + @abstractmethod + def calculate_result( + self, simulation_bridge: SimulationBridge[SimulationConfigT] + ) -> float: + """ + Calculate result of the task + """ + pass + + def filter_entities_by_prefab_type( + self, entities: List[SpawnedEntity], prefab_types: List[str] + ) -> List[SpawnedEntity]: + """Filter and return only these entities that match provided prefab types""" + return [ent for ent in entities if ent.prefab_name in prefab_types] + + def euclidean_distance(self, pos1: PoseModel, pos2: PoseModel) -> float: + """Calculate euclidean distance between 2 positions""" + return ( + (pos1.translation.x - pos2.translation.x) ** 2 + + (pos1.translation.y - pos2.translation.y) ** 2 + + (pos1.translation.z - pos2.translation.z) ** 2 + ) ** 0.5 + + def is_adjacent(self, pos1: PoseModel, pos2: PoseModel, threshold_distance: float): + """ + Check if positions are adjacent to each other, the threshold_distance is a distance + in simulation, refering to how close they have to be to classify them as adjacent + """ + self.logger.debug( # type: ignore + f"Euclidean distance: {self.euclidean_distance(pos1, pos2)}, pos1: {pos1}, pos2: {pos2}" + ) + return self.euclidean_distance(pos1, pos2) < threshold_distance + + def is_adjacent_to_any( + self, pos1: PoseModel, positions: List[PoseModel], threshold_distance: float + ) -> bool: + """ + Check if given position is adjacent to any position in the given list. + """ + + return any( + self.is_adjacent(pos1, pos2, threshold_distance) for pos2 in positions + ) + + def count_adjacent( + self, positions: List[PoseModel], threshold_distance: float + ) -> int: + """ + Count how many adjacent positions are in the given list. + Note that position has to be adjacent to only 1 other position + to be counted, not all of them + """ + adjacent_count = 0 + + for i, p1 in enumerate(positions): + for j, p2 in enumerate(positions): + if i != j: + if self.is_adjacent(p1, p2, threshold_distance): + adjacent_count += 1 + break + + return adjacent_count + + +class Scenario(Generic[SimulationConfigT]): + """Single instances are run separatly by benchmark""" + + def __init__( + self, + task: Task, + simulation_config: SimulationConfigT, + simulation_config_path: str, + ) -> None: + if not task.validate_config(simulation_config): + raise ValueError("This scene is invalid for this task.") + self.task = task + self.simulation_config = simulation_config + # NOTE (jm) needed for logging which config was used, + # there probably is better method to do it + self.simulation_config_path = simulation_config_path + + +class Benchmark: + """ + Defined by a set of scenarios to be done + """ + + def __init__( + self, + simulation_bridge: SimulationBridge[SimulationConfigT], + scenarios: List[Scenario[SimulationConfigT]], + logger: loggers_type | None = None, + ) -> None: + self.simulation_bridge = simulation_bridge + self.num_of_scenarios = len(scenarios) + self.scenarios = enumerate(iter(scenarios)) + self.results: List[Dict[str, Any]] = [] + if logger: + self._logger = logger + else: + self._logger = logging.getLogger(__name__) + + @classmethod + def create_scenarios( + cls, + tasks: List[Task], + simulation_configs: List[SimulationConfigT], + simulation_configs_paths: List[str], + ) -> List[Scenario[SimulationConfigT]]: + # TODO (jm) hacky_fix, taking paths as args here, not the best solution, + # but more changes to code would be required + scenarios: List[Scenario[SimulationConfigT]] = [] + for task in tasks: + for sim_conf, sim_path in zip(simulation_configs, simulation_configs_paths): + try: + scenarios.append( + Scenario( + task=task, + simulation_config=sim_conf, + simulation_config_path=sim_path, + ) + ) + except ValueError as e: + print( + f"Could not create Scenario from task: {task.get_prompt()} and simulation_config: {sim_conf}, {e}" + ) + return scenarios + + def run_next(self, agent) -> None: + """ + Runs the next scenario + """ + try: + i, scenario = next(self.scenarios) # Get the next scenario + + self.simulation_bridge.setup_scene(scenario.simulation_config) + self._logger.info( # type: ignore + "======================================================================================" + ) + self._logger.info( # type: ignore + f"RUNNING SCENARIO NUMBER {i + 1} / {self.num_of_scenarios}, TASK: {scenario.task.get_prompt()}" + ) + initial_result = scenario.task.calculate_result(self.simulation_bridge) + self._logger.info(f"RESULT OF THE INITIAL SETUP: {initial_result}") # type: ignore + tool_calls_num = 0 + + ts = time.perf_counter() + for state in agent.stream( + {"messages": [HumanMessage(content=scenario.task.get_prompt())]} + ): + graph_node_name = list(state.keys())[0] + msg = state[graph_node_name]["messages"][-1] + + if isinstance(msg, HumanMultimodalMessage): + last_msg = msg.text + elif isinstance(msg, BaseMessage): + if isinstance(msg.content, list): + if len(msg.content) == 1: + if type(msg.content[0]) is dict: + last_msg = msg.content[0].get("text", "") + else: + last_msg = msg.content + self._logger.debug(f"{graph_node_name}: {last_msg}") # type: ignore + + else: + raise ValueError(f"Unexpected type of message: {type(msg)}") + + if isinstance(msg, AIMessage): + # TODO (jm) figure out more robust way of counting tool calls + tool_calls_num += len(msg.tool_calls) + + self._logger.info(f"AI Message: {msg}") # type: ignore + + te = time.perf_counter() + + result = scenario.task.calculate_result(self.simulation_bridge) + total_time = te - ts + self._logger.info( # type: ignore + f"TASK SCORE: {result}, TOTAL TIME: {total_time:.3f}, NUM_OF_TOOL_CALLS: {tool_calls_num}" + ) + + self.results.append( + { + "task": scenario.task.get_prompt(), + "simulation_config": scenario.simulation_config_path, + "initial_score": initial_result, + "final_score": result, + "total_time": f"{total_time:.3f}", + "number_of_tool_calls": tool_calls_num, + } + ) + + except StopIteration: + print("No more scenarios left to run.") + + def get_results(self) -> List[Dict[str, Any]]: + return self.results + + def dump_results_to_csv(self, filename: str) -> None: + if not self.results: + self._logger.warning("No results to save.") # type: ignore + return + + fieldnames = [ + "task", + "initial_score", + "simulation_config", + "final_score", + "total_time", + "number_of_tool_calls", + ] + + with open(filename, mode="w", newline="", encoding="utf-8") as file: + writer = csv.DictWriter(file, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(self.results) + + self._logger.info(f"Results saved to {filename}") # type: ignore diff --git a/src/rai_bench/rai_bench/main.py b/src/rai_bench/rai_bench/main.py new file mode 100644 index 000000000..7875c92c4 --- /dev/null +++ b/src/rai_bench/rai_bench/main.py @@ -0,0 +1,187 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +########### EXAMPLE USAGE ########### +import logging +import time +from pathlib import Path +from typing import List + +import rclpy +from langchain.tools import BaseTool +from rai.agents.conversational_agent import create_conversational_agent +from rai.communication.ros2.connectors import ROS2ARIConnector +from rai.tools.ros.manipulation import GetObjectPositionsTool, MoveToPointTool +from rai.tools.ros2.topics import GetROS2ImageTool, GetROS2TopicsNamesAndTypesTool +from rai.utils.model_initialization import get_llm_model +from rai_open_set_vision.tools import GetGrabbingPointTool + +from rai_bench.benchmark_model import Benchmark, Task +from rai_bench.o3de_test_bench.tasks import GrabCarrotTask, PlaceCubesTask +from rai_sim.o3de.o3de_bridge import ( + O3DEngineArmManipulationBridge, + O3DExROS2SimulationConfig, + PoseModel, +) +from rai_sim.simulation_bridge import Rotation, Translation + +if __name__ == "__main__": + rclpy.init() + connector = ROS2ARIConnector() + node = connector.node + node.declare_parameter("conversion_ratio", 1.0) + + # define model + llm = get_llm_model(model_type="complex_model", streaming=True) + + system_prompt = """ + You are a robotic arm with interfaces to detect and manipulate objects. + Here are the coordinates information: + x - front to back (positive is forward) + y - left to right (positive is right) + z - up to down (positive is up) + Before starting the task, make sure to grab the camera image to understand the environment. + """ + # define tools + tools: List[BaseTool] = [ + GetObjectPositionsTool( + connector=connector, + target_frame="panda_link0", + source_frame="RGBDCamera5", + camera_topic="/color_image5", + depth_topic="/depth_image5", + camera_info_topic="/color_camera_info5", + get_grabbing_point_tool=GetGrabbingPointTool(connector=connector), + ), + MoveToPointTool(connector=connector, manipulator_frame="panda_link0"), + GetROS2ImageTool(connector=connector), + GetROS2TopicsNamesAndTypesTool(connector=connector), + ] + # define loggers + log_file = "src/rai_bench/rai_bench/benchmark.log" + file_handler = logging.FileHandler(log_file) + file_handler.setLevel(logging.DEBUG) + + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + file_handler.setFormatter(formatter) + + bench_logger = logging.getLogger("Benchmark logger") + bench_logger.setLevel(logging.INFO) + bench_logger.addHandler(file_handler) + + agent_logger = logging.getLogger("Agent logger") + agent_logger.setLevel(logging.INFO) + agent_logger.addHandler(file_handler) + + configs_dir = "src/rai_bench/rai_bench/o3de_test_bench/configs/" + connector_path = configs_dir + "o3de_config.yaml" + #### Create scenarios manually + # load different scenes + # one_carrot_simulation_config = O3DExROS2SimulationConfig.load_config( + # base_config_path=Path(configs_dir + "scene1.yaml"), + # connector_config_path=Path(connector_path), + # ) + # multiple_carrot_simulation_config = O3DExROS2SimulationConfig.load_config( + # base_config_path=Path(configs_dir + "scene2.yaml"), + # connector_config_path=Path(connector_path), + # ) + # red_cubes_simulation_config = O3DExROS2SimulationConfig.load_config( + # base_config_path=Path(configs_dir + "scene3.yaml"), + # connector_config_path=Path(connector_path), + # ) + # multiple_cubes_simulation_config = O3DExROS2SimulationConfig.load_config( + # base_config_path=Path(configs_dir + "scene4.yaml"), + # connector_config_path=Path(connector_path), + # ) + # # combine different scene configs with the tasks to create various scenarios + # scenarios = [ + # Scenario( + # task=GrabCarrotTask(logger=bench_logger), + # simulation_config=one_carrot_simulation_config, + # simulation_config_path=configs_dir + "scene1.yaml", + # ), + # Scenario( + # task=GrabCarrotTask(logger=bench_logger), + # simulation_config=multiple_carrot_simulation_config, + # simulation_config_path=configs_dir + "scene2.yaml", + # ), + # Scenario( + # task=PlaceCubesTask(logger=bench_logger), + # simulation_config=red_cubes_simulation_config, + # simulation_config_path=configs_dir + "scene3.yaml", + # ), + # Scenario( + # task=PlaceCubesTask(logger=bench_logger), + # simulation_config=multiple_cubes_simulation_config, + # simulation_config_path=configs_dir + "scene4.yaml", + # ), + # ] + + ### Create scenarios automatically + simulation_configs_paths = [ + configs_dir + "scene1.yaml", + configs_dir + "scene2.yaml", + configs_dir + "scene3.yaml", + configs_dir + "scene4.yaml", + ] + simulations_configs = [ + O3DExROS2SimulationConfig.load_config(Path(path), Path(connector_path)) + for path in simulation_configs_paths + ] + tasks: List[Task] = [ + GrabCarrotTask(logger=bench_logger), + PlaceCubesTask(logger=bench_logger), + ] + scenarios = Benchmark.create_scenarios( + tasks=tasks, + simulation_configs=simulations_configs, + simulation_configs_paths=simulation_configs_paths, + ) + + # custom request to arm + base_arm_pose = PoseModel( + translation=Translation(x=0.5, y=0.1, z=0.3), + rotation=Rotation(x=1.0, y=0.0, z=0.0, w=0.0), + ) + + o3de = O3DEngineArmManipulationBridge(connector, logger=agent_logger) + # define benchamrk + benchmark = Benchmark( + simulation_bridge=o3de, + scenarios=scenarios, + logger=bench_logger, + ) + for i, s in enumerate(scenarios): + agent = create_conversational_agent( + llm, tools, system_prompt, logger=agent_logger + ) + benchmark.run_next(agent=agent) + o3de.move_arm( + pose=base_arm_pose, + initial_gripper_state=True, + final_gripper_state=False, + frame_id="panda_link0", + ) # return to case position + time.sleep(2) # admire the end position for a second ;) + + bench_logger.info("===============================================================") + bench_logger.info("ALL SCENARIOS DONE. BENCHMARK COMPLETED!") + bench_logger.info("===============================================================") + benchmark.dump_results_to_csv(filename="src/rai_bench/rai_bench/results.csv") + + connector.shutdown() + o3de.shutdown() + rclpy.shutdown() diff --git a/src/rai_bench/rai_bench/o3de_test_bench/__init__.py b/src/rai_bench/rai_bench/o3de_test_bench/__init__.py new file mode 100644 index 000000000..97ceef6f0 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/__init__.py @@ -0,0 +1,13 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/scene1.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/scene1.yaml new file mode 100644 index 000000000..a683362a6 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/scene1.yaml @@ -0,0 +1,25 @@ +entities: + - name: carrot1 + prefab_name: carrot + pose: + translation: + x: 0.5 + y: -0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: corn2 + prefab_name: corn + pose: + translation: + x: 0.5 + y: 0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/scene2.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/scene2.yaml new file mode 100644 index 000000000..2be04e047 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/scene2.yaml @@ -0,0 +1,51 @@ +entities: + - name: carrot1 + prefab_name: carrot + pose: + translation: + x: 0.5 + y: -0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + + - name: carrot2 + prefab_name: carrot + pose: + translation: + x: 0.5 + y: 0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: carrot3 + prefab_name: carrot + pose: + translation: + x: 0.5 + y: 0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + + - name: carrot4 + prefab_name: carrot + pose: + translation: + x: 0.5 + y: -0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/scene3.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/scene3.yaml new file mode 100644 index 000000000..1eef69a48 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/scene3.yaml @@ -0,0 +1,25 @@ +entities: + - name: cube1 + prefab_name: red_cube + pose: + translation: + x: 0.5 + y: -0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: cube2 + prefab_name: red_cube + pose: + translation: + x: 0.5 + y: 0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/scene4.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/scene4.yaml new file mode 100644 index 000000000..00b814c93 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/scene4.yaml @@ -0,0 +1,50 @@ +entities: + - name: cube1 + prefab_name: red_cube + pose: + translation: + x: 0.5 + y: -0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: cube2 + prefab_name: blue_cube + pose: + translation: + x: 0.4 + y: -0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + + - name: cube3 + prefab_name: yellow_cube + pose: + translation: + x: 0.5 + y: -0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: cube4 + prefab_name: yellow_cube + pose: + translation: + x: 0.5 + y: 0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/tasks/__init__.py b/src/rai_bench/rai_bench/o3de_test_bench/tasks/__init__.py new file mode 100644 index 000000000..5be82bf8c --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/tasks/__init__.py @@ -0,0 +1,17 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from rai_bench.o3de_test_bench.tasks.grab_carrot_task import GrabCarrotTask +from rai_bench.o3de_test_bench.tasks.place_cubes_task import PlaceCubesTask + +__all__ = ["GrabCarrotTask", "PlaceCubesTask"] diff --git a/src/rai_bench/rai_bench/o3de_test_bench/tasks/grab_carrot_task.py b/src/rai_bench/rai_bench/o3de_test_bench/tasks/grab_carrot_task.py new file mode 100644 index 000000000..ca040fd62 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/tasks/grab_carrot_task.py @@ -0,0 +1,101 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from rai_bench.benchmark_model import ( + EntitiesMismatchException, + Task, +) +from rai_sim.o3de.o3de_bridge import ( + SimulationBridge, +) +from rai_sim.simulation_bridge import SimulationConfig, SimulationConfigT + + +class GrabCarrotTask(Task): + def get_prompt(self) -> str: + return "Manipulate objects, so that all carrots to the left side of the table (positive y)" + + def validate_config(self, simulation_config: SimulationConfig) -> bool: + for ent in simulation_config.entities: + if ent.prefab_name == "carrot": + return True + + return False + + def calculate_result( + self, simulation_bridge: SimulationBridge[SimulationConfigT] + ) -> float: + # TODO (jm) extract common logic to some parent manipulation task? + initially_misplaced_now_correct = 0 # when the object which was in the incorrect place at the start, is in a correct place at the end + initially_misplaced_still_incorrect = 0 # when the object which was in the incorrect place at the start, is in a incorrect place at the end + initially_correct_still_correct = 0 # when the object which was in the correct place at the start, is in a correct place at the end + initially_correct_now_incorrect = 0 # when the object which was in the correct place at the start, is in a incorrect place at the end + + scene_state = simulation_bridge.get_scene_state() + initial_carrots = self.filter_entities_by_prefab_type( + simulation_bridge.spawned_entities, prefab_types=["carrot"] + ) + final_carrots = self.filter_entities_by_prefab_type( + scene_state.entities, prefab_types=["carrot"] + ) + num_initial_carrots = len(initial_carrots) + + if num_initial_carrots != len(final_carrots): + raise EntitiesMismatchException( + "Number of initially spawned entities does not match number of entities present at the end." + ) + + else: + self.logger.debug(f"initial positions: {initial_carrots}") # type: ignore + self.logger.debug(f"current positions: {final_carrots}") # type: ignore + for ini_carrot in initial_carrots: + for final_carrot in final_carrots: + if ini_carrot.name == final_carrot.name: + initial_y = ini_carrot.pose.translation.y + final_y = final_carrot.pose.translation.y + # NOTE the specific coords that refer to for example + # middle of the table can differ across simulations, + # take that into consideration + if ( + initial_y <= 0.0 + ): # Carrot started in the incorrect place (right side) + if final_y >= 0.0: + initially_misplaced_now_correct += ( + 1 # Moved to correct side + ) + else: + initially_misplaced_still_incorrect += ( + 1 # Stayed on incorrect side + ) + else: # Carrot started in the correct place (left side) + if final_y >= 0.0: + initially_correct_still_correct += ( + 1 # Stayed on correct side + ) + else: + initially_correct_now_incorrect += ( + 1 # Moved incorrectly to the wrong side + ) + break + else: + raise EntitiesMismatchException( + f"Entity with name: {ini_carrot.name} which was present in initial scene, not found in final scene." + ) + + self.logger.info( # type: ignore + f"initially_misplaced_now_correct: {initially_misplaced_now_correct}, initially_misplaced_still_incorrect: {initially_misplaced_still_incorrect}, initially_correct_still_correct: {initially_correct_still_correct}, initially_correct_now_incorrect: {initially_correct_now_incorrect}" + ) + return ( + initially_misplaced_now_correct + initially_correct_still_correct + ) / num_initial_carrots diff --git a/src/rai_bench/rai_bench/o3de_test_bench/tasks/place_cubes_task.py b/src/rai_bench/rai_bench/o3de_test_bench/tasks/place_cubes_task.py new file mode 100644 index 000000000..26bdd590e --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/tasks/place_cubes_task.py @@ -0,0 +1,104 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from rai_bench.benchmark_model import ( + EntitiesMismatchException, + Task, +) +from rai_sim.o3de.o3de_bridge import SimulationBridge +from rai_sim.simulation_bridge import SimulationConfig, SimulationConfigT + + +class PlaceCubesTask(Task): + def get_prompt(self) -> str: + return "Manipulate objects, so that all cubes are adjacent to at least one cube" + + def validate_config(self, simulation_config: SimulationConfig) -> bool: + cube_types = ["red_cube", "blue_cube", "yellow_cube"] + cubes_num = 0 + for ent in simulation_config.entities: + if ent.prefab_name in cube_types: + cubes_num += 1 + if cubes_num > 1: + return True + + return False + + def calculate_result( + self, simulation_bridge: SimulationBridge[SimulationConfigT] + ) -> float: + # TODO (jm) extract common logic to some parent manipulation task? + initially_misplaced_now_correct = 0 # when the object which was in the incorrect place at the start, is in a correct place at the end + initially_misplaced_still_incorrect = 0 # when the object which was in the incorrect place at the start, is in a incorrect place at the end + initially_correct_still_correct = 0 # when the object which was in the correct place at the start, is in a correct place at the end + initially_correct_now_incorrect = 0 # when the object which was in the correct place at the start, is in a incorrect place at the end + + cube_types = ["red_cube", "blue_cube", "yellow_cube"] + scene_state = simulation_bridge.get_scene_state() + + initial_cubes = self.filter_entities_by_prefab_type( + simulation_bridge.spawned_entities, prefab_types=cube_types + ) + final_cubes = self.filter_entities_by_prefab_type( + scene_state.entities, prefab_types=cube_types + ) + num_of_objects = len(initial_cubes) + + if num_of_objects != len(final_cubes): + raise EntitiesMismatchException( + "Number of initially spawned entities does not match number of entities present at the end." + ) + + else: + ini_poses = [cube.pose for cube in initial_cubes] + final_poses = [cube.pose for cube in final_cubes] + # NOTE the specific coords that refer to for example + # middle of the table can differ across simulations, + # take that into consideration + self.logger.debug(f"initial positions: {initial_cubes}") + self.logger.debug(f"current positions: {final_cubes}") + for i, ini_cube in enumerate(initial_cubes): + for j, final_cube in enumerate(final_cubes): + if ini_cube.name == final_cube.name: + was_adjacent_initially = self.is_adjacent_to_any( + ini_cube.pose, + [p for p in ini_poses if p != ini_cube.pose], + 0.15, + ) + is_adjacent_finally = self.is_adjacent_to_any( + final_cube.pose, + [p for p in final_poses if p != final_cube.pose], + 0.15, + ) + if not was_adjacent_initially and is_adjacent_finally: + initially_misplaced_now_correct += 1 + elif not was_adjacent_initially and not is_adjacent_finally: + initially_misplaced_still_incorrect += 1 + elif was_adjacent_initially and is_adjacent_finally: + initially_correct_still_correct += 1 + elif was_adjacent_initially and not is_adjacent_finally: + initially_correct_now_incorrect += 1 + + break + else: + raise EntitiesMismatchException( + f"Entity with name: {ini_cube.name} which was present in initial scene, not found in final scene." + ) + + self.logger.info( + f"initially_misplaced_now_correct: {initially_misplaced_now_correct}, initially_misplaced_still_incorrect: {initially_misplaced_still_incorrect}, initially_correct_still_correct: {initially_correct_still_correct}, initially_correct_now_incorrect: {initially_correct_now_incorrect}" + ) + return ( + initially_misplaced_now_correct + initially_correct_still_correct + ) / num_of_objects diff --git a/src/rai_core/rai/agents/conversational_agent.py b/src/rai_core/rai/agents/conversational_agent.py index 072a1c164..739b159ae 100644 --- a/src/rai_core/rai/agents/conversational_agent.py +++ b/src/rai_core/rai/agents/conversational_agent.py @@ -56,7 +56,7 @@ def create_conversational_agent( debug=False, ): _logger = None - if isinstance(logger, RcutilsLogger): + if logger: _logger = logger else: _logger = logging.getLogger(__name__) diff --git a/src/rai_core/rai/agents/tool_runner.py b/src/rai_core/rai/agents/tool_runner.py index 12e0889d3..5c35ac9a8 100644 --- a/src/rai_core/rai/agents/tool_runner.py +++ b/src/rai_core/rai/agents/tool_runner.py @@ -69,8 +69,13 @@ def run_one(call: ToolCall): ts = time.perf_counter() output = self.tools_by_name[call["name"]].invoke(call, config) # type: ignore te = time.perf_counter() - ts + tool_output_log = ( + str(output.content)[:1000] + "..." + if len(str(output.content)) > 1000 + else "" + ) self.logger.info( - f"Tool {call['name']} completed in {te:.2f} seconds. Tool output: {str(output.content)[:100]}{'...' if len(str(output.content)) > 100 else ''}" + f"Tool {call['name']} completed in {te:.2f} seconds. Tool output: {tool_output_log}" ) self.logger.debug( f"Tool {call['name']} output: \n\n{str(output.content)}" diff --git a/src/rai_core/rai/communication/ros2/connectors.py b/src/rai_core/rai/communication/ros2/connectors.py index af7953d78..fb7db7d3d 100644 --- a/src/rai_core/rai/communication/ros2/connectors.py +++ b/src/rai_core/rai/communication/ros2/connectors.py @@ -70,6 +70,7 @@ def __init__( self._service_api = ROS2ServiceAPI(self._node) self._actions_api = ROS2ActionAPI(self._node) self._tf_buffer = Buffer(node=self._node) + self.tf_listener = TransformListener(self._tf_buffer, self._node) self._executor = MultiThreadedExecutor() self._executor.add_node(self._node) @@ -180,7 +181,6 @@ def get_transform( source_frame: str, timeout_sec: float = 5.0, ) -> TransformStamped: - tf_listener = TransformListener(self._tf_buffer, self._node) transform_available = self.wait_for_transform( self._tf_buffer, target_frame, source_frame, timeout_sec ) @@ -192,20 +192,25 @@ def get_transform( target_frame, source_frame, rclpy.time.Time(), - timeout=Duration(seconds=timeout_sec), + timeout=Duration(seconds=int(timeout_sec)), ) - tf_listener.unregister() + return transform def terminate_action(self, action_handle: str, **kwargs: Any): self._actions_api.terminate_goal(action_handle) + @property + def node(self) -> Node: + return self._node + def shutdown(self): - self._executor.shutdown() - self._thread.join() + self.tf_listener.unregister() + self._node.destroy_node() self._actions_api.shutdown() self._topic_api.shutdown() - self._node.destroy_node() + self._executor.shutdown() + self._thread.join() class ROS2HRIMessage(HRIMessage): @@ -279,15 +284,19 @@ def __init__( ] _targets = [ - target - if isinstance(target, tuple) - else (target, TopicConfig(is_subscriber=False)) + ( + target + if isinstance(target, tuple) + else (target, TopicConfig(is_subscriber=False)) + ) for target in targets ] _sources = [ - source - if isinstance(source, tuple) - else (source, TopicConfig(is_subscriber=True)) + ( + source + if isinstance(source, tuple) + else (source, TopicConfig(is_subscriber=True)) + ) for source in sources ] diff --git a/src/rai_core/rai/tools/ros/manipulation.py b/src/rai_core/rai/tools/ros/manipulation.py index 8deacd603..9436490b0 100644 --- a/src/rai_core/rai/tools/ros/manipulation.py +++ b/src/rai_core/rai/tools/ros/manipulation.py @@ -15,21 +15,15 @@ from typing import Literal, Type import numpy as np -import rclpy -import rclpy.callback_groups -import rclpy.executors -import rclpy.qos -import rclpy.subscription -import rclpy.task from geometry_msgs.msg import Point, Pose, PoseStamped, Quaternion from langchain_core.tools import BaseTool from pydantic import BaseModel, Field from rai_open_set_vision.tools import GetGrabbingPointTool -from rclpy.client import Client -from rclpy.node import Node from tf2_geometry_msgs import do_transform_pose +from rai.communication.ros2.connectors import ROS2ARIConnector from rai.tools.utils import TF2TransformFetcher +from rai.utils.ros_async import get_future_result from rai_interfaces.srv import ManipulatorMoveTo @@ -52,8 +46,7 @@ class MoveToPointTool(BaseTool): "success of grabbing or releasing objects. Use additional sensors or tools for that information." ) - node: Node - client: Client + connector: ROS2ARIConnector = Field(..., exclude=True) manipulator_frame: str = Field(..., description="Manipulator frame") min_z: float = Field(default=0.135, description="Minimum z coordinate [m]") @@ -72,16 +65,6 @@ class MoveToPointTool(BaseTool): args_schema: Type[MoveToPointToolInput] = MoveToPointToolInput - def __init__(self, node: Node, **kwargs): - super().__init__( - node=node, - client=node.create_client( - ManipulatorMoveTo, - "/manipulator_move_to", - ), - **kwargs, - ) - def _run( self, x: float, @@ -89,6 +72,10 @@ def _run( z: float, task: Literal["grab", "drop"], ) -> str: + client = self.connector.node.create_client( + ManipulatorMoveTo, + "/manipulator_move_to", + ) pose_stamped = PoseStamped() pose_stamped.header.frame_id = self.manipulator_frame pose_stamped.pose = Pose( @@ -117,21 +104,18 @@ def _run( request.initial_gripper_state = False # closed request.final_gripper_state = True # open - future = self.client.call_async(request) - self.node.get_logger().debug( + future = client.call_async(request) + self.connector.node.get_logger().debug( f"Calling ManipulatorMoveTo service with request: x={request.target_pose.pose.position.x:.2f}, y={request.target_pose.pose.position.y:.2f}, z={request.target_pose.pose.position.z:.2f}" ) + response = get_future_result(future, timeout_sec=5.0) + if response is None: + return f"Service call failed for point ({x:.2f}, {y:.2f}, {z:.2f})." - rclpy.spin_until_future_complete(self.node, future, timeout_sec=5.0) - - if future.result() is not None: - response = future.result() - if response.success: - return f"End effector successfully positioned at coordinates ({x:.2f}, {y:.2f}, {z:.2f}). Note: The status of object interaction (grab/drop) is not confirmed by this movement." - else: - return f"Failed to position end effector at coordinates ({x:.2f}, {y:.2f}, {z:.2f})." + if response.success: + return f"End effector successfully positioned at coordinates ({x:.2f}, {y:.2f}, {z:.2f}). Note: The status of object interaction (grab/drop) is not confirmed by this movement." else: - return f"Service call failed for point ({x:.2f}, {y:.2f}, {z:.2f})." + return f"Failed to position end effector at coordinates ({x:.2f}, {y:.2f}, {z:.2f})." class GetObjectPositionsToolInput(BaseModel): @@ -153,14 +137,9 @@ class GetObjectPositionsTool(BaseTool): camera_topic: str # rgb camera topic depth_topic: str camera_info_topic: str # rgb camera info topic - node: Node + connector: ROS2ARIConnector = Field(..., exclude=True) get_grabbing_point_tool: GetGrabbingPointTool - def __init__(self, node: Node, **kwargs): - super(GetObjectPositionsTool, self).__init__( - node=node, get_grabbing_point_tool=GetGrabbingPointTool(node=node), **kwargs - ) - args_schema: Type[GetObjectPositionsToolInput] = GetObjectPositionsToolInput @staticmethod diff --git a/src/rai_core/rai/tools/ros/utils.py b/src/rai_core/rai/tools/ros/utils.py index 8c34e23c4..1e207e31c 100644 --- a/src/rai_core/rai/tools/ros/utils.py +++ b/src/rai_core/rai/tools/ros/utils.py @@ -151,7 +151,9 @@ def wait_for_message( if msg_info is not None: return True, msg_info[0] finally: - node.destroy_subscription(sub) + # TODO(boczekbartek): uncomment when rclpy resolves: https://github.com/ros2/rclpy/issues/1142 + # node.destroy_subscription(sub) + pass return False, None diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/examples/talker.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/examples/talker.py index 370c6e488..72fe82198 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/examples/talker.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/examples/talker.py @@ -28,7 +28,9 @@ def __init__(self): self.declare_parameter("image_path", "") self.cli = self.create_client(RAIGroundingDino, "grounding_dino_classify") while not self.cli.wait_for_service(timeout_sec=1.0): - self.get_logger().info("service not available, waiting again...") + self.get_logger().info( + "service grounding_dino_classify not available, waiting again..." + ) self.req = RAIGroundingDino.Request() self.bridge = CvBridge() @@ -56,7 +58,9 @@ def __init__(self): super().__init__(node_name="GSClientExample", parameter_overrides=[]) self.cli = self.create_client(RAIGroundedSam, "grounded_sam_segment") while not self.cli.wait_for_service(timeout_sec=1.0): - self.get_logger().info("service not available, waiting again...") + self.get_logger().info( + "service grounded_sam_segment not available, waiting again..." + ) self.req = RAIGroundedSam.Request() self.bridge = CvBridge() diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounded_sam.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounded_sam.py index 45fe9eb52..85fa70488 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounded_sam.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounded_sam.py @@ -34,9 +34,6 @@ class GSamService(Node): def __init__(self): super().__init__(node_name=GSAM_NODE_NAME, parameter_overrides=[]) - self.srv = self.create_service( - RAIGroundedSam, GSAM_SERVICE_NAME, self.segment_callback - ) self.declare_parameter("weights_path", "") try: @@ -49,6 +46,10 @@ def __init__(self): self.get_logger().error("Could not load model") raise Exception("Could not load model") + self.srv = self.create_service( + RAIGroundedSam, GSAM_SERVICE_NAME, self.segment_callback + ) + def _init_weight_path(self): try: found_path = get_package_share_directory("rai_open_set_vision") diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounding_dino.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounding_dino.py index 7927d7cf8..eba1ee44a 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounding_dino.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounding_dino.py @@ -43,9 +43,7 @@ class GDRequest(TypedDict): class GDinoService(Node): def __init__(self): super().__init__(node_name=GDINO_NODE_NAME, parameter_overrides=[]) - self.srv = self.create_service( - RAIGroundingDino, GDINO_SERVICE_NAME, self.classify_callback - ) + self.declare_parameter("weights_path", "") try: weight_path = self.get_parameter("weights_path").value @@ -57,6 +55,10 @@ def __init__(self): self.get_logger().error("Could not load model") raise Exception("Could not load model") + self.srv = self.create_service( + RAIGroundingDino, GDINO_SERVICE_NAME, self.classify_callback + ) + def _init_weight_path(self) -> Path: try: found_path = get_package_share_directory("rai_open_set_vision") diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py index 336ec4bc7..62656c5a8 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py @@ -15,14 +15,11 @@ from typing import List, NamedTuple, Type import numpy as np -import rclpy -import rclpy.qos import sensor_msgs.msg from pydantic import BaseModel, Field -from rai.node import RaiBaseNode +from rai.communication.ros2.connectors import ROS2ARIConnector from rai.tools.ros import Ros2BaseInput, Ros2BaseTool from rai.tools.ros.utils import convert_ros_img_to_ndarray -from rai.tools.utils import wait_for_message from rai.utils.ros_async import get_future_result from rclpy.exceptions import ( ParameterNotDeclaredException, @@ -82,7 +79,7 @@ class DistanceMeasurement(NamedTuple): # --------------------- Tools --------------------- class GroundingDinoBaseTool(Ros2BaseTool): - node: RaiBaseNode = Field(..., exclude=True, required=True) + connector: ROS2ARIConnector = Field(..., exclude=True) box_threshold: float = Field(default=0.35, description="Box threshold for GDINO") text_threshold: float = Field(default=0.45, description="Text threshold for GDINO") @@ -90,9 +87,11 @@ class GroundingDinoBaseTool(Ros2BaseTool): def _call_gdino_node( self, camera_img_message: sensor_msgs.msg.Image, object_names: list[str] ) -> Future: - cli = self.node.create_client(RAIGroundingDino, GDINO_SERVICE_NAME) + cli = self.connector.node.create_client(RAIGroundingDino, GDINO_SERVICE_NAME) while not cli.wait_for_service(timeout_sec=1.0): - self.node.get_logger().info("service not available, waiting again...") + self.node.get_logger().info( + f"service {GDINO_SERVICE_NAME} not available, waiting again..." + ) req = RAIGroundingDino.Request() req.source_img = camera_img_message req.classes = " , ".join(object_names) @@ -103,20 +102,16 @@ def _call_gdino_node( return future def get_img_from_topic(self, topic: str, timeout_sec: int = 2): - success, msg = wait_for_message( - sensor_msgs.msg.Image, - self.node, - topic, - qos_profile=rclpy.qos.qos_profile_sensor_data, - time_to_wait=timeout_sec, - ) - - if success: - self.node.get_logger().info(f"Received message of type from topic {topic}") + msg = self.connector.receive_message(topic, timeout_sec=timeout_sec).payload + + if msg is not None: + self.connector.node.get_logger().info( + f"Received message of {type(msg)} from topic {topic}" + ) return msg else: error = f"No message received in {timeout_sec} seconds from topic {topic}" - self.node.get_logger().error(error) + self.connector.node.get_logger().error(error) return error def _get_image_message(self, topic: str) -> sensor_msgs.msg.Image: diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py index 29ff0fe18..043802d8f 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py @@ -18,9 +18,10 @@ import numpy as np import rclpy import sensor_msgs.msg +from langchain_core.tools import BaseTool from pydantic import Field -from rai.node import RaiBaseNode -from rai.tools.ros import Ros2BaseInput, Ros2BaseTool +from rai.communication.ros2.connectors import ROS2ARIConnector +from rai.tools.ros import Ros2BaseInput from rai.tools.ros.utils import convert_ros_img_to_base64, convert_ros_img_to_ndarray from rai.utils.ros_async import get_future_result from rclpy import Future @@ -64,8 +65,8 @@ class GetGrabbingPointInput(Ros2BaseInput): # --------------------- Tools --------------------- -class GetSegmentationTool(Ros2BaseTool): - node: RaiBaseNode = Field(..., exclude=True) +class GetSegmentationTool: + connector: ROS2ARIConnector = Field(..., exclude=True) name: str = "" description: str = "" @@ -84,7 +85,7 @@ def _get_gsam_response(self, future: Future) -> Optional[RAIGroundedSam.Response return get_future_result(future) def _get_image_message(self, topic: str) -> sensor_msgs.msg.Image: - msg = self.node.get_raw_message_from_topic(topic) + msg = self.connector.receive_message(topic).payload if type(msg) is sensor_msgs.msg.Image: return msg else: @@ -93,9 +94,11 @@ def _get_image_message(self, topic: str) -> sensor_msgs.msg.Image: def _call_gdino_node( self, camera_img_message: sensor_msgs.msg.Image, object_name: str ) -> Future: - cli = self.node.create_client(RAIGroundingDino, GDINO_SERVICE_NAME) + cli = self.connector.node.create_client(RAIGroundingDino, GDINO_SERVICE_NAME) while not cli.wait_for_service(timeout_sec=1.0): - self.node.get_logger().info("service not available, waiting again...") + self.node.get_logger().info( + f"service {GDINO_SERVICE_NAME} not available, waiting again..." + ) req = RAIGroundingDino.Request() req.source_img = camera_img_message req.classes = object_name @@ -108,9 +111,11 @@ def _call_gdino_node( def _call_gsam_node( self, camera_img_message: sensor_msgs.msg.Image, data: RAIGroundingDino.Response ): - cli = self.node.create_client(RAIGroundedSam, "grounded_sam_segment") + cli = self.connector.node.create_client(RAIGroundedSam, "grounded_sam_segment") while not cli.wait_for_service(timeout_sec=1.0): - self.node.get_logger().info("service not available, waiting again...") + self.node.get_logger().info( + "service grounded_sam_segment not available, waiting again..." + ) req = RAIGroundedSam.Request() req.detections = data.detections req.source_img = camera_img_message @@ -126,9 +131,11 @@ def _run( camera_img_msg = self._get_image_message(camera_topic) future = self._call_gdino_node(camera_img_msg, object_name) - logger = self.node.get_logger() + logger = self.connector.node.get_logger() try: - conversion_ratio = self.node.get_parameter("conversion_ratio").value + conversion_ratio = self.connector.node.get_parameter( + "conversion_ratio" + ).value if not isinstance(conversion_ratio, float): logger.error( f"Parameter conversion_ratio was set badly: {type(conversion_ratio)}: {conversion_ratio} expected float. Using default value 0.001" @@ -185,19 +192,72 @@ def depth_to_point_cloud( return points -class GetGrabbingPointTool(GetSegmentationTool): +class GetGrabbingPointTool(BaseTool): + connector: ROS2ARIConnector = Field(..., exclude=True) + name: str = "GetGrabbingPointTool" description: str = "Get the grabbing point of an object" pcd: List[Any] = [] args_schema: Type[GetGrabbingPointInput] = GetGrabbingPointInput + box_threshold: float = Field(default=0.35, description="Box threshold for GDINO") + text_threshold: float = Field(default=0.45, description="Text threshold for GDINO") + + def _get_gdino_response( + self, future: Future + ) -> Optional[RAIGroundingDino.Response]: + return get_future_result(future) + + def _get_gsam_response(self, future: Future) -> Optional[RAIGroundedSam.Response]: + return get_future_result(future) + + def _get_image_message(self, topic: str) -> sensor_msgs.msg.Image: + msg = self.connector.receive_message(topic).payload + if type(msg) is sensor_msgs.msg.Image: + return msg + else: + raise Exception("Received wrong message") + + def _call_gdino_node( + self, camera_img_message: sensor_msgs.msg.Image, object_name: str + ) -> Future: + cli = self.connector.node.create_client(RAIGroundingDino, GDINO_SERVICE_NAME) + while not cli.wait_for_service(timeout_sec=1.0): + self.connector.node.get_logger().info( + "service not available, waiting again..." + ) + req = RAIGroundingDino.Request() + req.source_img = camera_img_message + req.classes = object_name + req.box_threshold = self.box_threshold + req.text_threshold = self.text_threshold + + future = cli.call_async(req) + return future + + def _call_gsam_node( + self, camera_img_message: sensor_msgs.msg.Image, data: RAIGroundingDino.Response + ): + cli = self.connector.node.create_client(RAIGroundedSam, "grounded_sam_segment") + while not cli.wait_for_service(timeout_sec=1.0): + self.connector.node.get_logger().info( + "service not available, waiting again..." + ) + req = RAIGroundedSam.Request() + req.detections = data.detections + req.source_img = camera_img_message + future = cli.call_async(req) + + return future def _get_camera_info_message(self, topic: str) -> sensor_msgs.msg.CameraInfo: for _ in range(3): - msg = self.node.get_raw_message_from_topic(topic, timeout_sec=3.0) + msg = self.connector.receive_message(topic, timeout_sec=3.0).payload if isinstance(msg, sensor_msgs.msg.CameraInfo): return msg - self.node.get_logger().warn("Received wrong message type. Retrying...") + self.connector.node.get_logger().warn( + "Received wrong message type. Retrying..." + ) raise Exception("Failed to receive correct CameraInfo message after 3 attempts") @@ -259,16 +319,18 @@ def _run( camera_info_topic: str, object_name: str, ): - camera_img_msg = self._get_image_message(camera_topic) - depth_msg = self._get_image_message(depth_topic) + camera_img_msg = self.connector.receive_message(camera_topic).payload + depth_msg = self.connector.receive_message(depth_topic).payload camera_info = self._get_camera_info_message(camera_info_topic) intrinsic = self._get_intrinsic_from_camera_info(camera_info) future = self._call_gdino_node(camera_img_msg, object_name) - logger = self.node.get_logger() + logger = self.connector.node.get_logger() try: - conversion_ratio = self.node.get_parameter("conversion_ratio").value + conversion_ratio = self.connector.node.get_parameter( + "conversion_ratio" + ).value if not isinstance(conversion_ratio, float): logger.error( f"Parameter conversion_ratio was set badly: {type(conversion_ratio)}: {conversion_ratio} expected float. Using default value 0.001" @@ -280,21 +342,17 @@ def _run( ) conversion_ratio = 0.001 resolved = None - while rclpy.ok(): - resolved = self._get_gdino_response(future) - if resolved is not None: - break + + resolved = get_future_result(future) assert resolved is not None future = self._call_gsam_node(camera_img_msg, resolved) ret = [] - while rclpy.ok(): - resolved = self._get_gsam_response(future) - if resolved is not None: - for img_msg in resolved.masks: - ret.append(convert_ros_img_to_base64(img_msg)) - break + resolved = get_future_result(future) + if resolved is not None: + for img_msg in resolved.masks: + ret.append(convert_ros_img_to_base64(img_msg)) assert resolved is not None rets = [] for mask_msg in resolved.masks: diff --git a/src/rai_sim/rai_sim/o3de/o3de_bridge.py b/src/rai_sim/rai_sim/o3de/o3de_bridge.py index 45e377d39..d514beb78 100644 --- a/src/rai_sim/rai_sim/o3de/o3de_bridge.py +++ b/src/rai_sim/rai_sim/o3de/o3de_bridge.py @@ -18,13 +18,16 @@ import subprocess import time from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional import yaml -from geometry_msgs.msg import Point, Pose, Quaternion +from geometry_msgs.msg import Point, Pose, PoseStamped, Quaternion from rai.communication.ros2.connectors import ROS2ARIConnector, ROS2ARIMessage +from rai.utils.ros_async import get_future_result +from std_msgs.msg import Header from tf2_geometry_msgs import do_transform_pose +from rai_interfaces.srv import ManipulatorMoveTo from rai_sim.simulation_bridge import ( Entity, PoseModel, @@ -40,6 +43,9 @@ class O3DExROS2SimulationConfig(SimulationConfig): binary_path: Path robotic_stack_command: str + required_services: List[str] + required_topics: List[str] + required_actions: List[str] @classmethod def load_config( @@ -199,6 +205,35 @@ def get_scene_state(self) -> SceneState: ) return SceneState(entities=entities) + def _is_robotic_stack_ready( + self, simulation_config: O3DExROS2SimulationConfig, retries: int = 30 + ) -> bool: + i = 0 + while i < retries: + topics = self.connector.get_topics_names_and_types() + services = self.connector.node.get_service_names_and_types() + topics_names = [tp[0] for tp in topics] + service_names = [srv[0] for srv in services] + self.logger.info( + f"required services: {simulation_config.required_services}" + ) + self.logger.info(f"required topics: {simulation_config.required_topics}") + self.logger.info(f"required actions: {simulation_config.required_actions}") + # NOTE actions will be listed in services and topics + if ( + all(srv in service_names for srv in simulation_config.required_services) + and all(tp in topics_names for tp in simulation_config.required_topics) + and all( + ac in service_names for ac in simulation_config.required_actions + ) + ): + self.logger.info("All required services are available.") + return True + + time.sleep(5) + retries += 1 + return False + def setup_scene(self, simulation_config: O3DExROS2SimulationConfig): if self.current_binary_path != simulation_config.binary_path: if self.current_sim_process: @@ -211,6 +246,11 @@ def setup_scene(self, simulation_config: O3DExROS2SimulationConfig): while self.spawned_entities: self._despawn_entity(self.spawned_entities[0]) + if not self._is_robotic_stack_ready(simulation_config=simulation_config): + raise RuntimeError( + "Not all required services, topics and actions are available" + ) + for entity in simulation_config.entities: self._spawn_entity(entity) @@ -304,3 +344,52 @@ def from_ros2_pose(self, pose: Pose) -> PoseModel: ) return PoseModel(translation=translation, rotation=rotation) + + +class O3DEngineArmManipulationBridge(O3DExROS2Bridge): + def move_arm( + self, + pose: PoseModel, + initial_gripper_state: bool, + final_gripper_state: bool, + frame_id: str, + ): + """Moves arm to a given position + + Args: + pose (PoseModel): where to move arm + initial_gripper_state (bool): False means closed grip, True means open grip + final_gripper_state (bool): False means closed grip, True means open grip + frame_id (str): reference frame + """ + + request = ManipulatorMoveTo.Request() + request.initial_gripper_state = initial_gripper_state + request.final_gripper_state = final_gripper_state + + request.target_pose = PoseStamped() + request.target_pose.header = Header() + request.target_pose.header.frame_id = frame_id + + request.target_pose.pose.position.x = pose.translation.x + request.target_pose.pose.position.x = pose.translation.y + request.target_pose.pose.position.z = pose.translation.z + + if pose.rotation: + request.target_pose.pose.orientation.x = pose.rotation.x + request.target_pose.pose.orientation.y = pose.rotation.y + request.target_pose.pose.orientation.z = pose.rotation.z + request.target_pose.pose.orientation.w = pose.rotation.w + + client = self.connector.node.create_client( + ManipulatorMoveTo, + "/manipulator_move_to", + ) + while not client.wait_for_service(timeout_sec=5.0): + self.connector.node.get_logger().info("Service not available, waiting...") + + self.connector.node.get_logger().info("Making request to arm manipulator...") + future = client.call_async(request) + result = get_future_result(future, timeout_sec=5.0) + + self.connector.node.get_logger().debug(f"Moving arm result: {result}") diff --git a/src/rai_sim/rai_sim/simulation_bridge.py b/src/rai_sim/rai_sim/simulation_bridge.py index 93959382d..a1e1fcbd2 100644 --- a/src/rai_sim/rai_sim/simulation_bridge.py +++ b/src/rai_sim/rai_sim/simulation_bridge.py @@ -36,7 +36,7 @@ class Rotation(BaseModel): class PoseModel(BaseModel): translation: Translation - rotation: Optional[Rotation] + rotation: Optional[Rotation] = None class Entity(BaseModel): From 6dfc98a727a6cf1bf861b75cb86850d339ccde70 Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Wed, 26 Feb 2025 10:24:49 +0100 Subject: [PATCH 3/8] fix: add __init__.py in rai_bench --- src/rai_bench/rai_bench/__init__.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 src/rai_bench/rai_bench/__init__.py diff --git a/src/rai_bench/rai_bench/__init__.py b/src/rai_bench/rai_bench/__init__.py new file mode 100644 index 000000000..97ceef6f0 --- /dev/null +++ b/src/rai_bench/rai_bench/__init__.py @@ -0,0 +1,13 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. From 053eb813bf26d6d3b6fb11929f4efe2b72ed53c3 Mon Sep 17 00:00:00 2001 From: Jakub Matejczyk <58983084+jmatejcz@users.noreply.github.com> Date: Thu, 27 Feb 2025 14:42:32 +0100 Subject: [PATCH 4/8] refactor: requested changes to rai_bench (#439) --- examples/manipulation-demo.launch.py | 1 - poetry.lock | 4 +- pyproject.toml | 2 +- src/rai_bench/README.md | 26 +++--- src/rai_bench/pyproject.toml | 2 +- src/rai_bench/rai_bench/benchmark_model.py | 89 ++++++++++++------- .../o3de_test_benchmark.py} | 2 +- .../o3de_test_bench/tasks/grab_carrot_task.py | 8 +- .../o3de_test_bench/tasks/place_cubes_task.py | 14 +-- src/rai_core/rai/agents/tool_runner.py | 7 +- 10 files changed, 84 insertions(+), 71 deletions(-) rename src/rai_bench/rai_bench/{main.py => examples/o3de_test_benchmark.py} (98%) diff --git a/examples/manipulation-demo.launch.py b/examples/manipulation-demo.launch.py index 35720a6af..ee28c9b44 100644 --- a/examples/manipulation-demo.launch.py +++ b/examples/manipulation-demo.launch.py @@ -51,7 +51,6 @@ def generate_launch_description(): launch_robotic_manipulation = Node( package="robotic_manipulation", executable="robotic_manipulation", - # name="robotic_manipulation_node", output="screen", parameters=[ {"use_sim_time": True}, diff --git a/poetry.lock b/poetry.lock index 9eb39f9fa..335612128 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5703,7 +5703,7 @@ url = "src/rai_asr" [[package]] name = "rai-bench" version = "0.1.0" -description = "" +description = "Package for running and creating benchmarks." optional = false python-versions = "^3.10" files = [] @@ -8331,4 +8331,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "^3.10, <3.13" -content-hash = "c5469635a5db79c258554ad9f4e49331515940e406fbf912822651a0e0c33dda" +content-hash = "d943b786f2bb8dddc9249475409a4d7c9c4b0a77041611c039431df55ad94000" diff --git a/pyproject.toml b/pyproject.toml index a3edd3730..92bcfa7c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ requests = "^2.32.2" pre-commit = "^3.7.0" openai = "^1.23.3" coloredlogs = "^15.0.1" +opencv-python = "^4.9.0.80" markdown = "^3.6" boto3 = "^1.34.98" tqdm = "^4.66.4" @@ -62,7 +63,6 @@ pytest-timeout = "^2.3.1" tomli-w = "^1.1.0" faster-whisper = "^1.1.1" pydub = "^0.25.1" -opencv-python = "^4.11.0.86" [tool.poetry.group.dev.dependencies] ipykernel = "^6.29.4" diff --git a/src/rai_bench/README.md b/src/rai_bench/README.md index e2abd3d4f..3db5d04ba 100644 --- a/src/rai_bench/README.md +++ b/src/rai_bench/README.md @@ -8,33 +8,29 @@ The RAI Bench is a package including benchmarks and providing frame for creating Frame components can be found in `src/rai_bench/rai_bench/benchmark_model.py` -- `Task` - abstract class for creating specific task. It introduces helper funtions that make it easier to calculate metrics/scores. Your custom tasks must implement a prompt got agent to do, a way to calculate a result and a validation if given scene config suits the task. -- -- `Scenario` - class defined by a Scene and Task. Can be created manually like: +- `Task` +- `Scenario` +- `Benchmark` - ```python - - ``` - -- `Benchmark` - class responsible for running and logging scenarios. +For more information about these classes go to -> `src/rai_bench/rai_bench/benchmark_model.py` ### O3DE TEST BENCHMARK -O3DE Test Benchmark (src/rai_bench/rai_bench/o3de_test_bench/), contains 2 Tasks(tasks/) - GrabCarrotTask and PlaceCubesTask (these tasks implement calculating scores) and 4 scene_configs(configs/) for O3DE robotic arm simulation. +O3DE Test Benchmark (`src/rai_bench/rai_bench/o3de_test_bench/`), contains 2 Tasks(`tasks/`) - GrabCarrotTask and PlaceCubesTask (these tasks implement calculating scores) and 4 scene_configs(`configs/`) for O3DE robotic arm simulation. Both tasks calculate score, taking into consideration 4 values: -- initially_misplaced_now_correct - when the object which was in the incorrect place at the start, is in a correct place at the end -- initially_misplaced_still_incorrect - when the object which was in the incorrect place at the start, is in a incorrect place at the end -- initially_correct_still_correct - when the object which was in the correct place at the start, is in a correct place at the end -- initially_correct_now_incorrect - when the object which was in the correct place at the start, is in a incorrect place at the end +- initially_misplaced_now_correct +- initially_misplaced_still_incorrect +- initially_correct_still_correct +- initially_correct_now_incorrect The result is a value between 0 and 1, calculated like (initially_misplaced_now_correct + initially_correct_still_correct) / number_of_initial_objects. This score is calculated at the beggining and at the end of each scenario. ### Example usage -Example of how to load scenes, define scenarios and run benchmark can be found in `src/rai_bench/rai_bench/benchmark_main.py` +Example of how to load scenes, define scenarios and run benchmark can be found in `src/rai_bench/rai_bench/examples/o3de_test_benchmark.py` Scenarios can be loaded manually like: @@ -56,5 +52,3 @@ scenarios = Benchmark.create_scenarios( ``` which will result in list of scenarios with combination of every possible task and scene(task decides if scene config is suitable for it). - -Both approaches can be found in `main.py` diff --git a/src/rai_bench/pyproject.toml b/src/rai_bench/pyproject.toml index 52255eb9a..a5426f409 100644 --- a/src/rai_bench/pyproject.toml +++ b/src/rai_bench/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "rai-bench" version = "0.1.0" -description = "" +description = "Package for running and creating benchmarks." authors = ["jmatejcz "] readme = "README.md" diff --git a/src/rai_bench/rai_bench/benchmark_model.py b/src/rai_bench/rai_bench/benchmark_model.py index bc47ef407..34cb7b776 100644 --- a/src/rai_bench/rai_bench/benchmark_model.py +++ b/src/rai_bench/rai_bench/benchmark_model.py @@ -34,16 +34,17 @@ class EntitiesMismatchException(Exception): - def __init__(self, message: str) -> None: - super().__init__(message) + pass class Task(ABC): """ - Task to perform. - Specyfic implementation should implement a way to calculate results. - Abstract provides utility functions for common calculations, that can be usefull when - creating metrics + Abstract of a Task. Provides utility functions for common calculations + that can be helfull when creating metrics. + Specific child classes should implement: + - get_prompt method + - validate_config + - calculate_result """ def __init__( @@ -57,6 +58,7 @@ def __init__( @abstractmethod def get_prompt(self) -> str: + """Returns the task instruction - the prompt that will be passed to agent""" pass @abstractmethod @@ -75,7 +77,8 @@ def calculate_result( self, simulation_bridge: SimulationBridge[SimulationConfigT] ) -> float: """ - Calculate result of the task + Calculates result of the task, based on info retrieved from simulation. + Should return score between 0.0 and 1. """ pass @@ -135,7 +138,10 @@ def count_adjacent( class Scenario(Generic[SimulationConfigT]): - """Single instances are run separatly by benchmark""" + """ + A Scenarios are defined by a pair of Task and Simlation Config. + Each Scenario is executed separatly by a Benchmark. + """ def __init__( self, @@ -154,7 +160,9 @@ def __init__( class Benchmark: """ - Defined by a set of scenarios to be done + Benchmark represents a set of Scenarios to be executed and evaluated. + It manages the execution, logs results, and provides functionality + for tracking and exporting performance metrics. """ def __init__( @@ -162,16 +170,20 @@ def __init__( simulation_bridge: SimulationBridge[SimulationConfigT], scenarios: List[Scenario[SimulationConfigT]], logger: loggers_type | None = None, + results_filename: str = "benchmark_results.csv", ) -> None: self.simulation_bridge = simulation_bridge self.num_of_scenarios = len(scenarios) self.scenarios = enumerate(iter(scenarios)) self.results: List[Dict[str, Any]] = [] + self.results_filename = results_filename if logger: self._logger = logger else: self._logger = logging.getLogger(__name__) + self._initialize_results_file() + @classmethod def create_scenarios( cls, @@ -198,6 +210,23 @@ def create_scenarios( ) return scenarios + def _initialize_results_file(self): + """Initialize the CSV file with headers.""" + fieldnames = [ + "task", + "simulation_config", + "initial_score", + "final_score", + "total_time", + "number_of_tool_calls", + ] + + with open( + self.results_filename, mode="w", newline="", encoding="utf-8" + ) as file: + writer = csv.DictWriter(file, fieldnames=fieldnames) + writer.writeheader() + def run_next(self, agent) -> None: """ Runs the next scenario @@ -251,40 +280,36 @@ def run_next(self, agent) -> None: f"TASK SCORE: {result}, TOTAL TIME: {total_time:.3f}, NUM_OF_TOOL_CALLS: {tool_calls_num}" ) - self.results.append( - { - "task": scenario.task.get_prompt(), - "simulation_config": scenario.simulation_config_path, - "initial_score": initial_result, - "final_score": result, - "total_time": f"{total_time:.3f}", - "number_of_tool_calls": tool_calls_num, - } - ) + scenario_result: Dict[str, Any] = { + "task": scenario.task.get_prompt(), + "simulation_config": scenario.simulation_config_path, + "initial_score": initial_result, + "final_score": result, + "total_time": f"{total_time:.3f}", + "number_of_tool_calls": tool_calls_num, + } + self.results.append(scenario_result) + self._save_scenario_result_to_csv(scenario_result) except StopIteration: print("No more scenarios left to run.") - def get_results(self) -> List[Dict[str, Any]]: - return self.results - - def dump_results_to_csv(self, filename: str) -> None: - if not self.results: - self._logger.warning("No results to save.") # type: ignore - return - + def _save_scenario_result_to_csv(self, result: Dict[str, Any]) -> None: + """Save a single scenario result to the CSV file.""" fieldnames = [ "task", - "initial_score", "simulation_config", + "initial_score", "final_score", "total_time", "number_of_tool_calls", ] - with open(filename, mode="w", newline="", encoding="utf-8") as file: + with open( + self.results_filename, mode="a", newline="", encoding="utf-8" + ) as file: writer = csv.DictWriter(file, fieldnames=fieldnames) - writer.writeheader() - writer.writerows(self.results) + writer.writerow(result) - self._logger.info(f"Results saved to {filename}") # type: ignore + def get_results(self) -> List[Dict[str, Any]]: + return self.results diff --git a/src/rai_bench/rai_bench/main.py b/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py similarity index 98% rename from src/rai_bench/rai_bench/main.py rename to src/rai_bench/rai_bench/examples/o3de_test_benchmark.py index 7875c92c4..bfa81d0bc 100644 --- a/src/rai_bench/rai_bench/main.py +++ b/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py @@ -163,6 +163,7 @@ simulation_bridge=o3de, scenarios=scenarios, logger=bench_logger, + results_filename="src/rai_bench/rai_bench/results.csv", ) for i, s in enumerate(scenarios): agent = create_conversational_agent( @@ -180,7 +181,6 @@ bench_logger.info("===============================================================") bench_logger.info("ALL SCENARIOS DONE. BENCHMARK COMPLETED!") bench_logger.info("===============================================================") - benchmark.dump_results_to_csv(filename="src/rai_bench/rai_bench/results.csv") connector.shutdown() o3de.shutdown() diff --git a/src/rai_bench/rai_bench/o3de_test_bench/tasks/grab_carrot_task.py b/src/rai_bench/rai_bench/o3de_test_bench/tasks/grab_carrot_task.py index ca040fd62..56203f194 100644 --- a/src/rai_bench/rai_bench/o3de_test_bench/tasks/grab_carrot_task.py +++ b/src/rai_bench/rai_bench/o3de_test_bench/tasks/grab_carrot_task.py @@ -59,10 +59,10 @@ def calculate_result( else: self.logger.debug(f"initial positions: {initial_carrots}") # type: ignore self.logger.debug(f"current positions: {final_carrots}") # type: ignore - for ini_carrot in initial_carrots: + for initial_carrot in initial_carrots: for final_carrot in final_carrots: - if ini_carrot.name == final_carrot.name: - initial_y = ini_carrot.pose.translation.y + if initial_carrot.name == final_carrot.name: + initial_y = initial_carrot.pose.translation.y final_y = final_carrot.pose.translation.y # NOTE the specific coords that refer to for example # middle of the table can differ across simulations, @@ -90,7 +90,7 @@ def calculate_result( break else: raise EntitiesMismatchException( - f"Entity with name: {ini_carrot.name} which was present in initial scene, not found in final scene." + f"Entity with name: {initial_carrot.name} which was present in initial scene, not found in final scene." ) self.logger.info( # type: ignore diff --git a/src/rai_bench/rai_bench/o3de_test_bench/tasks/place_cubes_task.py b/src/rai_bench/rai_bench/o3de_test_bench/tasks/place_cubes_task.py index 26bdd590e..9ad03c9b2 100644 --- a/src/rai_bench/rai_bench/o3de_test_bench/tasks/place_cubes_task.py +++ b/src/rai_bench/rai_bench/o3de_test_bench/tasks/place_cubes_task.py @@ -61,19 +61,19 @@ def calculate_result( ) else: - ini_poses = [cube.pose for cube in initial_cubes] + initial_poses = [cube.pose for cube in initial_cubes] final_poses = [cube.pose for cube in final_cubes] # NOTE the specific coords that refer to for example # middle of the table can differ across simulations, # take that into consideration self.logger.debug(f"initial positions: {initial_cubes}") self.logger.debug(f"current positions: {final_cubes}") - for i, ini_cube in enumerate(initial_cubes): - for j, final_cube in enumerate(final_cubes): - if ini_cube.name == final_cube.name: + for initial_cube in initial_cubes: + for final_cube in final_cubes: + if initial_cube.name == final_cube.name: was_adjacent_initially = self.is_adjacent_to_any( - ini_cube.pose, - [p for p in ini_poses if p != ini_cube.pose], + initial_cube.pose, + [p for p in initial_poses if p != initial_cube.pose], 0.15, ) is_adjacent_finally = self.is_adjacent_to_any( @@ -93,7 +93,7 @@ def calculate_result( break else: raise EntitiesMismatchException( - f"Entity with name: {ini_cube.name} which was present in initial scene, not found in final scene." + f"Entity with name: {initial_cube.name} which was present in initial scene, not found in final scene." ) self.logger.info( diff --git a/src/rai_core/rai/agents/tool_runner.py b/src/rai_core/rai/agents/tool_runner.py index 5c35ac9a8..12e0889d3 100644 --- a/src/rai_core/rai/agents/tool_runner.py +++ b/src/rai_core/rai/agents/tool_runner.py @@ -69,13 +69,8 @@ def run_one(call: ToolCall): ts = time.perf_counter() output = self.tools_by_name[call["name"]].invoke(call, config) # type: ignore te = time.perf_counter() - ts - tool_output_log = ( - str(output.content)[:1000] + "..." - if len(str(output.content)) > 1000 - else "" - ) self.logger.info( - f"Tool {call['name']} completed in {te:.2f} seconds. Tool output: {tool_output_log}" + f"Tool {call['name']} completed in {te:.2f} seconds. Tool output: {str(output.content)[:100]}{'...' if len(str(output.content)) > 100 else ''}" ) self.logger.debug( f"Tool {call['name']} output: \n\n{str(output.content)}" From 3fff1b274801f8a9b8a0a127d2da826c1b0d5e20 Mon Sep 17 00:00:00 2001 From: Magdalena Kotynia Date: Thu, 27 Feb 2025 16:59:26 +0100 Subject: [PATCH 5/8] test: rai_sim tests (#437) --- setup_shell.sh | 2 + src/rai_bench/rai_bench/benchmark_model.py | 12 +- .../rai_bench/examples/o3de_test_benchmark.py | 4 +- src/rai_sim/README.md | 7 +- src/rai_sim/rai_sim/o3de/o3de_bridge.py | 156 ++++--- src/rai_sim/rai_sim/simulation_bridge.py | 197 ++++++++- tests/rai_sim/conftest.py | 75 ++++ tests/rai_sim/test_o3de_bridge.py | 379 ++++++++++++++++++ tests/rai_sim/test_simulation_bridge.py | 325 +++++++++++++++ 9 files changed, 1061 insertions(+), 96 deletions(-) create mode 100644 tests/rai_sim/conftest.py create mode 100644 tests/rai_sim/test_o3de_bridge.py create mode 100644 tests/rai_sim/test_simulation_bridge.py diff --git a/setup_shell.sh b/setup_shell.sh index 164372ed9..506df2cd2 100755 --- a/setup_shell.sh +++ b/setup_shell.sh @@ -33,3 +33,5 @@ PYTHONPATH="$(dirname "$(dirname "$(poetry run which python)")")/lib/python$(poe PYTHONPATH="src/rai_core:$PYTHONPATH" PYTHONPATH="src/rai_asr:$PYTHONPATH" PYTHONPATH="src/rai_tts:$PYTHONPATH" +PYTHONPATH="src/rai_sim:$PYTHONPATH" +PYTHONPATH="src/rai_bench:$PYTHONPATH" diff --git a/src/rai_bench/rai_bench/benchmark_model.py b/src/rai_bench/rai_bench/benchmark_model.py index 34cb7b776..de9fc43eb 100644 --- a/src/rai_bench/rai_bench/benchmark_model.py +++ b/src/rai_bench/rai_bench/benchmark_model.py @@ -23,7 +23,7 @@ from rclpy.impl.rcutils_logger import RcutilsLogger from rai_sim.simulation_bridge import ( - PoseModel, + Pose, SimulationBridge, SimulationConfig, SimulationConfigT, @@ -88,7 +88,7 @@ def filter_entities_by_prefab_type( """Filter and return only these entities that match provided prefab types""" return [ent for ent in entities if ent.prefab_name in prefab_types] - def euclidean_distance(self, pos1: PoseModel, pos2: PoseModel) -> float: + def euclidean_distance(self, pos1: Pose, pos2: Pose) -> float: """Calculate euclidean distance between 2 positions""" return ( (pos1.translation.x - pos2.translation.x) ** 2 @@ -96,7 +96,7 @@ def euclidean_distance(self, pos1: PoseModel, pos2: PoseModel) -> float: + (pos1.translation.z - pos2.translation.z) ** 2 ) ** 0.5 - def is_adjacent(self, pos1: PoseModel, pos2: PoseModel, threshold_distance: float): + def is_adjacent(self, pos1: Pose, pos2: Pose, threshold_distance: float): """ Check if positions are adjacent to each other, the threshold_distance is a distance in simulation, refering to how close they have to be to classify them as adjacent @@ -107,7 +107,7 @@ def is_adjacent(self, pos1: PoseModel, pos2: PoseModel, threshold_distance: floa return self.euclidean_distance(pos1, pos2) < threshold_distance def is_adjacent_to_any( - self, pos1: PoseModel, positions: List[PoseModel], threshold_distance: float + self, pos1: Pose, positions: List[Pose], threshold_distance: float ) -> bool: """ Check if given position is adjacent to any position in the given list. @@ -117,9 +117,7 @@ def is_adjacent_to_any( self.is_adjacent(pos1, pos2, threshold_distance) for pos2 in positions ) - def count_adjacent( - self, positions: List[PoseModel], threshold_distance: float - ) -> int: + def count_adjacent(self, positions: List[Pose], threshold_distance: float) -> int: """ Count how many adjacent positions are in the given list. Note that position has to be adjacent to only 1 other position diff --git a/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py b/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py index bfa81d0bc..b644567d8 100644 --- a/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py +++ b/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py @@ -32,7 +32,7 @@ from rai_sim.o3de.o3de_bridge import ( O3DEngineArmManipulationBridge, O3DExROS2SimulationConfig, - PoseModel, + Pose, ) from rai_sim.simulation_bridge import Rotation, Translation @@ -152,7 +152,7 @@ ) # custom request to arm - base_arm_pose = PoseModel( + base_arm_pose = Pose( translation=Translation(x=0.5, y=0.1, z=0.3), rotation=Rotation(x=1.0, y=0.0, z=0.0, w=0.0), ) diff --git a/src/rai_sim/README.md b/src/rai_sim/README.md index f6e82bcb6..8008b4d2d 100644 --- a/src/rai_sim/README.md +++ b/src/rai_sim/README.md @@ -6,12 +6,13 @@ The RAI Sim is a package providing interface to implement connection with a spec ### Components -- `SimulationConnector` - An interface for connecting with a specific simulation. It manages scene setup, spawning, despawning objects, getting current state of the scene. +- `SimulationBridge` - An interface for connecting with a specific simulation. It manages scene setup, spawning, despawning objects, getting current state of the scene. -- `SimulationConfig` - base config class to specify the entities to be spawned. For each simulation connector there should be specified custom simulation config specifying additional parameters needed to run and connect with the simulation. +- `SimulationConfig` - base config class to specify the entities to be spawned. For each simulation bridge there should be specified custom simulation config specifying additional parameters needed to run and connect with the simulation. - `SceneState` - stores the current info about spawned entities ### Example implementation -- `O3DExROS2Connector` - An implementation of SimulationConnector for working with simulation based on O3DE and ROS2. +- `O3DExROS2Bridge` - An implementation of SimulationBridge for working with simulation based on O3DE and ROS2. +- `O3DExROS2SimulationConfig` - config class for `O3DExROS2Bridge` diff --git a/src/rai_sim/rai_sim/o3de/o3de_bridge.py b/src/rai_sim/rai_sim/o3de/o3de_bridge.py index d514beb78..8fbe022ee 100644 --- a/src/rai_sim/rai_sim/o3de/o3de_bridge.py +++ b/src/rai_sim/rai_sim/o3de/o3de_bridge.py @@ -18,10 +18,11 @@ import subprocess import time from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Set import yaml -from geometry_msgs.msg import Point, Pose, PoseStamped, Quaternion +from geometry_msgs.msg import Point, PoseStamped, Quaternion +from geometry_msgs.msg import Pose as ROS2Pose from rai.communication.ros2.connectors import ROS2ARIConnector, ROS2ARIMessage from rai.utils.ros_async import get_future_result from std_msgs.msg import Header @@ -30,7 +31,7 @@ from rai_interfaces.srv import ManipulatorMoveTo from rai_sim.simulation_bridge import ( Entity, - PoseModel, + Pose, Rotation, SceneState, SimulationBridge, @@ -43,9 +44,8 @@ class O3DExROS2SimulationConfig(SimulationConfig): binary_path: Path robotic_stack_command: str - required_services: List[str] - required_topics: List[str] - required_actions: List[str] + required_simulation_ros2_interfaces: dict[str, List[str]] + required_robotic_ros2_interfaces: dict[str, List[str]] @classmethod def load_config( @@ -110,13 +110,7 @@ def get_available_spawnable_names(self) -> list[str]: target="get_available_spawnable_names", msg_type="gazebo_msgs/srv/GetWorldProperties", ) - # NOTE (mkotynia) There is a bug in the gazebo_msgs/srv/GetWorldProperties service - payload.success is not set to True even if the service call is successful. It was reported to Kacper Dąbrowski and he is going to fix it. - # PR fixing the bug: https://github.com/o3de/o3de-extras/pull/828 - # TODO (mkotynia) uncomment check if response.payload.success when the bug is fixed and remove workaround check if response.payload.model_names. - - # if response.payload.success: - # return response.payload.model_names - if response.payload.model_names: + if response.payload.success: return response.payload.model_names else: raise RuntimeError( @@ -125,7 +119,7 @@ def get_available_spawnable_names(self) -> list[str]: def _spawn_entity(self, entity: Entity): pose = do_transform_pose( - self.to_ros2_pose(entity.pose), + self._to_ros2_pose(entity.pose), self.connector.get_transform("odom", "world"), ) @@ -176,15 +170,13 @@ def _despawn_entity(self, entity: SpawnedEntity): f"Failed to delete entity {entity.name}. Response: {response.payload.status_message}" ) - def get_object_pose(self, entity: SpawnedEntity) -> PoseModel: + def get_object_pose(self, entity: SpawnedEntity) -> Pose: object_frame = entity.name + "/" ros2_pose = do_transform_pose( - Pose(), self.connector.get_transform(object_frame + "odom", object_frame) - ) - ros2_pose = do_transform_pose( - ros2_pose, self.connector.get_transform("world", "odom") + ROS2Pose(), + self.connector.get_transform(object_frame + "odom", object_frame), ) - return self.from_ros2_pose(ros2_pose) + return self._from_ros2_pose(ros2_pose) def get_scene_state(self) -> SceneState: """ @@ -205,72 +197,112 @@ def get_scene_state(self) -> SceneState: ) return SceneState(entities=entities) - def _is_robotic_stack_ready( - self, simulation_config: O3DExROS2SimulationConfig, retries: int = 30 + def _is_ros2_stack_ready( + self, required_ros2_stack: dict[str, List[str]], retries: int = 30 ) -> bool: - i = 0 - while i < retries: - topics = self.connector.get_topics_names_and_types() - services = self.connector.node.get_service_names_and_types() - topics_names = [tp[0] for tp in topics] - service_names = [srv[0] for srv in services] - self.logger.info( - f"required services: {simulation_config.required_services}" - ) - self.logger.info(f"required topics: {simulation_config.required_topics}") - self.logger.info(f"required actions: {simulation_config.required_actions}") - # NOTE actions will be listed in services and topics - if ( - all(srv in service_names for srv in simulation_config.required_services) - and all(tp in topics_names for tp in simulation_config.required_topics) - and all( - ac in service_names for ac in simulation_config.required_actions + for i in range(retries): + available_topics = self.connector.get_topics_names_and_types() + available_services = self.connector.node.get_service_names_and_types() + available_topics_names = [tp[0] for tp in available_topics] + available_services_names = [srv[0] for srv in available_services] + + # Extract action names + available_actions_names: Set[str] = set() + for service in available_services_names: + if "/_action" in service: + action_name = service.split("/_action")[0] + available_actions_names.add(action_name) + + required_services = required_ros2_stack["services"] + required_topics = required_ros2_stack["topics"] + required_actions = required_ros2_stack["actions"] + self.logger.info(f"required services: {required_services}") + self.logger.info(f"required topics: {required_topics}") + self.logger.info(f"required actions: {required_actions}") + self.logger.info(f"available actions: {available_actions_names}") + + missing_services = [ + service + for service in required_services + if service not in available_services_names + ] + missing_topics = [ + topic + for topic in required_topics + if topic not in available_topics_names + ] + missing_actions = [ + action + for action in required_actions + if action not in available_actions_names + ] + + if missing_services: + self.logger.warning( + f"Waiting for missing services {missing_services} out of required services: {required_services}" + ) + + if missing_topics: + self.logger.warning( + f"Waiting for missing topics: {missing_topics} out of required topics: {required_topics}" ) - ): - self.logger.info("All required services are available.") + + if missing_actions: + self.logger.warning( + f"Waiting for missing actions: {missing_actions} out of required actions: {required_actions}" + ) + + if not (missing_services or missing_topics or missing_actions): + self.logger.info("All required ROS2 stack components are available.") return True - time.sleep(5) - retries += 1 + time.sleep(3) + + self.logger.error( + "Maximum number of retries reached. Required ROS2 stack components are not fully available." + ) return False def setup_scene(self, simulation_config: O3DExROS2SimulationConfig): if self.current_binary_path != simulation_config.binary_path: if self.current_sim_process: self.shutdown() - self._launch_binary(simulation_config.binary_path) - self._launch_robotic_stack(simulation_config.robotic_stack_command) + self._launch_binary(simulation_config) + self._launch_robotic_stack(simulation_config) self.current_binary_path = simulation_config.binary_path else: while self.spawned_entities: self._despawn_entity(self.spawned_entities[0]) - if not self._is_robotic_stack_ready(simulation_config=simulation_config): - raise RuntimeError( - "Not all required services, topics and actions are available" - ) - for entity in simulation_config.entities: self._spawn_entity(entity) - def _launch_binary(self, binary_path: Path): - command = [binary_path.as_posix()] + def _launch_binary(self, simulation_config: O3DExROS2SimulationConfig): + command = [simulation_config.binary_path.as_posix()] self.logger.info(f"Running command: {command}") self.current_sim_process = subprocess.Popen( command, ) if not self._has_process_started(process=self.current_sim_process): raise RuntimeError("Process did not start in time.") + if not self._is_ros2_stack_ready( + required_ros2_stack=simulation_config.required_simulation_ros2_interfaces + ): + raise RuntimeError("ROS2 stack is not ready in time.") - def _launch_robotic_stack(self, robotic_stack_command: str): - command = shlex.split(robotic_stack_command) + def _launch_robotic_stack(self, simulation_config: O3DExROS2SimulationConfig): + command = shlex.split(simulation_config.robotic_stack_command) self.logger.info(f"Running command: {command}") self.current_robotic_stack_process = subprocess.Popen( command, ) if not self._has_process_started(self.current_robotic_stack_process): raise RuntimeError("Process did not start in time.") + if not self._is_ros2_stack_ready( + required_ros2_stack=simulation_config.required_robotic_ros2_interfaces + ): + raise RuntimeError("ROS2 stack is not ready in time.") def _has_process_started(self, process: subprocess.Popen[Any], timeout: int = 15): start_time = time.time() @@ -303,9 +335,9 @@ def _try_service_call( return response # type: ignore # NOTE (mkotynia) probably to be refactored, other bridges may also want to use pose conversion to/from ROS2 format - def to_ros2_pose(self, pose: PoseModel) -> Pose: + def _to_ros2_pose(self, pose: Pose) -> ROS2Pose: """ - Converts pose in PoseModel format to pose in ROS2 Pose format. + Converts pose to pose in ROS2 Pose format. """ position = Point( x=pose.translation.x, y=pose.translation.y, z=pose.translation.z @@ -321,13 +353,13 @@ def to_ros2_pose(self, pose: PoseModel) -> Pose: else: orientation = Quaternion() - ros2_pose = Pose(position=position, orientation=orientation) + ros2_pose = ROS2Pose(position=position, orientation=orientation) return ros2_pose - def from_ros2_pose(self, pose: Pose) -> PoseModel: + def _from_ros2_pose(self, pose: ROS2Pose) -> Pose: """ - Converts ROS2 pose to PoseModel format + Converts ROS2Pose to Pose """ translation = Translation( @@ -343,13 +375,13 @@ def from_ros2_pose(self, pose: Pose) -> PoseModel: w=pose.orientation.w, # type: ignore ) - return PoseModel(translation=translation, rotation=rotation) + return Pose(translation=translation, rotation=rotation) class O3DEngineArmManipulationBridge(O3DExROS2Bridge): def move_arm( self, - pose: PoseModel, + pose: Pose, initial_gripper_state: bool, final_gripper_state: bool, frame_id: str, @@ -357,7 +389,7 @@ def move_arm( """Moves arm to a given position Args: - pose (PoseModel): where to move arm + pose (Pose): where to move arm initial_gripper_state (bool): False means closed grip, True means open grip final_gripper_state (bool): False means closed grip, True means open grip frame_id (str): reference frame diff --git a/src/rai_sim/rai_sim/simulation_bridge.py b/src/rai_sim/rai_sim/simulation_bridge.py index a1e1fcbd2..336406ad0 100644 --- a/src/rai_sim/rai_sim/simulation_bridge.py +++ b/src/rai_sim/rai_sim/simulation_bridge.py @@ -18,48 +18,102 @@ from typing import Generic, List, Optional, TypeVar import yaml -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, Field, field_validator class Translation(BaseModel): - x: float - y: float - z: float + """ + Represents the position of an object in 3D space using + x, y, and z coordinates. + """ + + x: float = Field(description="X coordinate in meters") + y: float = Field(description="Y coordinate in meters") + z: float = Field(description="Z coordinate in meters") class Rotation(BaseModel): - x: float - y: float - z: float - w: float + """ + Represents a 3D rotation using quaternion representation. + """ + x: float = Field(description="X component of the quaternion") + y: float = Field(description="Y component of the quaternion") + z: float = Field(description="Z component of the quaternion") + w: float = Field(description="W component of the quaternion") + + +class Pose(BaseModel): + """ + Represents the complete pose (position and orientation) of an object. + """ -class PoseModel(BaseModel): - translation: Translation - rotation: Optional[Rotation] = None + translation: Translation = Field( + description="The position of the object in 3D space" + ) + rotation: Optional[Rotation] = Field( + default=None, + description="The orientation of the object as a quaternion. Optional if orientation is not needed and default orientation is handled by the bridge", + ) class Entity(BaseModel): - name: str - prefab_name: str - pose: PoseModel + """ + Entity that can be spawned in the simulation environment. + """ + + name: str = Field(description="Unique name for the entity") + prefab_name: str = Field( + description="Name of the prefab resource to use for spawning this entity" + ) + pose: Pose = Field(description="Initial pose of the entity") class SpawnedEntity(Entity): - id: str + """ + Entity that has been spawned in the simulation environment. + """ + + id: str = Field( + description="Unique identifier assigned to the spawned entity instance" + ) class SimulationConfig(BaseModel): """ - Setup of simulation - arrangemenet of objects in the environment. + Setup of simulation - arrangement of objects in the environment. + + Attributes + ---------- + entities : List[Entity] + List of entities to be spawned in the simulation. """ - # NOTE (mkotynia) can be extended by other attributes - entities: List[Entity] + entities: List[Entity] = Field( + description="List of entities to be spawned in the simulation environment" + ) @field_validator("entities") @classmethod def check_unique_names(cls, entities: List[Entity]) -> List[Entity]: + """ + Validates that all entity names in the configuration are unique. + + Parameters + ---------- + entities : List[Entity] + List of entities to validate. + + Returns + ------- + List[Entity] + The validated list of entities. + + Raises + ------ + ValueError + If any entity names are duplicated. + """ names = [entity.name for entity in entities] if len(names) != len(set(names)): raise ValueError("Each entity must have a unique name.") @@ -67,6 +121,19 @@ def check_unique_names(cls, entities: List[Entity]) -> List[Entity]: @classmethod def load_base_config(cls, base_config_path: Path) -> "SimulationConfig": + """ + Loads a simulation configuration from a YAML file. + + Parameters + ---------- + base_config_path : Path + Path to the YAML configuration file. + + Returns + ------- + SimulationConfig + The loaded simulation configuration. + """ with open(base_config_path) as f: content = yaml.safe_load(f) return cls(**content) @@ -74,11 +141,17 @@ def load_base_config(cls, base_config_path: Path) -> "SimulationConfig": class SceneState(BaseModel): """ - Info about current entities' state in the scene. + Info about current state of the scene. + + Attributes + ---------- + entities : List[SpawnedEntity] + List of all entities currently present in the scene. """ - # NOTE (mkotynia) can be extended by other attributes - entities: List[SpawnedEntity] + entities: List[SpawnedEntity] = Field( + description="List of all entities currently spawned in the scene with their current poses" + ) SimulationConfigT = TypeVar("SimulationConfigT", bound=SimulationConfig) @@ -100,20 +173,100 @@ def __init__(self, logger: Optional[logging.Logger] = None): @abstractmethod def setup_scene(self, simulation_config: SimulationConfigT): + """ + Runs and sets up the simulation scene according to the provided configuration. + + Parameters + ---------- + simulation_config : SimulationConfigT + Configuration containing the simulation initialization and setup details including + entities to be spawned and their initial poses. + + Returns + ------- + None + """ pass @abstractmethod def _spawn_entity(self, entity: Entity): + """ + Spawns a single entity in the simulation environment. + + Parameters + ---------- + entity : Entity + Entity object containing the entity's properties + + Returns + ------- + None + + Notes + ----- + The spawned entity should be added to the spawned_entities list maintained + by the simulation bridge. + """ pass @abstractmethod def _despawn_entity(self, entity: SpawnedEntity): + """ + Removes a previously spawned entity from the simulation environment. + + Parameters + ---------- + entity : SpawnedEntity + Entity object representing the spawned entity to be removed. + + Returns + ------- + None + + Notes + ----- + Despawned entity should be removed from the spawned_entities list maintained + by the simulation bridge. + """ pass @abstractmethod - def get_object_pose(self, entity: SpawnedEntity) -> PoseModel: + def get_object_pose(self, entity: SpawnedEntity) -> Pose: + """ + Gets the current pose of a spawned entity. + + This method queries the simulation to get the current position and + orientation of a specific entity. + + Parameters + ---------- + entity : SpawnedEntity + Entity object representing the spawned entity whose pose is + to be retrieved. + + Returns + ------- + Pose + Object containing the entity's current pose. + """ pass @abstractmethod def get_scene_state(self) -> SceneState: + """ + Gets the current state of the simulation scene. + + Parameters + ---------- + None + + Returns + ------- + SceneState + Object containing the current state of the scene. + + Notes + ----- + SceneState should contain the current poses of spawned_entities. + """ pass diff --git a/tests/rai_sim/conftest.py b/tests/rai_sim/conftest.py new file mode 100644 index 000000000..4ebb264a9 --- /dev/null +++ b/tests/rai_sim/conftest.py @@ -0,0 +1,75 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +import pytest + + +@pytest.fixture +def sample_base_yaml_config(tmp_path: Path) -> Path: + yaml_content = """ + entities: + - name: entity1 + prefab_name: cube + pose: + translation: + x: 1.0 + y: 2.0 + z: 3.0 + + - name: entity2 + prefab_name: carrot + pose: + translation: + x: 1.0 + y: 2.0 + z: 3.0 + rotation: + x: 0.1 + y: 0.2 + z: 0.3 + w: 0.4 + """ + file_path = tmp_path / "test_config.yaml" + file_path.write_text(yaml_content) + return file_path + + +@pytest.fixture +def sample_o3dexros2_config(tmp_path: Path) -> Path: + yaml_content = """ + binary_path: /path/to/binary + robotic_stack_command: "ros2 launch robotic_stack.launch.py" + required_simulation_ros2_interfaces: + services: + - /spawn_entity + - /delete_entity + topics: + - /color_image5 + - /depth_image5 + - /color_camera_info5 + actions: [] + required_robotic_ros2_interfaces: + services: + - /grounding_dino_classify + - /grounded_sam_segment + - /manipulator_move_to + topics: [] + actions: + - /execute_trajectory + """ + file_path = tmp_path / "test_o3dexros2_config.yaml" + file_path.write_text(yaml_content) + return file_path diff --git a/tests/rai_sim/test_o3de_bridge.py b/tests/rai_sim/test_o3de_bridge.py new file mode 100644 index 000000000..b7687304a --- /dev/null +++ b/tests/rai_sim/test_o3de_bridge.py @@ -0,0 +1,379 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import signal +import typing +import unittest +from pathlib import Path +from typing import List, Optional, Tuple, get_args, get_origin +from unittest.mock import MagicMock, patch + +import rclpy +import rclpy.qos +from geometry_msgs.msg import Point, Quaternion, TransformStamped +from geometry_msgs.msg import Pose as ROS2Pose +from rai.communication.ros2.connectors import ROS2ARIConnector, ROS2ARIMessage +from rclpy.node import Node +from rclpy.qos import QoSProfile + +from rai_sim.o3de.o3de_bridge import O3DExROS2Bridge, O3DExROS2SimulationConfig +from rai_sim.simulation_bridge import ( + Entity, + Pose, + Rotation, + SpawnedEntity, + Translation, +) + + +def test_load_config(sample_base_yaml_config: Path, sample_o3dexros2_config: Path): + config = O3DExROS2SimulationConfig.load_config( + sample_base_yaml_config, sample_o3dexros2_config + ) + assert isinstance(config, O3DExROS2SimulationConfig) + assert config.binary_path == Path("/path/to/binary") + assert config.robotic_stack_command == "ros2 launch robotic_stack.launch.py" + assert config.required_simulation_ros2_interfaces == { + "services": ["/spawn_entity", "/delete_entity"], + "topics": ["/color_image5", "/depth_image5", "/color_camera_info5"], + "actions": [], + } + assert config.required_robotic_ros2_interfaces == { + "services": [ + "/grounding_dino_classify", + "/grounded_sam_segment", + "/manipulator_move_to", + ], + "topics": [], + "actions": ["/execute_trajectory"], + } + assert isinstance(config.entities, list) + assert all(isinstance(e, Entity) for e in config.entities) + + assert len(config.entities) == 2 + + +class TestO3DExROS2Bridge(unittest.TestCase): + def setUp(self): + self.mock_connector = MagicMock(spec=ROS2ARIConnector) + self.mock_logger = MagicMock() + self.bridge = O3DExROS2Bridge( + connector=self.mock_connector, logger=self.mock_logger + ) + + # Create test data + self.test_entity = Entity( + name="test_entity1", + prefab_name="cube", + pose=Pose( + translation=Translation(x=1.0, y=2.0, z=3.0), + rotation=Rotation(x=0.0, y=0.0, z=0.0, w=1.0), + ), + ) + + self.test_spawned_entity = SpawnedEntity( + id="entity_id_123", + name="test_entity1", + prefab_name="cube", + pose=Pose( + translation=Translation(x=1.0, y=2.0, z=3.0), + rotation=Rotation(x=0.0, y=0.0, z=0.0, w=1.0), + ), + ) + + self.test_config = O3DExROS2SimulationConfig( + binary_path=Path("/path/to/binary"), + robotic_stack_command="ros2 launch robot.launch.py", + entities=[self.test_entity], + required_simulation_ros2_interfaces={ + "services": [], + "topics": [], + "actions": [], + }, + required_robotic_ros2_interfaces={ + "services": [], + "topics": [], + "actions": [], + }, + ) + + def test_init(self): + self.assertEqual(self.bridge.connector, self.mock_connector) + self.assertEqual(self.bridge.logger, self.mock_logger) + self.assertIsNone(self.bridge.current_sim_process) + self.assertIsNone(self.bridge.current_robotic_stack_process) + self.assertIsNone(self.bridge.current_binary_path) + self.assertEqual(self.bridge.spawned_entities, []) + + @patch("subprocess.Popen") + def test_launch_robotic_stack(self, mock_popen): + mock_process = MagicMock() + mock_process.poll.return_value = None + mock_process.pid = 54321 + mock_popen.return_value = mock_process + self.bridge._launch_robotic_stack(self.test_config) + + mock_popen.assert_called_once_with(["ros2", "launch", "robot.launch.py"]) + self.assertEqual(self.bridge.current_robotic_stack_process, mock_process) + + @patch("subprocess.Popen") + def test_launch_binary(self, mock_popen): + mock_process = MagicMock() + mock_process.poll.return_value = None + mock_process.pid = 54322 + mock_popen.return_value = mock_process + + self.bridge._launch_binary(self.test_config) + + mock_popen.assert_called_once_with(["/path/to/binary"]) + self.assertEqual(self.bridge.current_sim_process, mock_process) + + def test_shutdown_binary(self): + mock_process = MagicMock() + mock_process.poll.return_value = 0 + + self.bridge.current_sim_process = mock_process + + self.bridge._shutdown_binary() + + mock_process.send_signal.assert_called_once_with(signal.SIGINT) + mock_process.wait.assert_called_once() + + self.assertIsNone(self.bridge.current_sim_process) + + def test_shutdown_robotic_stack(self): + self.bridge.current_robotic_stack_process = MagicMock() + self.bridge.current_robotic_stack_process.poll.return_value = 0 + + self.bridge._shutdown_robotic_stack() + + self.bridge.current_robotic_stack_process.send_signal.assert_called_once_with( + signal.SIGINT + ) + self.bridge.current_robotic_stack_process.wait.assert_called_once() + + def test_get_available_spawnable_names(self): + # Mock the response + response = MagicMock() + response.payload.model_names = ["cube", "carrot"] + self.bridge._try_service_call = MagicMock(return_value=response) + + names = self.bridge.get_available_spawnable_names() + + self.bridge._try_service_call.assert_called_once() + self.assertEqual(names, ["cube", "carrot"]) + + def test_to_ros2_pose(self): + # Create a pose + pose = Pose( + translation=Translation(x=1.0, y=2.0, z=3.0), + rotation=Rotation(x=0.1, y=0.2, z=0.3, w=0.4), + ) + + # Convert to ROS2 pose + ros2_pose = self.bridge._to_ros2_pose(pose) + + # Check the conversion + self.assertEqual(ros2_pose.position.x, 1.0) + self.assertEqual(ros2_pose.position.y, 2.0) + self.assertEqual(ros2_pose.position.z, 3.0) + self.assertEqual(ros2_pose.orientation.x, 0.1) + self.assertEqual(ros2_pose.orientation.y, 0.2) + self.assertEqual(ros2_pose.orientation.z, 0.3) + self.assertEqual(ros2_pose.orientation.w, 0.4) + + def test_from_ros2_pose(self): + # Create a ROS2 pose + position = Point(x=1.0, y=2.0, z=3.0) + orientation = Quaternion(x=0.1, y=0.2, z=0.3, w=0.4) + ros2_pose = ROS2Pose(position=position, orientation=orientation) + + # Convert from ROS2Pose to Pose + pose = self.bridge._from_ros2_pose(ros2_pose) + + # Check the conversion + self.assertEqual(pose.translation.x, 1.0) + self.assertEqual(pose.translation.y, 2.0) + self.assertEqual(pose.translation.z, 3.0) + self.assertEqual(pose.rotation.x, 0.1) + self.assertEqual(pose.rotation.y, 0.2) + self.assertEqual(pose.rotation.z, 0.3) + self.assertEqual(pose.rotation.w, 0.4) + + +class TestROS2ARIConnectorInterface(unittest.TestCase): + """Tests to ensure the ROS2ARIConnector interface meets the expectations of O3DExROS2Bridge.""" + + def setUp(self): + rclpy.init() + self.connector = ROS2ARIConnector() + + def tearDown(self): + rclpy.shutdown() + + def test_connector_required_methods_exist(self): + """Test that all required methods exist on the ROS2ARIConnector.""" + connector = ROS2ARIConnector() + + # Check that all required methods exist + self.assertTrue( + hasattr(connector, "service_call"), "service_call method is missing" + ) + self.assertTrue( + hasattr(connector, "get_transform"), "get_transform method is missing" + ) + self.assertTrue( + hasattr(connector, "send_message"), "send_message method is missing" + ) + self.assertTrue( + hasattr(connector, "receive_message"), "receive_message method is missing" + ) + self.assertTrue(hasattr(connector, "shutdown"), "shutdown method is missing") + self.assertTrue( + hasattr(connector, "get_topics_names_and_types"), + "get_topics_names_and_types method is missing", + ) + self.assertTrue( + hasattr(connector, "node"), + "node property is missing", + ) + + def resolve_annotation(self, annotation: type) -> type: + """Helper function to unwrap Optional types. Workaround for problem with asserting Optional types.""" + if get_origin(annotation) is typing.Optional: + return get_args(annotation)[0] + return annotation + + def test_get_transform_signature(self): + signature = inspect.signature(self.connector.get_transform) + parameters = signature.parameters + + expected_params: dict[str, type] = { + "target_frame": str, + "source_frame": str, + "timeout_sec": float, + } + + assert list(parameters.keys()) == list(expected_params.keys()), ( + f"Parameter names do not match, expected: {list(expected_params.keys())}, got: {list(parameters.keys())}" + ) + + for param_name, expected_type in expected_params.items(): + param = parameters[param_name] + self.assertEqual( + self.resolve_annotation(param.annotation), + self.resolve_annotation(expected_type), + f"Parameter '{param_name}' has incorrect type, expected: {expected_type}, got: {param.annotation}", + ) + + # Check return type explicitly + assert signature.return_annotation is TransformStamped, ( + f"Return type is incorrect, expected: TransformStamped, got: {signature.return_annotation}" + ) + + def test_send_message_signature(self): + signature = inspect.signature(self.connector.send_message) + parameters = signature.parameters + + expected_params: dict[str, type] = { + "message": ROS2ARIMessage, + "target": str, + "msg_type": str, + "auto_qos_matching": bool, + "qos_profile": Optional[QoSProfile], + } + + self.assertListEqual( + list(parameters.keys())[: len(expected_params)], + list(expected_params.keys()), + f"Parameter names do not match, expected: {list(expected_params.keys())}, got: {list(parameters.keys())}", + ) + + for param_name, expected_type in expected_params.items(): + param = parameters[param_name] + self.assertEqual( + self.resolve_annotation(param.annotation), + self.resolve_annotation(expected_type), + f"Parameter '{param_name}' has incorrect type, expected: {expected_type}, got: {param.annotation}", + ) + + self.assertIs( + signature.return_annotation, + inspect.Signature.empty, + "send_message should have no return value", + ) + + def test_receive_message_signature(self): + signature = inspect.signature(self.connector.receive_message) + parameters = signature.parameters + + expected_params: dict[str, type] = { + "source": str, + "timeout_sec": float, + "msg_type": Optional[str], + "auto_topic_type": bool, + } + + self.assertListEqual( + list(parameters.keys())[: len(expected_params)], + list(expected_params.keys()), + f"Parameter names do not match, expected: {list(expected_params.keys())}, got: {list(parameters.keys())}", + ) + + for param_name, expected_type in expected_params.items(): + param = parameters[param_name] + self.assertEqual( + self.resolve_annotation(param.annotation), + self.resolve_annotation(expected_type), + f"Parameter '{param_name}' has incorrect type, expected: {expected_type}, got: {param.annotation}", + ) + + self.assertIs( + signature.return_annotation, + ROS2ARIMessage, + f"Return type is incorrect, expected: ROS2ARIMessage, got: {signature.return_annotation}", + ) + + def test_get_topics_names_and_types_signature(self): + signature = inspect.signature(self.connector.get_topics_names_and_types) + parameters = signature.parameters + + expected_params: dict[str, type] = {} + + assert list(parameters.keys()) == list(expected_params.keys()), ( + f"Parameter names do not match, expected: {list(expected_params.keys())}, got: {list(parameters.keys())}" + ) + + for param_name, expected_type in expected_params.items(): + param = parameters[param_name] + self.assertEqual( + self.resolve_annotation(param.annotation), + self.resolve_annotation(expected_type), + f"Parameter '{param_name}' has incorrect type, expected: {expected_type}, got: {param.annotation}", + ) + + self.assertEqual( + signature.return_annotation, + List[Tuple[str, List[str]]], + f"Return type is incorrect, expected: List[Tuple[str, List[str]]], got: {signature.return_annotation}", + ) + + def test_node_property(self): + """Test that the node property returns the expected Node instance.""" + mock_node = MagicMock(spec=Node) + self.connector._node = mock_node + + self.assertEqual(self.connector.node, mock_node) + self.assertIsInstance(self.connector.node, Node) diff --git a/tests/rai_sim/test_simulation_bridge.py b/tests/rai_sim/test_simulation_bridge.py new file mode 100644 index 000000000..39f91ab33 --- /dev/null +++ b/tests/rai_sim/test_simulation_bridge.py @@ -0,0 +1,325 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import unittest +from pathlib import Path +from typing import Optional + +import pytest +from pydantic import ValidationError + +from rai_sim.simulation_bridge import ( + Entity, + Pose, + Rotation, + SceneState, + SimulationBridge, + SimulationConfig, + SpawnedEntity, + Translation, +) + + +# Helper Functions +def create_translation(x: float, y: float, z: float) -> Translation: + return Translation(x=x, y=y, z=z) + + +def create_rotation(x: float, y: float, z: float, w: float) -> Rotation: + return Rotation(x=x, y=y, z=z, w=w) + + +def create_pose(translation: Translation, rotation: Optional[Rotation] = None) -> Pose: + return Pose(translation=translation, rotation=rotation) + + +# Test Cases +def test_translation(): + translation = create_translation(x=1.1, y=2.2, z=3.3) + + assert isinstance(translation.x, float) + assert isinstance(translation.y, float) + assert isinstance(translation.z, float) + + assert translation.x == 1.1 + assert translation.y == 2.2 + assert translation.z == 3.3 + + +def test_rotation(): + rotation = Rotation(x=0.1, y=0.2, z=0.3, w=0.4) + + assert isinstance(rotation.x, float) + assert isinstance(rotation.y, float) + assert isinstance(rotation.z, float) + assert isinstance(rotation.w, float) + + assert rotation.x == 0.1 + assert rotation.y == 0.2 + assert rotation.z == 0.3 + assert rotation.w == 0.4 + + +def test_pose(): + translation = create_translation(x=1.1, y=2.2, z=3.3) + rotation = Rotation(x=0.1, y=0.2, z=0.3, w=0.4) + + pose = Pose(translation=translation, rotation=rotation) + + assert isinstance(pose.translation, Translation) + assert isinstance(pose.rotation, Rotation) + + assert pose.translation.x == 1.1 + assert pose.rotation.w == 0.4 + + +def test_optional_rotation(): + translation = create_translation(x=1.1, y=2.2, z=3.3) + pose = create_pose(translation=translation) + + assert isinstance(pose.translation, Translation) + assert pose.rotation is None + + +def test_entity(): + translation = create_translation(x=1.1, y=2.2, z=3.3) + rotation = create_rotation(x=0.1, y=0.2, z=0.3, w=0.4) + pose = create_pose(translation=translation, rotation=rotation) + + entity = Entity(name="test_cube", prefab_name="cube", pose=pose) + + assert isinstance(entity.name, str) + assert isinstance(entity.prefab_name, str) + assert isinstance(entity.pose, Pose) + + assert entity.name == "test_cube" + assert entity.prefab_name == "cube" + assert entity.pose == pose + + +def test_spawned_entity(): + translation = create_translation(x=1.1, y=2.2, z=3.3) + rotation = create_rotation(x=0.1, y=0.2, z=0.3, w=0.4) + pose = create_pose(translation=translation, rotation=rotation) + + spawned_entity = SpawnedEntity( + name="test_cube", + prefab_name="cube", + pose=pose, + id="id_123", + ) + + assert isinstance(spawned_entity.name, str) + assert isinstance(spawned_entity.prefab_name, str) + assert isinstance(spawned_entity.pose, Pose) + assert isinstance(spawned_entity.id, str) + + assert spawned_entity.id == "id_123" + + +def test_simulation_config_unique_names(): + translation = create_translation(x=1.1, y=2.2, z=3.3) + rotation = create_rotation(x=0.1, y=0.2, z=0.3, w=0.4) + pose = create_pose(translation=translation, rotation=rotation) + + entities = [ + Entity(name="entity1", prefab_name="cube", pose=pose), + Entity(name="entity2", prefab_name="carrot", pose=pose), + ] + + config = SimulationConfig(entities=entities) + + assert isinstance(config.entities, list) + assert all(isinstance(e, Entity) for e in config.entities) + + assert len(config.entities) == 2 + + +def test_simulation_config_duplicate_names(): + translation = create_translation(x=1.1, y=2.2, z=3.3) + rotation = create_rotation(x=0.1, y=0.2, z=0.3, w=0.4) + pose = create_pose(translation=translation, rotation=rotation) + + entities = [ + Entity(name="duplicate", prefab_name="cube", pose=pose), + Entity(name="duplicate", prefab_name="carrot", pose=pose), + ] + + with pytest.raises(ValidationError): + SimulationConfig(entities=entities) + + +def test_load_base_config(sample_base_yaml_config: Path): + config = SimulationConfig.load_base_config(sample_base_yaml_config) + + assert isinstance(config.entities, list) + assert all(isinstance(e, Entity) for e in config.entities) + + assert len(config.entities) == 2 + + +class MockSimulationBridge(SimulationBridge[SimulationConfig]): + """Mock implementation of SimulationBridge for testing.""" + + def setup_scene(self, simulation_config: SimulationConfig): + """Mock implementation of setup_scene.""" + for entity in simulation_config.entities: + self._spawn_entity(entity) + + def _spawn_entity(self, entity: Entity): + """Mock implementation of _spawn_entity.""" + spawned_entity = SpawnedEntity( + id=f"id_{entity.name}", + name=entity.name, + prefab_name=entity.prefab_name, + pose=entity.pose, + ) + self.spawned_entities.append(spawned_entity) + + def _despawn_entity(self, entity: SpawnedEntity): + """Mock implementation of _despawn_entity.""" + self.spawned_entities = [e for e in self.spawned_entities if e.id != entity.id] + + def get_object_pose(self, entity: SpawnedEntity) -> Pose: + """Mock implementation of get_object_pose.""" + for spawned_entity in self.spawned_entities: + if spawned_entity.id == entity.id: + return spawned_entity.pose + raise ValueError(f"Entity with id {entity.id} not found") + + def get_scene_state(self) -> SceneState: + """Mock implementation of get_scene_state.""" + return SceneState(entities=self.spawned_entities) + + +class TestSimulationBridge(unittest.TestCase): + def setUp(self): + """Set up test fixtures.""" + self.logger = logging.getLogger("test_logger") + self.bridge = MockSimulationBridge(logger=self.logger) + + # Create test entities + self.test_entity1: Entity = Entity( + name="test_entity1", + prefab_name="test_prefab1", + pose=Pose( + translation=Translation(x=1.0, y=2.0, z=3.0), + rotation=Rotation(x=0.0, y=0.0, z=0.0, w=1.0), + ), + ) + + self.test_entity2: Entity = Entity( + name="test_entity2", + prefab_name="test_prefab2", + pose=Pose(translation=Translation(x=4.0, y=5.0, z=6.0), rotation=None), + ) + + # Create a test configuration + self.test_config = SimulationConfig( + entities=[self.test_entity1, self.test_entity2] + ) + + def test_init(self): + # Test with provided logger + bridge = MockSimulationBridge(logger=self.logger) + self.assertEqual(bridge.logger, self.logger) + self.assertEqual(len(bridge.spawned_entities), 0) + + # Test with default logger + bridge = MockSimulationBridge() + self.assertIsNotNone(bridge.logger) + self.assertEqual(len(bridge.spawned_entities), 0) + + def test_setup_scene(self): + self.bridge.setup_scene(self.test_config) + + # Check if entities were spawned + self.assertEqual(len(self.bridge.spawned_entities), 2) + + # Check if the spawned entities have correct properties + self.assertEqual(self.bridge.spawned_entities[0].name, "test_entity1") + self.assertEqual(self.bridge.spawned_entities[0].prefab_name, "test_prefab1") + self.assertEqual(self.bridge.spawned_entities[0].pose.translation.x, 1.0) + self.assertEqual(self.bridge.spawned_entities[0].pose.translation.y, 2.0) + self.assertEqual(self.bridge.spawned_entities[0].pose.translation.z, 3.0) + assert self.bridge.spawned_entities[0].pose.rotation + self.assertEqual(self.bridge.spawned_entities[0].pose.rotation.x, 0.0) + self.assertEqual(self.bridge.spawned_entities[0].pose.rotation.y, 0.0) + self.assertEqual(self.bridge.spawned_entities[0].pose.rotation.z, 0.0) + self.assertEqual(self.bridge.spawned_entities[0].pose.rotation.w, 1.0) + + self.assertEqual(self.bridge.spawned_entities[1].name, "test_entity2") + self.assertEqual(self.bridge.spawned_entities[1].prefab_name, "test_prefab2") + self.assertEqual(self.bridge.spawned_entities[1].pose.translation.x, 4.0) + self.assertEqual(self.bridge.spawned_entities[1].pose.translation.y, 5.0) + self.assertEqual(self.bridge.spawned_entities[1].pose.translation.z, 6.0) + self.assertIsNone(self.bridge.spawned_entities[1].pose.rotation) + + def test_spawn_entity(self): + self.bridge._spawn_entity(self.test_entity1) # type: ignore + spawned_entity = self.bridge.spawned_entities[0] + # Check if entity was added to spawned_entities + self.assertEqual(len(self.bridge.spawned_entities), 1) + self.assertIsInstance(spawned_entity, SpawnedEntity) + + def test_despawn_entity(self): + # First spawn an entity + self.bridge._spawn_entity(self.test_entity1) # type: ignore + self.assertEqual(len(self.bridge.spawned_entities), 1) + + # Then despawn it + self.bridge._despawn_entity(self.bridge.spawned_entities[0]) # type: ignore + self.assertEqual(len(self.bridge.spawned_entities), 0) + + def test_get_object_pose(self): + # First spawn an entity + self.bridge._spawn_entity(self.test_entity1) # type: ignore + + # Get the pose + pose = self.bridge.get_object_pose(self.bridge.spawned_entities[0]) + + # Check if the pose matches + self.assertEqual(pose.translation.x, 1.0) + self.assertEqual(pose.translation.y, 2.0) + self.assertEqual(pose.translation.z, 3.0) + assert pose.rotation + self.assertEqual(pose.rotation.x, 0.0) + self.assertEqual(pose.rotation.y, 0.0) + self.assertEqual(pose.rotation.z, 0.0) + self.assertEqual(pose.rotation.w, 1.0) + + # Test for non-existent entity + non_existent_entity = SpawnedEntity( + id="non_existent", + name="non_existent", + prefab_name="non_existent", + pose=Pose(translation=Translation(x=0.0, y=0.0, z=0.0)), + ) + + with self.assertRaises(ValueError): + self.bridge.get_object_pose(non_existent_entity) + + def test_get_scene_state(self): + # First spawn some entities + self.bridge._spawn_entity(self.test_entity1) # type: ignore + self.bridge._spawn_entity(self.test_entity2) # type: ignore + + # Get the scene state + scene_state = self.bridge.get_scene_state() + + # Check if the scene state contains the correct entities + self.assertEqual(len(scene_state.entities), 2) + self.assertEqual(scene_state.entities[0].name, "test_entity1") + self.assertEqual(scene_state.entities[1].name, "test_entity2") From 70e824f35ee97a76c865fb12940e61b0c22e5524 Mon Sep 17 00:00:00 2001 From: Magdalena Kotynia Date: Fri, 28 Feb 2025 12:24:42 +0100 Subject: [PATCH 6/8] fix: simbench fixes (#443) --- .../rai_bench/examples/o3de_test_benchmark.py | 11 +++++++++-- src/rai_sim/rai_sim/o3de/o3de_bridge.py | 4 ++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py b/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py index b644567d8..cedbafd57 100644 --- a/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py +++ b/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py @@ -15,6 +15,7 @@ ########### EXAMPLE USAGE ########### import logging import time +from datetime import datetime from pathlib import Path from typing import List @@ -69,7 +70,12 @@ GetROS2TopicsNamesAndTypesTool(connector=connector), ] # define loggers - log_file = "src/rai_bench/rai_bench/benchmark.log" + now = datetime.now() + experiment_dir = ( + f"src/rai_bench/rai_bench/experiments/{now.strftime('%Y-%m-%d_%H-%M-%S')}" + ) + Path(experiment_dir).mkdir(parents=True, exist_ok=True) + log_file = f"{experiment_dir}/benchmark.log" file_handler = logging.FileHandler(log_file) file_handler.setLevel(logging.DEBUG) @@ -159,11 +165,12 @@ o3de = O3DEngineArmManipulationBridge(connector, logger=agent_logger) # define benchamrk + results_filename = f"{experiment_dir}/results.csv" benchmark = Benchmark( simulation_bridge=o3de, scenarios=scenarios, logger=bench_logger, - results_filename="src/rai_bench/rai_bench/results.csv", + results_filename=results_filename, ) for i, s in enumerate(scenarios): agent = create_conversational_agent( diff --git a/src/rai_sim/rai_sim/o3de/o3de_bridge.py b/src/rai_sim/rai_sim/o3de/o3de_bridge.py index 8fbe022ee..3a981531a 100644 --- a/src/rai_sim/rai_sim/o3de/o3de_bridge.py +++ b/src/rai_sim/rai_sim/o3de/o3de_bridge.py @@ -176,6 +176,9 @@ def get_object_pose(self, entity: SpawnedEntity) -> Pose: ROS2Pose(), self.connector.get_transform(object_frame + "odom", object_frame), ) + ros2_pose = do_transform_pose( + ros2_pose, self.connector.get_transform("world", "odom") + ) return self._from_ros2_pose(ros2_pose) def get_scene_state(self) -> SceneState: @@ -274,6 +277,7 @@ def setup_scene(self, simulation_config: O3DExROS2SimulationConfig): else: while self.spawned_entities: self._despawn_entity(self.spawned_entities[0]) + self.logger.info(f"Entities after despawn: {self.spawned_entities}") for entity in simulation_config.entities: self._spawn_entity(entity) From e7527b5cf0d643026b95523cb593924791087384 Mon Sep 17 00:00:00 2001 From: Jakub Matejczyk <58983084+jmatejcz@users.noreply.github.com> Date: Fri, 28 Feb 2025 13:54:19 +0100 Subject: [PATCH 7/8] perf: improve performance of O3DE Benchmark (#442) --- src/rai_bench/rai_bench/examples/o3de_test_benchmark.py | 4 ++-- src/rai_sim/rai_sim/o3de/o3de_bridge.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py b/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py index cedbafd57..6068806d7 100644 --- a/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py +++ b/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py @@ -159,7 +159,7 @@ # custom request to arm base_arm_pose = Pose( - translation=Translation(x=0.5, y=0.1, z=0.3), + translation=Translation(x=0.1, y=0.5, z=0.4), rotation=Rotation(x=1.0, y=0.0, z=0.0, w=0.0), ) @@ -183,7 +183,7 @@ final_gripper_state=False, frame_id="panda_link0", ) # return to case position - time.sleep(2) # admire the end position for a second ;) + time.sleep(0.2) # admire the end position for a second ;) bench_logger.info("===============================================================") bench_logger.info("ALL SCENARIOS DONE. BENCHMARK COMPLETED!") diff --git a/src/rai_sim/rai_sim/o3de/o3de_bridge.py b/src/rai_sim/rai_sim/o3de/o3de_bridge.py index 3a981531a..14ba247d4 100644 --- a/src/rai_sim/rai_sim/o3de/o3de_bridge.py +++ b/src/rai_sim/rai_sim/o3de/o3de_bridge.py @@ -201,7 +201,7 @@ def get_scene_state(self) -> SceneState: return SceneState(entities=entities) def _is_ros2_stack_ready( - self, required_ros2_stack: dict[str, List[str]], retries: int = 30 + self, required_ros2_stack: dict[str, List[str]], retries: int = 120 ) -> bool: for i in range(retries): available_topics = self.connector.get_topics_names_and_types() @@ -259,7 +259,7 @@ def _is_ros2_stack_ready( self.logger.info("All required ROS2 stack components are available.") return True - time.sleep(3) + time.sleep(0.5) self.logger.error( "Maximum number of retries reached. Required ROS2 stack components are not fully available." @@ -408,7 +408,7 @@ def move_arm( request.target_pose.header.frame_id = frame_id request.target_pose.pose.position.x = pose.translation.x - request.target_pose.pose.position.x = pose.translation.y + request.target_pose.pose.position.y = pose.translation.y request.target_pose.pose.position.z = pose.translation.z if pose.rotation: From 84c64934b64aea9e28ce599290d1fd9e26c5c668 Mon Sep 17 00:00:00 2001 From: Jakub Matejczyk <58983084+jmatejcz@users.noreply.github.com> Date: Fri, 28 Feb 2025 16:07:48 +0100 Subject: [PATCH 8/8] refactor: Task score always start at 0.0 (#444) --- src/rai_bench/rai_bench/benchmark_model.py | 33 ++--- .../o3de_test_bench/tasks/grab_carrot_task.py | 133 ++++++++++-------- .../o3de_test_bench/tasks/place_cubes_task.py | 131 +++++++++-------- 3 files changed, 153 insertions(+), 144 deletions(-) diff --git a/src/rai_bench/rai_bench/benchmark_model.py b/src/rai_bench/rai_bench/benchmark_model.py index de9fc43eb..83f4c1f66 100644 --- a/src/rai_bench/rai_bench/benchmark_model.py +++ b/src/rai_bench/rai_bench/benchmark_model.py @@ -180,6 +180,13 @@ def __init__( else: self._logger = logging.getLogger(__name__) + self.fieldnames = [ + "task", + "simulation_config", + "final_score", + "total_time", + "number_of_tool_calls", + ] self._initialize_results_file() @classmethod @@ -192,6 +199,7 @@ def create_scenarios( # TODO (jm) hacky_fix, taking paths as args here, not the best solution, # but more changes to code would be required scenarios: List[Scenario[SimulationConfigT]] = [] + for task in tasks: for sim_conf, sim_path in zip(simulation_configs, simulation_configs_paths): try: @@ -210,19 +218,10 @@ def create_scenarios( def _initialize_results_file(self): """Initialize the CSV file with headers.""" - fieldnames = [ - "task", - "simulation_config", - "initial_score", - "final_score", - "total_time", - "number_of_tool_calls", - ] - with open( self.results_filename, mode="w", newline="", encoding="utf-8" ) as file: - writer = csv.DictWriter(file, fieldnames=fieldnames) + writer = csv.DictWriter(file, fieldnames=self.fieldnames) writer.writeheader() def run_next(self, agent) -> None: @@ -239,8 +238,6 @@ def run_next(self, agent) -> None: self._logger.info( # type: ignore f"RUNNING SCENARIO NUMBER {i + 1} / {self.num_of_scenarios}, TASK: {scenario.task.get_prompt()}" ) - initial_result = scenario.task.calculate_result(self.simulation_bridge) - self._logger.info(f"RESULT OF THE INITIAL SETUP: {initial_result}") # type: ignore tool_calls_num = 0 ts = time.perf_counter() @@ -281,7 +278,6 @@ def run_next(self, agent) -> None: scenario_result: Dict[str, Any] = { "task": scenario.task.get_prompt(), "simulation_config": scenario.simulation_config_path, - "initial_score": initial_result, "final_score": result, "total_time": f"{total_time:.3f}", "number_of_tool_calls": tool_calls_num, @@ -294,19 +290,10 @@ def run_next(self, agent) -> None: def _save_scenario_result_to_csv(self, result: Dict[str, Any]) -> None: """Save a single scenario result to the CSV file.""" - fieldnames = [ - "task", - "simulation_config", - "initial_score", - "final_score", - "total_time", - "number_of_tool_calls", - ] - with open( self.results_filename, mode="a", newline="", encoding="utf-8" ) as file: - writer = csv.DictWriter(file, fieldnames=fieldnames) + writer = csv.DictWriter(file, fieldnames=self.fieldnames) writer.writerow(result) def get_results(self) -> List[Dict[str, Any]]: diff --git a/src/rai_bench/rai_bench/o3de_test_bench/tasks/grab_carrot_task.py b/src/rai_bench/rai_bench/o3de_test_bench/tasks/grab_carrot_task.py index 56203f194..d011051d1 100644 --- a/src/rai_bench/rai_bench/o3de_test_bench/tasks/grab_carrot_task.py +++ b/src/rai_bench/rai_bench/o3de_test_bench/tasks/grab_carrot_task.py @@ -11,91 +11,100 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Tuple -from rai_bench.benchmark_model import ( - EntitiesMismatchException, - Task, -) +from rai_bench.benchmark_model import EntitiesMismatchException, Task from rai_sim.o3de.o3de_bridge import ( SimulationBridge, ) -from rai_sim.simulation_bridge import SimulationConfig, SimulationConfigT +from rai_sim.simulation_bridge import SimulationConfig, SimulationConfigT, SpawnedEntity class GrabCarrotTask(Task): + obj_types = ["carrot"] + + # TODO (jm) extract common logic to some parent manipulation task def get_prompt(self) -> str: return "Manipulate objects, so that all carrots to the left side of the table (positive y)" def validate_config(self, simulation_config: SimulationConfig) -> bool: for ent in simulation_config.entities: - if ent.prefab_name == "carrot": + if ent.prefab_name in self.obj_types: return True return False - def calculate_result( - self, simulation_bridge: SimulationBridge[SimulationConfigT] - ) -> float: - # TODO (jm) extract common logic to some parent manipulation task? - initially_misplaced_now_correct = 0 # when the object which was in the incorrect place at the start, is in a correct place at the end - initially_misplaced_still_incorrect = 0 # when the object which was in the incorrect place at the start, is in a incorrect place at the end - initially_correct_still_correct = 0 # when the object which was in the correct place at the start, is in a correct place at the end - initially_correct_now_incorrect = 0 # when the object which was in the correct place at the start, is in a incorrect place at the end + def calculate_correct(self, entities: List[SpawnedEntity]) -> Tuple[int, int]: + """Calculate how many objects are positioned correct and incorrect""" + correct = sum(1 for ent in entities if ent.pose.translation.y > 0.0) + incorrect: int = len(entities) - correct + return correct, incorrect - scene_state = simulation_bridge.get_scene_state() + def calculate_initial_placements( + self, simulation_bridge: SimulationBridge[SimulationConfigT] + ) -> tuple[int, int]: + """ + Calculates the number of objects that are correctly and incorrectly placed initially. + """ initial_carrots = self.filter_entities_by_prefab_type( - simulation_bridge.spawned_entities, prefab_types=["carrot"] + simulation_bridge.spawned_entities, prefab_types=self.obj_types + ) + initially_correct, initially_incorrect = self.calculate_correct( + entities=initial_carrots + ) + + self.logger.info( # type: ignore + f"Initially correctly placed carrots: {initially_correct}, Initially incorrectly placed carrots: {initially_incorrect}" ) + return initially_correct, initially_incorrect + + def calculate_final_placements( + self, simulation_bridge: SimulationBridge[SimulationConfigT] + ) -> tuple[int, int]: + """ + Calculates the number of objects that are correctly and incorrectly placed at the end of the simulation. + """ + scene_state = simulation_bridge.get_scene_state() final_carrots = self.filter_entities_by_prefab_type( - scene_state.entities, prefab_types=["carrot"] + scene_state.entities, prefab_types=self.obj_types + ) + final_correct, final_incorrect = self.calculate_correct(entities=final_carrots) + + self.logger.info( # type: ignore + f"Finally correctly placed carrots: {final_correct}, Finally incorrectly placed carrots: {final_incorrect}" ) - num_initial_carrots = len(initial_carrots) + return final_correct, final_incorrect - if num_initial_carrots != len(final_carrots): + def calculate_result( + self, simulation_bridge: SimulationBridge[SimulationConfigT] + ) -> float: + """ + Calculates a score from 0.0 to 1.0, where 0.0 represents the initial placements or worse and 1.0 represents perfect final placements. + """ + initially_correct, initially_incorrect = self.calculate_initial_placements( + simulation_bridge + ) + final_correct, final_incorrect = self.calculate_final_placements( + simulation_bridge + ) + + total_objects = initially_correct + initially_incorrect + if total_objects == 0: + return 1.0 + elif (initially_correct + initially_incorrect) != ( + final_correct + final_incorrect + ): raise EntitiesMismatchException( - "Number of initially spawned entities does not match number of entities present at the end." + "number of initial entities does not match final entities number." ) - + elif initially_incorrect == 0: + pass + # NOTE all objects are placed correctly + # no point in running task + raise ValueError("All objects are placed correctly at the start.") else: - self.logger.debug(f"initial positions: {initial_carrots}") # type: ignore - self.logger.debug(f"current positions: {final_carrots}") # type: ignore - for initial_carrot in initial_carrots: - for final_carrot in final_carrots: - if initial_carrot.name == final_carrot.name: - initial_y = initial_carrot.pose.translation.y - final_y = final_carrot.pose.translation.y - # NOTE the specific coords that refer to for example - # middle of the table can differ across simulations, - # take that into consideration - if ( - initial_y <= 0.0 - ): # Carrot started in the incorrect place (right side) - if final_y >= 0.0: - initially_misplaced_now_correct += ( - 1 # Moved to correct side - ) - else: - initially_misplaced_still_incorrect += ( - 1 # Stayed on incorrect side - ) - else: # Carrot started in the correct place (left side) - if final_y >= 0.0: - initially_correct_still_correct += ( - 1 # Stayed on correct side - ) - else: - initially_correct_now_incorrect += ( - 1 # Moved incorrectly to the wrong side - ) - break - else: - raise EntitiesMismatchException( - f"Entity with name: {initial_carrot.name} which was present in initial scene, not found in final scene." - ) + corrected = final_correct - initially_correct + score = max(0.0, corrected / initially_incorrect) - self.logger.info( # type: ignore - f"initially_misplaced_now_correct: {initially_misplaced_now_correct}, initially_misplaced_still_incorrect: {initially_misplaced_still_incorrect}, initially_correct_still_correct: {initially_correct_still_correct}, initially_correct_now_incorrect: {initially_correct_now_incorrect}" - ) - return ( - initially_misplaced_now_correct + initially_correct_still_correct - ) / num_initial_carrots + self.logger.info(f"Calculated score: {score:.2f}") # type: ignore + return score diff --git a/src/rai_bench/rai_bench/o3de_test_bench/tasks/place_cubes_task.py b/src/rai_bench/rai_bench/o3de_test_bench/tasks/place_cubes_task.py index 9ad03c9b2..e7547c49e 100644 --- a/src/rai_bench/rai_bench/o3de_test_bench/tasks/place_cubes_task.py +++ b/src/rai_bench/rai_bench/o3de_test_bench/tasks/place_cubes_task.py @@ -11,94 +11,107 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Tuple from rai_bench.benchmark_model import ( EntitiesMismatchException, Task, ) from rai_sim.o3de.o3de_bridge import SimulationBridge -from rai_sim.simulation_bridge import SimulationConfig, SimulationConfigT +from rai_sim.simulation_bridge import SimulationConfig, SimulationConfigT, SpawnedEntity class PlaceCubesTask(Task): + # TODO (jm) extract common logic to some parent manipulation task + obj_types = ["red_cube", "blue_cube", "yellow_cube"] + def get_prompt(self) -> str: return "Manipulate objects, so that all cubes are adjacent to at least one cube" def validate_config(self, simulation_config: SimulationConfig) -> bool: - cube_types = ["red_cube", "blue_cube", "yellow_cube"] cubes_num = 0 for ent in simulation_config.entities: - if ent.prefab_name in cube_types: + if ent.prefab_name in self.obj_types: cubes_num += 1 if cubes_num > 1: return True return False - def calculate_result( + def calculate_correct(self, entities: List[SpawnedEntity]) -> Tuple[int, int]: + """Calculate how many objects are positioned correct and incorrect""" + correct = sum( + 1 + for ent in entities + if self.is_adjacent_to_any( + ent.pose, [e.pose for e in entities if e != ent], 0.15 + ) + ) + incorrect: int = len(entities) - correct + return correct, incorrect + + def calculate_initial_placements( self, simulation_bridge: SimulationBridge[SimulationConfigT] - ) -> float: - # TODO (jm) extract common logic to some parent manipulation task? - initially_misplaced_now_correct = 0 # when the object which was in the incorrect place at the start, is in a correct place at the end - initially_misplaced_still_incorrect = 0 # when the object which was in the incorrect place at the start, is in a incorrect place at the end - initially_correct_still_correct = 0 # when the object which was in the correct place at the start, is in a correct place at the end - initially_correct_now_incorrect = 0 # when the object which was in the correct place at the start, is in a incorrect place at the end + ) -> tuple[int, int]: + """ + Calculates the number of objects that are correctly and incorrectly placed initially. + """ + initial_cubes = self.filter_entities_by_prefab_type( + simulation_bridge.spawned_entities, prefab_types=self.obj_types + ) + initially_correct, initially_incorrect = self.calculate_correct( + entities=initial_cubes + ) - cube_types = ["red_cube", "blue_cube", "yellow_cube"] + self.logger.info( # type: ignore + f"Initially correctly placed cubes: {initially_correct}, Initially incorrectly placed cubes: {initially_incorrect}" + ) + return initially_correct, initially_incorrect + + def calculate_final_placements( + self, simulation_bridge: SimulationBridge[SimulationConfigT] + ) -> tuple[int, int]: + """ + Calculates the number of objects that are correctly and incorrectly placed at the end of the simulation. + """ scene_state = simulation_bridge.get_scene_state() + final_cubes = self.filter_entities_by_prefab_type( + scene_state.entities, prefab_types=self.obj_types + ) + final_correct, final_incorrect = self.calculate_correct(entities=final_cubes) - initial_cubes = self.filter_entities_by_prefab_type( - simulation_bridge.spawned_entities, prefab_types=cube_types + self.logger.info( # type: ignore + f"Finally correctly placed cubes: {final_correct}, Finally incorrectly placed cubes: {final_incorrect}" ) - final_cubes = self.filter_entities_by_prefab_type( - scene_state.entities, prefab_types=cube_types + return final_correct, final_incorrect + + def calculate_result( + self, simulation_bridge: SimulationBridge[SimulationConfigT] + ) -> float: + """ + Calculates a score from 0.0 to 1.0, where 0.0 represents the initial placements or worse and 1.0 represents perfect final placements. + """ + initially_correct, initially_incorrect = self.calculate_initial_placements( + simulation_bridge + ) + final_correct, final_incorrect = self.calculate_final_placements( + simulation_bridge ) - num_of_objects = len(initial_cubes) - if num_of_objects != len(final_cubes): + total_objects = initially_correct + initially_incorrect + if total_objects == 0: + return 1.0 + elif (initially_correct + initially_incorrect) != ( + final_correct + final_incorrect + ): raise EntitiesMismatchException( - "Number of initially spawned entities does not match number of entities present at the end." + "number of initial entities does not match final entities number." ) - + elif initially_incorrect == 0: + raise ValueError("All objects are placed correctly at the start.") else: - initial_poses = [cube.pose for cube in initial_cubes] - final_poses = [cube.pose for cube in final_cubes] - # NOTE the specific coords that refer to for example - # middle of the table can differ across simulations, - # take that into consideration - self.logger.debug(f"initial positions: {initial_cubes}") - self.logger.debug(f"current positions: {final_cubes}") - for initial_cube in initial_cubes: - for final_cube in final_cubes: - if initial_cube.name == final_cube.name: - was_adjacent_initially = self.is_adjacent_to_any( - initial_cube.pose, - [p for p in initial_poses if p != initial_cube.pose], - 0.15, - ) - is_adjacent_finally = self.is_adjacent_to_any( - final_cube.pose, - [p for p in final_poses if p != final_cube.pose], - 0.15, - ) - if not was_adjacent_initially and is_adjacent_finally: - initially_misplaced_now_correct += 1 - elif not was_adjacent_initially and not is_adjacent_finally: - initially_misplaced_still_incorrect += 1 - elif was_adjacent_initially and is_adjacent_finally: - initially_correct_still_correct += 1 - elif was_adjacent_initially and not is_adjacent_finally: - initially_correct_now_incorrect += 1 - - break - else: - raise EntitiesMismatchException( - f"Entity with name: {initial_cube.name} which was present in initial scene, not found in final scene." - ) + corrected = final_correct - initially_correct + score = max(0.0, corrected / initially_incorrect) - self.logger.info( - f"initially_misplaced_now_correct: {initially_misplaced_now_correct}, initially_misplaced_still_incorrect: {initially_misplaced_still_incorrect}, initially_correct_still_correct: {initially_correct_still_correct}, initially_correct_now_incorrect: {initially_correct_now_incorrect}" - ) - return ( - initially_misplaced_now_correct + initially_correct_still_correct - ) / num_of_objects + self.logger.info(f"Calculated score: {score:.2f}") # type: ignore + return score