Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 28 additions & 23 deletions src/rai_bench/rai_bench/examples/o3de_test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

import rclpy
from langchain.tools import BaseTool
from rai_open_set_vision.tools import GetGrabbingPointTool

from rai.agents.conversational_agent import create_conversational_agent
from rai.communication.ros2.connectors import ROS2ARIConnector
from rai.tools.ros.manipulation import (
Expand All @@ -32,8 +34,6 @@
GetROS2TopicsNamesAndTypesTool,
)
from rai.utils.model_initialization import get_llm_model
from rai_open_set_vision.tools import GetGrabbingPointTool

from rai_bench.benchmark_model import Benchmark
from rai_bench.o3de_test_bench.scenarios import (
easy_scenarios,
Expand Down Expand Up @@ -164,26 +164,31 @@

all_scenarios = t_scenarios + e_scenarios + m_scenarios + h_scenarios + vh_scenarios
o3de = O3DEngineArmManipulationBridge(connector, logger=agent_logger)
# define benchamrk
results_filename = f"{experiment_dir}/results.csv"
benchmark = Benchmark(
simulation_bridge=o3de,
scenarios=all_scenarios,
logger=bench_logger,
results_filename=results_filename,
)
for i in range(len(all_scenarios)):
agent = create_conversational_agent(
llm, tools, system_prompt, logger=agent_logger
try:
# define benchamrk
results_filename = f"{experiment_dir}/results.csv"
benchmark = Benchmark(
simulation_bridge=o3de,
scenarios=all_scenarios,
logger=bench_logger,
results_filename=results_filename,
)
benchmark.run_next(agent=agent)
o3de.reset_arm()
time.sleep(0.2) # admire the end position for a second ;)

bench_logger.info("===============================================================")
bench_logger.info("ALL SCENARIOS DONE. BENCHMARK COMPLETED!")
bench_logger.info("===============================================================")
for i in range(len(all_scenarios)):
agent = create_conversational_agent(
llm, tools, system_prompt, logger=agent_logger
)
benchmark.run_next(agent=agent)
o3de.reset_arm()
time.sleep(0.2) # admire the end position for a second ;)

connector.shutdown()
o3de.shutdown()
rclpy.shutdown()
bench_logger.info(
"==============================================================="
)
bench_logger.info("ALL SCENARIOS DONE. BENCHMARK COMPLETED!")
bench_logger.info(
"==============================================================="
)
finally:
connector.shutdown()
o3de.shutdown()
rclpy.shutdown()
21 changes: 7 additions & 14 deletions src/rai_sim/rai_sim/o3de/o3de_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,11 @@
import yaml
from geometry_msgs.msg import Point, PoseStamped, Quaternion
from geometry_msgs.msg import Pose as ROS2Pose
from rai.communication.ros2.connectors import ROS2ARIConnector, ROS2ARIMessage
from rai.utils.ros_async import get_future_result
from std_msgs.msg import Header
from std_srvs.srv import Trigger
from tf2_geometry_msgs import do_transform_pose

from rai.communication.ros2.connectors import ROS2ARIConnector, ROS2ARIMessage
from rai.utils.ros_async import get_future_result
from rai_interfaces.srv import ManipulatorMoveTo
from rai_sim.simulation_bridge import (
Entity,
Expand Down Expand Up @@ -396,19 +395,13 @@ def _from_ros2_pose(self, pose: ROS2Pose) -> Pose:

class O3DEngineArmManipulationBridge(O3DExROS2Bridge):
def reset_arm(self):
client = self.connector.node.create_client(
Trigger,
"/reset_manipulator",
self.connector.service_call(
ROS2ARIMessage(payload={}),
target="/reset_manipulator",
msg_type="std_srvs/srv/Trigger",
)
while not client.wait_for_service(timeout_sec=5.0):
self.connector.node.get_logger().info("Service not available, waiting...")

self.connector.node.get_logger().info("Making request to reset manipulator...")
request = Trigger.Request()
future = client.call_async(request)
result = get_future_result(future, timeout_sec=5.0)

self.connector.node.get_logger().debug(f"Reset manipulator result: {result}")
self.connector.node.get_logger().debug("Reset manipulator arm: DONE")

def move_arm(
self,
Expand Down