Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 10 additions & 23 deletions src/rai_bench/rai_bench/benchmark_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,13 @@ def __init__(
else:
self._logger = logging.getLogger(__name__)

self.fieldnames = [
"task",
"simulation_config",
"final_score",
"total_time",
"number_of_tool_calls",
]
self._initialize_results_file()

@classmethod
Expand All @@ -192,6 +199,7 @@ def create_scenarios(
# TODO (jm) hacky_fix, taking paths as args here, not the best solution,
# but more changes to code would be required
scenarios: List[Scenario[SimulationConfigT]] = []

for task in tasks:
for sim_conf, sim_path in zip(simulation_configs, simulation_configs_paths):
try:
Expand All @@ -210,19 +218,10 @@ def create_scenarios(

def _initialize_results_file(self):
"""Initialize the CSV file with headers."""
fieldnames = [
"task",
"simulation_config",
"initial_score",
"final_score",
"total_time",
"number_of_tool_calls",
]

with open(
self.results_filename, mode="w", newline="", encoding="utf-8"
) as file:
writer = csv.DictWriter(file, fieldnames=fieldnames)
writer = csv.DictWriter(file, fieldnames=self.fieldnames)
writer.writeheader()

def run_next(self, agent) -> None:
Expand All @@ -239,8 +238,6 @@ def run_next(self, agent) -> None:
self._logger.info( # type: ignore
f"RUNNING SCENARIO NUMBER {i + 1} / {self.num_of_scenarios}, TASK: {scenario.task.get_prompt()}"
)
initial_result = scenario.task.calculate_result(self.simulation_bridge)
self._logger.info(f"RESULT OF THE INITIAL SETUP: {initial_result}") # type: ignore
tool_calls_num = 0

ts = time.perf_counter()
Expand Down Expand Up @@ -281,7 +278,6 @@ def run_next(self, agent) -> None:
scenario_result: Dict[str, Any] = {
"task": scenario.task.get_prompt(),
"simulation_config": scenario.simulation_config_path,
"initial_score": initial_result,
"final_score": result,
"total_time": f"{total_time:.3f}",
"number_of_tool_calls": tool_calls_num,
Expand All @@ -294,19 +290,10 @@ def run_next(self, agent) -> None:

def _save_scenario_result_to_csv(self, result: Dict[str, Any]) -> None:
"""Save a single scenario result to the CSV file."""
fieldnames = [
"task",
"simulation_config",
"initial_score",
"final_score",
"total_time",
"number_of_tool_calls",
]

with open(
self.results_filename, mode="a", newline="", encoding="utf-8"
) as file:
writer = csv.DictWriter(file, fieldnames=fieldnames)
writer = csv.DictWriter(file, fieldnames=self.fieldnames)
writer.writerow(result)

def get_results(self) -> List[Dict[str, Any]]:
Expand Down
133 changes: 71 additions & 62 deletions src/rai_bench/rai_bench/o3de_test_bench/tasks/grab_carrot_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,91 +11,100 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple

from rai_bench.benchmark_model import (
EntitiesMismatchException,
Task,
)
from rai_bench.benchmark_model import EntitiesMismatchException, Task
from rai_sim.o3de.o3de_bridge import (
SimulationBridge,
)
from rai_sim.simulation_bridge import SimulationConfig, SimulationConfigT
from rai_sim.simulation_bridge import SimulationConfig, SimulationConfigT, SpawnedEntity


class GrabCarrotTask(Task):
obj_types = ["carrot"]

# TODO (jm) extract common logic to some parent manipulation task
def get_prompt(self) -> str:
return "Manipulate objects, so that all carrots to the left side of the table (positive y)"

def validate_config(self, simulation_config: SimulationConfig) -> bool:
for ent in simulation_config.entities:
if ent.prefab_name == "carrot":
if ent.prefab_name in self.obj_types:
return True

return False

def calculate_result(
self, simulation_bridge: SimulationBridge[SimulationConfigT]
) -> float:
# TODO (jm) extract common logic to some parent manipulation task?
initially_misplaced_now_correct = 0 # when the object which was in the incorrect place at the start, is in a correct place at the end
initially_misplaced_still_incorrect = 0 # when the object which was in the incorrect place at the start, is in a incorrect place at the end
initially_correct_still_correct = 0 # when the object which was in the correct place at the start, is in a correct place at the end
initially_correct_now_incorrect = 0 # when the object which was in the correct place at the start, is in a incorrect place at the end
def calculate_correct(self, entities: List[SpawnedEntity]) -> Tuple[int, int]:
"""Calculate how many objects are positioned correct and incorrect"""
correct = sum(1 for ent in entities if ent.pose.translation.y > 0.0)
incorrect: int = len(entities) - correct
return correct, incorrect

scene_state = simulation_bridge.get_scene_state()
def calculate_initial_placements(
self, simulation_bridge: SimulationBridge[SimulationConfigT]
) -> tuple[int, int]:
"""
Calculates the number of objects that are correctly and incorrectly placed initially.
"""
initial_carrots = self.filter_entities_by_prefab_type(
simulation_bridge.spawned_entities, prefab_types=["carrot"]
simulation_bridge.spawned_entities, prefab_types=self.obj_types
)
initially_correct, initially_incorrect = self.calculate_correct(
entities=initial_carrots
)

self.logger.info( # type: ignore
f"Initially correctly placed carrots: {initially_correct}, Initially incorrectly placed carrots: {initially_incorrect}"
)
return initially_correct, initially_incorrect

def calculate_final_placements(
self, simulation_bridge: SimulationBridge[SimulationConfigT]
) -> tuple[int, int]:
"""
Calculates the number of objects that are correctly and incorrectly placed at the end of the simulation.
"""
scene_state = simulation_bridge.get_scene_state()
final_carrots = self.filter_entities_by_prefab_type(
scene_state.entities, prefab_types=["carrot"]
scene_state.entities, prefab_types=self.obj_types
)
final_correct, final_incorrect = self.calculate_correct(entities=final_carrots)

self.logger.info( # type: ignore
f"Finally correctly placed carrots: {final_correct}, Finally incorrectly placed carrots: {final_incorrect}"
)
num_initial_carrots = len(initial_carrots)
return final_correct, final_incorrect

if num_initial_carrots != len(final_carrots):
def calculate_result(
self, simulation_bridge: SimulationBridge[SimulationConfigT]
) -> float:
"""
Calculates a score from 0.0 to 1.0, where 0.0 represents the initial placements or worse and 1.0 represents perfect final placements.
"""
initially_correct, initially_incorrect = self.calculate_initial_placements(
simulation_bridge
)
final_correct, final_incorrect = self.calculate_final_placements(
simulation_bridge
)

total_objects = initially_correct + initially_incorrect
if total_objects == 0:
return 1.0
elif (initially_correct + initially_incorrect) != (
final_correct + final_incorrect
):
raise EntitiesMismatchException(
"Number of initially spawned entities does not match number of entities present at the end."
"number of initial entities does not match final entities number."
)

elif initially_incorrect == 0:
pass
# NOTE all objects are placed correctly
# no point in running task
raise ValueError("All objects are placed correctly at the start.")
else:
self.logger.debug(f"initial positions: {initial_carrots}") # type: ignore
self.logger.debug(f"current positions: {final_carrots}") # type: ignore
for initial_carrot in initial_carrots:
for final_carrot in final_carrots:
if initial_carrot.name == final_carrot.name:
initial_y = initial_carrot.pose.translation.y
final_y = final_carrot.pose.translation.y
# NOTE the specific coords that refer to for example
# middle of the table can differ across simulations,
# take that into consideration
if (
initial_y <= 0.0
): # Carrot started in the incorrect place (right side)
if final_y >= 0.0:
initially_misplaced_now_correct += (
1 # Moved to correct side
)
else:
initially_misplaced_still_incorrect += (
1 # Stayed on incorrect side
)
else: # Carrot started in the correct place (left side)
if final_y >= 0.0:
initially_correct_still_correct += (
1 # Stayed on correct side
)
else:
initially_correct_now_incorrect += (
1 # Moved incorrectly to the wrong side
)
break
else:
raise EntitiesMismatchException(
f"Entity with name: {initial_carrot.name} which was present in initial scene, not found in final scene."
)
corrected = final_correct - initially_correct
score = max(0.0, corrected / initially_incorrect)

self.logger.info( # type: ignore
f"initially_misplaced_now_correct: {initially_misplaced_now_correct}, initially_misplaced_still_incorrect: {initially_misplaced_still_incorrect}, initially_correct_still_correct: {initially_correct_still_correct}, initially_correct_now_incorrect: {initially_correct_now_incorrect}"
)
return (
initially_misplaced_now_correct + initially_correct_still_correct
) / num_initial_carrots
self.logger.info(f"Calculated score: {score:.2f}") # type: ignore
return score
Loading