diff --git a/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py b/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py index 44b18cf39..86528922f 100644 --- a/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py +++ b/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py @@ -21,6 +21,8 @@ import rclpy from langchain.tools import BaseTool +from rai_open_set_vision.tools import GetGrabbingPointTool + from rai.agents.conversational_agent import create_conversational_agent from rai.communication.ros2.connectors import ROS2ARIConnector from rai.tools.ros.manipulation import ( @@ -32,8 +34,6 @@ GetROS2TopicsNamesAndTypesTool, ) from rai.utils.model_initialization import get_llm_model -from rai_open_set_vision.tools import GetGrabbingPointTool - from rai_bench.benchmark_model import Benchmark from rai_bench.o3de_test_bench.scenarios import ( easy_scenarios, @@ -164,26 +164,31 @@ all_scenarios = t_scenarios + e_scenarios + m_scenarios + h_scenarios + vh_scenarios o3de = O3DEngineArmManipulationBridge(connector, logger=agent_logger) - # define benchamrk - results_filename = f"{experiment_dir}/results.csv" - benchmark = Benchmark( - simulation_bridge=o3de, - scenarios=all_scenarios, - logger=bench_logger, - results_filename=results_filename, - ) - for i in range(len(all_scenarios)): - agent = create_conversational_agent( - llm, tools, system_prompt, logger=agent_logger + try: + # define benchamrk + results_filename = f"{experiment_dir}/results.csv" + benchmark = Benchmark( + simulation_bridge=o3de, + scenarios=all_scenarios, + logger=bench_logger, + results_filename=results_filename, ) - benchmark.run_next(agent=agent) - o3de.reset_arm() - time.sleep(0.2) # admire the end position for a second ;) - - bench_logger.info("===============================================================") - bench_logger.info("ALL SCENARIOS DONE. BENCHMARK COMPLETED!") - bench_logger.info("===============================================================") + for i in range(len(all_scenarios)): + agent = create_conversational_agent( + llm, tools, system_prompt, logger=agent_logger + ) + benchmark.run_next(agent=agent) + o3de.reset_arm() + time.sleep(0.2) # admire the end position for a second ;) - connector.shutdown() - o3de.shutdown() - rclpy.shutdown() + bench_logger.info( + "===============================================================" + ) + bench_logger.info("ALL SCENARIOS DONE. BENCHMARK COMPLETED!") + bench_logger.info( + "===============================================================" + ) + finally: + connector.shutdown() + o3de.shutdown() + rclpy.shutdown() diff --git a/src/rai_sim/rai_sim/o3de/o3de_bridge.py b/src/rai_sim/rai_sim/o3de/o3de_bridge.py index 8fe6fbe0f..33682881e 100644 --- a/src/rai_sim/rai_sim/o3de/o3de_bridge.py +++ b/src/rai_sim/rai_sim/o3de/o3de_bridge.py @@ -23,12 +23,11 @@ import yaml from geometry_msgs.msg import Point, PoseStamped, Quaternion from geometry_msgs.msg import Pose as ROS2Pose -from rai.communication.ros2.connectors import ROS2ARIConnector, ROS2ARIMessage -from rai.utils.ros_async import get_future_result from std_msgs.msg import Header -from std_srvs.srv import Trigger from tf2_geometry_msgs import do_transform_pose +from rai.communication.ros2.connectors import ROS2ARIConnector, ROS2ARIMessage +from rai.utils.ros_async import get_future_result from rai_interfaces.srv import ManipulatorMoveTo from rai_sim.simulation_bridge import ( Entity, @@ -396,19 +395,13 @@ def _from_ros2_pose(self, pose: ROS2Pose) -> Pose: class O3DEngineArmManipulationBridge(O3DExROS2Bridge): def reset_arm(self): - client = self.connector.node.create_client( - Trigger, - "/reset_manipulator", + self.connector.service_call( + ROS2ARIMessage(payload={}), + target="/reset_manipulator", + msg_type="std_srvs/srv/Trigger", ) - while not client.wait_for_service(timeout_sec=5.0): - self.connector.node.get_logger().info("Service not available, waiting...") - - self.connector.node.get_logger().info("Making request to reset manipulator...") - request = Trigger.Request() - future = client.call_async(request) - result = get_future_result(future, timeout_sec=5.0) - self.connector.node.get_logger().debug(f"Reset manipulator result: {result}") + self.connector.node.get_logger().debug("Reset manipulator arm: DONE") def move_arm( self,