Skip to content

Goal position fixes #1637

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Oct 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 28 additions & 19 deletions examples/traffic_histories_to_observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)
from smarts.core.colors import Colors
from smarts.core.controllers import ActionSpaceType, ControllerOutOfLaneException
from smarts.core.coordinates import Point
from smarts.core.local_traffic_provider import LocalTrafficProvider
from smarts.core.plan import PositionalGoal
from smarts.core.scenario import Scenario
Expand Down Expand Up @@ -84,24 +85,12 @@ def __init__(
"No output dir provided. Observations will not be saved."
)
self._smarts = None
self._create_missions()

if agent_interface is not None:
self.agent_interface = agent_interface
else:
self.agent_interface = self._create_default_interface()

def _create_missions(self):
self._missions = dict()
orig_missions = self._scenario.discover_missions_of_traffic_histories()
for v_id, mission in orig_missions.items():
veh_goal = self._scenario._get_vehicle_goal(v_id)
# TODO: get prefixed vehicle_id from TrafficHistoryProvider
veh_id = f"history-vehicle-{v_id}"
self._missions[veh_id] = replace(
mission, goal=PositionalGoal(veh_goal, radius=3)
)

def _create_default_interface(
self, img_meters: int = 64, img_pixels: int = 256, action_space="TargetPose"
) -> AgentInterface:
Expand Down Expand Up @@ -258,8 +247,33 @@ def collect(
)

if self._output_dir:
# Get original missions for all vehicles
missions = dict()
orig_missions = self._scenario.discover_missions_of_traffic_histories()
for v_id, mission in orig_missions.items():
# TODO: get prefixed vehicle_id from TrafficHistoryProvider
veh_id = f"history-vehicle-{v_id}"
missions[veh_id] = mission

# Save recorded observations as pickle files
for car, data in collected_data.items():
# Fill in mission with proper goal position for all observations
last_t = max(data.keys())
last_state = data[last_t].ego_vehicle_state
goal_pos = Point(last_state.position[0], last_state.position[1])
new_mission = replace(
missions[last_state.id], goal=PositionalGoal(goal_pos, radius=3)
)
for t in data.keys():
ego_state = data[t].ego_vehicle_state
new_ego_state = ego_state._replace(mission=new_mission)
data[t] = replace(data[t], ego_vehicle_state=new_ego_state)

# Create terminal state for last timestep, when the vehicle reaches the goal
events = data[last_t].events
new_events = events._replace(reached_goal=True)
data[last_t] = replace(data[last_t], events=new_events)

outfile = os.path.join(
self._output_dir,
f"{car}.pkl",
Expand Down Expand Up @@ -295,18 +309,13 @@ def _record_data(
# Get observations from each vehicle and record them.
obs = dict()
obs, _, _, _ = self._smarts.observe_from(list(valid_vehicles))
self._logger.info(f"t={t}, active_vehicles={len(valid_vehicles)}")
self._logger.debug(f"t={t}, active_vehicles={len(valid_vehicles)}")
for id_ in list(obs):
ego_state = obs[id_].ego_vehicle_state
if ego_state.lane_index is None:
del obs[id_]
continue

mission = self._missions[ego_state.id]
if mission:
new_ego_state = ego_state._replace(mission=mission)
obs[id_] = replace(obs[id_], ego_vehicle_state=new_ego_state)

top_down_rgb = obs[id_].top_down_rgb
if top_down_rgb:
res = top_down_rgb.metadata.resolution
Expand All @@ -333,7 +342,7 @@ def _record_data(

# TODO: handle case where neighboring vehicle has lane_index of None too
for car, car_obs in obs.items():
collected_data.setdefault(car, {}).setdefault(t, {})
collected_data.setdefault(car, {})
collected_data[car][t] = car_obs


Expand Down
7 changes: 2 additions & 5 deletions smarts/core/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,11 +569,8 @@ def _get_vehicle_goal(self, vehicle_id: str) -> Point:
vehicle_id, final_exit_time
)
assert final_pose
final_pos_x, final_pos_y, final_heading, _ = final_pose
# missions start from front bumper, but pos is center of vehicle
veh_dims = self._traffic_history.vehicle_dims(vehicle_id)
final_hhx, final_hhy = radians_to_vec(final_heading) * (0.5 * veh_dims.length)
return Point(final_pos_x + final_hhx, final_pos_y + final_hhy)
final_pos_x, final_pos_y, _, _ = final_pose
return Point(final_pos_x, final_pos_y)

def discover_missions_of_traffic_histories(self) -> Dict[str, Mission]:
"""Retrieves the missions of traffic history vehicles."""
Expand Down