diff --git a/demos.repos b/demos.repos index bda98f1cf..f0b6a59ce 100644 --- a/demos.repos +++ b/demos.repos @@ -6,4 +6,4 @@ repositories: src/examples/rai-manipulation-demo: type: git url: https://github.com/RobotecAI/rai-manipulation-demo.git - version: development + version: kd/wait_for_clock diff --git a/src/rai_bench/README.md b/src/rai_bench/README.md index 3db5d04ba..7eca910c5 100644 --- a/src/rai_bench/README.md +++ b/src/rai_bench/README.md @@ -1,36 +1,56 @@ ## RAI Benchmark -## Description +### Description The RAI Bench is a package including benchmarks and providing frame for creating new benchmarks -## Frame Components - -Frame components can be found in `src/rai_bench/rai_bench/benchmark_model.py` +### Frame Components - `Task` - `Scenario` - `Benchmark` -For more information about these classes go to -> `src/rai_bench/rai_bench/benchmark_model.py` +For more information about these classes go to -> [benchmark_model](./rai_bench/benchmark_model.py) + +### O3DE Test Benchmark + +The O3DE Test Benchmark [o3de_test_benchmark_module](./rai_bench/o3de_test_bench/) provides tasks and scene configurations for robotic arm manipulation task. The tasks use a common `ManipulationTask` logic and can be parameterized, which allows for many task variants. The current tasks include: + +- **MoveObjectToLeftTask** +- **GroupObjectsTask** +- **BuildCubeTowerTask** +- **PlaceObjectAtCoordTask** +- **RotateObjectTask** (currently not applicable due to limitations in the ManipulatorMoveTo tool) + +The result of a task is a value between 0 and 1, calculated like initially_misplaced_now_correct / initially_misplaced. This score is calculated at the end of each scenario. -### O3DE TEST BENCHMARK +Current O3DE simulation binaries: -O3DE Test Benchmark (`src/rai_bench/rai_bench/o3de_test_bench/`), contains 2 Tasks(`tasks/`) - GrabCarrotTask and PlaceCubesTask (these tasks implement calculating scores) and 4 scene_configs(`configs/`) for O3DE robotic arm simulation. +### Running -Both tasks calculate score, taking into consideration 4 values: +1. Download O3DE simulation binary and unzip it. -- initially_misplaced_now_correct -- initially_misplaced_still_incorrect -- initially_correct_still_correct -- initially_correct_now_incorrect + - [ros2-humble](https://robotec-ml-rai-public.s3.eu-north-1.amazonaws.com/RAIManipulationDemo_jammyhumble.zip) + - [ros2-jazzy](https://robotec-ml-rai-public.s3.eu-north-1.amazonaws.com/RAIManipulationDemo_noblejazzy.zip) -The result is a value between 0 and 1, calculated like (initially_misplaced_now_correct + initially_correct_still_correct) / number_of_initial_objects. -This score is calculated at the beggining and at the end of each scenario. +2. Follow step 2 from [Manipulation demo Setup section](../../docs/demos/manipulation.md#setup) + +3. Adjust the path to the binary in: [o3de_config.yaml](./rai_bench/o3de_test_bench/configs/o3de_config.yaml) +4. Run benchmark with: + + ```bash + cd rai + source setup_shell.sh + python src/rai_bench/rai_bench/examples/o3de_test_benchmark.py + ``` + +> [!NOTE] +> For now benchmark runs all available scenarios (~160). See [Examples](#example-usege) +> section for details. ### Example usage -Example of how to load scenes, define scenarios and run benchmark can be found in `src/rai_bench/rai_bench/examples/o3de_test_benchmark.py` +Example of how to load scenes, define scenarios and run benchmark can be found in [o3de_test_benchmark_example](./rai_bench/examples/o3de_test_benchmark.py) Scenarios can be loaded manually like: @@ -52,3 +72,33 @@ scenarios = Benchmark.create_scenarios( ``` which will result in list of scenarios with combination of every possible task and scene(task decides if scene config is suitable for it). + +or can be imported from exisitng packets [scenarios_packets](./rai_bench/o3de_test_bench/scenarios.py): + +```python +t_scenarios = trivial_scenarios( + configs_dir=configs_dir, connector_path=connector_path, logger=bench_logger + ) +e_scenarios = easy_scenarios( + configs_dir=configs_dir, connector_path=connector_path, logger=bench_logger +) +m_scenarios = medium_scenarios( + configs_dir=configs_dir, connector_path=connector_path, logger=bench_logger +) +h_scenarios = hard_scenarios( + configs_dir=configs_dir, connector_path=connector_path, logger=bench_logger +) +vh_scenarios = very_hard_scenarios( + configs_dir=configs_dir, connector_path=connector_path, logger=bench_logger +) +``` + +which are grouped by their subjective difficulty. For now there are 10 trivial, 42 easy, 23 medium, 38 hard and 47 very hard scenarios. +Check docstrings and code in [scenarios_packets](./rai_bench/o3de_test_bench/scenarios.py) if you want to know how scenarios are assigned to difficulty level. + +### Development + +When creating new task or changing existing ones, make sure to add unit tests for score calculation in [rai_bench_tests](../../tests/rai_bench/). +This applies also when you are adding or changing the helper methods in `Task` or `ManipulationTask`. + +The number of scenarios can be easily extened without writing new tasks, by increasing number of variants of the same task and adding more simulation configs but it won't improve variety of scenarios as much as creating new tasks. diff --git a/src/rai_bench/rai_bench/benchmark_model.py b/src/rai_bench/rai_bench/benchmark_model.py index 83f4c1f66..804eaeea4 100644 --- a/src/rai_bench/rai_bench/benchmark_model.py +++ b/src/rai_bench/rai_bench/benchmark_model.py @@ -11,26 +11,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import csv import logging +import math import time from abc import ABC, abstractmethod -from typing import Any, Dict, Generic, List, Union +from collections import defaultdict +from typing import Any, Dict, Generic, List, Set, TypeVar, Union from langchain_core.messages import AIMessage, BaseMessage, HumanMessage +from langgraph.graph.state import CompiledStateGraph from rai.messages import HumanMultimodalMessage from rclpy.impl.rcutils_logger import RcutilsLogger from rai_sim.simulation_bridge import ( + Entity, Pose, SimulationBridge, SimulationConfig, SimulationConfigT, - SpawnedEntity, ) loggers_type = Union[RcutilsLogger, logging.Logger] +EntityT = TypeVar("EntityT", bound=Entity) class EntitiesMismatchException(Exception): @@ -58,7 +61,9 @@ def __init__( @abstractmethod def get_prompt(self) -> str: - """Returns the task instruction - the prompt that will be passed to agent""" + """ + Returns the task instruction - the prompt that will be passed to agent + """ pass @abstractmethod @@ -77,16 +82,39 @@ def calculate_result( self, simulation_bridge: SimulationBridge[SimulationConfigT] ) -> float: """ - Calculates result of the task, based on info retrieved from simulation. - Should return score between 0.0 and 1. + Calculate the task result (score) based on the simulation information. + + Parameters + ---------- + simulation_bridge : SimulationBridge[SimulationConfigT] + The simulation bridge used to retrieve simulation data. + + Returns + ------- + float + A score between 0.0 and 1.0. """ pass - def filter_entities_by_prefab_type( - self, entities: List[SpawnedEntity], prefab_types: List[str] - ) -> List[SpawnedEntity]: - """Filter and return only these entities that match provided prefab types""" - return [ent for ent in entities if ent.prefab_name in prefab_types] + def filter_entities_by_object_type( + self, entities: List[EntityT], object_types: List[str] + ) -> List[EntityT]: + """ + Filter and return only the entities that match the provided prefab types. + + Parameters + ---------- + entities : List[EntityT] + The list of entities to filter. + object_types : List[str] + The allowed object types. + + Returns + ------- + List[EntityT] + A list of entities whose prefab_name is in object_types. + """ + return [ent for ent in entities if ent.prefab_name in object_types] def euclidean_distance(self, pos1: Pose, pos2: Pose) -> float: """Calculate euclidean distance between 2 positions""" @@ -98,19 +126,43 @@ def euclidean_distance(self, pos1: Pose, pos2: Pose) -> float: def is_adjacent(self, pos1: Pose, pos2: Pose, threshold_distance: float): """ - Check if positions are adjacent to each other, the threshold_distance is a distance - in simulation, refering to how close they have to be to classify them as adjacent + Check if two positions are adjacent, based on a threshold distance. + + Parameters + ---------- + pos1 : Pose + The first position. + pos2 : Pose + The second position. + threshold_distance : float + The maximum allowed distance for the positions to be considered adjacent. + + Returns + ------- + bool + True if the Euclidean distance between pos1 and pos2 is less than threshold_distance, False otherwise. """ - self.logger.debug( # type: ignore - f"Euclidean distance: {self.euclidean_distance(pos1, pos2)}, pos1: {pos1}, pos2: {pos2}" - ) return self.euclidean_distance(pos1, pos2) < threshold_distance def is_adjacent_to_any( self, pos1: Pose, positions: List[Pose], threshold_distance: float ) -> bool: """ - Check if given position is adjacent to any position in the given list. + Check if a position is adjacent to any position in a given list. + + Parameters + ---------- + pos1 : Pose + The position to check. + positions : List[Pose] + A list of positions to compare against. + threshold_distance : float + The distance threshold for adjacency. + + Returns + ------- + bool + True if pos1 is adjacent to any position in positions, False otherwise. """ return any( @@ -119,9 +171,19 @@ def is_adjacent_to_any( def count_adjacent(self, positions: List[Pose], threshold_distance: float) -> int: """ - Count how many adjacent positions are in the given list. - Note that position has to be adjacent to only 1 other position - to be counted, not all of them + Count how many positions in the list are adjacent to at least one other position. + + Parameters + ---------- + positions : List[Pose] + A list of positions. + threshold_distance : float + The distance threshold to determine adjacency. + + Returns + ------- + int + The count of positions that are adjacent to at least one other position. """ adjacent_count = 0 @@ -134,10 +196,171 @@ def count_adjacent(self, positions: List[Pose], threshold_distance: float) -> in return adjacent_count + def build_neighbourhood_list( + self, entities: List[EntityT], threshold_distance: float = 0.15 + ) -> Dict[EntityT, List[EntityT]]: + """ + Build a neighbourhood list assigning a list of neighbours to every entity based on a threshold distance. + + Parameters + ---------- + entities : List[EntityT] + The list of entities. + threshold_distance : float, optional + The maximum distance between entities to consider them neighbours. Default is 0.15. + + Returns + ------- + Dict[EntityT, List[EntityT]] + A dictionary mapping each entity to a list of neighbouring entities. + """ + neighbourhood_graph: Dict[EntityT, List[EntityT]] = { + entity: [] for entity in entities + } + for entity in entities: + neighbourhood_graph[entity] = [ + other + for other in entities + if entity != other + and self.is_adjacent(entity.pose, other.pose, threshold_distance) + ] + return neighbourhood_graph + + def group_entities_by_type( + self, entities: List[EntityT] + ) -> Dict[str, List[EntityT]]: + """ + Group entities by their prefab type. + + Parameters + ---------- + entities : List[EntityT] + The list of entities to group. + + Returns + ------- + Dict[str, List[EntityT]] + A dictionary with keys as prefab names and values as lists of entities of that type. + """ + entities_by_type: Dict[str, List[EntityT]] = defaultdict(list) + for entity in entities: + entities_by_type[entity.prefab_name].append(entity) + return entities_by_type + + def check_neighbourhood_types( + self, + neighbourhood: List[EntityT], + allowed_types: List[str], + ) -> bool: + """ + Check if all entities in the neighbourhood are of the allowed types. + + Parameters + ---------- + neighbourhood : List[EntityT] + The list of neighbouring entities. + allowed_types : List[str] + The allowed prefab types. + + Returns + ------- + bool + True if the neighbourhood is empty or if all neighbours have a prefab_name in allowed_types, False otherwise. + """ + return not neighbourhood or all( + adj.prefab_name in allowed_types for adj in neighbourhood + ) + + def find_clusters( + self, neighbourhood_list: Dict[EntityT, List[EntityT]] + ) -> List[List[EntityT]]: + """ + Identify clusters of entities using a DFS algorithm. + + Each connected component in the neighbourhood graph is considered a cluster. + Lone entities are counted as their own cluster. + + Parameters + ---------- + neighbourhood_list : Dict[EntityT, List[EntityT]] + A dictionary mapping entities to their list of neighbours. + + Returns + ------- + List[List[EntityT]] + A list of clusters, where each cluster is a list of connected entities. + """ + visited: Set[EntityT] = set() + clusters: List[List[EntityT]] = [] + + def dfs(node: EntityT, cluster: List[EntityT]): + visited.add(node) + cluster.append(node) + for neighbor in neighbourhood_list.get(node, []): + if neighbor not in visited: + dfs(neighbor, cluster) + + for node in neighbourhood_list.keys(): + if node not in visited: + component: List[EntityT] = [] + dfs(node, component) + clusters.append(component) + + return clusters + + def group_entities_along_z_axis( + # NOTE (jmatejcz) figure out how to group by other coords and orientation, without reapeting code + self, + entities: List[EntityT], + margin: float, + ) -> List[List[EntityT]]: + """ + Group entities that are aligned along the z axis based on their x and y coordinates. + + Entities are first sorted by their x and y coordinates. Then, each entity is added to an existing group + if its (x, y) distance from the first entity in the group is within the specified margin. + Otherwise, a new group is created. + + Example + ---------- + You have 2 separate vertical towers of cubes. + In that case method will return 2 groups of entities, one for each tower. + + Parameters + ---------- + entities : List[EntityT] + The list of entities to group. + margin : float + The maximum allowable Euclidean distance in the x-y plane to consider entities as part of the same group. + + Returns + ------- + List[List[EntityT]] + A list of groups (clusters) of entities. + """ + + entities = sorted( + entities, key=lambda ent: (ent.pose.translation.x, ent.pose.translation.y) + ) + + groups: List[List[EntityT]] = [] + for entity in entities: + placed = False + for group in groups: + dx = group[0].pose.translation.x - entity.pose.translation.x + dy = group[0].pose.translation.y - entity.pose.translation.y + if math.sqrt(dx * dx + dy * dy) <= margin: + group.append(entity) + placed = True + break + if not placed: + groups.append([entity]) + return groups + class Scenario(Generic[SimulationConfigT]): """ - A Scenarios are defined by a pair of Task and Simlation Config. + A Scenario are defined by a pair of Task and Simlation Config. Each Scenario is executed separatly by a Benchmark. """ @@ -147,12 +370,27 @@ def __init__( simulation_config: SimulationConfigT, simulation_config_path: str, ) -> None: - if not task.validate_config(simulation_config): - raise ValueError("This scene is invalid for this task.") + """ + Initialize a Scenario. + + Parameters + ---------- + task : Task + The task to be executed. + simulation_config : SimulationConfigT + The simulation configuration for the scenario. + simulation_config_path : str + The file path to the simulation configuration. + + Raises + ------ + ValueError + If the provided simulation configuration is not valid for the task. + """ self.task = task self.simulation_config = simulation_config - # NOTE (jm) needed for logging which config was used, - # there probably is better method to do it + # NOTE (jmatejcz) needed for logging which config was used, + # there probably is better way to do it self.simulation_config_path = simulation_config_path @@ -195,14 +433,33 @@ def create_scenarios( tasks: List[Task], simulation_configs: List[SimulationConfigT], simulation_configs_paths: List[str], + logger: loggers_type | None = None, ) -> List[Scenario[SimulationConfigT]]: - # TODO (jm) hacky_fix, taking paths as args here, not the best solution, + """ + Create scenarios by pairing each task with each suitable simulation configuration. + + Parameters + ---------- + tasks : List[Task] + The list of tasks. + simulation_configs : List[SimulationConfigT] + The list of simulation configurations. + simulation_configs_paths : List[str] + The corresponding file paths for the simulation configurations. + + Returns + ------- + List[Scenario[SimulationConfigT]] + A list of scenarios generated from the given tasks and simulation configurations. + """ + # NOTE (jmatejcz) hacky_fix, taking paths as args here, not the best solution, # but more changes to code would be required scenarios: List[Scenario[SimulationConfigT]] = [] - + if not logger: + logger = logging.getLogger(__name__) for task in tasks: for sim_conf, sim_path in zip(simulation_configs, simulation_configs_paths): - try: + if task.validate_config(simulation_config=sim_conf): scenarios.append( Scenario( task=task, @@ -210,9 +467,9 @@ def create_scenarios( simulation_config_path=sim_path, ) ) - except ValueError as e: - print( - f"Could not create Scenario from task: {task.get_prompt()} and simulation_config: {sim_conf}, {e}" + else: + logger.debug( + f"Simulation config: {sim_path} is not suitable for task: {task.get_prompt()}" ) return scenarios @@ -224,25 +481,36 @@ def _initialize_results_file(self): writer = csv.DictWriter(file, fieldnames=self.fieldnames) writer.writeheader() - def run_next(self, agent) -> None: + def run_next(self, agent: CompiledStateGraph) -> None: """ - Runs the next scenario + Run the next scenario in the benchmark. + + Parameters + ---------- + agent : CompiledStateGraph + The agent used to execute the scenario. + + This method sets up the scene, streams the agent's responses, logs messages, + counts tool calls, calculates the final task score, and writes the result to a CSV file. """ try: i, scenario = next(self.scenarios) # Get the next scenario self.simulation_bridge.setup_scene(scenario.simulation_config) - self._logger.info( # type: ignore + self._logger.info( "======================================================================================" ) - self._logger.info( # type: ignore - f"RUNNING SCENARIO NUMBER {i + 1} / {self.num_of_scenarios}, TASK: {scenario.task.get_prompt()}" + self._logger.info( + f"RUNNING SCENARIO NUMBER {i + 1} / {self.num_of_scenarios}\n TASK: {scenario.task.get_prompt()}\n SIMULATION_CONFIG: {scenario.simulation_config_path}" ) tool_calls_num = 0 ts = time.perf_counter() for state in agent.stream( - {"messages": [HumanMessage(content=scenario.task.get_prompt())]} + {"messages": [HumanMessage(content=scenario.task.get_prompt())]}, + { + "recursion_limit": 100 + }, # NOTE (jmatejcz) what should be recursion limit? ): graph_node_name = list(state.keys())[0] msg = state[graph_node_name]["messages"][-1] @@ -256,34 +524,34 @@ def run_next(self, agent) -> None: last_msg = msg.content[0].get("text", "") else: last_msg = msg.content - self._logger.debug(f"{graph_node_name}: {last_msg}") # type: ignore + self._logger.debug(f"{graph_node_name}: {last_msg}") else: raise ValueError(f"Unexpected type of message: {type(msg)}") if isinstance(msg, AIMessage): - # TODO (jm) figure out more robust way of counting tool calls tool_calls_num += len(msg.tool_calls) - self._logger.info(f"AI Message: {msg}") # type: ignore + self._logger.info(f"AI Message: {msg}") te = time.perf_counter() - - result = scenario.task.calculate_result(self.simulation_bridge) - total_time = te - ts - self._logger.info( # type: ignore - f"TASK SCORE: {result}, TOTAL TIME: {total_time:.3f}, NUM_OF_TOOL_CALLS: {tool_calls_num}" - ) - - scenario_result: Dict[str, Any] = { - "task": scenario.task.get_prompt(), - "simulation_config": scenario.simulation_config_path, - "final_score": result, - "total_time": f"{total_time:.3f}", - "number_of_tool_calls": tool_calls_num, - } - self.results.append(scenario_result) - self._save_scenario_result_to_csv(scenario_result) + try: + result = scenario.task.calculate_result(self.simulation_bridge) + total_time = te - ts + self._logger.info( + f"TASK SCORE: {result}, TOTAL TIME: {total_time:.3f}, NUM_OF_TOOL_CALLS: {tool_calls_num}" + ) + scenario_result: Dict[str, Any] = { + "task": scenario.task.get_prompt(), + "simulation_config": scenario.simulation_config_path, + "final_score": result, + "total_time": f"{total_time:.3f}", + "number_of_tool_calls": tool_calls_num, + } + self.results.append(scenario_result) + self._save_scenario_result_to_csv(scenario_result) + except EntitiesMismatchException as e: + self._logger.error(e) except StopIteration: print("No more scenarios left to run.") diff --git a/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py b/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py index 8e0031d6c..4a6dabeb0 100644 --- a/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py +++ b/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py @@ -23,19 +23,28 @@ from langchain.tools import BaseTool from rai.agents.conversational_agent import create_conversational_agent from rai.communication.ros2.connectors import ROS2ARIConnector -from rai.tools.ros.manipulation import MoveToPointTool -from rai.tools.ros2.topics import GetROS2ImageTool, GetROS2TopicsNamesAndTypesTool +from rai.tools.ros.manipulation import ( + GetObjectPositionsTool, + MoveToPointTool, +) +from rai.tools.ros2.topics import ( + GetROS2ImageTool, + GetROS2TopicsNamesAndTypesTool, +) from rai.utils.model_initialization import get_llm_model +from rai_open_set_vision.tools import GetGrabbingPointTool -from rai_bench.benchmark_model import Benchmark, Task -from rai_bench.o3de_test_bench.tasks import GrabCarrotTask, PlaceCubesTask +from rai_bench.benchmark_model import Benchmark +from rai_bench.o3de_test_bench.scenarios import ( + easy_scenarios, + hard_scenarios, + medium_scenarios, + trivial_scenarios, + very_hard_scenarios, +) from rai_sim.o3de.o3de_bridge import ( O3DEngineArmManipulationBridge, - O3DExROS2SimulationConfig, - Pose, ) -from rai_sim.simulation_bridge import Rotation, Translation -from rai_sim.tools import GetObjectPositionsGroundTruthTool if __name__ == "__main__": rclpy.init() @@ -54,6 +63,21 @@ z - up to down (positive is up) Before starting the task, make sure to grab the camera image to understand the environment. """ + # define tools + tools: List[BaseTool] = [ + GetObjectPositionsTool( + connector=connector, + target_frame="panda_link0", + source_frame="RGBDCamera5", + camera_topic="/color_image5", + depth_topic="/depth_image5", + camera_info_topic="/color_camera_info5", + get_grabbing_point_tool=GetGrabbingPointTool(connector=connector), + ), + MoveToPointTool(connector=connector, manipulator_frame="panda_link0"), + GetROS2ImageTool(connector=connector), + GetROS2TopicsNamesAndTypesTool(connector=connector), + ] # define loggers now = datetime.now() experiment_dir = ( @@ -121,68 +145,50 @@ # ), # ] - ### Create scenarios automatically - simulation_configs_paths = [ - configs_dir + "scene1.yaml", - configs_dir + "scene2.yaml", - configs_dir + "scene3.yaml", - configs_dir + "scene4.yaml", - ] - simulations_configs = [ - O3DExROS2SimulationConfig.load_config(Path(path), Path(connector_path)) - for path in simulation_configs_paths - ] - tasks: List[Task] = [ - GrabCarrotTask(logger=bench_logger), - PlaceCubesTask(logger=bench_logger), - ] - scenarios = Benchmark.create_scenarios( - tasks=tasks, - simulation_configs=simulations_configs, - simulation_configs_paths=simulation_configs_paths, + ### import ready scenarios + t_scenarios = trivial_scenarios( + configs_dir=configs_dir, connector_path=connector_path, logger=bench_logger ) - - # custom request to arm - base_arm_pose = Pose( - translation=Translation(x=0.1, y=0.5, z=0.4), - rotation=Rotation(x=1.0, y=0.0, z=0.0, w=0.0), + e_scenarios = easy_scenarios( + configs_dir=configs_dir, connector_path=connector_path, logger=bench_logger + ) + m_scenarios = medium_scenarios( + configs_dir=configs_dir, connector_path=connector_path, logger=bench_logger + ) + h_scenarios = hard_scenarios( + configs_dir=configs_dir, connector_path=connector_path, logger=bench_logger + ) + vh_scenarios = very_hard_scenarios( + configs_dir=configs_dir, connector_path=connector_path, logger=bench_logger ) + all_scenarios = t_scenarios + e_scenarios + m_scenarios + h_scenarios + vh_scenarios o3de = O3DEngineArmManipulationBridge(connector, logger=agent_logger) - # define benchamrk - results_filename = f"{experiment_dir}/results.csv" - benchmark = Benchmark( - simulation_bridge=o3de, - scenarios=scenarios, - logger=bench_logger, - results_filename=results_filename, - ) - for i, s in enumerate(scenarios): - # define tools - tools: List[BaseTool] = [ - GetObjectPositionsGroundTruthTool( - simulation=o3de, - ), - MoveToPointTool(connector=connector, manipulator_frame="panda_link0"), - GetROS2ImageTool(connector=connector), - GetROS2TopicsNamesAndTypesTool(connector=connector), - ] - agent = create_conversational_agent( - llm, tools, system_prompt, logger=agent_logger + try: + # define benchamrk + results_filename = f"{experiment_dir}/results.csv" + benchmark = Benchmark( + simulation_bridge=o3de, + scenarios=all_scenarios, + logger=bench_logger, + results_filename=results_filename, ) - benchmark.run_next(agent=agent) - o3de.move_arm( - pose=base_arm_pose, - initial_gripper_state=True, - final_gripper_state=False, - frame_id="panda_link0", - ) # return to case position - time.sleep(0.2) # admire the end position for a second ;) + for i in range(len(all_scenarios)): + agent = create_conversational_agent( + llm, tools, system_prompt, logger=agent_logger + ) + benchmark.run_next(agent=agent) + o3de.reset_arm() + time.sleep(0.2) # admire the end position for a second ;) - bench_logger.info("===============================================================") - bench_logger.info("ALL SCENARIOS DONE. BENCHMARK COMPLETED!") - bench_logger.info("===============================================================") - - connector.shutdown() - o3de.shutdown() - rclpy.shutdown() + bench_logger.info( + "===============================================================" + ) + bench_logger.info("ALL SCENARIOS DONE. BENCHMARK COMPLETED!") + bench_logger.info( + "===============================================================" + ) + finally: + connector.shutdown() + o3de.shutdown() + rclpy.shutdown() diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/1a.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/1a.yaml new file mode 100644 index 000000000..e932c6e9a --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/1a.yaml @@ -0,0 +1,13 @@ +entities: + - name: apple1 + prefab_name: apple + pose: + translation: + x: 0.3 + y: -0.1 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/1a_1t.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/1a_1t.yaml new file mode 100644 index 000000000..c59b0476f --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/1a_1t.yaml @@ -0,0 +1,25 @@ +entities: + - name: apple1 + prefab_name: apple + pose: + translation: + x: 0.3 + y: -0.1 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: tomato1 + prefab_name: tomato + pose: + translation: + x: 0.2 + y: 0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/1a_2bc.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/1a_2bc.yaml new file mode 100644 index 000000000..c653157d6 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/1a_2bc.yaml @@ -0,0 +1,37 @@ +entities: + - name: apple1 + prefab_name: apple + pose: + translation: + x: 0.3 + y: -0.1 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: blue_cube1 + prefab_name: blue_cube + pose: + translation: + x: 0.1 + y: 0.5 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: blue_cube2 + prefab_name: blue_cube + pose: + translation: + x: 0.2 + y: -0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/1bc_1rc_1yc.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/1bc_1rc_1yc.yaml new file mode 100644 index 000000000..69c7dc854 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/1bc_1rc_1yc.yaml @@ -0,0 +1,37 @@ +entities: + - name: blue_cube1 + prefab_name: blue_cube + pose: + translation: + x: 0.15 + y: 0.35 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: red_cube1 + prefab_name: red_cube + pose: + translation: + x: 0.45 + y: 0.1 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: yellow_cube1 + prefab_name: yellow_cube + pose: + translation: + x: 0.25 + y: 0.5 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/1carrot.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/1carrot.yaml new file mode 100644 index 000000000..d5266fc28 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/1carrot.yaml @@ -0,0 +1,13 @@ +entities: + - name: carrot1 + prefab_name: carrot + pose: + translation: + x: 0.5 + y: -0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/1carrot_1a_1t_1bc_1corn.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/1carrot_1a_1t_1bc_1corn.yaml new file mode 100644 index 000000000..a1c9f9e68 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/1carrot_1a_1t_1bc_1corn.yaml @@ -0,0 +1,61 @@ +entities: + - name: carrot1 + prefab_name: carrot + pose: + translation: + x: 0.1 + y: -0.2 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: apple1 + prefab_name: tomato + pose: + translation: + x: 0.2 + y: -0.1 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: tomato1 + prefab_name: tomato + pose: + translation: + x: 0.3 + y: 0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: blue_cube1 + prefab_name: blue_cube + pose: + translation: + x: 0.4 + y: 0.5 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: corn1 + prefab_name: corn + pose: + translation: + x: 0.55 + y: 0.5 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/1carrot_1a_2t_1bc_1rc_3yc_stacked.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/1carrot_1a_2t_1bc_1rc_3yc_stacked.yaml new file mode 100644 index 000000000..5986dd370 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/1carrot_1a_2t_1bc_1rc_3yc_stacked.yaml @@ -0,0 +1,110 @@ +entities: + - name: carrot1 + prefab_name: carrot + pose: + translation: + x: 0.05 + y: -0.45 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: apple1 + prefab_name: apple + pose: + translation: + x: 0.20 + y: -0.10 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: tomato1 + prefab_name: tomato + pose: + translation: + x: 0.25 + y: -0.10 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: tomato2 + prefab_name: tomato + pose: + translation: + x: 0.25 + y: -0.15 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: blue_cube1 + prefab_name: blue_cube + pose: + translation: + x: 0.40 + y: 0.40 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: red_cube1 + prefab_name: red_cube + pose: + translation: + x: 0.45 + y: 0.40 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: yellow_cube1 + prefab_name: yellow_cube + pose: + translation: + x: 0.50 + y: 0.40 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: yellow_cube2 + prefab_name: yellow_cube + pose: + translation: + x: 0.50 + y: 0.40 + z: 0.15 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + + - name: yellow_cube3 + prefab_name: yellow_cube + pose: + translation: + x: 0.50 + y: 0.40 + z: 0.25 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/1carrot_1bc.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/1carrot_1bc.yaml new file mode 100644 index 000000000..506c44724 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/1carrot_1bc.yaml @@ -0,0 +1,25 @@ +entities: + - name: carrot1 + prefab_name: carrot + pose: + translation: + x: 0.3 + y: -0.1 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: blue_cube1 + prefab_name: blue_cube + pose: + translation: + x: 0.2 + y: 0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/scene1.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/1carrot_1corn.yaml similarity index 100% rename from src/rai_bench/rai_bench/o3de_test_bench/configs/scene1.yaml rename to src/rai_bench/rai_bench/o3de_test_bench/configs/1carrot_1corn.yaml diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/1carrot_1t_1rc.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/1carrot_1t_1rc.yaml new file mode 100644 index 000000000..abdc7ccc1 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/1carrot_1t_1rc.yaml @@ -0,0 +1,37 @@ +entities: + - name: carrot1 + prefab_name: carrot + pose: + translation: + x: 0.4 + y: 0.0 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: tomato1 + prefab_name: tomato + pose: + translation: + x: 0.45 + y: 0.0 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: red_cube1 + prefab_name: red_cube + pose: + translation: + x: 0.35 + y: 0.0 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/1rc.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/1rc.yaml new file mode 100644 index 000000000..8423048b4 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/1rc.yaml @@ -0,0 +1,13 @@ +entities: + - name: red_cube1 + prefab_name: red_cube + pose: + translation: + x: 0.2 + y: -0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/1rc_2bc_3yc.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/1rc_2bc_3yc.yaml new file mode 100644 index 000000000..6151757ad --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/1rc_2bc_3yc.yaml @@ -0,0 +1,73 @@ +entities: + - name: cube1 + prefab_name: red_cube + pose: + translation: + x: 0.5 + y: -0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: cube2 + prefab_name: blue_cube + pose: + translation: + x: 0.4 + y: -0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: cube3 + prefab_name: yellow_cube + pose: + translation: + x: 0.5 + y: -0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: cube4 + prefab_name: yellow_cube + pose: + translation: + x: 0.5 + y: 0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: cube5 + prefab_name: yellow_cube + pose: + translation: + x: 0.6 + y: 0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: cube6 + prefab_name: blue_cube + pose: + translation: + x: 0.6 + y: -0.5 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/1t.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/1t.yaml new file mode 100644 index 000000000..cab0487ef --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/1t.yaml @@ -0,0 +1,13 @@ +entities: + - name: tomato1 + prefab_name: tomato + pose: + translation: + x: 0.2 + y: 0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/1yc.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/1yc.yaml new file mode 100644 index 000000000..356a5df22 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/1yc.yaml @@ -0,0 +1,13 @@ +entities: + - name: yellow_cube1 + prefab_name: yellow_cube + pose: + translation: + x: 0.5 + y: 0.1 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/1yc_1rc.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/1yc_1rc.yaml new file mode 100644 index 000000000..4a2244311 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/1yc_1rc.yaml @@ -0,0 +1,25 @@ +entities: + - name: yellow_cube1 + prefab_name: yellow_cube + pose: + translation: + x: 0.2 + y: 0.1 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: red_cube1 + prefab_name: red_cube + pose: + translation: + x: 0.4 + y: 0.2 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/2a_1bc.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/2a_1bc.yaml new file mode 100644 index 000000000..b7bccd51a --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/2a_1bc.yaml @@ -0,0 +1,37 @@ +entities: + - name: apple1 + prefab_name: apple + pose: + translation: + x: 0.3 + y: -0.1 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: apple2 + prefab_name: apple + pose: + translation: + x: 0.3 + y: -0.15 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: blue_cube1 + prefab_name: blue_cube + pose: + translation: + x: 0.3 + y: -0.2 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/2a_1c_2rc.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/2a_1c_2rc.yaml new file mode 100644 index 000000000..c4dc5a065 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/2a_1c_2rc.yaml @@ -0,0 +1,75 @@ +entities: + - name: apple1 + prefab_name: apple + pose: + translation: + x: 0.1 + y: -0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: apple2 + prefab_name: apple + pose: + translation: + x: 0.35 + y: 0.1 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + + - name: corn1 + prefab_name: corn + pose: + translation: + x: 0.2 + y: 0.0 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + + - name: corn2 + prefab_name: corn + pose: + translation: + x: 0.1 + y: -0.45 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: red_cube1 + prefab_name: red_cube + pose: + translation: + x: 0.05 + y: 0.25 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: red_cube2 + prefab_name: red_cube + pose: + translation: + x: 0.25 + y: -0.5 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/2carrots_1a_1t_1bc_1rc_1yc_1corn.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/2carrots_1a_1t_1bc_1rc_1yc_1corn.yaml new file mode 100644 index 000000000..3f5717d31 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/2carrots_1a_1t_1bc_1rc_1yc_1corn.yaml @@ -0,0 +1,97 @@ +entities: + - name: carrot1 + prefab_name: carrot + pose: + translation: + x: 0.1 + y: -0.2 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: apple1 + prefab_name: apple + pose: + translation: + x: 0.3 + y: -0.1 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: tomato1 + prefab_name: tomato + pose: + translation: + x: 0.35 + y: -0.1 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: blue_cube1 + prefab_name: blue_cube + pose: + translation: + x: 0.5 + y: 0.5 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: red_cube1 + prefab_name: red_cube + pose: + translation: + x: 0.55 + y: 0.5 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: yellow_cube1 + prefab_name: yellow_cube + pose: + translation: + x: 0.6 + y: 0.5 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: corn1 + prefab_name: corn + pose: + translation: + x: 0.6 + y: 0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: carrot2 + prefab_name: carrot + pose: + translation: + x: 0.15 + y: -0.2 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/2carrots_2a.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/2carrots_2a.yaml new file mode 100644 index 000000000..4e8b92025 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/2carrots_2a.yaml @@ -0,0 +1,50 @@ +entities: + - name: carrot1 + prefab_name: carrot + pose: + translation: + x: 0.5 + y: -0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: corn1 + prefab_name: apple + pose: + translation: + x: 0.5 + y: 0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + + - name: carrot2 + prefab_name: carrot + pose: + translation: + x: 0.1 + y: -0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: corn2 + prefab_name: apple + pose: + translation: + x: 0.1 + y: 0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/scene3.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/2rc.yaml similarity index 100% rename from src/rai_bench/rai_bench/o3de_test_bench/configs/scene3.yaml rename to src/rai_bench/rai_bench/o3de_test_bench/configs/2rc.yaml diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/2rc_3bc_4yc_stacked.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/2rc_3bc_4yc_stacked.yaml new file mode 100644 index 000000000..2f4e783dc --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/2rc_3bc_4yc_stacked.yaml @@ -0,0 +1,112 @@ +entities: + - name: red_cube1 + prefab_name: red_cube + pose: + translation: + x: 0.1 + y: -0.2 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: red_cube2 + prefab_name: red_cube + pose: + translation: + x: 0.1 + y: -0.2 + z: 0.15 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + + - name: blue_cube1 + prefab_name: blue_cube + pose: + translation: + x: 0.1 + y: -0.2 + z: 0.25 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + + - name: blue_cube2 + prefab_name: blue_cube + pose: + translation: + x: 0.2 + y: -0.1 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + + - name: blue_cube3 + prefab_name: blue_cube + pose: + translation: + x: 0.2 + y: -0.1 + z: 0.15 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: yellow_cube1 + prefab_name: yellow_cube + pose: + translation: + x: 0.2 + y: -0.1 + z: 0.25 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: yellow_cube2 + prefab_name: yellow_cube + pose: + translation: + x: 0.2 + y: -0.1 + z: 0.35 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: yellow_cube3 + prefab_name: yellow_cube + pose: + translation: + x: 0.5 + y: -0.5 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: yellow_cube4 + prefab_name: yellow_cube + pose: + translation: + x: 0.4 + y: -0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/2t.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/2t.yaml new file mode 100644 index 000000000..eed15f6bb --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/2t.yaml @@ -0,0 +1,25 @@ +entities: + - name: tomato1 + prefab_name: tomato + pose: + translation: + x: 0.3 + y: 0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: tomato2 + prefab_name: tomato + pose: + translation: + x: 0.4 + y: 0.5 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/2t_3a_1corn_2rc.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/2t_3a_1corn_2rc.yaml new file mode 100644 index 000000000..aa773cc80 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/2t_3a_1corn_2rc.yaml @@ -0,0 +1,97 @@ +entities: + - name: tomato1 + prefab_name: tomato + pose: + translation: + x: 0.2 + y: 0.5 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: tomato2 + prefab_name: tomato + pose: + translation: + x: 0.3 + y: -0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: apple1 + prefab_name: apple + pose: + translation: + x: 0.35 + y: 0.1 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: apple2 + prefab_name: apple + pose: + translation: + x: 0.4 + y: -0.2 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: apple3 + prefab_name: apple + pose: + translation: + x: 0.2 + y: 0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: corn1 + prefab_name: corn + pose: + translation: + x: 0.1 + y: -0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: red_cube1 + prefab_name: red_cube + pose: + translation: + x: 0.5 + y: 0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: red_cube2 + prefab_name: red_cube + pose: + translation: + x: 0.25 + y: -0.5 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/scene4.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/2yc_1bc_1rc.yaml similarity index 100% rename from src/rai_bench/rai_bench/o3de_test_bench/configs/scene4.yaml rename to src/rai_bench/rai_bench/o3de_test_bench/configs/2yc_1bc_1rc.yaml diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/3a_4t_2bc.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/3a_4t_2bc.yaml new file mode 100644 index 000000000..73013b2cd --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/3a_4t_2bc.yaml @@ -0,0 +1,112 @@ +entities: + - name: apple1 + prefab_name: apple + pose: + translation: + x: 0.4 + y: -0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: apple2 + prefab_name: apple + pose: + translation: + x: 0.3 + y: -0.2 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + + - name: apple3 + prefab_name: apple + pose: + translation: + x: 0.5 + y: 0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + + - name: tomato1 + prefab_name: tomato + pose: + translation: + x: 0.4 + y: -0.35 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + + - name: tomato2 + prefab_name: tomato + pose: + translation: + x: 0.4 + y: 0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: tomato3 + prefab_name: tomato + pose: + translation: + x: 0.1 + y: -0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: tomato4 + prefab_name: tomato + pose: + translation: + x: 0.5 + y: 0.35 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: blue_cube1 + prefab_name: blue_cube + pose: + translation: + x: 0.6 + y: -0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: blue_cube2 + prefab_name: blue_cube + pose: + translation: + x: 0.35 + y: 0.30 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/3carrots_1a_1t_2bc_2yc.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/3carrots_1a_1t_2bc_2yc.yaml new file mode 100644 index 000000000..df5418540 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/3carrots_1a_1t_2bc_2yc.yaml @@ -0,0 +1,109 @@ +entities: + - name: carrot1 + prefab_name: carrot + pose: + translation: + x: 0.1 + y: -0.2 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: carrot2 + prefab_name: carrot + pose: + translation: + x: 0.4 + y: 0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: carrot3 + prefab_name: carrot + pose: + translation: + x: 0.5 + y: -0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: apple1 + prefab_name: apple + pose: + translation: + x: 0.3 + y: -0.1 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: blue_cube1 + prefab_name: blue_cube + pose: + translation: + x: 0.1 + y: 0.5 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: blue_cube2 + prefab_name: blue_cube + pose: + translation: + x: 0.5 + y: -0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: tomato1 + prefab_name: tomato + pose: + translation: + x: 0.2 + y: 0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: yellow_cube1 + prefab_name: yellow_cube + pose: + translation: + x: 0.2 + y: 0.1 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: yellow_cube2 + prefab_name: yellow_cube + pose: + translation: + x: 0.3 + y: -0.5 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/3carrots_1a_2bc_1rc_1yc_1corn.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/3carrots_1a_2bc_1rc_1yc_1corn.yaml new file mode 100644 index 000000000..ba8782eca --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/3carrots_1a_2bc_1rc_1yc_1corn.yaml @@ -0,0 +1,112 @@ +entities: + - name: carrot1 + prefab_name: carrot + pose: + translation: + x: 0.0 + y: -0.2 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: carrot2 + prefab_name: carrot + pose: + translation: + x: 0.05 + y: -0.2 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + + - name: carrot3 + prefab_name: carrot + pose: + translation: + x: 0.05 + y: -0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + + - name: apple1 + prefab_name: apple + pose: + translation: + x: 0.2 + y: -0.1 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + + - name: blue_cube1 + prefab_name: blue_cube + pose: + translation: + x: 0.4 + y: 0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: red_cube1 + prefab_name: red_cube + pose: + translation: + x: 0.45 + y: 0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: yellow_cube1 + prefab_name: yellow_cube + pose: + translation: + x: 0.5 + y: 0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: corn1 + prefab_name: corn + pose: + translation: + x: 0.6 + y: 0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: blue_cube2 + prefab_name: blue_cube + pose: + translation: + x: 0.4 + y: 0.45 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/3rc_3bc_stacked.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/3rc_3bc_stacked.yaml new file mode 100644 index 000000000..bf9063b88 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/3rc_3bc_stacked.yaml @@ -0,0 +1,78 @@ +entities: + - name: cube1 + prefab_name: red_cube + pose: + translation: + x: 0.5 + y: -0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + + - name: cube2 + prefab_name: blue_cube + pose: + translation: + x: 0.4 + y: -0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + + - name: cube3 + prefab_name: red_cube + pose: + translation: + x: 0.5 + y: -0.3 + z: 0.15 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + + - name: cube4 + prefab_name: red_cube + pose: + translation: + x: 0.5 + y: -0.3 + z: 0.25 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + + - name: cube5 + prefab_name: blue_cube + pose: + translation: + x: 0.3 + y: -0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + + - name: cube6 + prefab_name: blue_cube + pose: + translation: + x: 0.3 + y: -0.3 + z: 0.15 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/4bc.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/4bc.yaml new file mode 100644 index 000000000..edfe61abe --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/4bc.yaml @@ -0,0 +1,49 @@ +entities: + - name: blue_cube1 + prefab_name: blue_cube + pose: + translation: + x: 0.1 + y: -0.2 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: blue_cube2 + prefab_name: blue_cube + pose: + translation: + x: 0.2 + y: -0.1 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: blue_cube3 + prefab_name: blue_cube + pose: + translation: + x: 0.3 + y: 0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 + - name: blue_cube4 + prefab_name: blue_cube + pose: + translation: + x: 0.4 + y: 0.5 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.0 + w: 1.0 diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/scene2.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/4carrots.yaml similarity index 100% rename from src/rai_bench/rai_bench/o3de_test_bench/configs/scene2.yaml rename to src/rai_bench/rai_bench/o3de_test_bench/configs/4carrots.yaml diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/4carrots_rotated.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/4carrots_rotated.yaml new file mode 100644 index 000000000..b77b61f7c --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/4carrots_rotated.yaml @@ -0,0 +1,49 @@ +entities: + - name: carrot1 + prefab_name: carrot + pose: + translation: + x: 0.5 + y: -0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.7071 # sin(45°) + w: 0.7071 # cos(45°) -> 90° rotation + - name: carrot2 + prefab_name: carrot + pose: + translation: + x: 0.5 + y: 0.3 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.3827 # sin(22.5°) + w: 0.9239 # cos(22.5°) -> 45° rotation + - name: carrot3 + prefab_name: carrot + pose: + translation: + x: 0.5 + y: 0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 0.9239 # sin(67.5°) + w: 0.3827 # cos(67.5°) -> 135° rotation + - name: carrot4 + prefab_name: carrot + pose: + translation: + x: 0.5 + y: -0.4 + z: 0.05 + rotation: + x: 0.0 + y: 0.0 + z: 1.0 # sin(90°) + w: 0.0 # cos(90°) -> 180° rotation diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/o3de_config.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/o3de_config.yaml new file mode 100644 index 000000000..9a86f79ca --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/o3de_config.yaml @@ -0,0 +1,19 @@ +binary_path: /path/to/RAIManipulationDemo/RAIManipulationDemo.GameLauncher +level: RoboticManipulationBenchmark +robotic_stack_command: ros2 launch examples/manipulation-demo-no-binary.launch.py +required_simulation_ros2_interfaces: + services: + - /spawn_entity + - /delete_entity + topics: + - /color_image5 + - /depth_image5 + - /color_camera_info5 + actions: [] +required_robotic_ros2_interfaces: + services: + - /grounding_dino_classify + - /grounded_sam_segment + - /manipulator_move_to + topics: [] + actions: [] diff --git a/src/rai_bench/rai_bench/o3de_test_bench/scenarios.py b/src/rai_bench/rai_bench/o3de_test_bench/scenarios.py new file mode 100644 index 000000000..cdd26e791 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/scenarios.py @@ -0,0 +1,569 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from pathlib import Path +from typing import List, Union + +from rclpy.impl.rcutils_logger import RcutilsLogger + +from rai_bench.benchmark_model import ( + Benchmark, + Scenario, + Task, +) +from rai_bench.o3de_test_bench.tasks import ( + BuildCubeTowerTask, + GroupObjectsTask, + MoveObjectsToLeftTask, + PlaceCubesTask, + PlaceObjectAtCoordTask, +) +from rai_sim.o3de.o3de_bridge import ( + O3DExROS2SimulationConfig, +) + +loggers_type = Union[RcutilsLogger, logging.Logger] + + +def trivial_scenarios( + configs_dir: str, connector_path: str, logger: loggers_type | None +) -> List[Scenario[O3DExROS2SimulationConfig]]: + """Packet of trivial scenarios. The grading is subjective. + This packet contains easy variants of 'easy' tasks with minimalistic scenes setups(1 object). + + In this packet: + PlaceObjectAtCoordTask with large allowable_displacement + MoveObjectsToLeftTask with only 1 object type + + This level of difficulty requires recognizing position of object and moving it once + + Parameters + ---------- + configs_dir : str + path to directory with simulation configs + connector_path : str + path to connector config + + + Returns + ------- + List[Scenario[O3DExROS2SimulationConfig]] + list of trivial scenarios + """ + simulation_configs_paths: List[str] = [ + configs_dir + "1a.yaml", + configs_dir + "1rc.yaml", + configs_dir + "1t.yaml", + configs_dir + "1yc.yaml", + configs_dir + "1carrot.yaml", + ] + simulations_configs = [ + O3DExROS2SimulationConfig.load_config(Path(path), Path(connector_path)) + for path in simulation_configs_paths + ] + # place object at coords + place_obj_types = [ + "apple", + "carrot", + "yellow_cube", + "red_cube", + ] + target_coords = [(0.3, 0.3), (0.2, -0.4)] + allowable_displacements = [0.1] # large margin + place_object_tasks: List[Task] = [] + for obj in place_obj_types: + for coord in target_coords: + for disp in allowable_displacements: + place_object_tasks.append( + PlaceObjectAtCoordTask(obj, coord, disp, logger=logger) + ) + easy_place_objects_scenarios = Benchmark.create_scenarios( + tasks=place_object_tasks, + simulation_configs=simulations_configs, + simulation_configs_paths=simulation_configs_paths, + ) + # move objects to the left + object_groups = [["carrot"], ["red_cube"], ["tomato"], ["yellow_cube"]] + + move_to_left_tasks: List[Task] = [ + MoveObjectsToLeftTask(obj_types=objects, logger=logger) + for objects in object_groups + ] + + easy_move_to_left_scenarios = Benchmark.create_scenarios( + tasks=move_to_left_tasks, + simulation_configs=simulations_configs, + simulation_configs_paths=simulation_configs_paths, + ) + + return [*easy_move_to_left_scenarios, *easy_place_objects_scenarios] + + +def easy_scenarios( + configs_dir: str, connector_path: str, logger: loggers_type | None +) -> List[Scenario[O3DExROS2SimulationConfig]]: + """Packet of easy scenarios. The grading is subjective. + This packet contains easy variants of 'easy' tasks with scenes containg no more than 3 objects + + In this packet: + PlaceObjectAtCoordTask with small allowable_displacement + MoveObjectsToLeftTask with only 1 object type + PlaceCubesTask with large threshold + + This level of difficulty requires recognizing proper type of object. + Some scenarios will require moving more than 1 object or moving with more precision. + + Parameters + ---------- + configs_dir : str + path to directory with simulation configs + connector_path : str + path to connector config + + + Returns + ------- + List[Scenario[O3DExROS2SimulationConfig]] + list of easy scenarios + """ + simulation_configs_paths: List[str] = [ + configs_dir + "1a_1t.yaml", + configs_dir + "1a_2bc.yaml", + configs_dir + "1bc_1rc_1yc.yaml", + configs_dir + "1carrot_1bc.yaml", + configs_dir + "1carrot_1corn.yaml", + configs_dir + "1yc_1rc.yaml", + configs_dir + "2rc.yaml", + configs_dir + "2t.yaml", + configs_dir + "2a_1bc.yaml", + configs_dir + "1carrot_1t_1rc.yaml", + ] + simulations_configs = [ + O3DExROS2SimulationConfig.load_config(Path(path), Path(connector_path)) + for path in simulation_configs_paths + ] + # place object at coords + place_obj_types = [ + "apple", + "tomato", + "carrot", + "yellow_cube", + "red_cube", + ] + target_coords = [(0.3, 0.3), (0.2, -0.4)] + allowable_displacements = [0.1] # large margin + place_object_tasks: List[Task] = [] + for obj in place_obj_types: + for coord in target_coords: + for disp in allowable_displacements: + place_object_tasks.append( + PlaceObjectAtCoordTask(obj, coord, disp, logger=logger) + ) + easy_place_objects_scenarios = Benchmark.create_scenarios( + tasks=place_object_tasks, + simulation_configs=simulations_configs, + simulation_configs_paths=simulation_configs_paths, + logger=logger, + ) + # move objects to the left + object_groups = [ + ["carrot"], + ["red_cube"], + ["blue_cube"], + ["yellow_cube"], + ["tomato"], + ] + + move_to_left_tasks: List[Task] = [ + MoveObjectsToLeftTask(obj_types=objects, logger=logger) + for objects in object_groups + ] + + easy_move_to_left_scenarios = Benchmark.create_scenarios( + tasks=move_to_left_tasks, + simulation_configs=simulations_configs, + simulation_configs_paths=simulation_configs_paths, + ) + + # place cubes + task = PlaceCubesTask(threshold_distance=0.2, logger=logger) + easy_place_cubes_scenarios = Benchmark.create_scenarios( + tasks=[task], + simulation_configs=simulations_configs, + simulation_configs_paths=simulation_configs_paths, + ) + + return [ + *easy_move_to_left_scenarios, + *easy_place_objects_scenarios, + *easy_place_cubes_scenarios, + ] + + +def medium_scenarios( + configs_dir: str, connector_path: str, logger: loggers_type | None +) -> List[Scenario[O3DExROS2SimulationConfig]]: + """Packet of medium scenarios. The grading is subjective. + This packet contains harder variants of 'easy' tasks with scenes containg 4-7 objects + and easy variants of 'hard' tasks with scenes contating 2-3 objects + + In this packet: + MoveObjectsToLeftTask with multiple object types to move + PlaceCubesTask with small threshold + BuildTowerTask with only one type of objects to move + GroupObjectsTask with only one type of objects to move + + This level of difficulty requires recognizing multiple proper type of objects. + All scenarios will require moving more than 1 object. + Some tasks will require good spacial awareness to make structures. + + Parameters + ---------- + configs_dir : str + path to directory with simulation configs + connector_path : str + path to connector config + + + Returns + ------- + List[Scenario[O3DExROS2SimulationConfig]] + list of easy scenarios + """ + medium_simulation_configs_paths: List[str] = [ + configs_dir + "1rc_2bc_3yc.yaml", + configs_dir + "2carrots_2a.yaml", + configs_dir + "2yc_1bc_1rc.yaml", + configs_dir + "4carrots.yaml", + configs_dir + "1carrot_1a_1t_1bc_1corn.yaml", + configs_dir + "4bc.yaml", + configs_dir + "2a_1c_2rc.yaml", + ] + + easy_simulation_configs_paths: List[str] = [ + configs_dir + "1a_1t.yaml", + configs_dir + "1a_2bc.yaml", + configs_dir + "1bc_1rc_1yc.yaml", + configs_dir + "1carrot_1bc.yaml", + configs_dir + "1carrot_1corn.yaml", + configs_dir + "1yc_1rc.yaml", + configs_dir + "2rc.yaml", + configs_dir + "2t.yaml", + configs_dir + "2a_1bc.yaml", + configs_dir + "1carrot_1t_1rc.yaml", + ] + medium_simulations_configs = [ + O3DExROS2SimulationConfig.load_config(Path(path), Path(connector_path)) + for path in medium_simulation_configs_paths + ] + easy_simulations_configs = [ + O3DExROS2SimulationConfig.load_config(Path(path), Path(connector_path)) + for path in easy_simulation_configs_paths + ] + # move objects to the left + object_groups = [ + ["red_cube", "blue_cube"], + ["carrots"], + ["carrots", "apple"], + ["yellow_cube", "blue_cube"], + ["tomato", "apple"], + ["blue_cube"], + ] + + move_to_left_tasks: List[Task] = [ + MoveObjectsToLeftTask(obj_types=objects, logger=logger) + for objects in object_groups + ] + + move_to_left_scenarios = Benchmark.create_scenarios( + tasks=move_to_left_tasks, + simulation_configs=medium_simulations_configs, + simulation_configs_paths=medium_simulation_configs_paths, + logger=logger, + ) + + # place cubes + task = PlaceCubesTask(threshold_distance=0.1, logger=logger) + easy_place_cubes_scenarios = Benchmark.create_scenarios( + tasks=[task], + simulation_configs=medium_simulations_configs, + simulation_configs_paths=medium_simulation_configs_paths, + logger=logger, + ) + + # build tower task + object_groups = [ + ["red_cube", "blue_cube", "yellow_cube"], + ] + + build_tower_tasks: List[Task] = [ + BuildCubeTowerTask(obj_types=objects, logger=logger) + for objects in object_groups + ] + + build_tower_scenarios = Benchmark.create_scenarios( + tasks=build_tower_tasks, + simulation_configs=easy_simulations_configs, + simulation_configs_paths=easy_simulation_configs_paths, + ) + + # group object task + object_groups = [ + ["apple"], + ["carrot"], + ["tomato"], + ["red_cube"], + ["tomato"], + ["blue_cube"], + ] + + group_object_tasks: List[Task] = [ + GroupObjectsTask(obj_types=objects, logger=logger) for objects in object_groups + ] + + group_object_scenarios = Benchmark.create_scenarios( + tasks=group_object_tasks, + simulation_configs=easy_simulations_configs, + simulation_configs_paths=easy_simulation_configs_paths, + ) + return [ + *move_to_left_scenarios, + *build_tower_scenarios, + *easy_place_cubes_scenarios, + *group_object_scenarios, + ] + + +def hard_scenarios( + configs_dir: str, connector_path: str, logger: loggers_type | None +) -> List[Scenario[O3DExROS2SimulationConfig]]: + """Packet of hard scenarios. The grading is subjective. + This packet contains harder variants of 'easy' tasks with majority of scenes containg 8+ objects, + Objects can be positioned in an unusual way, for example stacked. + And easy variants of 'hard' tasks with scenes containing 4-7 objects + + In this packet: + MoveObjectsToLeftTask with multiple object types to move + PlaceCubesTask with small threshold + BuildTowerTask with all cubes available + GroupObjectsTask with 1-2 types of objects to be grouped + + This level of difficulty requires recognizing multiple proper type of objects. + All scenarios will require moving multiple objects. + Some tasks will require good spacial awareness to make structures. + + Parameters + ---------- + configs_dir : str + path to directory with simulation configs + connector_path : str + path to connector config + + + Returns + ------- + List[Scenario[O3DExROS2SimulationConfig]] + list of easy scenarios + """ + medium_simulation_configs_paths: List[str] = [ + configs_dir + "1rc_2bc_3yc.yaml", + configs_dir + "2carrots_2a.yaml", + configs_dir + "2yc_1bc_1rc.yaml", + configs_dir + "4carrots.yaml", + configs_dir + "1carrot_1a_1t_1bc_1corn.yaml", + configs_dir + "4bc.yaml", + configs_dir + "2a_1c_2rc.yaml", + ] + + hard_simulation_configs_paths: List[str] = [ + configs_dir + "3carrots_1a_1t_2bc_2yc.yaml", + configs_dir + "1carrot_1a_2t_1bc_1rc_3yc_stacked.yaml", + configs_dir + "2carrots_1a_1t_1bc_1rc_1yc_1corn.yaml", + configs_dir + "2rc_3bc_4yc_stacked.yaml", + configs_dir + "2t_3a_1corn_2rc.yaml", + configs_dir + "3a_4t_2bc.yaml", + configs_dir + "2rc.yaml", + configs_dir + "3carrots_1a_2bc_1rc_1yc_1corn.yaml", + configs_dir + "3rc_3bc_stacked.yaml", + ] + medium_simulations_configs = [ + O3DExROS2SimulationConfig.load_config(Path(path), Path(connector_path)) + for path in medium_simulation_configs_paths + ] + hard_simulations_configs = [ + O3DExROS2SimulationConfig.load_config(Path(path), Path(connector_path)) + for path in hard_simulation_configs_paths + ] + # move objects to the left + object_groups = [ + ["red_cube", "blue_cube"], + ["carrots", "apple", "yellow_cube"], + ["carrots", "apple"], + ["yellow_cube", "blue_cube"], + ["tomato", "apple"], + ["blue_cube", "red_cube"], + ] + + move_to_left_tasks: List[Task] = [ + MoveObjectsToLeftTask(obj_types=objects, logger=logger) + for objects in object_groups + ] + + move_to_left_scenarios = Benchmark.create_scenarios( + tasks=move_to_left_tasks, + simulation_configs=hard_simulations_configs, + simulation_configs_paths=hard_simulation_configs_paths, + ) + + # place cubes + task = PlaceCubesTask(threshold_distance=0.1, logger=logger) + easy_place_cubes_scenarios = Benchmark.create_scenarios( + tasks=[task], + simulation_configs=hard_simulations_configs, + simulation_configs_paths=hard_simulation_configs_paths, + ) + + # build tower task + object_groups = [ + ["red_cube", "blue_cube", "yellow_cube"], + ] + + build_tower_tasks: List[Task] = [ + BuildCubeTowerTask(obj_types=objects, logger=logger) + for objects in object_groups + ] + + build_tower_scenarios = Benchmark.create_scenarios( + tasks=build_tower_tasks, + simulation_configs=medium_simulations_configs, + simulation_configs_paths=medium_simulation_configs_paths, + logger=logger, + ) + + # group object task + object_groups = [ + ["apple", "carrot"], + ["carrot", "tomato"], + ["tomato", "blue_cube"], + ["red_cube"], + ["carrot"], + ["blue_cube"], + ["yellow_cube", "red_cube"], + ] + + group_object_tasks: List[Task] = [ + GroupObjectsTask(obj_types=objects, logger=logger) for objects in object_groups + ] + + group_object_scenarios = Benchmark.create_scenarios( + tasks=group_object_tasks, + simulation_configs=medium_simulations_configs, + simulation_configs_paths=medium_simulation_configs_paths, + ) + return [ + *move_to_left_scenarios, + *build_tower_scenarios, + *easy_place_cubes_scenarios, + *group_object_scenarios, + ] + + +def very_hard_scenarios( + configs_dir: str, connector_path: str, logger: loggers_type | None +) -> List[Scenario[O3DExROS2SimulationConfig]]: + """Packet of very_hard scenarios. The grading is subjective. + This packet contains harder variants of 'hard' tasks with majority of scenes containg 8+ objects, + Objects can be positioned in an unusual way, for example stacked. + In this packet: + BuildTowerTask with only ceratin type of cubes + GroupObjectsTask with multiple objects to be grouped + + This level of difficulty requires recognizing multiple proper type of objects. + All scenarios will require moving multiple objects. + All tasks will require very good spacial awareness to make structures. + + Parameters + ---------- + configs_dir : str + path to directory with simulation configs + connector_path : str + path to connector config + + + Returns + ------- + List[Scenario[O3DExROS2SimulationConfig]] + list of easy scenarios + """ + hard_simulation_configs_paths: List[str] = [ + configs_dir + "3carrots_1a_1t_2bc_2yc.yaml", + configs_dir + "1carrot_1a_2t_1bc_1rc_3yc_stacked.yaml", + configs_dir + "2carrots_1a_1t_1bc_1rc_1yc_1corn.yaml", + configs_dir + "2rc_3bc_4yc_stacked.yaml", + configs_dir + "2t_3a_1corn_2rc.yaml", + configs_dir + "3a_4t_2bc.yaml", + configs_dir + "2rc.yaml", + configs_dir + "3carrots_1a_2bc_1rc_1yc_1corn.yaml", + configs_dir + "3rc_3bc_stacked.yaml", + ] + hard_simulations_configs = [ + O3DExROS2SimulationConfig.load_config(Path(path), Path(connector_path)) + for path in hard_simulation_configs_paths + ] + # build tower task + object_groups = [ + ["red_cube", "blue_cube"], + ["red_cube"], + ["blue_cube"], + ["yellow_cube"], + ["yellow_cube"], + ["blue_cube"], + ] + + build_tower_tasks: List[Task] = [ + BuildCubeTowerTask(obj_types=objects, logger=logger) + for objects in object_groups + ] + + build_tower_scenarios = Benchmark.create_scenarios( + tasks=build_tower_tasks, + simulation_configs=hard_simulations_configs, + simulation_configs_paths=hard_simulation_configs_paths, + logger=logger, + ) + + # group object task + object_groups = [ + ["apple", "carrot"], + ["carrot", "tomato"], + ["tomato", "blue_cube", "yellow_cube"], + ["red_cube", "blue_cube"], + ["tomato", "apple", "carrot"], + ["blue_cube", "carrot"], + ] + + group_object_tasks: List[Task] = [ + GroupObjectsTask(obj_types=objects, logger=logger) for objects in object_groups + ] + + group_object_scenarios = Benchmark.create_scenarios( + tasks=group_object_tasks, + simulation_configs=hard_simulations_configs, + simulation_configs_paths=hard_simulation_configs_paths, + ) + return [ + *build_tower_scenarios, + *group_object_scenarios, + ] diff --git a/src/rai_bench/rai_bench/o3de_test_bench/tasks/__init__.py b/src/rai_bench/rai_bench/o3de_test_bench/tasks/__init__.py index 5be82bf8c..c3d5532e1 100644 --- a/src/rai_bench/rai_bench/o3de_test_bench/tasks/__init__.py +++ b/src/rai_bench/rai_bench/o3de_test_bench/tasks/__init__.py @@ -11,7 +11,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from rai_bench.o3de_test_bench.tasks.grab_carrot_task import GrabCarrotTask -from rai_bench.o3de_test_bench.tasks.place_cubes_task import PlaceCubesTask +from rai_bench.o3de_test_bench.tasks.build_tower_task import ( + BuildCubeTowerTask, +) +from rai_bench.o3de_test_bench.tasks.group_objects_task import ( + GroupObjectsTask, +) +from rai_bench.o3de_test_bench.tasks.move_object_to_left_task import ( + MoveObjectsToLeftTask, +) +from rai_bench.o3de_test_bench.tasks.place_at_coord_task import ( + PlaceObjectAtCoordTask, +) +from rai_bench.o3de_test_bench.tasks.place_cubes_task import ( + PlaceCubesTask, +) +from rai_bench.o3de_test_bench.tasks.rotate_object_task import ( + RotateObjectTask, +) -__all__ = ["GrabCarrotTask", "PlaceCubesTask"] +__all__ = [ + "BuildCubeTowerTask", + "GroupObjectsTask", + "MoveObjectsToLeftTask", + "PlaceCubesTask", + "PlaceObjectAtCoordTask", + "RotateObjectTask", +] diff --git a/src/rai_bench/rai_bench/o3de_test_bench/tasks/build_tower_task.py b/src/rai_bench/rai_bench/o3de_test_bench/tasks/build_tower_task.py new file mode 100644 index 000000000..f7ee9db37 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/tasks/build_tower_task.py @@ -0,0 +1,136 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import List, Tuple, Union + +from rclpy.impl.rcutils_logger import RcutilsLogger + +from rai_bench.o3de_test_bench.tasks.manipulation_task import ( + ManipulationTask, +) +from rai_sim.simulation_bridge import Entity, SimulationConfig + +loggers_type = Union[RcutilsLogger, logging.Logger] + + +class BuildCubeTowerTask(ManipulationTask): + ALLOWED_OBJECTS = {"red_cube", "blue_cube", "yellow_cube"} + # fixed upper limit for allowable displacement for every object type + # it should ensure that this displacement is not greater than half of the object size + MAXIMUM_DISPLACEMENT = {"red_cube": 0.02, "blue_cube": 0.02, "yellow_cube": 0.02} + + def __init__( + self, + obj_types: List[str], + allowable_displacement: float = 0.02, + logger: loggers_type | None = None, + ): + """ + This task requires that cubes of the specified types are arranged into a single vertical tower. + Only objects with types specified in `obj_types` (which must be a subset of the allowed objects) + are considered. Cubes are grouped by their z-coordinate using a horizontal tolerance, and only + groups with more than one cube are considered towers. The height of the tallest tower determines + the number of correctly placed cubes. + + Parameters + ---------- + obj_types : List[str] + A list of cube types (e.g., ["red_cube", "blue_cube"]) to be used for building the tower. + Each type must be one of the allowed objects: {"red_cube", "blue_cube", "yellow_cube"}. + allowable_displacement : float, optional + The allowable horizontal displacement (tolerance, in meters) used when grouping cubes by their + z-coordinate. Default is 0.02. + + + Raises + ------ + TypeError + If any of the provided object types is not allowed. + """ + # NOTE (jmatejcz) what if allowable_displament is greater then the size of object? + # we could check the z distance between entities + # or trust user with this + super().__init__(logger) + if not set(obj_types).issubset(self.ALLOWED_OBJECTS): + raise TypeError( + f"Invalid obj_types provided: {obj_types}. Allowed objects: {self.ALLOWED_OBJECTS}" + ) + for obj_type in obj_types: + if allowable_displacement > self.MAXIMUM_DISPLACEMENT[obj_type]: + raise ValueError( + f"allowable_displacement too large. For object type: {obj_type} maximum is {self.MAXIMUM_DISPLACEMENT[obj_type]}" + ) + self.obj_types = obj_types + self.allowable_displacement = allowable_displacement + + def get_prompt(self) -> str: + cube_names = ", ".join(obj + "s" for obj in self.obj_types).replace("_", " ") + return f"Manipulate objects so that all {cube_names} form a single vertical tower. Other types of objects cannot be included in a tower." + + def check_if_required_objects_present( + self, simulation_config: SimulationConfig + ) -> bool: + """ + Validate that at least two cubes of the specified types are present. + + Returns + ------- + bool + True if at least two cubes of the allowed types are present, False otherwise. + """ + cube_count = sum( + 1 for ent in simulation_config.entities if ent.prefab_name in self.obj_types + ) + return cube_count > 1 + + def calculate_correct(self, entities: List[Entity]) -> Tuple[int, int]: + """ + Calculate the number of correctly and incorrectly placed cubes. + + This task does not consider a single cube as correctly placed. + Cubes are grouped by their z-coordinate using a horizontal tolerance. + The highest tower (the group with the most cubes) is considered correct, + and all other cubes are counted as incorrect. + + Parameters + ---------- + entities : List[Entity] + List of all entities present in the simulation scene. + + Returns + ------- + Tuple[int, int] + A tuple where the first element is the number of correctly placed cubes (from the tallest tower) + and the second element is the number of incorrectly placed cubes. + """ + + # Group entities by z-coordinate + grouped_entities = self.group_entities_along_z_axis( + entities, self.allowable_displacement + ) + selected_type_objects = self.filter_entities_by_object_type( + entities=entities, object_types=self.obj_types + ) + + correct = 0 + incorrect = 0 + for group in grouped_entities: + if len(group) > 1: + # we treat single standing cubes as incorrect + if all(entity.prefab_name in self.obj_types for entity in group): + # highest tower is number of correctly placed objects + correct = max(correct, len(group)) + incorrect = len(selected_type_objects) - correct + return correct, incorrect diff --git a/src/rai_bench/rai_bench/o3de_test_bench/tasks/grab_carrot_task.py b/src/rai_bench/rai_bench/o3de_test_bench/tasks/grab_carrot_task.py deleted file mode 100644 index d011051d1..000000000 --- a/src/rai_bench/rai_bench/o3de_test_bench/tasks/grab_carrot_task.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright (C) 2025 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import List, Tuple - -from rai_bench.benchmark_model import EntitiesMismatchException, Task -from rai_sim.o3de.o3de_bridge import ( - SimulationBridge, -) -from rai_sim.simulation_bridge import SimulationConfig, SimulationConfigT, SpawnedEntity - - -class GrabCarrotTask(Task): - obj_types = ["carrot"] - - # TODO (jm) extract common logic to some parent manipulation task - def get_prompt(self) -> str: - return "Manipulate objects, so that all carrots to the left side of the table (positive y)" - - def validate_config(self, simulation_config: SimulationConfig) -> bool: - for ent in simulation_config.entities: - if ent.prefab_name in self.obj_types: - return True - - return False - - def calculate_correct(self, entities: List[SpawnedEntity]) -> Tuple[int, int]: - """Calculate how many objects are positioned correct and incorrect""" - correct = sum(1 for ent in entities if ent.pose.translation.y > 0.0) - incorrect: int = len(entities) - correct - return correct, incorrect - - def calculate_initial_placements( - self, simulation_bridge: SimulationBridge[SimulationConfigT] - ) -> tuple[int, int]: - """ - Calculates the number of objects that are correctly and incorrectly placed initially. - """ - initial_carrots = self.filter_entities_by_prefab_type( - simulation_bridge.spawned_entities, prefab_types=self.obj_types - ) - initially_correct, initially_incorrect = self.calculate_correct( - entities=initial_carrots - ) - - self.logger.info( # type: ignore - f"Initially correctly placed carrots: {initially_correct}, Initially incorrectly placed carrots: {initially_incorrect}" - ) - return initially_correct, initially_incorrect - - def calculate_final_placements( - self, simulation_bridge: SimulationBridge[SimulationConfigT] - ) -> tuple[int, int]: - """ - Calculates the number of objects that are correctly and incorrectly placed at the end of the simulation. - """ - scene_state = simulation_bridge.get_scene_state() - final_carrots = self.filter_entities_by_prefab_type( - scene_state.entities, prefab_types=self.obj_types - ) - final_correct, final_incorrect = self.calculate_correct(entities=final_carrots) - - self.logger.info( # type: ignore - f"Finally correctly placed carrots: {final_correct}, Finally incorrectly placed carrots: {final_incorrect}" - ) - return final_correct, final_incorrect - - def calculate_result( - self, simulation_bridge: SimulationBridge[SimulationConfigT] - ) -> float: - """ - Calculates a score from 0.0 to 1.0, where 0.0 represents the initial placements or worse and 1.0 represents perfect final placements. - """ - initially_correct, initially_incorrect = self.calculate_initial_placements( - simulation_bridge - ) - final_correct, final_incorrect = self.calculate_final_placements( - simulation_bridge - ) - - total_objects = initially_correct + initially_incorrect - if total_objects == 0: - return 1.0 - elif (initially_correct + initially_incorrect) != ( - final_correct + final_incorrect - ): - raise EntitiesMismatchException( - "number of initial entities does not match final entities number." - ) - elif initially_incorrect == 0: - pass - # NOTE all objects are placed correctly - # no point in running task - raise ValueError("All objects are placed correctly at the start.") - else: - corrected = final_correct - initially_correct - score = max(0.0, corrected / initially_incorrect) - - self.logger.info(f"Calculated score: {score:.2f}") # type: ignore - return score diff --git a/src/rai_bench/rai_bench/o3de_test_bench/tasks/group_objects_task.py b/src/rai_bench/rai_bench/o3de_test_bench/tasks/group_objects_task.py new file mode 100644 index 000000000..cdc24faca --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/tasks/group_objects_task.py @@ -0,0 +1,130 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import List, Tuple, Union + +from rclpy.impl.rcutils_logger import RcutilsLogger + +from rai_bench.o3de_test_bench.tasks.manipulation_task import ( + ManipulationTask, +) +from rai_sim.simulation_bridge import Entity, SimulationConfig + +loggers_type = Union[RcutilsLogger, logging.Logger] + + +class GroupObjectsTask(ManipulationTask): + def __init__( + self, + obj_types: List[str], + threshold_distance: float = 0.15, + logger: loggers_type | None = None, + ): + """ + This task requires that objects of specified types form a single, well-defined cluster. + + Parameters + ---------- + obj_types : List[str] + A list of object types to be grouped into clusters. Only objects whose prefab names match + one of these types will be evaluated. + threshold_distance : float, optional + The maximum distance between two objects (in meters) for them to be considered neighbours + when building the neighbourhood list. Defaults to 0.15. + """ + super().__init__(logger) + self.obj_types = obj_types + self.threshold_distance = threshold_distance + + def get_prompt(self) -> str: + obj_names = ", ".join(obj + "s" for obj in self.obj_types).replace( + "_", " " + ) # create prompt, add s for plural form + return ( + f"Manipulate objects so that {obj_names} form separate clusters based on their types. " + "Each cluster must: " + "1. Contain ALL objects of the same type " + "2. Contain ONLY objects of the same type " + "3. Form a single connected group " + "4. Be completely separated from other clusters " + ) + + def check_if_required_objects_present( + self, simulation_config: SimulationConfig + ) -> bool: + """ + Returns + ------- + bool + True if at least one entity of all specified object types are present, False otherwise. + """ + object_types_present = { + ent.prefab_name + for ent in simulation_config.entities + if ent.prefab_name in self.obj_types + } + + return set(self.obj_types) <= object_types_present + + def calculate_correct(self, entities: List[Entity]) -> Tuple[int, int]: + """ + Count correctly and incorrectly clustered objects based on clustering rules. + + Method first groups the entities by type. + Then, using the specified threshold distance, it builds a neighbourhood list + and identifies clusters using a DFS-based algorithm. + A cluster is considered properly clustered if: + 1. Only one cluster is found for that type. + 2. All objects in the cluster have neighbours exclusively of the same type. + If these conditions are met, the objects in that cluster are counted as correctly clustered. + Otherwise, all objects of that type are counted as misclustered. + + Parameters + ---------- + entities : List[Entity] + List of all entities present in the simulation scene. + + Returns + ------- + Tuple[int, int] + A tuple where the first element is the number of correctly clustered objects + and the second element is the number of misclustered objects. + """ + properly_clustered: List[Entity] = [] + misclustered: List[Entity] = [] + + neighbourhood_list = self.build_neighbourhood_list( + entities, threshold_distance=self.threshold_distance + ) + clusters = self.find_clusters(neighbourhood_list) + selected_type_objects = self.filter_entities_by_object_type( + entities=entities, object_types=self.obj_types + ) + entities_by_type = self.group_entities_by_type(selected_type_objects) + for obj_type, objects in entities_by_type.items(): + # Filter clusters that contain only entities of this type. + clusters_of_type = [ + cluster + for cluster in clusters + if all(ent.prefab_name == obj_type for ent in cluster) + ] + # Check if exactly one cluster exists and it includes all objects of that type. + if len(clusters_of_type) == 1 and len(objects) == len(clusters_of_type[0]): + # Verify that every entity in this cluster has neighbours exclusively of the allowed type. + properly_clustered.extend(clusters_of_type[0]) + else: + # Either no cluster or more than one cluster means misclustering. + misclustered.extend(objects) + + return len(properly_clustered), len(misclustered) diff --git a/src/rai_bench/rai_bench/o3de_test_bench/tasks/manipulation_task.py b/src/rai_bench/rai_bench/o3de_test_bench/tasks/manipulation_task.py new file mode 100644 index 000000000..22d11a00f --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/tasks/manipulation_task.py @@ -0,0 +1,198 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from abc import ABC, abstractmethod +from typing import List, Tuple, Union + +from rclpy.impl.rcutils_logger import RcutilsLogger + +from rai_bench.benchmark_model import ( + EntitiesMismatchException, + EntityT, + Task, +) +from rai_sim.simulation_bridge import ( + SimulationBridge, + SimulationConfig, + SimulationConfigT, +) + +loggers_type = Union[RcutilsLogger, logging.Logger] + + +class ManipulationTask(Task, ABC): + """ + Common class for manipulaiton tasks + obj_types variable represents object types that will be considered as the subject of the task. + That means that based on positions of these objects simulation config will be evaluated + and score will be calculated. + + Example + ------- + MoveObjectsToLeftTask with 'carrot' as objects type, will check if carrtos are present + and then calculated score based on how many carrots were moved to the left side + """ + + obj_types: List[str] = [] + + @abstractmethod + def check_if_required_objects_present( + self, simulation_config: SimulationConfig + ) -> bool: + """ + Check if the required objects are present in the simulation configuration. + + Returns + ------- + bool + True if the required objects are present, False otherwise. + """ + return True + + def check_if_any_placed_incorrectly( + self, simulation_config: SimulationConfig + ) -> bool: + """ + Check if any object is placed incorrectly in the simulation configuration. + Save number of initially correctly and incorrectly placed objects for + future calculations + + Returns + ------- + bool + True if at least one object is placed incorrectly, False otherwise. + """ + _, incorrect = self.calculate_correct(entities=simulation_config.entities) + return incorrect > 0 + + def validate_config(self, simulation_config: SimulationConfig) -> bool: + """ + Validate the simulation configuration. + + Checks whether the required objects are present and if any of them is placed incorrectly. + If these conditions are not met, the task should not be run with this configuration. + + Parameters + ---------- + simulation_config : SimulationConfig + The simulation configuration to validate. + + Returns + ------- + bool + True if the configuration is valid, False otherwise. + """ + + if self.check_if_required_objects_present( + simulation_config=simulation_config + ) and self.check_if_any_placed_incorrectly(simulation_config=simulation_config): + return True + else: + return False + + @abstractmethod + def calculate_correct(self, entities: List[EntityT]) -> Tuple[int, int]: + """Method to calculate how many objects are placed correctly + + Parameters + ---------- + entities : List[EntityT] + list of ALL entities present in the simulaiton scene + + Returns + ------- + Tuple[int, int] + first int HAVE TO be number of correctly placed objects, second int - number of incorrectly placed objects + """ + pass + + def calculate_current_placements( + self, simulation_bridge: SimulationBridge[SimulationConfigT] + ) -> tuple[int, int]: + """ + Get the current placements of objects in the simulation + and calculated their current placements + + Parameters + ---------- + simulation_bridge : SimulationBridge[SimulationConfigT] + The simulation bridge containing the current scene state. + + Returns + ------- + tuple[int, int] + A tuple where the first element is the number of currently correctly placed objects + and the second element is the number of currently incorrectly placed objects. + """ + scene_state = simulation_bridge.get_scene_state() + current_correct, current_incorrect = self.calculate_correct( + entities=scene_state.entities + ) + + self.logger.info( + f"Currently correctly placed objects: {current_correct}, Currenlty incorrectly placed objects: {current_incorrect}" + ) + return current_correct, current_incorrect + + def calculate_result( + self, simulation_bridge: SimulationBridge[SimulationConfig] + ) -> float: + """ + Calculate the task score based on the difference between initial and current placements. + + The score ranges from 0.0 to 1.0, where 0.0 indicates that the initial placements + remain unchanged (or got worse), and 1.0 indicates perfect placements relative to the initial ones. + The score is computed as the improvement in the number of correctly placed objects + divided by the number of initially incorrectly placed objects. + + Parameters + ---------- + simulation_bridge : SimulationBridge[SimulationConfig] + The simulation bridge that provides access to the current scene state. + + Returns + ------- + float + The calculated score, ranging from 0.0 to 1.0. + + Raises + ------ + EntitiesMismatchException + If the total number of initial entities does not match the total number of current entities. + """ + initially_correct, initially_incorrect = self.calculate_correct( + entities=simulation_bridge.spawned_entities + ) + self.logger.info( + f"Objects placed correctly in simulation config: {initially_correct}, incorrectly: {initially_incorrect}" + ) + current_correct, current_incorrect = self.calculate_current_placements( + simulation_bridge + ) + + initial_objects_num = initially_correct + initially_incorrect + current_objects_num = current_correct + current_incorrect + if initial_objects_num == 0: + return 1.0 + elif initial_objects_num != current_objects_num: + raise EntitiesMismatchException( + f"number of initial entities does not match current entities number, initially: {initially_correct + initially_incorrect}, current: {current_correct + current_incorrect}" + ) + else: + corrected = current_correct - initially_correct + score = max(0.0, corrected / initially_incorrect) + + self.logger.info(f"Calculated score: {score:.2f}") + return score diff --git a/src/rai_bench/rai_bench/o3de_test_bench/tasks/move_object_to_left_task.py b/src/rai_bench/rai_bench/o3de_test_bench/tasks/move_object_to_left_task.py new file mode 100644 index 000000000..f76ff83cb --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/tasks/move_object_to_left_task.py @@ -0,0 +1,79 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import List, Tuple, Union + +from rclpy.impl.rcutils_logger import RcutilsLogger + +from rai_bench.o3de_test_bench.tasks.manipulation_task import ( + ManipulationTask, +) +from rai_sim.simulation_bridge import Entity, SimulationConfig + +loggers_type = Union[RcutilsLogger, logging.Logger] + + +class MoveObjectsToLeftTask(ManipulationTask): + def __init__(self, obj_types: List[str], logger: loggers_type | None = None): + """ + This task requires moving all objects of specified types to the left side of the table (positive y). + + Parameters + ---------- + obj_types : List[str] + A list of object types to be moved. + """ + super().__init__(logger=logger) + self.obj_types = obj_types + + def get_prompt(self) -> str: + obj_names = ", ".join(obj + "s" for obj in self.obj_types).replace("_", " ") + # NOTE (jmatejcz) specifing (positive y) might not be the best way to tell agent what to do, + # but 'left side' is depending on where camera is positioned so it might not be enough + return f"Manipulate objects, so that all of the {obj_names} are on the left side of the table (positive y)" + + def check_if_required_objects_present( + self, simulation_config: SimulationConfig + ) -> bool: + """Validate if any object present""" + object_types_present = self.group_entities_by_type( + entities=simulation_config.entities + ) + return set(self.obj_types) <= object_types_present.keys() + + def calculate_correct(self, entities: List[Entity]) -> Tuple[int, int]: + """ + Calculate the number of objects correctly moved to the left side of the table. + + An object is considered correctly placed if its y-coordinate is positive. + + Parameters + ---------- + entities : List[Entity] + List of all entities present in the simulation scene. + + Returns + ------- + Tuple[int, int] + A tuple where the first element is the number of correctly placed objects (with positive y) + and the second element is the number of incorrectly placed objects. + """ + selected_type_objects = self.filter_entities_by_object_type( + entities=entities, object_types=self.obj_types + ) + correct = sum( + 1 for ent in selected_type_objects if ent.pose.translation.y > 0.0 + ) + incorrect: int = len(selected_type_objects) - correct + return correct, incorrect diff --git a/src/rai_bench/rai_bench/o3de_test_bench/tasks/place_at_coord_task.py b/src/rai_bench/rai_bench/o3de_test_bench/tasks/place_at_coord_task.py new file mode 100644 index 000000000..0a14396f4 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/tasks/place_at_coord_task.py @@ -0,0 +1,101 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import math +from typing import List, Tuple, Union + +from rclpy.impl.rcutils_logger import RcutilsLogger + +from rai_bench.o3de_test_bench.tasks.manipulation_task import ( + ManipulationTask, +) +from rai_sim.simulation_bridge import Entity, SimulationConfig + +loggers_type = Union[RcutilsLogger, logging.Logger] + + +class PlaceObjectAtCoordTask(ManipulationTask): + def __init__( + self, + obj_type: str, + target_position: Tuple[float, float], + allowable_displacement: float = 0.02, + logger: loggers_type | None = None, + ): + """ + This task requires placing one object of specified type into specified coords. + + Parameters + ---------- + obj_type : str + The type of object to be placed. + target_position : Tuple[float, float] + The target (x, y) coordinates (in meters) where one object of the specified type should be placed. + The z coordinate is not enforced. + allowable_displacement : float, optional + The acceptable deviation (in meters) from the target (x, y) coordinates. + Defaults to 0.02. + """ + super().__init__(logger) + self.obj_type = obj_type + self.target_position = target_position + self.allowable_displacement = allowable_displacement + + def get_prompt(self) -> str: + x, y = self.target_position + return ( + f"Manipulate one {self.obj_type.replace('_', ' ')} so that it is placed at " + f"the coordinates (x: {x}, y: {y})." + ) + + def check_if_required_objects_present( + self, simulation_config: SimulationConfig + ) -> bool: + count = sum( + 1 for ent in simulation_config.entities if ent.prefab_name == self.obj_type + ) + return count >= 1 + + def calculate_correct(self, entities: List[Entity]) -> Tuple[int, int]: + """ + Calculate the number of correctly and incorrectly placed objects. + + This task is successful if exactly one object of the specified type is placed + at the target (x, y) coordinates (within the allowable displacement). If more than one + object exists, only one counts as correct or incorrect + + Parameters + ---------- + entities : List[Entity] + List of all entities present in the simulation scene. + + Returns + ------- + Tuple[int, int] + A tuple where the first element number of correctly placed objects, second number of incorrect + """ + target_objects = [ent for ent in entities if ent.prefab_name == self.obj_type] + correct = 0 + + for ent in target_objects: + dx = ent.pose.translation.x - self.target_position[0] + dy = ent.pose.translation.y - self.target_position[1] + distance = math.sqrt(dx**2 + dy**2) + if distance <= self.allowable_displacement: + correct = 1 # Only one correct placement is needed. + break + + incorrect = 1 - correct + return correct, incorrect diff --git a/src/rai_bench/rai_bench/o3de_test_bench/tasks/place_cubes_task.py b/src/rai_bench/rai_bench/o3de_test_bench/tasks/place_cubes_task.py index e7547c49e..08096bd88 100644 --- a/src/rai_bench/rai_bench/o3de_test_bench/tasks/place_cubes_task.py +++ b/src/rai_bench/rai_bench/o3de_test_bench/tasks/place_cubes_task.py @@ -11,24 +11,52 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple +import logging +from typing import List, Tuple, Union -from rai_bench.benchmark_model import ( - EntitiesMismatchException, - Task, +from rclpy.impl.rcutils_logger import RcutilsLogger + +from rai_bench.o3de_test_bench.tasks.manipulation_task import ( + ManipulationTask, ) -from rai_sim.o3de.o3de_bridge import SimulationBridge -from rai_sim.simulation_bridge import SimulationConfig, SimulationConfigT, SpawnedEntity +from rai_sim.simulation_bridge import Entity, SimulationConfig + +loggers_type = Union[RcutilsLogger, logging.Logger] -class PlaceCubesTask(Task): - # TODO (jm) extract common logic to some parent manipulation task +class PlaceCubesTask(ManipulationTask): obj_types = ["red_cube", "blue_cube", "yellow_cube"] + def __init__( + self, + threshold_distance: float = 0.15, + logger: loggers_type | None = None, + ): + """ + This task requires that evry cube is placed to at least one other cube. + + Parameters + ---------- + threshold_distance : float, optional + The distance threshold (in meters) used to determine if two cubes are adjacent. If the Euclidean + distance between two cubes is less than or equal to this value, they are considered adjacent. + Defaults to 0.15. + """ + super().__init__(logger) + self.threshold_distance = threshold_distance + def get_prompt(self) -> str: return "Manipulate objects, so that all cubes are adjacent to at least one cube" - def validate_config(self, simulation_config: SimulationConfig) -> bool: + def check_if_required_objects_present( + self, simulation_config: SimulationConfig + ) -> bool: + """ + Returns + ------- + bool + True if at least two cubes are present; otherwise, False. + """ cubes_num = 0 for ent in simulation_config.entities: if ent.prefab_name in self.obj_types: @@ -38,80 +66,33 @@ def validate_config(self, simulation_config: SimulationConfig) -> bool: return False - def calculate_correct(self, entities: List[SpawnedEntity]) -> Tuple[int, int]: - """Calculate how many objects are positioned correct and incorrect""" + def calculate_correct(self, entities: List[Entity]) -> Tuple[int, int]: + """ + Calculate the number of correctly and incorrectly placed cubes based on adjacency. + + An object is considered correctly placed if it is adjacent to at least one other cube + within the given threshold distance. + + Parameters + ---------- + entities : List[Entity] + List of all entities (cubes) present in the simulation scene. + + Returns + ------- + Tuple[int, int] + A tuple where the first element is the number of correctly placed cubes (i.e., cubes that + are adjacent to at least one other cube) and the second element is the number of + incorrectly placed cubes. + """ correct = sum( 1 for ent in entities if self.is_adjacent_to_any( - ent.pose, [e.pose for e in entities if e != ent], 0.15 + ent.pose, + [e.pose for e in entities if e != ent], + self.threshold_distance, ) ) incorrect: int = len(entities) - correct return correct, incorrect - - def calculate_initial_placements( - self, simulation_bridge: SimulationBridge[SimulationConfigT] - ) -> tuple[int, int]: - """ - Calculates the number of objects that are correctly and incorrectly placed initially. - """ - initial_cubes = self.filter_entities_by_prefab_type( - simulation_bridge.spawned_entities, prefab_types=self.obj_types - ) - initially_correct, initially_incorrect = self.calculate_correct( - entities=initial_cubes - ) - - self.logger.info( # type: ignore - f"Initially correctly placed cubes: {initially_correct}, Initially incorrectly placed cubes: {initially_incorrect}" - ) - return initially_correct, initially_incorrect - - def calculate_final_placements( - self, simulation_bridge: SimulationBridge[SimulationConfigT] - ) -> tuple[int, int]: - """ - Calculates the number of objects that are correctly and incorrectly placed at the end of the simulation. - """ - scene_state = simulation_bridge.get_scene_state() - final_cubes = self.filter_entities_by_prefab_type( - scene_state.entities, prefab_types=self.obj_types - ) - final_correct, final_incorrect = self.calculate_correct(entities=final_cubes) - - self.logger.info( # type: ignore - f"Finally correctly placed cubes: {final_correct}, Finally incorrectly placed cubes: {final_incorrect}" - ) - return final_correct, final_incorrect - - def calculate_result( - self, simulation_bridge: SimulationBridge[SimulationConfigT] - ) -> float: - """ - Calculates a score from 0.0 to 1.0, where 0.0 represents the initial placements or worse and 1.0 represents perfect final placements. - """ - initially_correct, initially_incorrect = self.calculate_initial_placements( - simulation_bridge - ) - final_correct, final_incorrect = self.calculate_final_placements( - simulation_bridge - ) - - total_objects = initially_correct + initially_incorrect - if total_objects == 0: - return 1.0 - elif (initially_correct + initially_incorrect) != ( - final_correct + final_incorrect - ): - raise EntitiesMismatchException( - "number of initial entities does not match final entities number." - ) - elif initially_incorrect == 0: - raise ValueError("All objects are placed correctly at the start.") - else: - corrected = final_correct - initially_correct - score = max(0.0, corrected / initially_incorrect) - - self.logger.info(f"Calculated score: {score:.2f}") # type: ignore - return score diff --git a/src/rai_bench/rai_bench/o3de_test_bench/tasks/rotate_object_task.py b/src/rai_bench/rai_bench/o3de_test_bench/tasks/rotate_object_task.py new file mode 100644 index 000000000..6b80c9891 --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/tasks/rotate_object_task.py @@ -0,0 +1,124 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import math +from typing import List, Tuple, Union + +from rclpy.impl.rcutils_logger import RcutilsLogger + +from rai_bench.o3de_test_bench.tasks.manipulation_task import ( + ManipulationTask, +) +from rai_sim.simulation_bridge import Entity, Rotation, SimulationConfig + +loggers_type = Union[RcutilsLogger, logging.Logger] + + +class RotateObjectTask(ManipulationTask): + def __init__( + self, + obj_types: List[str], + target_quaternion: Rotation, + logger: loggers_type | None = None, + ): + # NOTE (jmatejcz) for now manipulaiton tool does not support passing rotation + # NOTE (jmatejcz) rotating around other axis than z seems to not have much sense as the objecet will fall + # can the target rotation be expressed differently? maybe only by rotation around z axis as an angle? + """ + Parameters + ---------- + obj_types : List[str] + List of allowed object types that will be rotated. + target_quaternion : Tuple[float, float, float, float] + The target rotation expressed as a quaternion (x, y, z, w). + """ + super().__init__(logger=logger) + self.obj_types = obj_types + self.target_quaternion = target_quaternion + + def get_prompt(self) -> str: + object_names = ", ".join(obj.replace("_", " ") for obj in self.obj_types) + return ( + f"Rotate each {object_names} to the target orientation specified by the quaternion " + f"- x:{self.target_quaternion.x}, y:{self.target_quaternion.y}, z:{self.target_quaternion.z}, w:{self.target_quaternion.w} " + "Remember to rotate the gripper when grabbing objects." + ) + + def check_if_required_objects_present( + self, simulation_config: SimulationConfig + ) -> bool: + """ + Validate that at least one object of the specified types is present. + + Returns + ------- + bool + True if at least one allowed object is present, False otherwise. + """ + return any( + ent.prefab_name in self.obj_types for ent in simulation_config.entities + ) + + def calculate_correct( + self, entities: List[Entity], allowable_rotation_error: float = 5.0 + ) -> Tuple[int, int]: + """ + Calculate the number of correctly rotated objects and incorrectly rotated objects, + operating on quaternion representations. + + For each object, the dot product between its rotation quaternion and the target quaternion + is computed. The angular difference is calculated as: + + angle_diff = 2 * acos(|dot(current, target)|) + + This value (converted from radians to degrees) is compared with the allowable rotation error. + If the difference is within the allowable error, the object's orientation is considered correct. + + Parameters + ---------- + entities : List[Entity] + List of all entities present in the simulation scene. + allowable_rotation_error : float, optional + The acceptable deviation (in degrees) from the target rotation. Defaults to 5.0. + + Returns + ------- + Tuple[int, int] + A tuple where the first element is the number of correctly rotated objects and the second element + is the number of incorrectly rotated objects. + """ + correct = 0 + incorrect = 0 + for entity in entities: + if entity.prefab_name in self.obj_types: + if not entity.pose.rotation: + ValueError("Entity has no rotation defined.") + else: + dot = ( + entity.pose.rotation.x * self.target_quaternion.x + + entity.pose.rotation.y * self.target_quaternion.y + + entity.pose.rotation.z * self.target_quaternion.z + + entity.pose.rotation.w * self.target_quaternion.w + ) + # Account for the double cover: q and -q represent the same rotation. + dot = abs(dot) + # Clamp the dot product to avoid domain errors. + dot = max(min(dot, 1.0), -1.0) + angle_diff_deg = math.degrees(2 * math.acos(dot)) + if angle_diff_deg <= allowable_rotation_error: + correct += 1 + else: + incorrect += 1 + return correct, incorrect diff --git a/src/rai_core/rai/agents/conversational_agent.py b/src/rai_core/rai/agents/conversational_agent.py index 739b159ae..40e74c784 100644 --- a/src/rai_core/rai/agents/conversational_agent.py +++ b/src/rai_core/rai/agents/conversational_agent.py @@ -21,6 +21,7 @@ from langchain_core.messages import BaseMessage, SystemMessage from langchain_core.tools import BaseTool from langgraph.graph import START, StateGraph +from langgraph.graph.state import CompiledStateGraph from langgraph.prebuilt.tool_node import tools_condition from rclpy.impl.rcutils_logger import RcutilsLogger @@ -54,7 +55,7 @@ def create_conversational_agent( system_prompt: str, logger: Optional[RcutilsLogger | logging.Logger] = None, debug=False, -): +) -> CompiledStateGraph: _logger = None if logger: _logger = logger diff --git a/src/rai_sim/rai_sim/o3de/o3de_bridge.py b/src/rai_sim/rai_sim/o3de/o3de_bridge.py index 14ba247d4..924465879 100644 --- a/src/rai_sim/rai_sim/o3de/o3de_bridge.py +++ b/src/rai_sim/rai_sim/o3de/o3de_bridge.py @@ -43,6 +43,7 @@ class O3DExROS2SimulationConfig(SimulationConfig): binary_path: Path + level: Optional[str] = None robotic_stack_command: str required_simulation_ros2_interfaces: dict[str, List[str]] required_robotic_ros2_interfaces: dict[str, List[str]] @@ -201,7 +202,7 @@ def get_scene_state(self) -> SceneState: return SceneState(entities=entities) def _is_ros2_stack_ready( - self, required_ros2_stack: dict[str, List[str]], retries: int = 120 + self, required_ros2_stack: dict[str, List[str]], retries: int = 360 ) -> bool: for i in range(retries): available_topics = self.connector.get_topics_names_and_types() @@ -266,7 +267,10 @@ def _is_ros2_stack_ready( ) return False - def setup_scene(self, simulation_config: O3DExROS2SimulationConfig): + def setup_scene( + self, + simulation_config: O3DExROS2SimulationConfig, + ): if self.current_binary_path != simulation_config.binary_path: if self.current_sim_process: self.shutdown() @@ -282,8 +286,15 @@ def setup_scene(self, simulation_config: O3DExROS2SimulationConfig): for entity in simulation_config.entities: self._spawn_entity(entity) - def _launch_binary(self, simulation_config: O3DExROS2SimulationConfig): - command = [simulation_config.binary_path.as_posix()] + def _launch_binary( + self, + simulation_config: O3DExROS2SimulationConfig, + ): + command = [ + simulation_config.binary_path.as_posix(), + ] + if simulation_config.level: + command.append(f"+LoadLevel {simulation_config.level}") self.logger.info(f"Running command: {command}") self.current_sim_process = subprocess.Popen( command, @@ -383,6 +394,15 @@ def _from_ros2_pose(self, pose: ROS2Pose) -> Pose: class O3DEngineArmManipulationBridge(O3DExROS2Bridge): + def reset_arm(self): + self.connector.service_call( + ROS2ARIMessage(payload={}), + target="/reset_manipulator", + msg_type="std_srvs/srv/Trigger", + ) + + self.connector.node.get_logger().debug("Reset manipulator arm: DONE") + def move_arm( self, pose: Pose, diff --git a/src/rai_sim/rai_sim/simulation_bridge.py b/src/rai_sim/rai_sim/simulation_bridge.py index 336406ad0..9241583e4 100644 --- a/src/rai_sim/rai_sim/simulation_bridge.py +++ b/src/rai_sim/rai_sim/simulation_bridge.py @@ -68,6 +68,15 @@ class Entity(BaseModel): ) pose: Pose = Field(description="Initial pose of the entity") + def __hash__(self) -> int: + return hash(self.name) + + def __eq__(self, other) -> bool: + if isinstance(other, Entity) or isinstance(other, SpawnedEntity): + return self.name == other.name + else: + return False + class SpawnedEntity(Entity): """ diff --git a/tests/rai_bench/conftest.py b/tests/rai_bench/conftest.py new file mode 100644 index 000000000..e0d117e36 --- /dev/null +++ b/tests/rai_bench/conftest.py @@ -0,0 +1,35 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from rai_sim.simulation_bridge import ( + Entity, + Pose, + Rotation, + Translation, +) + + +def create_entity( + name: str, + prefab: str, + x: float, + y: float, + z: float, + rotation: Rotation | None = None, +) -> Entity: + return Entity( + name=name, + prefab_name=prefab, + pose=Pose(translation=Translation(x=x, y=y, z=z), rotation=rotation), + ) diff --git a/tests/rai_bench/tasks/test_build_tower_task.py b/tests/rai_bench/tasks/test_build_tower_task.py new file mode 100644 index 000000000..e21c46b47 --- /dev/null +++ b/tests/rai_bench/tasks/test_build_tower_task.py @@ -0,0 +1,84 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from rai_bench.o3de_test_bench.tasks import BuildCubeTowerTask +from tests.rai_bench.conftest import create_entity + + +def test_calculate_proper_tower() -> None: + task = BuildCubeTowerTask(["red_cube"]) + e1 = create_entity("cube1", "red_cube", 0.0, 0.0, 0.0) + e2 = create_entity("cube2", "red_cube", 0.01, 0.01, 0.03) + e3 = create_entity("cube3", "red_cube", 0.0, 0.0, 0.04) + correct, incorrect = task.calculate_correct([e1, e2, e3]) + assert correct == 3 + assert incorrect == 0 + + +def test_calculate_multiple_groups() -> None: + task = BuildCubeTowerTask(["red_cube"]) + e1 = create_entity("cube1", "red_cube", 0.0, 0.0, 0.0) # Group 1 + e2 = create_entity("cube2", "red_cube", 0.0, 0.0, 0.03) # Group 1 + e3 = create_entity("cube3", "red_cube", 0.0, 0.0, 0.06) # Group 1 + + e4 = create_entity("cube4", "red_cube", 0.0, 1.0, 0.03) # Group 2 + e5 = create_entity("cube5", "red_cube", 0.0, 1.0, 0.06) # Group 2 + + # correct objects should be 3 as highest tower is 3 cubes high + correct, incorrect = task.calculate_correct([e1, e2, e3, e4, e5]) + assert correct == 3 + assert incorrect == 2 + + +def test_calculate_invalid_entity() -> None: + task = BuildCubeTowerTask(["red_cube"]) + e1 = create_entity("cube1", "red_cube", 0.0, 0.0, 0.0) + e2 = create_entity("cube2", "yellow_cube", 0.0, 0.0, 0.03) + e3 = create_entity("cube3", "red_cube", 0.0, 0.0, 0.06) + correct, incorrect = task.calculate_correct([e1, e2, e3]) + # The presence of an invalid object causes all cubes to be marked as incorrect. + assert correct == 0 + assert incorrect == 2 + + +def test_calculate_single_entity() -> None: + task = BuildCubeTowerTask(["red_cube"]) + e1 = create_entity("cube1", "red_cube", 0.0, 0.0, 0.0) + correct, incorrect = task.calculate_correct([e1]) + # A single cube in a group is considered incorrectly placed. + assert correct == 0 + assert incorrect == 1 + + +def test_calculate_seperate_entities() -> None: + task = BuildCubeTowerTask(["red_cube"]) + e1 = create_entity("cube1", "red_cube", 0.0, 0.0, 0.0) + e2 = create_entity("cube2", "red_cube", 0.2, 0.2, 0.0) + e3 = create_entity("cube3", "red_cube", -0.2, -0.2, 0.0) + correct, incorrect = task.calculate_correct([e1, e2, e3]) + # not a single cube should be treated as correct + assert correct == 0 + assert incorrect == 3 + + +def test_too_big_displacement() -> None: + with pytest.raises(ValueError): + BuildCubeTowerTask(["red_cube"], allowable_displacement=0.1) + + +def test_not_allowable_type() -> None: + with pytest.raises(TypeError): + BuildCubeTowerTask(["red_cube", "apple"], allowable_displacement=0.1) diff --git a/tests/rai_bench/tasks/test_group_objects_task.py b/tests/rai_bench/tasks/test_group_objects_task.py new file mode 100644 index 000000000..ac47728f0 --- /dev/null +++ b/tests/rai_bench/tasks/test_group_objects_task.py @@ -0,0 +1,117 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +from rai_bench.o3de_test_bench.tasks import GroupObjectsTask +from rai_sim.simulation_bridge import Entity +from tests.rai_bench.conftest import create_entity + + +def test_calculate_proper_cluster() -> None: + task = GroupObjectsTask(["red_cube"]) + e1 = create_entity("e1", "red_cube", 0.0, 0.0, 0.0) + e2 = create_entity("e2", "red_cube", 0.1, 0.0, 0.0) + e3 = create_entity("e3", "red_cube", 0.05, 0.05, 0.0) + entities: List[Entity] = [e1, e2, e3] + + correct, misclustered = task.calculate_correct(entities) + # Since all three red_cube objects form a single cluster, they are all considered correctly clustered. + assert correct == 3 + assert misclustered == 0 + + +def test_calculate_multiple_clusters_same_type() -> None: + task = GroupObjectsTask(["red_cube"]) + # Cluster 1 + e1 = create_entity("e1", "red_cube", 0.0, 0.0, 0.0) + e2 = create_entity("e2", "red_cube", 0.05, 0.0, 0.0) + # Cluster 2 + e3 = create_entity("e3", "red_cube", 5.0, 5.0, 0.0) + e4 = create_entity("e4", "red_cube", 5.05, 5.05, 0.0) + entities: List[Entity] = [e1, e2, e3, e4] + + correct, misclustered = task.calculate_correct(entities) + # Since there are two clusters for red_cube, all objects are considered misclustered. + assert correct == 0 + assert misclustered == 4 + + +def test_calculate_multiple_types_proper_clusters() -> None: + task = GroupObjectsTask(["red_cube", "blue_cube"]) + # Red cubes: form a proper cluster. + r1 = create_entity("r1", "red_cube", 0.0, 0.0, 0.0) + r2 = create_entity("r2", "red_cube", 0.1, 0.0, 0.0) + r3 = create_entity("r3", "red_cube", 0.05, 0.05, 0.0) + # Blue cubes: placed far apart so that each becomes its own cluster. + b1 = create_entity("b1", "blue_cube", 10.0, 10.0, 0.0) + b2 = create_entity("b2", "blue_cube", 10.1, 10.1, 0.0) + entities: List[Entity] = [r1, r2, r3, b1, b2] + + correct, misclustered = task.calculate_correct(entities) + + assert correct == 5 + assert misclustered == 0 + + +def test_calculate_multiple_types_mixed() -> None: + task = GroupObjectsTask(["red_cube", "blue_cube"]) + # Red cubes: form a proper cluster. + r1 = create_entity("r1", "red_cube", 0.0, 0.0, 0.0) + r2 = create_entity("r2", "red_cube", 0.1, 0.0, 0.0) + r3 = create_entity("r3", "red_cube", 0.05, 0.05, 0.0) + # Blue cubes: placed near to red cluster, so they are not separate + b1 = create_entity("b1", "blue_cube", 0.1, 0.05, 0.0) + b2 = create_entity("b2", "blue_cube", 0.05, 0.0, 0.0) + entities: List[Entity] = [r1, r2, r3, b1, b2] + + correct, misclustered = task.calculate_correct(entities) + + assert correct == 0 + assert misclustered == 5 + + +def test_calculate_other_types_mixed() -> None: + task = GroupObjectsTask(["red_cube"]) + # Red cubes: form a proper cluster. + r1 = create_entity("r1", "red_cube", 0.0, 0.0, 0.0) + r2 = create_entity("r2", "red_cube", 0.1, 0.0, 0.0) + r3 = create_entity("r3", "red_cube", 0.05, 0.05, 0.0) + # Blue cubes: arent clustered but placed near to red cluster, + # so the red cubes cluster does not contain only red cubes + b1 = create_entity("b1", "blue_cube", 0.1, 0.05, 0.0) + b2 = create_entity("b2", "blue_cube", 0.05, 0.0, 0.0) + entities: List[Entity] = [r1, r2, r3, b1, b2] + + correct, misclustered = task.calculate_correct(entities) + + assert correct == 0 + assert misclustered == 3 + + +def test_calculate_no_specified_objects() -> None: + task = GroupObjectsTask(["apple"]) + # Red cubes: form a proper cluster. + r1 = create_entity("r1", "red_cube", 0.0, 0.0, 0.0) + r2 = create_entity("r2", "red_cube", 0.1, 0.0, 0.0) + r3 = create_entity("r3", "red_cube", 0.05, 0.05, 0.0) + + b1 = create_entity("b1", "blue_cube", 0.1, 0.05, 0.0) + b2 = create_entity("b2", "blue_cube", 0.05, 0.0, 0.0) + entities: List[Entity] = [r1, r2, r3, b1, b2] + + correct, misclustered = task.calculate_correct(entities) + + assert correct == 0 + assert misclustered == 0 diff --git a/tests/rai_bench/tasks/test_move_to_left_task.py b/tests/rai_bench/tasks/test_move_to_left_task.py new file mode 100644 index 000000000..d2d4e6130 --- /dev/null +++ b/tests/rai_bench/tasks/test_move_to_left_task.py @@ -0,0 +1,45 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from rai_bench.o3de_test_bench.tasks import MoveObjectsToLeftTask +from tests.rai_bench.conftest import create_entity + + +def test_calculate_all_on_left() -> None: + task = MoveObjectsToLeftTask(["red_cube", "blue_cube"]) + e1 = create_entity("obj1", "red_cube", 0.0, 1.0, 0.0) # y > 0 → correct + e2 = create_entity("obj2", "blue_cube", 0.0, 0.5, 0.0) # y > 0 → correct + correct, incorrect = task.calculate_correct([e1, e2]) + assert correct == 2 + assert incorrect == 0 + + +def test_calculate_some_not_on_left() -> None: + task = MoveObjectsToLeftTask(["red_cube", "blue_cube"]) + e1 = create_entity("obj1", "red_cube", 0.0, 1.0, 0.0) # y > 0 → correct + e2 = create_entity("obj2", "blue_cube", 0.0, -0.5, 0.0) # y < 0 → incorrect + correct, incorrect = task.calculate_correct([e1, e2]) + assert correct == 1 + assert incorrect == 1 + + +def test_calculate_other_types() -> None: + task = MoveObjectsToLeftTask(["red_cube", "blue_cube"]) + e1 = create_entity("obj1", "red_cube", 0.0, 1.0, 0.0) # valid type, y > 0 → correct + e2 = create_entity( + "obj2", "apple", 0.0, 1.0, 0.0 + ) # invalid type, should be ignored + correct, incorrect = task.calculate_correct([e1, e2]) + assert correct == 1 + assert incorrect == 0 diff --git a/tests/rai_bench/tasks/test_place_at_coords_task.py b/tests/rai_bench/tasks/test_place_at_coords_task.py new file mode 100644 index 000000000..ee5975db2 --- /dev/null +++ b/tests/rai_bench/tasks/test_place_at_coords_task.py @@ -0,0 +1,62 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from rai_bench.o3de_test_bench.tasks import PlaceObjectAtCoordTask +from tests.rai_bench.conftest import create_entity + + +def test_calculate_single_object_correct() -> None: + task = PlaceObjectAtCoordTask("carrot", (0.5, 0.5)) + e1 = create_entity("carrot1", "carrot", 0.5, 0.5, 0.0) + correct, incorrect = task.calculate_correct([e1]) + assert correct == 1 + assert incorrect == 0 + + +def test_calculate_single_object_incorrect() -> None: + task = PlaceObjectAtCoordTask("carrot", (0.5, 0.5)) + e1 = create_entity("carrot1", "carrot", 0.7, 0.7, 0.0) + correct, incorrect = task.calculate_correct([e1]) + assert correct == 0 + assert incorrect == 1 + + +def test_calculate_multiple_objects_one_correct() -> None: + task = PlaceObjectAtCoordTask("carrot", (0.5, 0.5)) + e1 = create_entity("carrot1", "carrot", 0.5, 0.5, 0.0) # Correct placement. + e2 = create_entity("carrot2", "carrot", 0.6, 0.6, 0.0) # Incorrect placement. + correct, incorrect = task.calculate_correct([e1, e2]) + # only one is considered + assert correct == 1 + assert incorrect == 0 + + +def test_calculate_multiple_objects_multiple_correct() -> None: + task = PlaceObjectAtCoordTask("carrot", (0.5, 0.5)) + e1 = create_entity("carrot1", "carrot", 0.5, 0.5, 0.0) # Correct placement. + e2 = create_entity( + "carrot2", "carrot", 0.501, 0.501, 0.0 + ) # also correct placement. + correct, incorrect = task.calculate_correct([e1, e2]) + assert correct == 1 + assert incorrect == 0 + + +def test_calculate_multiple_objects_none_correct() -> None: + task = PlaceObjectAtCoordTask("carrot", (0.5, 0.5)) + e1 = create_entity("carrot1", "carrot", 0.6, 0.6, 0.0) # Off target. + e2 = create_entity("carrot2", "carrot", 0.7, 0.7, 0.0) # Off target. + correct, incorrect = task.calculate_correct([e1, e2]) + assert correct == 0 + assert incorrect == 1 diff --git a/tests/rai_bench/tasks/test_place_cubes_task.py b/tests/rai_bench/tasks/test_place_cubes_task.py new file mode 100644 index 000000000..e12265cfd --- /dev/null +++ b/tests/rai_bench/tasks/test_place_cubes_task.py @@ -0,0 +1,61 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from rai_bench.o3de_test_bench.tasks import PlaceCubesTask +from tests.rai_bench.conftest import create_entity + + +def test_calculate_all_adjacent() -> None: + task = PlaceCubesTask(threshold_distance=0.15) + # Create three cubes that are all close to each other. + e1 = create_entity("cube1", "red_cube", 0.0, 0.0, 0.0) + e2 = create_entity("cube2", "red_cube", 0.1, 0.1, 0.0) + e3 = create_entity("cube3", "red_cube", 0.2, 0.2, 0.0) + correct, incorrect = task.calculate_correct([e1, e2, e3]) + assert correct == 3 + assert incorrect == 0 + + +def test_calculate_one_separated() -> None: + task = PlaceCubesTask(threshold_distance=0.15) + # Two cubes close together and one isolated. + e1 = create_entity("cube1", "red_cube", 0.0, 0.0, 0.0) + e2 = create_entity("cube2", "red_cube", 0.1, 0.1, 0.0) + e3 = create_entity("cube3", "red_cube", 1.0, 1.0, 0.0) # Isolated cube. + correct, incorrect = task.calculate_correct([e1, e2, e3]) + assert correct == 2 + assert incorrect == 1 + + +def test_calculate_none_adjacent() -> None: + task = PlaceCubesTask(threshold_distance=0.15) + # All cubes are far apart. + e1 = create_entity("cube1", "red_cube", 0.0, 0.0, 0.0) + e2 = create_entity("cube2", "red_cube", 1.0, 1.0, 0.0) + e3 = create_entity("cube3", "red_cube", 2.0, 2.0, 0.0) + correct, incorrect = task.calculate_correct([e1, e2, e3]) + assert correct == 0 + assert incorrect == 3 + + +def test_calculate_2_clusters_adjacent() -> None: + task = PlaceCubesTask(threshold_distance=0.15) + # All cubes form 2 clusters + e1 = create_entity("cube1", "red_cube", 0.0, 0.0, 0.0) + e2 = create_entity("cube2", "red_cube", 0.1, 0.0, 0.0) + e3 = create_entity("cube3", "red_cube", 2.0, 2.0, 0.0) + e4 = create_entity("cube4", "red_cube", 2.1, 2.0, 0.0) + correct, incorrect = task.calculate_correct([e1, e2, e3, e4]) + assert correct == 4 + assert incorrect == 0 diff --git a/tests/rai_bench/tasks/test_rotate_objects_task.py b/tests/rai_bench/tasks/test_rotate_objects_task.py new file mode 100644 index 000000000..5c41540a5 --- /dev/null +++ b/tests/rai_bench/tasks/test_rotate_objects_task.py @@ -0,0 +1,92 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +from rai_bench.o3de_test_bench.tasks import RotateObjectTask +from rai_sim.simulation_bridge import Rotation +from tests.rai_bench.conftest import create_entity + + +def test_calculate_perfect_match() -> None: + target = Rotation(x=0.0, y=0.0, z=0.0, w=1.0) + task = RotateObjectTask(["apple"], target_quaternion=target) + e1 = create_entity( + "obj1", "apple", 0.3, 0.0, 0.05, Rotation(x=0.0, y=0.0, z=0.0, w=1.0) + ) + correct, incorrect = task.calculate_correct([e1], allowable_rotation_error=5.0) + assert correct == 1 + assert incorrect == 0 + + +def test_calculate_error_under_threshold() -> None: + target = Rotation(x=0.0, y=0.0, z=0.0, w=1.0) + task = RotateObjectTask(["apple"], target_quaternion=target) + half_angle = math.radians(3.0) + current_rotation = Rotation( + x=0.0, y=0.0, z=math.sin(half_angle), w=math.cos(half_angle) + ) + e1 = create_entity("obj1", "apple", 0.3, 0.0, 0.05, current_rotation) + # rotation error less then margin + correct, incorrect = task.calculate_correct([e1], allowable_rotation_error=7.0) + assert correct == 1 + assert incorrect == 0 + + +def test_calculate_multiple_types() -> None: + target = Rotation(x=0.0, y=0.0, z=0.0, w=1.0) + task = RotateObjectTask(["apple", "carrot"], target_quaternion=target) + half_angle = math.radians(1.0) + current_rotation = Rotation( + x=0.0, y=0.0, z=math.sin(half_angle), w=math.cos(half_angle) + ) + e1 = create_entity("obj1", "apple", 0.3, 0.0, 0.05, current_rotation) + e2 = create_entity("obj2", "apple", 0.4, 0.1, 0.05, current_rotation) + e3 = create_entity("obj3", "carrot", 0.3, 0.0, 0.05, current_rotation) + correct, incorrect = task.calculate_correct( + [e1, e2, e3], allowable_rotation_error=5.0 + ) + assert correct == 3 + assert incorrect == 0 + + +def test_calculate_mixed_types() -> None: + target = Rotation(x=0.0, y=0.0, z=0.0, w=1.0) + task = RotateObjectTask(["yellow_cube", "carrot"], target_quaternion=target) + half_angle = math.radians(1.0) + current_rotation = Rotation( + x=0.0, y=0.0, z=math.sin(half_angle), w=math.cos(half_angle) + ) + e1 = create_entity("obj1", "apple", 0.3, 0.0, 0.05, current_rotation) + e2 = create_entity("obj2", "yellow_cube", 0.4, 0.1, 0.05, current_rotation) + e3 = create_entity("obj3", "carrot", 0.3, 0.0, 0.05, current_rotation) + correct, incorrect = task.calculate_correct( + [e1, e2, e3], allowable_rotation_error=5.0 + ) + assert correct == 2 + assert incorrect == 0 + + +def test_calculate_error_above_threshold() -> None: + target = Rotation(x=0.0, y=0.0, z=0.0, w=1.0) + task = RotateObjectTask(["apple"], target_quaternion=target) + half_angle = math.radians(5) + current_rotation = Rotation( + x=0.0, y=0.0, z=math.sin(half_angle), w=math.cos(half_angle) + ) + e1 = create_entity("obj1", "apple", 0.3, 0.0, 0.05, current_rotation) + correct, incorrect = task.calculate_correct([e1], allowable_rotation_error=5.0) + # The rotation error is 10°, so it exceeds the 5° threshold. + assert correct == 0 + assert incorrect == 1 diff --git a/tests/rai_bench/tasks/test_task.py b/tests/rai_bench/tasks/test_task.py new file mode 100644 index 000000000..16e56b7b7 --- /dev/null +++ b/tests/rai_bench/tasks/test_task.py @@ -0,0 +1,116 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Set + +from rai_bench.benchmark_model import Task +from rai_sim.simulation_bridge import Entity, Pose, Translation +from tests.rai_bench.conftest import create_entity + + +class DummyTask(Task): + def get_prompt(self) -> str: + return "dummy prompt" + + def validate_config(self, simulation_config: Any) -> bool: + return True + + def calculate_result(self, simulation_bridge: Any) -> float: + return 1.0 + + +def create_pose(x: float, y: float, z: float) -> Pose: + return Pose(translation=Translation(x=x, y=y, z=z)) + + +def test_build_neighbourhood_list() -> None: + task = DummyTask() + e1: Entity = create_entity("e1", "red_cube", 0, 0, 0) + e2: Entity = create_entity("e2", "red_cube", 0.1, 0, 0) + e3: Entity = create_entity("e3", "red_cube", 1, 1, 1) + entities: List[Entity] = [e1, e2, e3] + + neighbourhood: Dict[Entity, List[Entity]] = task.build_neighbourhood_list( + entities, threshold_distance=0.2 + ) + # e1 and e2 should be neighbours of each other; e3 remains isolated. + assert set(neighbourhood[e1]) == {e2}, neighbourhood[e1] + assert set(neighbourhood[e2]) == {e1} + assert neighbourhood[e3] == [] + + +def test_check_neighbourhood_types() -> None: + task = DummyTask() + + e1: Entity = create_entity("e1", "red_cube", 0, 0, 0) + e2: Entity = create_entity("e2", "red_cube", 0, 0, 0) + + assert task.check_neighbourhood_types([e1, e2], allowed_types=["red_cube"]) is True + assert ( + task.check_neighbourhood_types([e1, e2], allowed_types=["blue_cube"]) is False + ) + assert task.check_neighbourhood_types([], allowed_types=["red_cube"]) is True + + +def test_find_clusters() -> None: + task = DummyTask() + e1: Entity = create_entity("e1", "red_cube", 0, 0, 0) + e2: Entity = create_entity("e2", "red_cube", 0, 0, 0) + e3: Entity = create_entity("e3", "red_cube", 0, 0, 0) + e4: Entity = create_entity("e4", "red_cube", 0, 0, 0) + # Manually create a neighbourhood graph: + neighbourhood: Dict[Entity, List[Entity]] = { + e1: [e2], + e2: [e1, e3], + e3: [e2], + e4: [], + } + clusters: List[List[Entity]] = task.find_clusters(neighbourhood) + # Convert to sets for order-independent comparison. + clusters_as_sets: List[Set[Entity]] = [set(cluster) for cluster in clusters] + assert {e1, e2, e3} in clusters_as_sets + assert {e4} in clusters_as_sets + assert len(clusters_as_sets) == 2 + + +def test_group_entities_by_z_coordinate_all_stacked() -> None: + task = DummyTask() + e1: Entity = create_entity("e1", "red_cube", 0, 0, 0.0) + e2: Entity = create_entity("e2", "red_cube", 0, 0, 0.05) + e3: Entity = create_entity("e3", "red_cube", 0, 0, 0.2) + e4: Entity = create_entity("e4", "red_cube", 0, 0, 0.25) + e5: Entity = create_entity("e5", "red_cube", 0, 0, 0.5) + entities: List[Entity] = [e1, e2, e3, e4, e5] + + groups: List[List[Entity]] = task.group_entities_along_z_axis(entities, margin=0.1) + assert len(groups) == 1 + assert groups[0] == entities + + +def test_group_entities_by_z_coordinate_2_stacks() -> None: + task = DummyTask() + e1: Entity = create_entity("e1", "red_cube", 0, 1, 0.0) + e2: Entity = create_entity("e2", "red_cube", 0, 1, 0.05) + + e3: Entity = create_entity("e3", "red_cube", 0, 0, 0.0) + e4: Entity = create_entity("e4", "red_cube", 0, 0, 0.05) + e5: Entity = create_entity("e5", "red_cube", 0, 0, 0.1) + entities: List[Entity] = [e1, e2, e3, e4, e5] + + groups: List[List[Entity]] = task.group_entities_along_z_axis(entities, margin=0.01) + # Convert to sets for order-independent comparison. + groups_as_sets: List[Set[Entity]] = [set(group) for group in groups] + assert len(groups) == 2 + assert {e1, e2} in groups_as_sets + assert {e3, e4, e5} in groups_as_sets