diff --git a/.gitignore b/.gitignore index 48d048836..9ee309471 100644 --- a/.gitignore +++ b/.gitignore @@ -172,3 +172,5 @@ logs/ src/examples/*-demo artifact_database.pkl + +imgui.ini diff --git a/examples/manipulation-demo-no-binary.launch.py b/examples/manipulation-demo-no-binary.launch.py new file mode 100755 index 000000000..494deea7b --- /dev/null +++ b/examples/manipulation-demo-no-binary.launch.py @@ -0,0 +1,59 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from launch import LaunchDescription +from launch.actions import ( + IncludeLaunchDescription, +) +from launch.launch_description_sources import PythonLaunchDescriptionSource +from launch_ros.actions import Node +from launch_ros.substitutions import FindPackageShare + + +# TODO (mkotynia) think about separation of launches +def generate_launch_description(): + launch_moveit = IncludeLaunchDescription( + PythonLaunchDescriptionSource( + [ + "src/examples/rai-manipulation-demo/Project/Examples/panda_moveit_config_demo.launch.py", + ] + ) + ) + + launch_robotic_manipulation = Node( + package="robotic_manipulation", + executable="robotic_manipulation", + # name="robotic_manipulation_node", + output="screen", + parameters=[ + {"use_sim_time": True}, + ], + ) + + launch_openset = IncludeLaunchDescription( + PythonLaunchDescriptionSource( + [ + FindPackageShare("rai_bringup"), + "/launch/openset.launch.py", + ] + ), + ) + + return LaunchDescription( + [ + launch_openset, + launch_moveit, + launch_robotic_manipulation, + ] + ) diff --git a/examples/manipulation-demo.launch.py b/examples/manipulation-demo.launch.py index a9210698f..ee28c9b44 100644 --- a/examples/manipulation-demo.launch.py +++ b/examples/manipulation-demo.launch.py @@ -12,22 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import rclpy -from launch import LaunchContext, LaunchDescription +from launch import LaunchDescription from launch.actions import ( DeclareLaunchArgument, ExecuteProcess, IncludeLaunchDescription, - OpaqueFunction, - RegisterEventHandler, ) -from launch.event_handlers import OnExecutionComplete, OnProcessStart from launch.launch_description_sources import PythonLaunchDescriptionSource from launch.substitutions import LaunchConfiguration from launch_ros.actions import Node from launch_ros.substitutions import FindPackageShare -from rclpy.qos import QoSProfile, ReliabilityPolicy -from rosgraph_msgs.msg import Clock def generate_launch_description(): @@ -46,21 +40,6 @@ def generate_launch_description(): output="screen", ) - def wait_for_clock_message(context: LaunchContext, *args, **kwargs): - rclpy.init() - node = rclpy.create_node("wait_for_game_launcher") - node.create_subscription( - Clock, - "/clock", - lambda msg: rclpy.shutdown(), - QoSProfile(depth=1, reliability=ReliabilityPolicy.BEST_EFFORT), - ) - rclpy.spin(node) - return None - - # Game launcher will start publishing the clock message after loading the simulation - wait_for_game_launcher = OpaqueFunction(function=wait_for_clock_message) - launch_moveit = IncludeLaunchDescription( PythonLaunchDescriptionSource( [ @@ -72,7 +51,6 @@ def wait_for_clock_message(context: LaunchContext, *args, **kwargs): launch_robotic_manipulation = Node( package="robotic_manipulation", executable="robotic_manipulation", - name="robotic_manipulation_node", output="screen", parameters=[ {"use_sim_time": True}, @@ -90,28 +68,10 @@ def wait_for_clock_message(context: LaunchContext, *args, **kwargs): return LaunchDescription( [ - # Include the game_launcher argument game_launcher_arg, - # Launch the game launcher and wait for it to load launch_game_launcher, - RegisterEventHandler( - event_handler=OnProcessStart( - target_action=launch_game_launcher, - on_start=[ - wait_for_game_launcher, - ], - ) - ), - # Launch the MoveIt node 
after loading the simulation - RegisterEventHandler( - event_handler=OnExecutionComplete( - target_action=wait_for_game_launcher, - on_completion=[ - launch_openset, - launch_moveit, - launch_robotic_manipulation, - ], - ) - ), + launch_openset, + launch_moveit, + launch_robotic_manipulation, ] ) diff --git a/examples/manipulation-demo.py b/examples/manipulation-demo.py index 6f73248c2..92c820031 100644 --- a/examples/manipulation-demo.py +++ b/examples/manipulation-demo.py @@ -12,37 +12,37 @@ # See the License for the specific language goveself.rning permissions and # limitations under the License. -import threading import rclpy import rclpy.qos from langchain_core.messages import HumanMessage from rai.agents.conversational_agent import create_conversational_agent -from rai.node import RaiBaseNode +from rai.communication.ros2.connectors import ROS2ARIConnector from rai.tools.ros.manipulation import GetObjectPositionsTool, MoveToPointTool -from rai.tools.ros.native import GetCameraImage, Ros2GetTopicsNamesAndTypesTool +from rai.tools.ros2.topics import GetROS2ImageTool, GetROS2TopicsNamesAndTypesTool from rai.utils.model_initialization import get_llm_model +from rai_open_set_vision.tools import GetGrabbingPointTool def create_agent(): rclpy.init() - node = RaiBaseNode(node_name="manipulation_demo") + connector = ROS2ARIConnector() + node = connector.node node.declare_parameter("conversion_ratio", 1.0) - threading.Thread(target=node.spin).start() - tools = [ GetObjectPositionsTool( - node=node, + connector=connector, target_frame="panda_link0", source_frame="RGBDCamera5", camera_topic="/color_image5", depth_topic="/depth_image5", camera_info_topic="/color_camera_info5", + get_grabbing_point_tool=GetGrabbingPointTool(connector=connector), ), - MoveToPointTool(node=node, manipulator_frame="panda_link0"), - GetCameraImage(node=node), - Ros2GetTopicsNamesAndTypesTool(node=node), + MoveToPointTool(connector=connector, manipulator_frame="panda_link0"), + 
GetROS2ImageTool(connector=connector), + GetROS2TopicsNamesAndTypesTool(connector=connector), ] llm = get_llm_model(model_type="complex_model", streaming=True) diff --git a/poetry.lock b/poetry.lock index 93059b345..335612128 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5700,6 +5700,35 @@ torchaudio = "^2.3.1" type = "directory" url = "src/rai_asr" +[[package]] +name = "rai-bench" +version = "0.1.0" +description = "Package for running and creating benchmarks." +optional = false +python-versions = "^3.10" +files = [] +develop = true + +[package.source] +type = "directory" +url = "src/rai_bench" + +[[package]] +name = "rai-sim" +version = "0.0.1" +description = "Package to run simulations" +optional = false +python-versions = "^3.10, <3.13" +files = [] +develop = true + +[package.dependencies] +PyYAML = "*" + +[package.source] +type = "directory" +url = "src/rai_sim" + [[package]] name = "rai-tts" version = "1.0.0" @@ -8302,4 +8331,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "^3.10, <3.13" -content-hash = "242e440c4ce4b31fa629d198a3d79b0854e84284d013d819a4b7a24e633a1706" +content-hash = "d943b786f2bb8dddc9249475409a4d7c9c4b0a77041611c039431df55ad94000" diff --git a/pyproject.toml b/pyproject.toml index b313dd299..92bcfa7c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,8 @@ python = "^3.10, <3.13" rai = {path = "src/rai_core", develop = true} rai_asr = {path = "src/rai_asr", develop = true} rai_tts = {path = "src/rai_tts", develop = true} +rai_sim = {path = "src/rai_sim", develop = true} +rai_bench = {path = "src/rai_bench", develop = true} langchain-core = "^0.3" langchain = "*" diff --git a/setup_shell.sh b/setup_shell.sh index cc67a5369..506df2cd2 100755 --- a/setup_shell.sh +++ b/setup_shell.sh @@ -30,3 +30,8 @@ esac export PYTHONPATH PYTHONPATH="$(dirname "$(dirname "$(poetry run which python)")")/lib/python$(poetry run python --version | awk '{print $2}' | cut -d. 
-f1,2)/site-packages:$PYTHONPATH" +PYTHONPATH="src/rai_core:$PYTHONPATH" +PYTHONPATH="src/rai_asr:$PYTHONPATH" +PYTHONPATH="src/rai_tts:$PYTHONPATH" +PYTHONPATH="src/rai_sim:$PYTHONPATH" +PYTHONPATH="src/rai_bench:$PYTHONPATH" diff --git a/src/rai_bench/README.md b/src/rai_bench/README.md new file mode 100644 index 000000000..3db5d04ba --- /dev/null +++ b/src/rai_bench/README.md @@ -0,0 +1,54 @@ +## RAI Benchmark + +## Description + +The RAI Bench is a package including benchmarks and providing a framework for creating new benchmarks + +## Frame Components + +Frame components can be found in `src/rai_bench/rai_bench/benchmark_model.py` + +- `Task` +- `Scenario` +- `Benchmark` + +For more information about these classes go to -> `src/rai_bench/rai_bench/benchmark_model.py` + +### O3DE TEST BENCHMARK + +O3DE Test Benchmark (`src/rai_bench/rai_bench/o3de_test_bench/`), contains 2 Tasks(`tasks/`) - GrabCarrotTask and PlaceCubesTask (these tasks implement calculating scores) and 4 scene_configs(`configs/`) for O3DE robotic arm simulation. + +Both tasks calculate score, taking into consideration 4 values: + +- initially_misplaced_now_correct +- initially_misplaced_still_incorrect +- initially_correct_still_correct +- initially_correct_now_incorrect + +The result is a value between 0 and 1, calculated as (initially_misplaced_now_correct + initially_correct_still_correct) / number_of_initial_objects. +This score is calculated at the beginning and at the end of each scenario. 
+ +### Example usage + +An example of how to load scenes, define scenarios and run a benchmark can be found in `src/rai_bench/rai_bench/examples/o3de_test_benchmark.py` + +Scenarios can be loaded manually like: + +```python +one_carrot_simulation_config = O3DExROS2SimulationConfig.load_config( + base_config_path=Path("path_to_scene.yaml"), + connector_config_path=Path("path_to_o3de_config.yaml"), + ) + +Scenario(task=GrabCarrotTask(logger=some_logger), simulation_config=one_carrot_simulation_config, simulation_config_path="path_to_scene.yaml") +``` + +or automatically like: + +```python +scenarios = Benchmark.create_scenarios( + tasks=tasks, simulation_configs=simulations_configs, simulation_configs_paths=simulation_configs_paths + ) +``` + +which will result in a list of scenarios combining every task with every scene config that the task accepts (each task decides whether a scene config is suitable for it). diff --git a/src/rai_bench/pyproject.toml b/src/rai_bench/pyproject.toml new file mode 100644 index 000000000..a5426f409 --- /dev/null +++ b/src/rai_bench/pyproject.toml @@ -0,0 +1,17 @@ +[tool.poetry] +name = "rai-bench" +version = "0.1.0" +description = "Package for running and creating benchmarks." +authors = ["jmatejcz "] +readme = "README.md" + +packages = [ + { include = "rai_bench", from = "." }, +] +[tool.poetry.dependencies] +python = "^3.10" + + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" diff --git a/src/rai_bench/rai_bench/__init__.py b/src/rai_bench/rai_bench/__init__.py new file mode 100644 index 000000000..97ceef6f0 --- /dev/null +++ b/src/rai_bench/rai_bench/__init__.py @@ -0,0 +1,13 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/rai_bench/rai_bench/benchmark_model.py b/src/rai_bench/rai_bench/benchmark_model.py new file mode 100644 index 000000000..83f4c1f66 --- /dev/null +++ b/src/rai_bench/rai_bench/benchmark_model.py @@ -0,0 +1,300 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import csv +import logging +import time +from abc import ABC, abstractmethod +from typing import Any, Dict, Generic, List, Union + +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage +from rai.messages import HumanMultimodalMessage +from rclpy.impl.rcutils_logger import RcutilsLogger + +from rai_sim.simulation_bridge import ( + Pose, + SimulationBridge, + SimulationConfig, + SimulationConfigT, + SpawnedEntity, +) + +loggers_type = Union[RcutilsLogger, logging.Logger] + + +class EntitiesMismatchException(Exception): + pass + + +class Task(ABC): + """ + Abstract of a Task. Provides utility functions for common calculations + that can be helfull when creating metrics. 
+ Specific child classes should implement: + - get_prompt method + - validate_config + - calculate_result + """ + + def __init__( + self, + logger: loggers_type | None = None, + ) -> None: + if logger: + self.logger = logger + else: + self.logger = logging.getLogger(__name__) + + @abstractmethod + def get_prompt(self) -> str: + """Returns the task instruction - the prompt that will be passed to agent""" + pass + + @abstractmethod + def validate_config(self, simulation_config: SimulationConfig) -> bool: + """Task should be able to verify if given config is suitable for specific task + + Args: + simulation_config (SimulationConfig): initial scene setup + Returns: + bool: True is suitable, False otherwise + """ + pass + + @abstractmethod + def calculate_result( + self, simulation_bridge: SimulationBridge[SimulationConfigT] + ) -> float: + """ + Calculates result of the task, based on info retrieved from simulation. + Should return score between 0.0 and 1. + """ + pass + + def filter_entities_by_prefab_type( + self, entities: List[SpawnedEntity], prefab_types: List[str] + ) -> List[SpawnedEntity]: + """Filter and return only these entities that match provided prefab types""" + return [ent for ent in entities if ent.prefab_name in prefab_types] + + def euclidean_distance(self, pos1: Pose, pos2: Pose) -> float: + """Calculate euclidean distance between 2 positions""" + return ( + (pos1.translation.x - pos2.translation.x) ** 2 + + (pos1.translation.y - pos2.translation.y) ** 2 + + (pos1.translation.z - pos2.translation.z) ** 2 + ) ** 0.5 + + def is_adjacent(self, pos1: Pose, pos2: Pose, threshold_distance: float): + """ + Check if positions are adjacent to each other, the threshold_distance is a distance + in simulation, refering to how close they have to be to classify them as adjacent + """ + self.logger.debug( # type: ignore + f"Euclidean distance: {self.euclidean_distance(pos1, pos2)}, pos1: {pos1}, pos2: {pos2}" + ) + return self.euclidean_distance(pos1, pos2) < 
threshold_distance + + def is_adjacent_to_any( + self, pos1: Pose, positions: List[Pose], threshold_distance: float + ) -> bool: + """ + Check if given position is adjacent to any position in the given list. + """ + + return any( + self.is_adjacent(pos1, pos2, threshold_distance) for pos2 in positions + ) + + def count_adjacent(self, positions: List[Pose], threshold_distance: float) -> int: + """ + Count how many adjacent positions are in the given list. + Note that position has to be adjacent to only 1 other position + to be counted, not all of them + """ + adjacent_count = 0 + + for i, p1 in enumerate(positions): + for j, p2 in enumerate(positions): + if i != j: + if self.is_adjacent(p1, p2, threshold_distance): + adjacent_count += 1 + break + + return adjacent_count + + +class Scenario(Generic[SimulationConfigT]): + """ + A Scenarios are defined by a pair of Task and Simlation Config. + Each Scenario is executed separatly by a Benchmark. + """ + + def __init__( + self, + task: Task, + simulation_config: SimulationConfigT, + simulation_config_path: str, + ) -> None: + if not task.validate_config(simulation_config): + raise ValueError("This scene is invalid for this task.") + self.task = task + self.simulation_config = simulation_config + # NOTE (jm) needed for logging which config was used, + # there probably is better method to do it + self.simulation_config_path = simulation_config_path + + +class Benchmark: + """ + Benchmark represents a set of Scenarios to be executed and evaluated. + It manages the execution, logs results, and provides functionality + for tracking and exporting performance metrics. 
+ """ + + def __init__( + self, + simulation_bridge: SimulationBridge[SimulationConfigT], + scenarios: List[Scenario[SimulationConfigT]], + logger: loggers_type | None = None, + results_filename: str = "benchmark_results.csv", + ) -> None: + self.simulation_bridge = simulation_bridge + self.num_of_scenarios = len(scenarios) + self.scenarios = enumerate(iter(scenarios)) + self.results: List[Dict[str, Any]] = [] + self.results_filename = results_filename + if logger: + self._logger = logger + else: + self._logger = logging.getLogger(__name__) + + self.fieldnames = [ + "task", + "simulation_config", + "final_score", + "total_time", + "number_of_tool_calls", + ] + self._initialize_results_file() + + @classmethod + def create_scenarios( + cls, + tasks: List[Task], + simulation_configs: List[SimulationConfigT], + simulation_configs_paths: List[str], + ) -> List[Scenario[SimulationConfigT]]: + # TODO (jm) hacky_fix, taking paths as args here, not the best solution, + # but more changes to code would be required + scenarios: List[Scenario[SimulationConfigT]] = [] + + for task in tasks: + for sim_conf, sim_path in zip(simulation_configs, simulation_configs_paths): + try: + scenarios.append( + Scenario( + task=task, + simulation_config=sim_conf, + simulation_config_path=sim_path, + ) + ) + except ValueError as e: + print( + f"Could not create Scenario from task: {task.get_prompt()} and simulation_config: {sim_conf}, {e}" + ) + return scenarios + + def _initialize_results_file(self): + """Initialize the CSV file with headers.""" + with open( + self.results_filename, mode="w", newline="", encoding="utf-8" + ) as file: + writer = csv.DictWriter(file, fieldnames=self.fieldnames) + writer.writeheader() + + def run_next(self, agent) -> None: + """ + Runs the next scenario + """ + try: + i, scenario = next(self.scenarios) # Get the next scenario + + self.simulation_bridge.setup_scene(scenario.simulation_config) + self._logger.info( # type: ignore + 
"======================================================================================" + ) + self._logger.info( # type: ignore + f"RUNNING SCENARIO NUMBER {i + 1} / {self.num_of_scenarios}, TASK: {scenario.task.get_prompt()}" + ) + tool_calls_num = 0 + + ts = time.perf_counter() + for state in agent.stream( + {"messages": [HumanMessage(content=scenario.task.get_prompt())]} + ): + graph_node_name = list(state.keys())[0] + msg = state[graph_node_name]["messages"][-1] + + if isinstance(msg, HumanMultimodalMessage): + last_msg = msg.text + elif isinstance(msg, BaseMessage): + if isinstance(msg.content, list): + if len(msg.content) == 1: + if type(msg.content[0]) is dict: + last_msg = msg.content[0].get("text", "") + else: + last_msg = msg.content + self._logger.debug(f"{graph_node_name}: {last_msg}") # type: ignore + + else: + raise ValueError(f"Unexpected type of message: {type(msg)}") + + if isinstance(msg, AIMessage): + # TODO (jm) figure out more robust way of counting tool calls + tool_calls_num += len(msg.tool_calls) + + self._logger.info(f"AI Message: {msg}") # type: ignore + + te = time.perf_counter() + + result = scenario.task.calculate_result(self.simulation_bridge) + total_time = te - ts + self._logger.info( # type: ignore + f"TASK SCORE: {result}, TOTAL TIME: {total_time:.3f}, NUM_OF_TOOL_CALLS: {tool_calls_num}" + ) + + scenario_result: Dict[str, Any] = { + "task": scenario.task.get_prompt(), + "simulation_config": scenario.simulation_config_path, + "final_score": result, + "total_time": f"{total_time:.3f}", + "number_of_tool_calls": tool_calls_num, + } + self.results.append(scenario_result) + self._save_scenario_result_to_csv(scenario_result) + + except StopIteration: + print("No more scenarios left to run.") + + def _save_scenario_result_to_csv(self, result: Dict[str, Any]) -> None: + """Save a single scenario result to the CSV file.""" + with open( + self.results_filename, mode="a", newline="", encoding="utf-8" + ) as file: + writer = 
csv.DictWriter(file, fieldnames=self.fieldnames) + writer.writerow(result) + + def get_results(self) -> List[Dict[str, Any]]: + return self.results diff --git a/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py b/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py new file mode 100644 index 000000000..6068806d7 --- /dev/null +++ b/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py @@ -0,0 +1,194 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +########### EXAMPLE USAGE ########### +import logging +import time +from datetime import datetime +from pathlib import Path +from typing import List + +import rclpy +from langchain.tools import BaseTool +from rai.agents.conversational_agent import create_conversational_agent +from rai.communication.ros2.connectors import ROS2ARIConnector +from rai.tools.ros.manipulation import GetObjectPositionsTool, MoveToPointTool +from rai.tools.ros2.topics import GetROS2ImageTool, GetROS2TopicsNamesAndTypesTool +from rai.utils.model_initialization import get_llm_model +from rai_open_set_vision.tools import GetGrabbingPointTool + +from rai_bench.benchmark_model import Benchmark, Task +from rai_bench.o3de_test_bench.tasks import GrabCarrotTask, PlaceCubesTask +from rai_sim.o3de.o3de_bridge import ( + O3DEngineArmManipulationBridge, + O3DExROS2SimulationConfig, + Pose, +) +from rai_sim.simulation_bridge import Rotation, Translation + +if __name__ == "__main__": + rclpy.init() + connector = ROS2ARIConnector() + node = connector.node + node.declare_parameter("conversion_ratio", 1.0) + + # define model + llm = get_llm_model(model_type="complex_model", streaming=True) + + system_prompt = """ + You are a robotic arm with interfaces to detect and manipulate objects. + Here are the coordinates information: + x - front to back (positive is forward) + y - left to right (positive is right) + z - up to down (positive is up) + Before starting the task, make sure to grab the camera image to understand the environment. 
+ """ + # define tools + tools: List[BaseTool] = [ + GetObjectPositionsTool( + connector=connector, + target_frame="panda_link0", + source_frame="RGBDCamera5", + camera_topic="/color_image5", + depth_topic="/depth_image5", + camera_info_topic="/color_camera_info5", + get_grabbing_point_tool=GetGrabbingPointTool(connector=connector), + ), + MoveToPointTool(connector=connector, manipulator_frame="panda_link0"), + GetROS2ImageTool(connector=connector), + GetROS2TopicsNamesAndTypesTool(connector=connector), + ] + # define loggers + now = datetime.now() + experiment_dir = ( + f"src/rai_bench/rai_bench/experiments/{now.strftime('%Y-%m-%d_%H-%M-%S')}" + ) + Path(experiment_dir).mkdir(parents=True, exist_ok=True) + log_file = f"{experiment_dir}/benchmark.log" + file_handler = logging.FileHandler(log_file) + file_handler.setLevel(logging.DEBUG) + + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + file_handler.setFormatter(formatter) + + bench_logger = logging.getLogger("Benchmark logger") + bench_logger.setLevel(logging.INFO) + bench_logger.addHandler(file_handler) + + agent_logger = logging.getLogger("Agent logger") + agent_logger.setLevel(logging.INFO) + agent_logger.addHandler(file_handler) + + configs_dir = "src/rai_bench/rai_bench/o3de_test_bench/configs/" + connector_path = configs_dir + "o3de_config.yaml" + #### Create scenarios manually + # load different scenes + # one_carrot_simulation_config = O3DExROS2SimulationConfig.load_config( + # base_config_path=Path(configs_dir + "scene1.yaml"), + # connector_config_path=Path(connector_path), + # ) + # multiple_carrot_simulation_config = O3DExROS2SimulationConfig.load_config( + # base_config_path=Path(configs_dir + "scene2.yaml"), + # connector_config_path=Path(connector_path), + # ) + # red_cubes_simulation_config = O3DExROS2SimulationConfig.load_config( + # base_config_path=Path(configs_dir + "scene3.yaml"), + # connector_config_path=Path(connector_path), + # ) + # 
multiple_cubes_simulation_config = O3DExROS2SimulationConfig.load_config( + # base_config_path=Path(configs_dir + "scene4.yaml"), + # connector_config_path=Path(connector_path), + # ) + # # combine different scene configs with the tasks to create various scenarios + # scenarios = [ + # Scenario( + # task=GrabCarrotTask(logger=bench_logger), + # simulation_config=one_carrot_simulation_config, + # simulation_config_path=configs_dir + "scene1.yaml", + # ), + # Scenario( + # task=GrabCarrotTask(logger=bench_logger), + # simulation_config=multiple_carrot_simulation_config, + # simulation_config_path=configs_dir + "scene2.yaml", + # ), + # Scenario( + # task=PlaceCubesTask(logger=bench_logger), + # simulation_config=red_cubes_simulation_config, + # simulation_config_path=configs_dir + "scene3.yaml", + # ), + # Scenario( + # task=PlaceCubesTask(logger=bench_logger), + # simulation_config=multiple_cubes_simulation_config, + # simulation_config_path=configs_dir + "scene4.yaml", + # ), + # ] + + ### Create scenarios automatically + simulation_configs_paths = [ + configs_dir + "scene1.yaml", + configs_dir + "scene2.yaml", + configs_dir + "scene3.yaml", + configs_dir + "scene4.yaml", + ] + simulations_configs = [ + O3DExROS2SimulationConfig.load_config(Path(path), Path(connector_path)) + for path in simulation_configs_paths + ] + tasks: List[Task] = [ + GrabCarrotTask(logger=bench_logger), + PlaceCubesTask(logger=bench_logger), + ] + scenarios = Benchmark.create_scenarios( + tasks=tasks, + simulation_configs=simulations_configs, + simulation_configs_paths=simulation_configs_paths, + ) + + # custom request to arm + base_arm_pose = Pose( + translation=Translation(x=0.1, y=0.5, z=0.4), + rotation=Rotation(x=1.0, y=0.0, z=0.0, w=0.0), + ) + + o3de = O3DEngineArmManipulationBridge(connector, logger=agent_logger) + # define benchamrk + results_filename = f"{experiment_dir}/results.csv" + benchmark = Benchmark( + simulation_bridge=o3de, + scenarios=scenarios, + logger=bench_logger, + 
results_filename=results_filename, + ) + for i, s in enumerate(scenarios): + agent = create_conversational_agent( + llm, tools, system_prompt, logger=agent_logger + ) + benchmark.run_next(agent=agent) + o3de.move_arm( + pose=base_arm_pose, + initial_gripper_state=True, + final_gripper_state=False, + frame_id="panda_link0", + ) # return to case position + time.sleep(0.2) # admire the end position for a second ;) + + bench_logger.info("===============================================================") + bench_logger.info("ALL SCENARIOS DONE. BENCHMARK COMPLETED!") + bench_logger.info("===============================================================") + + connector.shutdown() + o3de.shutdown() + rclpy.shutdown() diff --git a/src/rai_bench/rai_bench/o3de_test_bench/__init__.py b/src/rai_bench/rai_bench/o3de_test_bench/__init__.py new file mode 100644 index 000000000..97ceef6f0 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/__init__.py @@ -0,0 +1,13 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/scene1.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/scene1.yaml new file mode 100644 index 000000000..a683362a6 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/scene1.yaml @@ -0,0 +1,25 @@ +entities: + - name: carrot1 + prefab_name: carrot + pose: + translation: + x: 0.5 + y: -0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: corn2 + prefab_name: corn + pose: + translation: + x: 0.5 + y: 0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/scene2.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/scene2.yaml new file mode 100644 index 000000000..2be04e047 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/scene2.yaml @@ -0,0 +1,51 @@ +entities: + - name: carrot1 + prefab_name: carrot + pose: + translation: + x: 0.5 + y: -0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + + - name: carrot2 + prefab_name: carrot + pose: + translation: + x: 0.5 + y: 0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: carrot3 + prefab_name: carrot + pose: + translation: + x: 0.5 + y: 0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + + - name: carrot4 + prefab_name: carrot + pose: + translation: + x: 0.5 + y: -0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/scene3.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/scene3.yaml new file mode 100644 index 000000000..1eef69a48 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/scene3.yaml @@ -0,0 +1,25 @@ +entities: + - name: cube1 + prefab_name: red_cube + pose: + translation: + x: 0.5 + y: -0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: cube2 + prefab_name: red_cube + pose: + translation: + x: 0.5 + y: 0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git 
a/src/rai_bench/rai_bench/o3de_test_bench/configs/scene4.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/scene4.yaml new file mode 100644 index 000000000..00b814c93 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/scene4.yaml @@ -0,0 +1,50 @@ +entities: + - name: cube1 + prefab_name: red_cube + pose: + translation: + x: 0.5 + y: -0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: cube2 + prefab_name: blue_cube + pose: + translation: + x: 0.4 + y: -0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + + - name: cube3 + prefab_name: yellow_cube + pose: + translation: + x: 0.5 + y: -0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: cube4 + prefab_name: yellow_cube + pose: + translation: + x: 0.5 + y: 0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/tasks/__init__.py b/src/rai_bench/rai_bench/o3de_test_bench/tasks/__init__.py new file mode 100644 index 000000000..5be82bf8c --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/tasks/__init__.py @@ -0,0 +1,17 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from rai_bench.o3de_test_bench.tasks.grab_carrot_task import GrabCarrotTask +from rai_bench.o3de_test_bench.tasks.place_cubes_task import PlaceCubesTask + +__all__ = ["GrabCarrotTask", "PlaceCubesTask"] diff --git a/src/rai_bench/rai_bench/o3de_test_bench/tasks/grab_carrot_task.py b/src/rai_bench/rai_bench/o3de_test_bench/tasks/grab_carrot_task.py new file mode 100644 index 000000000..d011051d1 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/tasks/grab_carrot_task.py @@ -0,0 +1,110 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class GrabCarrotTask(Task):
    """Benchmark task: get every carrot onto the left side of the table.

    A carrot counts as correctly placed when its ``pose.translation.y`` is
    strictly greater than ``0.0``. The score measures how many of the
    initially misplaced carrots were corrected by the end of the run.
    """

    # Prefab names that this task treats as carrots.
    obj_types = ["carrot"]

    # TODO (jm) extract common logic to some parent manipulation task
    def get_prompt(self) -> str:
        """Return the natural-language instruction given to the agent."""
        return "Manipulate objects, so that all carrots are on the left side of the table (positive y)"

    def validate_config(self, simulation_config: SimulationConfig) -> bool:
        """Return True when the scene contains at least one carrot prefab."""
        return any(
            ent.prefab_name in self.obj_types for ent in simulation_config.entities
        )

    def calculate_correct(self, entities: List[SpawnedEntity]) -> Tuple[int, int]:
        """Count correctly and incorrectly positioned entities.

        Returns:
            ``(correct, incorrect)`` where ``correct`` is the number of
            entities with ``pose.translation.y > 0.0`` and ``incorrect`` is
            the remainder of ``entities``.
        """
        correct = sum(1 for ent in entities if ent.pose.translation.y > 0.0)
        incorrect: int = len(entities) - correct
        return correct, incorrect

    def calculate_initial_placements(
        self, simulation_bridge: SimulationBridge[SimulationConfigT]
    ) -> tuple[int, int]:
        """
        Calculates the number of objects that are correctly and incorrectly placed initially.

        The counts are based on ``simulation_bridge.spawned_entities`` filtered
        down to this task's ``obj_types``.
        """
        initial_carrots = self.filter_entities_by_prefab_type(
            simulation_bridge.spawned_entities, prefab_types=self.obj_types
        )
        initially_correct, initially_incorrect = self.calculate_correct(
            entities=initial_carrots
        )

        self.logger.info(  # type: ignore
            f"Initially correctly placed carrots: {initially_correct}, Initially incorrectly placed carrots: {initially_incorrect}"
        )
        return initially_correct, initially_incorrect

    def calculate_final_placements(
        self, simulation_bridge: SimulationBridge[SimulationConfigT]
    ) -> tuple[int, int]:
        """
        Calculates the number of objects that are correctly and incorrectly placed at the end of the simulation.

        The counts are based on the entities reported by
        ``simulation_bridge.get_scene_state()`` filtered down to this task's
        ``obj_types``.
        """
        scene_state = simulation_bridge.get_scene_state()
        final_carrots = self.filter_entities_by_prefab_type(
            scene_state.entities, prefab_types=self.obj_types
        )
        final_correct, final_incorrect = self.calculate_correct(entities=final_carrots)

        self.logger.info(  # type: ignore
            f"Finally correctly placed carrots: {final_correct}, Finally incorrectly placed carrots: {final_incorrect}"
        )
        return final_correct, final_incorrect

    def calculate_result(
        self, simulation_bridge: SimulationBridge[SimulationConfigT]
    ) -> float:
        """
        Calculates a score from 0.0 to 1.0, where 0.0 represents the initial placements or worse and 1.0 represents perfect final placements.

        Returns:
            ``1.0`` when there are no carrots at all; otherwise the fraction
            of initially misplaced carrots that ended up correctly placed,
            clamped at ``0.0`` from below.

        Raises:
            EntitiesMismatchException: the number of carrots at the end
                differs from the number at the start.
            ValueError: every carrot was already correctly placed at the
                start, so there is nothing to score.
        """
        initially_correct, initially_incorrect = self.calculate_initial_placements(
            simulation_bridge
        )
        final_correct, final_incorrect = self.calculate_final_placements(
            simulation_bridge
        )

        total_objects = initially_correct + initially_incorrect
        if total_objects == 0:
            return 1.0
        if total_objects != final_correct + final_incorrect:
            raise EntitiesMismatchException(
                "number of initial entities does not match final entities number."
            )
        if initially_incorrect == 0:
            # NOTE all objects are placed correctly
            # no point in running task
            raise ValueError("All objects are placed correctly at the start.")

        # A negative delta (the agent made things worse) is clamped to 0.0.
        corrected = final_correct - initially_correct
        score = max(0.0, corrected / initially_incorrect)

        self.logger.info(f"Calculated score: {score:.2f}")  # type: ignore
        return score
class PlaceCubesTask(Task):
    """Benchmark task that rewards cubes placed next to at least one other cube."""

    # TODO (jm) extract common logic to some parent manipulation task
    obj_types = ["red_cube", "blue_cube", "yellow_cube"]

    def get_prompt(self) -> str:
        """Return the instruction text handed to the agent."""
        prompt = "Manipulate objects, so that all cubes are adjacent to at least one cube"
        return prompt

    def validate_config(self, simulation_config: SimulationConfig) -> bool:
        """Accept only scenes that contain more than one cube prefab."""
        cube_count = sum(
            1
            for entity in simulation_config.entities
            if entity.prefab_name in self.obj_types
        )
        return cube_count > 1

    def calculate_correct(self, entities: List[SpawnedEntity]) -> Tuple[int, int]:
        """Return ``(adjacent, not_adjacent)`` counts for the given entities.

        An entity is counted as correct when ``is_adjacent_to_any`` reports it
        adjacent to the pose of some other entity in the list, using a
        threshold of 0.15 (units not visible here; presumably metres, to be
        confirmed against ``is_adjacent_to_any``).
        """
        adjacent_count = 0
        for candidate in entities:
            other_poses = [other.pose for other in entities if other != candidate]
            if self.is_adjacent_to_any(candidate.pose, other_poses, 0.15):
                adjacent_count += 1

        incorrect: int = len(entities) - adjacent_count
        return adjacent_count, incorrect

    def calculate_initial_placements(
        self, simulation_bridge: SimulationBridge[SimulationConfigT]
    ) -> tuple[int, int]:
        """Count correctly/incorrectly placed cubes among the spawned entities."""
        spawned_cubes = self.filter_entities_by_prefab_type(
            simulation_bridge.spawned_entities, prefab_types=self.obj_types
        )
        counts = self.calculate_correct(entities=spawned_cubes)
        initially_correct, initially_incorrect = counts

        self.logger.info(  # type: ignore
            f"Initially correctly placed cubes: {initially_correct}, Initially incorrectly placed cubes: {initially_incorrect}"
        )
        return initially_correct, initially_incorrect

    def calculate_final_placements(
        self, simulation_bridge: SimulationBridge[SimulationConfigT]
    ) -> tuple[int, int]:
        """Count correctly/incorrectly placed cubes in the current scene state."""
        current_entities = simulation_bridge.get_scene_state().entities
        cubes_now = self.filter_entities_by_prefab_type(
            current_entities, prefab_types=self.obj_types
        )
        counts = self.calculate_correct(entities=cubes_now)
        final_correct, final_incorrect = counts

        self.logger.info(  # type: ignore
            f"Finally correctly placed cubes: {final_correct}, Finally incorrectly placed cubes: {final_incorrect}"
        )
        return final_correct, final_incorrect

    def calculate_result(
        self, simulation_bridge: SimulationBridge[SimulationConfigT]
    ) -> float:
        """Score the run between 0.0 (no improvement or worse) and 1.0 (all fixed).

        Returns 1.0 outright when the scene holds no cubes. Raises
        ``EntitiesMismatchException`` when the cube count changed between the
        start and the end, and ``ValueError`` when nothing was misplaced at
        the start.
        """
        start_ok, start_bad = self.calculate_initial_placements(simulation_bridge)
        end_ok, end_bad = self.calculate_final_placements(simulation_bridge)

        cubes_at_start = start_ok + start_bad
        if cubes_at_start == 0:
            return 1.0

        if cubes_at_start != (end_ok + end_bad):
            raise EntitiesMismatchException(
                "number of initial entities does not match final entities number."
            )

        if start_bad == 0:
            raise ValueError("All objects are placed correctly at the start.")

        newly_fixed = end_ok - start_ok
        score = max(0.0, newly_fixed / start_bad)

        self.logger.info(f"Calculated score: {score:.2f}")  # type: ignore
        return score
self._tf_buffer.lookup_transform( target_frame, source_frame, rclpy.time.Time(), - timeout=Duration(seconds=timeout_sec), + timeout=Duration(seconds=int(timeout_sec)), ) - tf_listener.unregister() + return transform def terminate_action(self, action_handle: str, **kwargs: Any): self._actions_api.terminate_goal(action_handle) + @property + def node(self) -> Node: + return self._node + def shutdown(self): - self._executor.shutdown() - self._thread.join() + self.tf_listener.unregister() + self._node.destroy_node() self._actions_api.shutdown() self._topic_api.shutdown() - self._node.destroy_node() + self._executor.shutdown() + self._thread.join() class ROS2HRIMessage(HRIMessage): @@ -279,15 +284,19 @@ def __init__( ] _targets = [ - target - if isinstance(target, tuple) - else (target, TopicConfig(is_subscriber=False)) + ( + target + if isinstance(target, tuple) + else (target, TopicConfig(is_subscriber=False)) + ) for target in targets ] _sources = [ - source - if isinstance(source, tuple) - else (source, TopicConfig(is_subscriber=True)) + ( + source + if isinstance(source, tuple) + else (source, TopicConfig(is_subscriber=True)) + ) for source in sources ] diff --git a/src/rai_core/rai/tools/ros/manipulation.py b/src/rai_core/rai/tools/ros/manipulation.py index 8deacd603..9436490b0 100644 --- a/src/rai_core/rai/tools/ros/manipulation.py +++ b/src/rai_core/rai/tools/ros/manipulation.py @@ -15,21 +15,15 @@ from typing import Literal, Type import numpy as np -import rclpy -import rclpy.callback_groups -import rclpy.executors -import rclpy.qos -import rclpy.subscription -import rclpy.task from geometry_msgs.msg import Point, Pose, PoseStamped, Quaternion from langchain_core.tools import BaseTool from pydantic import BaseModel, Field from rai_open_set_vision.tools import GetGrabbingPointTool -from rclpy.client import Client -from rclpy.node import Node from tf2_geometry_msgs import do_transform_pose +from rai.communication.ros2.connectors import ROS2ARIConnector from 
rai.tools.utils import TF2TransformFetcher +from rai.utils.ros_async import get_future_result from rai_interfaces.srv import ManipulatorMoveTo @@ -52,8 +46,7 @@ class MoveToPointTool(BaseTool): "success of grabbing or releasing objects. Use additional sensors or tools for that information." ) - node: Node - client: Client + connector: ROS2ARIConnector = Field(..., exclude=True) manipulator_frame: str = Field(..., description="Manipulator frame") min_z: float = Field(default=0.135, description="Minimum z coordinate [m]") @@ -72,16 +65,6 @@ class MoveToPointTool(BaseTool): args_schema: Type[MoveToPointToolInput] = MoveToPointToolInput - def __init__(self, node: Node, **kwargs): - super().__init__( - node=node, - client=node.create_client( - ManipulatorMoveTo, - "/manipulator_move_to", - ), - **kwargs, - ) - def _run( self, x: float, @@ -89,6 +72,10 @@ def _run( z: float, task: Literal["grab", "drop"], ) -> str: + client = self.connector.node.create_client( + ManipulatorMoveTo, + "/manipulator_move_to", + ) pose_stamped = PoseStamped() pose_stamped.header.frame_id = self.manipulator_frame pose_stamped.pose = Pose( @@ -117,21 +104,18 @@ def _run( request.initial_gripper_state = False # closed request.final_gripper_state = True # open - future = self.client.call_async(request) - self.node.get_logger().debug( + future = client.call_async(request) + self.connector.node.get_logger().debug( f"Calling ManipulatorMoveTo service with request: x={request.target_pose.pose.position.x:.2f}, y={request.target_pose.pose.position.y:.2f}, z={request.target_pose.pose.position.z:.2f}" ) + response = get_future_result(future, timeout_sec=5.0) + if response is None: + return f"Service call failed for point ({x:.2f}, {y:.2f}, {z:.2f})." - rclpy.spin_until_future_complete(self.node, future, timeout_sec=5.0) - - if future.result() is not None: - response = future.result() - if response.success: - return f"End effector successfully positioned at coordinates ({x:.2f}, {y:.2f}, {z:.2f}). 
Note: The status of object interaction (grab/drop) is not confirmed by this movement." - else: - return f"Failed to position end effector at coordinates ({x:.2f}, {y:.2f}, {z:.2f})." + if response.success: + return f"End effector successfully positioned at coordinates ({x:.2f}, {y:.2f}, {z:.2f}). Note: The status of object interaction (grab/drop) is not confirmed by this movement." else: - return f"Service call failed for point ({x:.2f}, {y:.2f}, {z:.2f})." + return f"Failed to position end effector at coordinates ({x:.2f}, {y:.2f}, {z:.2f})." class GetObjectPositionsToolInput(BaseModel): @@ -153,14 +137,9 @@ class GetObjectPositionsTool(BaseTool): camera_topic: str # rgb camera topic depth_topic: str camera_info_topic: str # rgb camera info topic - node: Node + connector: ROS2ARIConnector = Field(..., exclude=True) get_grabbing_point_tool: GetGrabbingPointTool - def __init__(self, node: Node, **kwargs): - super(GetObjectPositionsTool, self).__init__( - node=node, get_grabbing_point_tool=GetGrabbingPointTool(node=node), **kwargs - ) - args_schema: Type[GetObjectPositionsToolInput] = GetObjectPositionsToolInput @staticmethod diff --git a/src/rai_core/rai/tools/ros/utils.py b/src/rai_core/rai/tools/ros/utils.py index 8c34e23c4..1e207e31c 100644 --- a/src/rai_core/rai/tools/ros/utils.py +++ b/src/rai_core/rai/tools/ros/utils.py @@ -151,7 +151,9 @@ def wait_for_message( if msg_info is not None: return True, msg_info[0] finally: - node.destroy_subscription(sub) + # TODO(boczekbartek): uncomment when rclpy resolves: https://github.com/ros2/rclpy/issues/1142 + # node.destroy_subscription(sub) + pass return False, None diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/examples/talker.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/examples/talker.py index 370c6e488..72fe82198 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/examples/talker.py +++ 
b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/examples/talker.py @@ -28,7 +28,9 @@ def __init__(self): self.declare_parameter("image_path", "") self.cli = self.create_client(RAIGroundingDino, "grounding_dino_classify") while not self.cli.wait_for_service(timeout_sec=1.0): - self.get_logger().info("service not available, waiting again...") + self.get_logger().info( + "service grounding_dino_classify not available, waiting again..." + ) self.req = RAIGroundingDino.Request() self.bridge = CvBridge() @@ -56,7 +58,9 @@ def __init__(self): super().__init__(node_name="GSClientExample", parameter_overrides=[]) self.cli = self.create_client(RAIGroundedSam, "grounded_sam_segment") while not self.cli.wait_for_service(timeout_sec=1.0): - self.get_logger().info("service not available, waiting again...") + self.get_logger().info( + "service grounded_sam_segment not available, waiting again..." + ) self.req = RAIGroundedSam.Request() self.bridge = CvBridge() diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounded_sam.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounded_sam.py index 45fe9eb52..85fa70488 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounded_sam.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounded_sam.py @@ -34,9 +34,6 @@ class GSamService(Node): def __init__(self): super().__init__(node_name=GSAM_NODE_NAME, parameter_overrides=[]) - self.srv = self.create_service( - RAIGroundedSam, GSAM_SERVICE_NAME, self.segment_callback - ) self.declare_parameter("weights_path", "") try: @@ -49,6 +46,10 @@ def __init__(self): self.get_logger().error("Could not load model") raise Exception("Could not load model") + self.srv = self.create_service( + RAIGroundedSam, GSAM_SERVICE_NAME, self.segment_callback + ) + def _init_weight_path(self): try: found_path = get_package_share_directory("rai_open_set_vision") diff --git 
a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounding_dino.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounding_dino.py index 7927d7cf8..eba1ee44a 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounding_dino.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounding_dino.py @@ -43,9 +43,7 @@ class GDRequest(TypedDict): class GDinoService(Node): def __init__(self): super().__init__(node_name=GDINO_NODE_NAME, parameter_overrides=[]) - self.srv = self.create_service( - RAIGroundingDino, GDINO_SERVICE_NAME, self.classify_callback - ) + self.declare_parameter("weights_path", "") try: weight_path = self.get_parameter("weights_path").value @@ -57,6 +55,10 @@ def __init__(self): self.get_logger().error("Could not load model") raise Exception("Could not load model") + self.srv = self.create_service( + RAIGroundingDino, GDINO_SERVICE_NAME, self.classify_callback + ) + def _init_weight_path(self) -> Path: try: found_path = get_package_share_directory("rai_open_set_vision") diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py index 336ec4bc7..62656c5a8 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py @@ -15,14 +15,11 @@ from typing import List, NamedTuple, Type import numpy as np -import rclpy -import rclpy.qos import sensor_msgs.msg from pydantic import BaseModel, Field -from rai.node import RaiBaseNode +from rai.communication.ros2.connectors import ROS2ARIConnector from rai.tools.ros import Ros2BaseInput, Ros2BaseTool from rai.tools.ros.utils import convert_ros_img_to_ndarray -from rai.tools.utils import wait_for_message from rai.utils.ros_async import get_future_result from rclpy.exceptions import ( 
ParameterNotDeclaredException, @@ -82,7 +79,7 @@ class DistanceMeasurement(NamedTuple): # --------------------- Tools --------------------- class GroundingDinoBaseTool(Ros2BaseTool): - node: RaiBaseNode = Field(..., exclude=True, required=True) + connector: ROS2ARIConnector = Field(..., exclude=True) box_threshold: float = Field(default=0.35, description="Box threshold for GDINO") text_threshold: float = Field(default=0.45, description="Text threshold for GDINO") @@ -90,9 +87,11 @@ class GroundingDinoBaseTool(Ros2BaseTool): def _call_gdino_node( self, camera_img_message: sensor_msgs.msg.Image, object_names: list[str] ) -> Future: - cli = self.node.create_client(RAIGroundingDino, GDINO_SERVICE_NAME) + cli = self.connector.node.create_client(RAIGroundingDino, GDINO_SERVICE_NAME) while not cli.wait_for_service(timeout_sec=1.0): - self.node.get_logger().info("service not available, waiting again...") + self.node.get_logger().info( + f"service {GDINO_SERVICE_NAME} not available, waiting again..." 
+ ) req = RAIGroundingDino.Request() req.source_img = camera_img_message req.classes = " , ".join(object_names) @@ -103,20 +102,16 @@ def _call_gdino_node( return future def get_img_from_topic(self, topic: str, timeout_sec: int = 2): - success, msg = wait_for_message( - sensor_msgs.msg.Image, - self.node, - topic, - qos_profile=rclpy.qos.qos_profile_sensor_data, - time_to_wait=timeout_sec, - ) - - if success: - self.node.get_logger().info(f"Received message of type from topic {topic}") + msg = self.connector.receive_message(topic, timeout_sec=timeout_sec).payload + + if msg is not None: + self.connector.node.get_logger().info( + f"Received message of {type(msg)} from topic {topic}" + ) return msg else: error = f"No message received in {timeout_sec} seconds from topic {topic}" - self.node.get_logger().error(error) + self.connector.node.get_logger().error(error) return error def _get_image_message(self, topic: str) -> sensor_msgs.msg.Image: diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py index 29ff0fe18..043802d8f 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py @@ -18,9 +18,10 @@ import numpy as np import rclpy import sensor_msgs.msg +from langchain_core.tools import BaseTool from pydantic import Field -from rai.node import RaiBaseNode -from rai.tools.ros import Ros2BaseInput, Ros2BaseTool +from rai.communication.ros2.connectors import ROS2ARIConnector +from rai.tools.ros import Ros2BaseInput from rai.tools.ros.utils import convert_ros_img_to_base64, convert_ros_img_to_ndarray from rai.utils.ros_async import get_future_result from rclpy import Future @@ -64,8 +65,8 @@ class GetGrabbingPointInput(Ros2BaseInput): # --------------------- Tools --------------------- -class 
GetSegmentationTool(Ros2BaseTool): - node: RaiBaseNode = Field(..., exclude=True) +class GetSegmentationTool: + connector: ROS2ARIConnector = Field(..., exclude=True) name: str = "" description: str = "" @@ -84,7 +85,7 @@ def _get_gsam_response(self, future: Future) -> Optional[RAIGroundedSam.Response return get_future_result(future) def _get_image_message(self, topic: str) -> sensor_msgs.msg.Image: - msg = self.node.get_raw_message_from_topic(topic) + msg = self.connector.receive_message(topic).payload if type(msg) is sensor_msgs.msg.Image: return msg else: @@ -93,9 +94,11 @@ def _get_image_message(self, topic: str) -> sensor_msgs.msg.Image: def _call_gdino_node( self, camera_img_message: sensor_msgs.msg.Image, object_name: str ) -> Future: - cli = self.node.create_client(RAIGroundingDino, GDINO_SERVICE_NAME) + cli = self.connector.node.create_client(RAIGroundingDino, GDINO_SERVICE_NAME) while not cli.wait_for_service(timeout_sec=1.0): - self.node.get_logger().info("service not available, waiting again...") + self.node.get_logger().info( + f"service {GDINO_SERVICE_NAME} not available, waiting again..." + ) req = RAIGroundingDino.Request() req.source_img = camera_img_message req.classes = object_name @@ -108,9 +111,11 @@ def _call_gdino_node( def _call_gsam_node( self, camera_img_message: sensor_msgs.msg.Image, data: RAIGroundingDino.Response ): - cli = self.node.create_client(RAIGroundedSam, "grounded_sam_segment") + cli = self.connector.node.create_client(RAIGroundedSam, "grounded_sam_segment") while not cli.wait_for_service(timeout_sec=1.0): - self.node.get_logger().info("service not available, waiting again...") + self.node.get_logger().info( + "service grounded_sam_segment not available, waiting again..." 
+ ) req = RAIGroundedSam.Request() req.detections = data.detections req.source_img = camera_img_message @@ -126,9 +131,11 @@ def _run( camera_img_msg = self._get_image_message(camera_topic) future = self._call_gdino_node(camera_img_msg, object_name) - logger = self.node.get_logger() + logger = self.connector.node.get_logger() try: - conversion_ratio = self.node.get_parameter("conversion_ratio").value + conversion_ratio = self.connector.node.get_parameter( + "conversion_ratio" + ).value if not isinstance(conversion_ratio, float): logger.error( f"Parameter conversion_ratio was set badly: {type(conversion_ratio)}: {conversion_ratio} expected float. Using default value 0.001" @@ -185,19 +192,72 @@ def depth_to_point_cloud( return points -class GetGrabbingPointTool(GetSegmentationTool): +class GetGrabbingPointTool(BaseTool): + connector: ROS2ARIConnector = Field(..., exclude=True) + name: str = "GetGrabbingPointTool" description: str = "Get the grabbing point of an object" pcd: List[Any] = [] args_schema: Type[GetGrabbingPointInput] = GetGrabbingPointInput + box_threshold: float = Field(default=0.35, description="Box threshold for GDINO") + text_threshold: float = Field(default=0.45, description="Text threshold for GDINO") + + def _get_gdino_response( + self, future: Future + ) -> Optional[RAIGroundingDino.Response]: + return get_future_result(future) + + def _get_gsam_response(self, future: Future) -> Optional[RAIGroundedSam.Response]: + return get_future_result(future) + + def _get_image_message(self, topic: str) -> sensor_msgs.msg.Image: + msg = self.connector.receive_message(topic).payload + if type(msg) is sensor_msgs.msg.Image: + return msg + else: + raise Exception("Received wrong message") + + def _call_gdino_node( + self, camera_img_message: sensor_msgs.msg.Image, object_name: str + ) -> Future: + cli = self.connector.node.create_client(RAIGroundingDino, GDINO_SERVICE_NAME) + while not cli.wait_for_service(timeout_sec=1.0): + 
self.connector.node.get_logger().info( + "service not available, waiting again..." + ) + req = RAIGroundingDino.Request() + req.source_img = camera_img_message + req.classes = object_name + req.box_threshold = self.box_threshold + req.text_threshold = self.text_threshold + + future = cli.call_async(req) + return future + + def _call_gsam_node( + self, camera_img_message: sensor_msgs.msg.Image, data: RAIGroundingDino.Response + ): + cli = self.connector.node.create_client(RAIGroundedSam, "grounded_sam_segment") + while not cli.wait_for_service(timeout_sec=1.0): + self.connector.node.get_logger().info( + "service not available, waiting again..." + ) + req = RAIGroundedSam.Request() + req.detections = data.detections + req.source_img = camera_img_message + future = cli.call_async(req) + + return future def _get_camera_info_message(self, topic: str) -> sensor_msgs.msg.CameraInfo: for _ in range(3): - msg = self.node.get_raw_message_from_topic(topic, timeout_sec=3.0) + msg = self.connector.receive_message(topic, timeout_sec=3.0).payload if isinstance(msg, sensor_msgs.msg.CameraInfo): return msg - self.node.get_logger().warn("Received wrong message type. Retrying...") + self.connector.node.get_logger().warn( + "Received wrong message type. Retrying..." 
+ ) raise Exception("Failed to receive correct CameraInfo message after 3 attempts") @@ -259,16 +319,18 @@ def _run( camera_info_topic: str, object_name: str, ): - camera_img_msg = self._get_image_message(camera_topic) - depth_msg = self._get_image_message(depth_topic) + camera_img_msg = self.connector.receive_message(camera_topic).payload + depth_msg = self.connector.receive_message(depth_topic).payload camera_info = self._get_camera_info_message(camera_info_topic) intrinsic = self._get_intrinsic_from_camera_info(camera_info) future = self._call_gdino_node(camera_img_msg, object_name) - logger = self.node.get_logger() + logger = self.connector.node.get_logger() try: - conversion_ratio = self.node.get_parameter("conversion_ratio").value + conversion_ratio = self.connector.node.get_parameter( + "conversion_ratio" + ).value if not isinstance(conversion_ratio, float): logger.error( f"Parameter conversion_ratio was set badly: {type(conversion_ratio)}: {conversion_ratio} expected float. Using default value 0.001" @@ -280,21 +342,17 @@ def _run( ) conversion_ratio = 0.001 resolved = None - while rclpy.ok(): - resolved = self._get_gdino_response(future) - if resolved is not None: - break + + resolved = get_future_result(future) assert resolved is not None future = self._call_gsam_node(camera_img_msg, resolved) ret = [] - while rclpy.ok(): - resolved = self._get_gsam_response(future) - if resolved is not None: - for img_msg in resolved.masks: - ret.append(convert_ros_img_to_base64(img_msg)) - break + resolved = get_future_result(future) + if resolved is not None: + for img_msg in resolved.masks: + ret.append(convert_ros_img_to_base64(img_msg)) assert resolved is not None rets = [] for mask_msg in resolved.masks: diff --git a/src/rai_sim/README.md b/src/rai_sim/README.md new file mode 100644 index 000000000..8008b4d2d --- /dev/null +++ b/src/rai_sim/README.md @@ -0,0 +1,18 @@ +## RAI Sim + +## Description + +The RAI Sim is a package providing interface to implement 
connection with a specific simulation. + +### Components + +- `SimulationBridge` - An interface for connecting with a specific simulation. It manages scene setup, spawning and despawning objects, and getting the current state of the scene. + +- `SimulationConfig` - base config class to specify the entities to be spawned. Each simulation bridge should define its own custom simulation config that specifies the additional parameters needed to run and connect to the simulation. + +- `SceneState` - stores the current info about spawned entities + +### Example implementation + +- `O3DExROS2Bridge` - An implementation of SimulationBridge for working with simulations based on O3DE and ROS2. +- `O3DExROS2SimulationConfig` - config class for `O3DExROS2Bridge` diff --git a/src/rai_sim/pyproject.toml b/src/rai_sim/pyproject.toml new file mode 100644 index 000000000..9042be051 --- /dev/null +++ b/src/rai_sim/pyproject.toml @@ -0,0 +1,23 @@ +[tool.poetry] +name = "rai-sim" +version = "0.0.1" +description = "Package to run simulations" +authors = ["Magdalena Kotynia ", "Kacper Dąbrowski ", "Jakub Matejczyk "] +readme = "README.md" +classifiers = [ + "Programming Language :: Python :: 3", + "Development Status :: 4 - Beta", + "License :: OSI Approved :: Apache Software License", +] +packages = [ + { include = "rai_sim", from = "." }, +] + +[tool.poetry.dependencies] +python = "^3.10, <3.13" +PyYAML = "*" + + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" diff --git a/src/rai_sim/rai_sim/__init__.py b/src/rai_sim/rai_sim/__init__.py new file mode 100644 index 000000000..f792966c5 --- /dev/null +++ b/src/rai_sim/rai_sim/__init__.py @@ -0,0 +1,15 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""RAI Simulations package.""" diff --git a/src/rai_sim/rai_sim/o3de/o3de_bridge.py b/src/rai_sim/rai_sim/o3de/o3de_bridge.py new file mode 100644 index 000000000..14ba247d4 --- /dev/null +++ b/src/rai_sim/rai_sim/o3de/o3de_bridge.py @@ -0,0 +1,431 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +import shlex +import signal +import subprocess +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Set + +import yaml +from geometry_msgs.msg import Point, PoseStamped, Quaternion +from geometry_msgs.msg import Pose as ROS2Pose +from rai.communication.ros2.connectors import ROS2ARIConnector, ROS2ARIMessage +from rai.utils.ros_async import get_future_result +from std_msgs.msg import Header +from tf2_geometry_msgs import do_transform_pose + +from rai_interfaces.srv import ManipulatorMoveTo +from rai_sim.simulation_bridge import ( + Entity, + Pose, + Rotation, + SceneState, + SimulationBridge, + SimulationConfig, + SpawnedEntity, + Translation, +) + + +class O3DExROS2SimulationConfig(SimulationConfig): + binary_path: Path + robotic_stack_command: str + required_simulation_ros2_interfaces: dict[str, List[str]] + required_robotic_ros2_interfaces: dict[str, List[str]] + + @classmethod + def load_config( + cls, base_config_path: Path, connector_config_path: Path + ) -> "O3DExROS2SimulationConfig": + base_config = SimulationConfig.load_base_config(base_config_path) + + with open(connector_config_path) as f: + connector_content: dict[str, Any] = yaml.safe_load(f) + return cls(**base_config.model_dump(), **connector_content) + + +class O3DExROS2Bridge(SimulationBridge[O3DExROS2SimulationConfig]): + def __init__( + self, connector: ROS2ARIConnector, logger: Optional[logging.Logger] = None + ): + super().__init__(logger=logger) + self.connector = connector + self.current_sim_process = None + self.current_robotic_stack_process = None + self.current_binary_path = None + + def shutdown(self): + self._shutdown_binary() + self._shutdown_robotic_stack() + + def _shutdown_binary(self): + if not self.current_sim_process: + return + self.current_sim_process.send_signal(signal.SIGINT) + self.current_sim_process.wait() + + if self.current_sim_process.poll() is None: + self.logger.error( + f"Parent process PID: 
{self.current_sim_process.pid} is still running." + ) + raise RuntimeError( + f"Failed to terminate main process PID {self.current_sim_process.pid}" + ) + + self.current_sim_process = None + + def _shutdown_robotic_stack(self): + if not self.current_robotic_stack_process: + return + + self.current_robotic_stack_process.send_signal(signal.SIGINT) + self.current_robotic_stack_process.wait() + + if self.current_robotic_stack_process.poll() is None: + self.logger.error( + f"Parent process PID: {self.current_robotic_stack_process.pid} is still running." + ) + raise RuntimeError( + f"Failed to terminate robotic stack process PID {self.current_robotic_stack_process.pid}" + ) + + def get_available_spawnable_names(self) -> list[str]: + msg = ROS2ARIMessage({}) + response = self._try_service_call( + msg, + target="get_available_spawnable_names", + msg_type="gazebo_msgs/srv/GetWorldProperties", + ) + if response.payload.success: + return response.payload.model_names + else: + raise RuntimeError( + f"Failed to get available spawnable names. 
Response: {response.payload.status_message}" + ) + + def _spawn_entity(self, entity: Entity): + pose = do_transform_pose( + self._to_ros2_pose(entity.pose), + self.connector.get_transform("odom", "world"), + ) + + msg_content: Dict[str, Any] = { + "name": entity.prefab_name, + "xml": "", + "robot_namespace": entity.name, + "initial_pose": { + "position": { + "x": pose.position.x, # type: ignore + "y": pose.position.y, # type: ignore + "z": pose.position.z, # type: ignore + }, + "orientation": { + "x": pose.orientation.x, # type: ignore + "y": pose.orientation.y, # type: ignore + "z": pose.orientation.z, # type: ignore + "w": pose.orientation.w, # type: ignore + }, + }, + } + + msg = ROS2ARIMessage(payload=msg_content) + response = self._try_service_call( + msg, target="spawn_entity", msg_type="gazebo_msgs/srv/SpawnEntity" + ) + if response and response.payload.success: + self.spawned_entities.append( + SpawnedEntity(id=response.payload.status_message, **entity.model_dump()) + ) + else: + raise RuntimeError( + f"Failed to spawn entity {entity.name}. Response: {response.payload.status_message}" + ) + + def _despawn_entity(self, entity: SpawnedEntity): + msg_content = {"name": entity.id} + + msg = ROS2ARIMessage(payload=msg_content) + + response = self._try_service_call( + msg, target="delete_entity", msg_type="gazebo_msgs/srv/DeleteEntity" + ) + if response.payload.success: + self.spawned_entities.remove(entity) + else: + raise RuntimeError( + f"Failed to delete entity {entity.name}. Response: {response.payload.status_message}" + ) + + def get_object_pose(self, entity: SpawnedEntity) -> Pose: + object_frame = entity.name + "/" + ros2_pose = do_transform_pose( + ROS2Pose(), + self.connector.get_transform(object_frame + "odom", object_frame), + ) + ros2_pose = do_transform_pose( + ros2_pose, self.connector.get_transform("world", "odom") + ) + return self._from_ros2_pose(ros2_pose) + + def get_scene_state(self) -> SceneState: + """ + Get the current scene state. 
+ """ + if not self.current_sim_process: + raise RuntimeError("Simulation is not running.") + entities: list[SpawnedEntity] = [] + for entity in self.spawned_entities: + current_pose = self.get_object_pose(entity) + entities.append( + SpawnedEntity( + id=entity.id, + name=entity.name, + prefab_name=entity.prefab_name, + pose=current_pose, + ) + ) + return SceneState(entities=entities) + + def _is_ros2_stack_ready( + self, required_ros2_stack: dict[str, List[str]], retries: int = 120 + ) -> bool: + for i in range(retries): + available_topics = self.connector.get_topics_names_and_types() + available_services = self.connector.node.get_service_names_and_types() + available_topics_names = [tp[0] for tp in available_topics] + available_services_names = [srv[0] for srv in available_services] + + # Extract action names + available_actions_names: Set[str] = set() + for service in available_services_names: + if "/_action" in service: + action_name = service.split("/_action")[0] + available_actions_names.add(action_name) + + required_services = required_ros2_stack["services"] + required_topics = required_ros2_stack["topics"] + required_actions = required_ros2_stack["actions"] + self.logger.info(f"required services: {required_services}") + self.logger.info(f"required topics: {required_topics}") + self.logger.info(f"required actions: {required_actions}") + self.logger.info(f"available actions: {available_actions_names}") + + missing_services = [ + service + for service in required_services + if service not in available_services_names + ] + missing_topics = [ + topic + for topic in required_topics + if topic not in available_topics_names + ] + missing_actions = [ + action + for action in required_actions + if action not in available_actions_names + ] + + if missing_services: + self.logger.warning( + f"Waiting for missing services {missing_services} out of required services: {required_services}" + ) + + if missing_topics: + self.logger.warning( + f"Waiting for missing topics: 
{missing_topics} out of required topics: {required_topics}" + ) + + if missing_actions: + self.logger.warning( + f"Waiting for missing actions: {missing_actions} out of required actions: {required_actions}" + ) + + if not (missing_services or missing_topics or missing_actions): + self.logger.info("All required ROS2 stack components are available.") + return True + + time.sleep(0.5) + + self.logger.error( + "Maximum number of retries reached. Required ROS2 stack components are not fully available." + ) + return False + + def setup_scene(self, simulation_config: O3DExROS2SimulationConfig): + if self.current_binary_path != simulation_config.binary_path: + if self.current_sim_process: + self.shutdown() + self._launch_binary(simulation_config) + self._launch_robotic_stack(simulation_config) + self.current_binary_path = simulation_config.binary_path + + else: + while self.spawned_entities: + self._despawn_entity(self.spawned_entities[0]) + self.logger.info(f"Entities after despawn: {self.spawned_entities}") + + for entity in simulation_config.entities: + self._spawn_entity(entity) + + def _launch_binary(self, simulation_config: O3DExROS2SimulationConfig): + command = [simulation_config.binary_path.as_posix()] + self.logger.info(f"Running command: {command}") + self.current_sim_process = subprocess.Popen( + command, + ) + if not self._has_process_started(process=self.current_sim_process): + raise RuntimeError("Process did not start in time.") + if not self._is_ros2_stack_ready( + required_ros2_stack=simulation_config.required_simulation_ros2_interfaces + ): + raise RuntimeError("ROS2 stack is not ready in time.") + + def _launch_robotic_stack(self, simulation_config: O3DExROS2SimulationConfig): + command = shlex.split(simulation_config.robotic_stack_command) + self.logger.info(f"Running command: {command}") + self.current_robotic_stack_process = subprocess.Popen( + command, + ) + if not self._has_process_started(self.current_robotic_stack_process): + raise 
RuntimeError("Process did not start in time.") + if not self._is_ros2_stack_ready( + required_ros2_stack=simulation_config.required_robotic_ros2_interfaces + ): + raise RuntimeError("ROS2 stack is not ready in time.") + + def _has_process_started(self, process: subprocess.Popen[Any], timeout: int = 15): + start_time = time.time() + while time.time() - start_time < timeout: + if process.poll() is None: + self.logger.info(f"Process started with PID {process.pid}") + return True + time.sleep(1) + return False + + def _try_service_call( + self, msg: ROS2ARIMessage, target: str, msg_type: str, n_retries: int = 3 + ) -> ROS2ARIMessage: + if n_retries < 1: + raise ValueError("Number of retries must be at least 1") + for _ in range(n_retries): + try: + response = self.connector.service_call( + msg, target=target, msg_type=msg_type + ) + except Exception as e: + error_message = f"Error while calling service {target} with msg_type {msg_type}: {e}" + self.logger.error(error_message) + raise RuntimeError(error_message) + if response.payload.success: + return response + self.logger.warning( + f"Retrying {target}, response success: {response.payload.success}" + ) + return response # type: ignore + + # NOTE (mkotynia) probably to be refactored, other bridges may also want to use pose conversion to/from ROS2 format + def _to_ros2_pose(self, pose: Pose) -> ROS2Pose: + """ + Converts pose to pose in ROS2 Pose format. 
+ """ + position = Point( + x=pose.translation.x, y=pose.translation.y, z=pose.translation.z + ) + + if pose.rotation is not None: + orientation = Quaternion( + x=pose.rotation.x, + y=pose.rotation.y, + z=pose.rotation.z, + w=pose.rotation.w, + ) + else: + orientation = Quaternion() + + ros2_pose = ROS2Pose(position=position, orientation=orientation) + + return ros2_pose + + def _from_ros2_pose(self, pose: ROS2Pose) -> Pose: + """ + Converts ROS2Pose to Pose + """ + + translation = Translation( + x=pose.position.x, # type: ignore + y=pose.position.y, # type: ignore + z=pose.position.z, # type: ignore + ) + + rotation = Rotation( + x=pose.orientation.x, # type: ignore + y=pose.orientation.y, # type: ignore + z=pose.orientation.z, # type: ignore + w=pose.orientation.w, # type: ignore + ) + + return Pose(translation=translation, rotation=rotation) + + +class O3DEngineArmManipulationBridge(O3DExROS2Bridge): + def move_arm( + self, + pose: Pose, + initial_gripper_state: bool, + final_gripper_state: bool, + frame_id: str, + ): + """Moves arm to a given position + + Args: + pose (Pose): where to move arm + initial_gripper_state (bool): False means closed grip, True means open grip + final_gripper_state (bool): False means closed grip, True means open grip + frame_id (str): reference frame + """ + + request = ManipulatorMoveTo.Request() + request.initial_gripper_state = initial_gripper_state + request.final_gripper_state = final_gripper_state + + request.target_pose = PoseStamped() + request.target_pose.header = Header() + request.target_pose.header.frame_id = frame_id + + request.target_pose.pose.position.x = pose.translation.x + request.target_pose.pose.position.y = pose.translation.y + request.target_pose.pose.position.z = pose.translation.z + + if pose.rotation: + request.target_pose.pose.orientation.x = pose.rotation.x + request.target_pose.pose.orientation.y = pose.rotation.y + request.target_pose.pose.orientation.z = pose.rotation.z + 
request.target_pose.pose.orientation.w = pose.rotation.w + + client = self.connector.node.create_client( + ManipulatorMoveTo, + "/manipulator_move_to", + ) + while not client.wait_for_service(timeout_sec=5.0): + self.connector.node.get_logger().info("Service not available, waiting...") + + self.connector.node.get_logger().info("Making request to arm manipulator...") + future = client.call_async(request) + result = get_future_result(future, timeout_sec=5.0) + + self.connector.node.get_logger().debug(f"Moving arm result: {result}") diff --git a/src/rai_sim/rai_sim/simulation_bridge.py b/src/rai_sim/rai_sim/simulation_bridge.py new file mode 100644 index 000000000..336406ad0 --- /dev/null +++ b/src/rai_sim/rai_sim/simulation_bridge.py @@ -0,0 +1,272 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Generic, List, Optional, TypeVar + +import yaml +from pydantic import BaseModel, Field, field_validator + + +class Translation(BaseModel): + """ + Represents the position of an object in 3D space using + x, y, and z coordinates. + """ + + x: float = Field(description="X coordinate in meters") + y: float = Field(description="Y coordinate in meters") + z: float = Field(description="Z coordinate in meters") + + +class Rotation(BaseModel): + """ + Represents a 3D rotation using quaternion representation. 
+ """ + + x: float = Field(description="X component of the quaternion") + y: float = Field(description="Y component of the quaternion") + z: float = Field(description="Z component of the quaternion") + w: float = Field(description="W component of the quaternion") + + +class Pose(BaseModel): + """ + Represents the complete pose (position and orientation) of an object. + """ + + translation: Translation = Field( + description="The position of the object in 3D space" + ) + rotation: Optional[Rotation] = Field( + default=None, + description="The orientation of the object as a quaternion. Optional if orientation is not needed and default orientation is handled by the bridge", + ) + + +class Entity(BaseModel): + """ + Entity that can be spawned in the simulation environment. + """ + + name: str = Field(description="Unique name for the entity") + prefab_name: str = Field( + description="Name of the prefab resource to use for spawning this entity" + ) + pose: Pose = Field(description="Initial pose of the entity") + + +class SpawnedEntity(Entity): + """ + Entity that has been spawned in the simulation environment. + """ + + id: str = Field( + description="Unique identifier assigned to the spawned entity instance" + ) + + +class SimulationConfig(BaseModel): + """ + Setup of simulation - arrangement of objects in the environment. + + Attributes + ---------- + entities : List[Entity] + List of entities to be spawned in the simulation. + """ + + entities: List[Entity] = Field( + description="List of entities to be spawned in the simulation environment" + ) + + @field_validator("entities") + @classmethod + def check_unique_names(cls, entities: List[Entity]) -> List[Entity]: + """ + Validates that all entity names in the configuration are unique. + + Parameters + ---------- + entities : List[Entity] + List of entities to validate. + + Returns + ------- + List[Entity] + The validated list of entities. + + Raises + ------ + ValueError + If any entity names are duplicated. 
+ """ + names = [entity.name for entity in entities] + if len(names) != len(set(names)): + raise ValueError("Each entity must have a unique name.") + return entities + + @classmethod + def load_base_config(cls, base_config_path: Path) -> "SimulationConfig": + """ + Loads a simulation configuration from a YAML file. + + Parameters + ---------- + base_config_path : Path + Path to the YAML configuration file. + + Returns + ------- + SimulationConfig + The loaded simulation configuration. + """ + with open(base_config_path) as f: + content = yaml.safe_load(f) + return cls(**content) + + +class SceneState(BaseModel): + """ + Info about current state of the scene. + + Attributes + ---------- + entities : List[SpawnedEntity] + List of all entities currently present in the scene. + """ + + entities: List[SpawnedEntity] = Field( + description="List of all entities currently spawned in the scene with their current poses" + ) + + +SimulationConfigT = TypeVar("SimulationConfigT", bound=SimulationConfig) + + +class SimulationBridge(ABC, Generic[SimulationConfigT]): + """ + Responsible for communication with simulation. + """ + + def __init__(self, logger: Optional[logging.Logger] = None): + self.spawned_entities: List[ + SpawnedEntity + ] = [] # list of spawned entities with their initial poses + if logger is None: + self.logger = logging.getLogger(__name__) + else: + self.logger = logger + + @abstractmethod + def setup_scene(self, simulation_config: SimulationConfigT): + """ + Runs and sets up the simulation scene according to the provided configuration. + + Parameters + ---------- + simulation_config : SimulationConfigT + Configuration containing the simulation initialization and setup details including + entities to be spawned and their initial poses. + + Returns + ------- + None + """ + pass + + @abstractmethod + def _spawn_entity(self, entity: Entity): + """ + Spawns a single entity in the simulation environment. 
+ + Parameters + ---------- + entity : Entity + Entity object containing the entity's properties + + Returns + ------- + None + + Notes + ----- + The spawned entity should be added to the spawned_entities list maintained + by the simulation bridge. + """ + pass + + @abstractmethod + def _despawn_entity(self, entity: SpawnedEntity): + """ + Removes a previously spawned entity from the simulation environment. + + Parameters + ---------- + entity : SpawnedEntity + Entity object representing the spawned entity to be removed. + + Returns + ------- + None + + Notes + ----- + Despawned entity should be removed from the spawned_entities list maintained + by the simulation bridge. + """ + pass + + @abstractmethod + def get_object_pose(self, entity: SpawnedEntity) -> Pose: + """ + Gets the current pose of a spawned entity. + + This method queries the simulation to get the current position and + orientation of a specific entity. + + Parameters + ---------- + entity : SpawnedEntity + Entity object representing the spawned entity whose pose is + to be retrieved. + + Returns + ------- + Pose + Object containing the entity's current pose. + """ + pass + + @abstractmethod + def get_scene_state(self) -> SceneState: + """ + Gets the current state of the simulation scene. + + Parameters + ---------- + None + + Returns + ------- + SceneState + Object containing the current state of the scene. + + Notes + ----- + SceneState should contain the current poses of spawned_entities. + """ + pass diff --git a/tests/rai_sim/conftest.py b/tests/rai_sim/conftest.py new file mode 100644 index 000000000..4ebb264a9 --- /dev/null +++ b/tests/rai_sim/conftest.py @@ -0,0 +1,75 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +import pytest + + +@pytest.fixture +def sample_base_yaml_config(tmp_path: Path) -> Path: + yaml_content = """ + entities: + - name: entity1 + prefab_name: cube + pose: + translation: + x: 1.0 + y: 2.0 + z: 3.0 + + - name: entity2 + prefab_name: carrot + pose: + translation: + x: 1.0 + y: 2.0 + z: 3.0 + rotation: + x: 0.1 + y: 0.2 + z: 0.3 + w: 0.4 + """ + file_path = tmp_path / "test_config.yaml" + file_path.write_text(yaml_content) + return file_path + + +@pytest.fixture +def sample_o3dexros2_config(tmp_path: Path) -> Path: + yaml_content = """ + binary_path: /path/to/binary + robotic_stack_command: "ros2 launch robotic_stack.launch.py" + required_simulation_ros2_interfaces: + services: + - /spawn_entity + - /delete_entity + topics: + - /color_image5 + - /depth_image5 + - /color_camera_info5 + actions: [] + required_robotic_ros2_interfaces: + services: + - /grounding_dino_classify + - /grounded_sam_segment + - /manipulator_move_to + topics: [] + actions: + - /execute_trajectory + """ + file_path = tmp_path / "test_o3dexros2_config.yaml" + file_path.write_text(yaml_content) + return file_path diff --git a/tests/rai_sim/test_o3de_bridge.py b/tests/rai_sim/test_o3de_bridge.py new file mode 100644 index 000000000..b7687304a --- /dev/null +++ b/tests/rai_sim/test_o3de_bridge.py @@ -0,0 +1,379 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import signal +import typing +import unittest +from pathlib import Path +from typing import List, Optional, Tuple, get_args, get_origin +from unittest.mock import MagicMock, patch + +import rclpy +import rclpy.qos +from geometry_msgs.msg import Point, Quaternion, TransformStamped +from geometry_msgs.msg import Pose as ROS2Pose +from rai.communication.ros2.connectors import ROS2ARIConnector, ROS2ARIMessage +from rclpy.node import Node +from rclpy.qos import QoSProfile + +from rai_sim.o3de.o3de_bridge import O3DExROS2Bridge, O3DExROS2SimulationConfig +from rai_sim.simulation_bridge import ( + Entity, + Pose, + Rotation, + SpawnedEntity, + Translation, +) + + +def test_load_config(sample_base_yaml_config: Path, sample_o3dexros2_config: Path): + config = O3DExROS2SimulationConfig.load_config( + sample_base_yaml_config, sample_o3dexros2_config + ) + assert isinstance(config, O3DExROS2SimulationConfig) + assert config.binary_path == Path("/path/to/binary") + assert config.robotic_stack_command == "ros2 launch robotic_stack.launch.py" + assert config.required_simulation_ros2_interfaces == { + "services": ["/spawn_entity", "/delete_entity"], + "topics": ["/color_image5", "/depth_image5", "/color_camera_info5"], + "actions": [], + } + assert config.required_robotic_ros2_interfaces == { + "services": [ + "/grounding_dino_classify", + "/grounded_sam_segment", + "/manipulator_move_to", + ], + "topics": [], + "actions": ["/execute_trajectory"], + } + assert isinstance(config.entities, list) + assert all(isinstance(e, Entity) for e in config.entities) 
+ + assert len(config.entities) == 2 + + +class TestO3DExROS2Bridge(unittest.TestCase): + def setUp(self): + self.mock_connector = MagicMock(spec=ROS2ARIConnector) + self.mock_logger = MagicMock() + self.bridge = O3DExROS2Bridge( + connector=self.mock_connector, logger=self.mock_logger + ) + + # Create test data + self.test_entity = Entity( + name="test_entity1", + prefab_name="cube", + pose=Pose( + translation=Translation(x=1.0, y=2.0, z=3.0), + rotation=Rotation(x=0.0, y=0.0, z=0.0, w=1.0), + ), + ) + + self.test_spawned_entity = SpawnedEntity( + id="entity_id_123", + name="test_entity1", + prefab_name="cube", + pose=Pose( + translation=Translation(x=1.0, y=2.0, z=3.0), + rotation=Rotation(x=0.0, y=0.0, z=0.0, w=1.0), + ), + ) + + self.test_config = O3DExROS2SimulationConfig( + binary_path=Path("/path/to/binary"), + robotic_stack_command="ros2 launch robot.launch.py", + entities=[self.test_entity], + required_simulation_ros2_interfaces={ + "services": [], + "topics": [], + "actions": [], + }, + required_robotic_ros2_interfaces={ + "services": [], + "topics": [], + "actions": [], + }, + ) + + def test_init(self): + self.assertEqual(self.bridge.connector, self.mock_connector) + self.assertEqual(self.bridge.logger, self.mock_logger) + self.assertIsNone(self.bridge.current_sim_process) + self.assertIsNone(self.bridge.current_robotic_stack_process) + self.assertIsNone(self.bridge.current_binary_path) + self.assertEqual(self.bridge.spawned_entities, []) + + @patch("subprocess.Popen") + def test_launch_robotic_stack(self, mock_popen): + mock_process = MagicMock() + mock_process.poll.return_value = None + mock_process.pid = 54321 + mock_popen.return_value = mock_process + self.bridge._launch_robotic_stack(self.test_config) + + mock_popen.assert_called_once_with(["ros2", "launch", "robot.launch.py"]) + self.assertEqual(self.bridge.current_robotic_stack_process, mock_process) + + @patch("subprocess.Popen") + def test_launch_binary(self, mock_popen): + mock_process = 
MagicMock() + mock_process.poll.return_value = None + mock_process.pid = 54322 + mock_popen.return_value = mock_process + + self.bridge._launch_binary(self.test_config) + + mock_popen.assert_called_once_with(["/path/to/binary"]) + self.assertEqual(self.bridge.current_sim_process, mock_process) + + def test_shutdown_binary(self): + mock_process = MagicMock() + mock_process.poll.return_value = 0 + + self.bridge.current_sim_process = mock_process + + self.bridge._shutdown_binary() + + mock_process.send_signal.assert_called_once_with(signal.SIGINT) + mock_process.wait.assert_called_once() + + self.assertIsNone(self.bridge.current_sim_process) + + def test_shutdown_robotic_stack(self): + self.bridge.current_robotic_stack_process = MagicMock() + self.bridge.current_robotic_stack_process.poll.return_value = 0 + + self.bridge._shutdown_robotic_stack() + + self.bridge.current_robotic_stack_process.send_signal.assert_called_once_with( + signal.SIGINT + ) + self.bridge.current_robotic_stack_process.wait.assert_called_once() + + def test_get_available_spawnable_names(self): + # Mock the response + response = MagicMock() + response.payload.model_names = ["cube", "carrot"] + self.bridge._try_service_call = MagicMock(return_value=response) + + names = self.bridge.get_available_spawnable_names() + + self.bridge._try_service_call.assert_called_once() + self.assertEqual(names, ["cube", "carrot"]) + + def test_to_ros2_pose(self): + # Create a pose + pose = Pose( + translation=Translation(x=1.0, y=2.0, z=3.0), + rotation=Rotation(x=0.1, y=0.2, z=0.3, w=0.4), + ) + + # Convert to ROS2 pose + ros2_pose = self.bridge._to_ros2_pose(pose) + + # Check the conversion + self.assertEqual(ros2_pose.position.x, 1.0) + self.assertEqual(ros2_pose.position.y, 2.0) + self.assertEqual(ros2_pose.position.z, 3.0) + self.assertEqual(ros2_pose.orientation.x, 0.1) + self.assertEqual(ros2_pose.orientation.y, 0.2) + self.assertEqual(ros2_pose.orientation.z, 0.3) + self.assertEqual(ros2_pose.orientation.w, 
0.4)
+
+    def test_from_ros2_pose(self):
+        """Check that ``_from_ros2_pose`` copies position and orientation into a Pose."""
+        # Create a ROS2 pose
+        position = Point(x=1.0, y=2.0, z=3.0)
+        orientation = Quaternion(x=0.1, y=0.2, z=0.3, w=0.4)
+        ros2_pose = ROS2Pose(position=position, orientation=orientation)
+
+        # Convert from ROS2Pose to Pose
+        pose = self.bridge._from_ros2_pose(ros2_pose)
+
+        # Check the conversion
+        self.assertEqual(pose.translation.x, 1.0)
+        self.assertEqual(pose.translation.y, 2.0)
+        self.assertEqual(pose.translation.z, 3.0)
+        self.assertEqual(pose.rotation.x, 0.1)
+        self.assertEqual(pose.rotation.y, 0.2)
+        self.assertEqual(pose.rotation.z, 0.3)
+        self.assertEqual(pose.rotation.w, 0.4)
+
+
+class TestROS2ARIConnectorInterface(unittest.TestCase):
+    """Tests to ensure the ROS2ARIConnector interface meets the expectations of O3DExROS2Bridge."""
+
+    def setUp(self):
+        """Initialise rclpy and create the connector inspected by each test."""
+        rclpy.init()
+        self.connector = ROS2ARIConnector()
+
+    def tearDown(self):
+        """Shut down the rclpy context initialised in ``setUp``."""
+        rclpy.shutdown()
+
+    def test_connector_required_methods_exist(self):
+        """Test that all required methods exist on the ROS2ARIConnector."""
+        # Inspect the connector created in setUp; building a second instance
+        # here would only leave an extra connector that is never shut down.
+        connector = self.connector
+
+        # Check that all required methods exist
+        for method_name in (
+            "service_call",
+            "get_transform",
+            "send_message",
+            "receive_message",
+            "shutdown",
+            "get_topics_names_and_types",
+        ):
+            self.assertTrue(
+                hasattr(connector, method_name), f"{method_name} method is missing"
+            )
+        self.assertTrue(
+            hasattr(connector, "node"),
+            "node property is missing",
+        )
+
+    def resolve_annotation(self, annotation: type) -> type:
+        """Unwrap ``Optional[X]`` to ``X``; return any other annotation unchanged.
+
+        ``get_origin(Optional[X])`` is ``typing.Union`` (never ``typing.Optional``),
+        and ``X | None`` reports ``types.UnionType``, so both origins are checked.
+        The signature tests pass the actual and the expected annotation through
+        this helper, so they compare the unwrapped types.
+
+        Args:
+            annotation: Annotation taken from a signature or from an expected table.
+
+        Returns:
+            The single non-None member of an optional annotation, otherwise
+            ``annotation`` itself (including unions with several non-None members).
+        """
+        import types  # function-scope: the file's import block is not touched here
+
+        # types.UnionType only exists on Python 3.10+; fall back to typing.Union.
+        union_origins = (typing.Union, getattr(types, "UnionType", typing.Union))
+        if get_origin(annotation) in union_origins:
+            members = [arg for arg in get_args(annotation) if arg is not type(None)]
+            if len(members) == 1:
+                return members[0]
+        return annotation
+
+    def _assert_parameter_annotations(self, parameters, expected_params) -> None:
+        """Assert every expected parameter carries the expected (unwrapped) annotation.
+
+        Args:
+            parameters: ``inspect.Signature.parameters`` mapping of the inspected method.
+            expected_params: Mapping of parameter name to expected annotation.
+        """
+        for param_name, expected_type in expected_params.items():
+            param = parameters[param_name]
+            self.assertEqual(
+                self.resolve_annotation(param.annotation),
+                self.resolve_annotation(expected_type),
+                f"Parameter '{param_name}' has incorrect type, expected: {expected_type}, got: {param.annotation}",
+            )
+
+    def test_get_transform_signature(self):
+        """Check get_transform's parameter names, annotations and return annotation."""
+        signature = inspect.signature(self.connector.get_transform)
+        parameters = signature.parameters
+
+        expected_params: dict[str, type] = {
+            "target_frame": str,
+            "source_frame": str,
+            "timeout_sec": float,
+        }
+
+        # The full parameter list must match exactly (no extra parameters allowed).
+        self.assertListEqual(
+            list(parameters.keys()),
+            list(expected_params.keys()),
+            f"Parameter names do not match, expected: {list(expected_params.keys())}, got: {list(parameters.keys())}",
+        )
+
+        self._assert_parameter_annotations(parameters, expected_params)
+
+        # Check return type explicitly
+        self.assertIs(
+            signature.return_annotation,
+            TransformStamped,
+            f"Return type is incorrect, expected: TransformStamped, got: {signature.return_annotation}",
+        )
+
+    def test_send_message_signature(self):
+        """Check send_message's leading parameters and that it has no return annotation."""
+        signature = inspect.signature(self.connector.send_message)
+        parameters = signature.parameters
+
+        expected_params: dict[str, type] = {
+            "message": ROS2ARIMessage,
+            "target": str,
+            "msg_type": str,
+            "auto_qos_matching": bool,
+            "qos_profile": Optional[QoSProfile],
+        }
+
+        # Only the leading parameters are compared, so trailing extras are tolerated.
+        self.assertListEqual(
+            list(parameters.keys())[: len(expected_params)],
+            list(expected_params.keys()),
+            f"Parameter names do not match, expected: {list(expected_params.keys())}, got: {list(parameters.keys())}",
+        )
+
+        self._assert_parameter_annotations(parameters, expected_params)
+
+        self.assertIs(
+            signature.return_annotation,
+            inspect.Signature.empty,
+            "send_message should have no return value",
+        )
+
+    def test_receive_message_signature(self):
+        """Check receive_message's leading parameters and its return annotation."""
+        signature = inspect.signature(self.connector.receive_message)
+        parameters = signature.parameters
+
+        expected_params: dict[str, type] = {
+            "source": str,
+            "timeout_sec": float,
+            "msg_type": Optional[str],
+            "auto_topic_type": bool,
+        }
+
+        # Only the leading parameters are compared, so trailing extras are tolerated.
+        self.assertListEqual(
+            list(parameters.keys())[: len(expected_params)],
+            list(expected_params.keys()),
+            f"Parameter names do not match, expected: {list(expected_params.keys())}, got: {list(parameters.keys())}",
+        )
+
+        self._assert_parameter_annotations(parameters, expected_params)
+
+        self.assertIs(
+            signature.return_annotation,
+            ROS2ARIMessage,
+            f"Return type is incorrect, expected: ROS2ARIMessage, got: {signature.return_annotation}",
+        )
+
+    def test_get_topics_names_and_types_signature(self):
+        """Check get_topics_names_and_types takes no parameters and its return annotation."""
+        signature = inspect.signature(self.connector.get_topics_names_and_types)
+        parameters = signature.parameters
+
+        expected_params: dict[str, type] = {}
+
+        self.assertListEqual(
+            list(parameters.keys()),
+            list(expected_params.keys()),
+            f"Parameter names do not match, expected: {list(expected_params.keys())}, got: {list(parameters.keys())}",
+        )
+
+        # No-op while expected_params is empty; kept so the check stays uniform.
+        self._assert_parameter_annotations(parameters, expected_params)
+
+        self.assertEqual(
+            signature.return_annotation,
+            List[Tuple[str, List[str]]],
+            f"Return type is incorrect, expected: List[Tuple[str, List[str]]], got: {signature.return_annotation}",
+        )
+
+    def test_node_property(self):
+        """Test that the node property returns the expected Node instance."""
+        mock_node = MagicMock(spec=Node)
+        self.connector._node = mock_node
+
+        self.assertEqual(self.connector.node, mock_node)
+        self.assertIsInstance(self.connector.node, Node)
diff --git a/tests/rai_sim/test_simulation_bridge.py b/tests/rai_sim/test_simulation_bridge.py
new file mode 100644
index 000000000..39f91ab33
--- /dev/null
+++ b/tests/rai_sim/test_simulation_bridge.py
@@ -0,0 +1,325 @@
+# Copyright (C) 2025 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import unittest
+from pathlib import Path
+from typing import Optional
+
+import pytest
+from pydantic import ValidationError
+
+from rai_sim.simulation_bridge import (
+    Entity,
+    Pose,
+    Rotation,
+    SceneState,
+    SimulationBridge,
+    SimulationConfig,
+    SpawnedEntity,
+    Translation,
+)
+
+
+# Helper Functions
+def create_translation(x: float, y: float, z: float) -> Translation:
+    """Build a Translation from its three components."""
+    return Translation(x=x, y=y, z=z)
+
+
+def create_rotation(x: float, y: float, z: float, w: float) -> Rotation:
+    """Build a Rotation from its four components."""
+    return Rotation(x=x, y=y, z=z, w=w)
+
+
+def create_pose(translation: Translation, rotation: Optional[Rotation] = None) -> Pose:
+    """Build a Pose; ``rotation`` is forwarded as ``None`` when omitted."""
+    return Pose(translation=translation, rotation=rotation)
+
+
+# Test Cases
+def test_translation():
+    """Translation stores its components as floats with the given values."""
+    translation = create_translation(x=1.1, y=2.2, z=3.3)
+
+    assert isinstance(translation.x, float)
+    assert isinstance(translation.y, float)
+    assert isinstance(translation.z, float)
+
+    assert translation.x == 1.1
+    assert translation.y == 2.2
+    assert translation.z == 3.3
+
+
+def test_rotation():
+    """Rotation stores its components as floats with the given values."""
+    rotation = Rotation(x=0.1, y=0.2, z=0.3, w=0.4)
+
+    assert isinstance(rotation.x, float)
+    assert isinstance(rotation.y, float)
+    assert isinstance(rotation.z, float)
+    assert isinstance(rotation.w, float)
+
+    assert rotation.x == 0.1
+    assert rotation.y == 0.2
+    assert rotation.z == 0.3
+    assert rotation.w == 0.4
+
+
+def test_pose():
+    """Pose keeps the Translation and Rotation it was constructed with."""
+    translation = create_translation(x=1.1, y=2.2, z=3.3)
+    rotation = Rotation(x=0.1, y=0.2, z=0.3, w=0.4)
+
+    pose = Pose(translation=translation, rotation=rotation)
+
+    assert isinstance(pose.translation, Translation)
+    assert isinstance(pose.rotation, Rotation)
+
+    assert pose.translation.x == 1.1
+    assert pose.rotation.w == 0.4
+
+
+def test_optional_rotation():
+    """A Pose built without a rotation reports ``rotation is None``."""
+    translation = create_translation(x=1.1, y=2.2, z=3.3)
+    pose = create_pose(translation=translation)
+
+    assert isinstance(pose.translation, Translation)
+    assert pose.rotation is None
+
+
+def test_entity():
+    """Entity exposes the name, prefab_name and pose it was constructed with."""
+    translation = create_translation(x=1.1, y=2.2, z=3.3)
+    rotation = create_rotation(x=0.1, y=0.2, z=0.3, w=0.4)
+    pose = create_pose(translation=translation, rotation=rotation)
+
+    entity = Entity(name="test_cube", prefab_name="cube", pose=pose)
+
+    assert isinstance(entity.name, str)
+    assert isinstance(entity.prefab_name, str)
+    assert isinstance(entity.pose, Pose)
+
+    assert entity.name == "test_cube"
+    assert entity.prefab_name == "cube"
+    assert entity.pose == pose
+
+
+def test_spawned_entity():
+    """SpawnedEntity carries the Entity fields plus a string ``id``."""
+    translation = create_translation(x=1.1, y=2.2, z=3.3)
+    rotation = create_rotation(x=0.1, y=0.2, z=0.3, w=0.4)
+    pose = create_pose(translation=translation, rotation=rotation)
+
+    spawned_entity = SpawnedEntity(
+        name="test_cube",
+        prefab_name="cube",
+        pose=pose,
+        id="id_123",
+    )
+
+    assert isinstance(spawned_entity.name, str)
+    assert isinstance(spawned_entity.prefab_name, str)
+    assert isinstance(spawned_entity.pose, Pose)
+    assert isinstance(spawned_entity.id, str)
+
+    assert spawned_entity.id == "id_123"
+
+
+def test_simulation_config_unique_names():
+    """SimulationConfig accepts entities whose names are all different."""
+    translation = create_translation(x=1.1, y=2.2, z=3.3)
+    rotation = create_rotation(x=0.1, y=0.2, z=0.3, w=0.4)
+    pose = create_pose(translation=translation, rotation=rotation)
+
+    entities = [
+        Entity(name="entity1", prefab_name="cube", pose=pose),
+        Entity(name="entity2", prefab_name="carrot", pose=pose),
+    ]
+
+    config = SimulationConfig(entities=entities)
+
+    assert isinstance(config.entities, list)
+    assert all(isinstance(e, Entity) for e in config.entities)
+
+    assert len(config.entities) == 2
+
+
+def test_simulation_config_duplicate_names():
+    """SimulationConfig is expected to raise ValidationError for repeated entity names."""
+    translation = create_translation(x=1.1, y=2.2, z=3.3)
+    rotation = create_rotation(x=0.1, y=0.2, z=0.3, w=0.4)
+    pose = create_pose(translation=translation, rotation=rotation)
+
+    # Same name, different prefabs: only the name collides.
+    entities = [
+        Entity(name="duplicate", prefab_name="cube", pose=pose),
+        Entity(name="duplicate", prefab_name="carrot", pose=pose),
+    ]
+
+    with pytest.raises(ValidationError):
+        SimulationConfig(entities=entities)
+
+
+def test_load_base_config(sample_base_yaml_config: Path):
+    """``load_base_config`` yields a config holding two Entity objects.
+
+    NOTE(review): ``sample_base_yaml_config`` is not defined in this file;
+    presumably a pytest fixture (e.g. from conftest) whose YAML describes two
+    entities -- confirm.
+    """
+    config = SimulationConfig.load_base_config(sample_base_yaml_config)
+
+    assert isinstance(config.entities, list)
+    assert all(isinstance(e, Entity) for e in config.entities)
+
+    assert len(config.entities) == 2
+
+
+class MockSimulationBridge(SimulationBridge[SimulationConfig]):
+    """Mock implementation of SimulationBridge for testing.
+
+    All state lives in ``self.spawned_entities``; no simulator is contacted.
+    """
+
+    def setup_scene(self, simulation_config: SimulationConfig):
+        """Mock implementation of setup_scene: spawn every entity of the config in order."""
+        for entity in simulation_config.entities:
+            self._spawn_entity(entity)
+
+    def _spawn_entity(self, entity: Entity):
+        """Mock implementation of _spawn_entity.
+
+        Wraps ``entity`` in a SpawnedEntity whose id is ``id_<entity.name>`` and
+        appends it to ``self.spawned_entities``.
+        """
+        spawned_entity = SpawnedEntity(
+            id=f"id_{entity.name}",
+            name=entity.name,
+            prefab_name=entity.prefab_name,
+            pose=entity.pose,
+        )
+        self.spawned_entities.append(spawned_entity)
+
+    def _despawn_entity(self, entity: SpawnedEntity):
+        """Mock implementation of _despawn_entity: drop every entry whose id equals ``entity.id``."""
+        self.spawned_entities = [e for e in self.spawned_entities if e.id != entity.id]
+
+    def get_object_pose(self, entity: SpawnedEntity) -> Pose:
+        """Mock implementation of get_object_pose.
+
+        Returns:
+            The pose of the first spawned entity whose id equals ``entity.id``.
+
+        Raises:
+            ValueError: If no spawned entity has that id.
+        """
+        for spawned_entity in self.spawned_entities:
+            if spawned_entity.id == entity.id:
+                return spawned_entity.pose
+        raise ValueError(f"Entity with id {entity.id} not found")
+
+    def get_scene_state(self) -> SceneState:
+        """Mock implementation of get_scene_state: wrap the current spawned entities."""
+        return SceneState(entities=self.spawned_entities)
+
+
+class TestSimulationBridge(unittest.TestCase):
+    """Exercises the SimulationBridge interface through MockSimulationBridge."""
+
+    def setUp(self):
+        """Set up test fixtures."""
+        self.logger = logging.getLogger("test_logger")
+        self.bridge = MockSimulationBridge(logger=self.logger)
+
+        # Create test entities
+        self.test_entity1: Entity = Entity(
+            name="test_entity1",
+            prefab_name="test_prefab1",
+            pose=Pose(
+                translation=Translation(x=1.0, y=2.0, z=3.0),
+                rotation=Rotation(x=0.0, y=0.0, z=0.0, w=1.0),
+            ),
+        )
+
+        # Second entity deliberately has no rotation.
+        self.test_entity2: Entity = Entity(
+            name="test_entity2",
+            prefab_name="test_prefab2",
+            pose=Pose(translation=Translation(x=4.0, y=5.0, z=6.0), rotation=None),
+        )
+
+        # Create a test configuration
+        self.test_config = SimulationConfig(
+            entities=[self.test_entity1, self.test_entity2]
+        )
+
+    def test_init(self):
+        """A new bridge keeps the given logger (or gets one) and starts with no entities."""
+        # Test with provided logger
+        bridge = MockSimulationBridge(logger=self.logger)
+        self.assertEqual(bridge.logger, self.logger)
+        self.assertEqual(len(bridge.spawned_entities), 0)
+
+        # Test with default logger
+        bridge = MockSimulationBridge()
+        self.assertIsNotNone(bridge.logger)
+        self.assertEqual(len(bridge.spawned_entities), 0)
+
+    def test_setup_scene(self):
+        """setup_scene spawns both configured entities, in order, with their poses."""
+        self.bridge.setup_scene(self.test_config)
+
+        # Check if entities were spawned
+        self.assertEqual(len(self.bridge.spawned_entities), 2)
+
+        # Check if the spawned entities have correct properties
+        self.assertEqual(self.bridge.spawned_entities[0].name, "test_entity1")
+        self.assertEqual(self.bridge.spawned_entities[0].prefab_name, "test_prefab1")
+        self.assertEqual(self.bridge.spawned_entities[0].pose.translation.x, 1.0)
+        self.assertEqual(self.bridge.spawned_entities[0].pose.translation.y, 2.0)
+        self.assertEqual(self.bridge.spawned_entities[0].pose.translation.z, 3.0)
+        # rotation is Optional; make sure it is set before reading its fields
+        assert self.bridge.spawned_entities[0].pose.rotation
+        self.assertEqual(self.bridge.spawned_entities[0].pose.rotation.x, 0.0)
+        self.assertEqual(self.bridge.spawned_entities[0].pose.rotation.y, 0.0)
+        self.assertEqual(self.bridge.spawned_entities[0].pose.rotation.z, 0.0)
+        self.assertEqual(self.bridge.spawned_entities[0].pose.rotation.w, 1.0)
+
+        self.assertEqual(self.bridge.spawned_entities[1].name, "test_entity2")
+        self.assertEqual(self.bridge.spawned_entities[1].prefab_name, "test_prefab2")
+        self.assertEqual(self.bridge.spawned_entities[1].pose.translation.x, 4.0)
+        self.assertEqual(self.bridge.spawned_entities[1].pose.translation.y, 5.0)
+        self.assertEqual(self.bridge.spawned_entities[1].pose.translation.z, 6.0)
+        self.assertIsNone(self.bridge.spawned_entities[1].pose.rotation)
+
+    def test_spawn_entity(self):
+        """_spawn_entity adds exactly one SpawnedEntity to the bridge."""
+        self.bridge._spawn_entity(self.test_entity1)  # type: ignore
+        spawned_entity = self.bridge.spawned_entities[0]
+        # Check if entity was added to spawned_entities
+        self.assertEqual(len(self.bridge.spawned_entities), 1)
+        self.assertIsInstance(spawned_entity, SpawnedEntity)
+
+    def test_despawn_entity(self):
+        """_despawn_entity removes a previously spawned entity."""
+        # First spawn an entity
+        self.bridge._spawn_entity(self.test_entity1)  # type: ignore
+        self.assertEqual(len(self.bridge.spawned_entities), 1)
+
+        # Then despawn it
+        self.bridge._despawn_entity(self.bridge.spawned_entities[0])  # type: ignore
+        self.assertEqual(len(self.bridge.spawned_entities), 0)
+
+    def test_get_object_pose(self):
+        """get_object_pose returns the spawned pose and raises ValueError for unknown ids."""
+        # First spawn an entity
+        self.bridge._spawn_entity(self.test_entity1)  # type: ignore
+
+        # Get the pose
+        pose = self.bridge.get_object_pose(self.bridge.spawned_entities[0])
+
+        # Check if the pose matches
+        self.assertEqual(pose.translation.x, 1.0)
+        self.assertEqual(pose.translation.y, 2.0)
+        self.assertEqual(pose.translation.z, 3.0)
+        # rotation is Optional; make sure it is set before reading its fields
+        assert pose.rotation
+        self.assertEqual(pose.rotation.x, 0.0)
+        self.assertEqual(pose.rotation.y, 0.0)
+        self.assertEqual(pose.rotation.z, 0.0)
+        self.assertEqual(pose.rotation.w, 1.0)
+
+        # Test for non-existent entity
+        non_existent_entity = SpawnedEntity(
+            id="non_existent",
+            name="non_existent",
+            prefab_name="non_existent",
+            pose=Pose(translation=Translation(x=0.0, y=0.0, z=0.0)),
+        )
+
+        with self.assertRaises(ValueError):
+            self.bridge.get_object_pose(non_existent_entity)
+
+    def test_get_scene_state(self):
+        """get_scene_state lists the spawned entities in spawn order."""
+        # First spawn some entities
+        self.bridge._spawn_entity(self.test_entity1)  # type: ignore
+        self.bridge._spawn_entity(self.test_entity2)  # type: ignore
+
+        # Get the scene state
+        scene_state = self.bridge.get_scene_state()
+
+        # Check if the scene state contains the correct entities
+        self.assertEqual(len(scene_state.entities), 2)
+        self.assertEqual(scene_state.entities[0].name, "test_entity1")
+        self.assertEqual(scene_state.entities[1].name, "test_entity2")