diff --git a/examples/history_vehicles_replacement_for_imitation_learning.py b/examples/history_vehicles_replacement_for_imitation_learning.py index 55dea60f58..cfae2277ef 100644 --- a/examples/history_vehicles_replacement_for_imitation_learning.py +++ b/examples/history_vehicles_replacement_for_imitation_learning.py @@ -20,6 +20,12 @@ def act(self, obs): def main(scenarios, headless, seed): scenarios_iterator = Scenario.scenario_variations(scenarios, []) + smarts = SMARTS( + agent_interfaces={}, + traffic_sim=SumoTrafficSimulation(headless=True, auto_start=True), + envision=Envision(), + ) + for _ in scenarios: scenario = next(scenarios_iterator) agent_missions = scenario.discover_missions_of_traffic_histories() @@ -33,14 +39,10 @@ def main(scenarios, headless, seed): ), agent_builder=KeepLaneAgent, ) - agent = agent_spec.build_agent() - smarts = SMARTS( - agent_interfaces={agent_id: agent_spec.interface}, - traffic_sim=SumoTrafficSimulation(headless=True, auto_start=True), - envision=Envision(), - ) + smarts.switch_ego_agent({agent_id: agent_spec.interface}) + observations = smarts.reset(scenario) dones = {agent_id: False} @@ -52,7 +54,7 @@ def main(scenarios, headless, seed): {agent_id: agent_action} ) - smarts.destroy() + smarts.destroy() if __name__ == "__main__": diff --git a/smarts/core/agent_manager.py b/smarts/core/agent_manager.py index b1a9b2fb17..6147751667 100644 --- a/smarts/core/agent_manager.py +++ b/smarts/core/agent_manager.py @@ -300,6 +300,9 @@ def send_observations_to_social_agents(self, observations): obs = observations[agent_id] self._remote_social_agents_action[agent_id] = remote_agent.act(obs) + def switch_initial_agent(self, agent_interface): + self._initial_interfaces = agent_interface + def setup_agents(self, sim): self.init_ego_agents(sim) self.setup_social_agents(sim) diff --git a/smarts/core/sensors.py b/smarts/core/sensors.py index 918baebba9..708cc5062c 100644 --- a/smarts/core/sensors.py +++ b/smarts/core/sensors.py @@ -298,7 +298,7 @@ def observe(sim, agent_id, sensor_state, vehicle): and sensor_state.steps_completed == 1 and agent_id in sim.agent_manager.ego_agent_ids ): - logger.warning(f"{agent_id} is done on the first step") + logger.warning(f"Agent Id: {agent_id} is done on the first step") return ( Observation( diff --git a/smarts/core/smarts.py b/smarts/core/smarts.py index 3043d8af33..f290fceb6b 100644 --- a/smarts/core/smarts.py +++ b/smarts/core/smarts.py @@ -369,6 +369,10 @@ def setup(self, scenario: Scenario): def add_provider(self, provider): self._providers.append(provider) + def switch_ego_agent(self, agent_interface): + self._agent_manager.switch_initial_agent(agent_interface) + self._is_setup = False + def _setup_road_network(self): glb_path = self.scenario.map_glb_filepath if self._road_network_np: diff --git a/smarts/core/trap_manager.py b/smarts/core/trap_manager.py index 41ddca183b..79f5369f71 100644 --- a/smarts/core/trap_manager.py +++ b/smarts/core/trap_manager.py @@ -145,6 +145,10 @@ def largest_vehicle_plane_dimension(vehicle): ), ) for v_id in sorted_vehicle_ids: + # Skip the capturing process if history traffic is used + if sim.scenario.traffic_history: + break + vehicle = vehicles[v_id] point = Point(vehicle.position)