From 149a1a69d2c6a8f2989254ad207cc5d3add701fa Mon Sep 17 00:00:00 2001 From: Tucker Date: Mon, 24 Apr 2023 09:30:00 -0400 Subject: [PATCH 01/13] Add low dimension interest observation. --- smarts/core/observations.py | 2 ++ smarts/core/sensors/__init__.py | 9 ++++++-- smarts/core/simulation_frame.py | 24 ++++++++++++++++++++++ smarts/core/smarts.py | 6 +++++- smarts/env/utils/observation_conversion.py | 12 ++++++++++- smarts/env/wrappers/format_obs.py | 7 ++++++- 6 files changed, 55 insertions(+), 5 deletions(-) diff --git a/smarts/core/observations.py b/smarts/core/observations.py index ae37729bee..9e0e0465e2 100644 --- a/smarts/core/observations.py +++ b/smarts/core/observations.py @@ -53,6 +53,8 @@ class VehicleObservation(NamedTuple): lane_position: Optional[RefLinePoint] = None """(s,t,h) coordinates within the lane, where s is the longitudinal offset along the lane, t is the lateral displacement from the lane center, and h (not yet supported) is the vertical displacement from the lane surface. See the Reference Line coordinate system in OpenDRIVE here: https://www.asam.net/index.php?eID=dumpFile&t=f&f=4089&token=deea5d707e2d0edeeb4fccd544a973de4bc46a09#_coordinate_systems """ + interest: bool = False + """If this vehicle is of interest in the current scenario.""" class EgoVehicleObservation(NamedTuple): diff --git a/smarts/core/sensors/__init__.py b/smarts/core/sensors/__init__.py index e94eab059f..40589931ba 100644 --- a/smarts/core/sensors/__init__.py +++ b/smarts/core/sensors/__init__.py @@ -70,7 +70,9 @@ LANE_INDEX_CONSTANT = -1 -def _make_vehicle_observation(road_map, neighborhood_vehicle): +def _make_vehicle_observation( + road_map, neighborhood_vehicle: VehicleState, sim_frame: SimulationFrame +): nv_lane = road_map.nearest_lane(neighborhood_vehicle.pose.point, radius=3) if nv_lane: nv_road_id = nv_lane.road.road_id @@ -91,6 +93,7 @@ def _make_vehicle_observation(road_map, neighborhood_vehicle): lane_id=nv_lane_id, lane_index=nv_lane_index, lane_position=None, + interest=sim_frame.actor_is_interest(neighborhood_vehicle.actor_id), ) @@ -262,7 +265,9 @@ def process_serialization_safe_sensors( for nv in neighborhood_vehicle_states_sensor( vehicle_state, sim_frame.vehicle_states.values() ): - veh_obs = _make_vehicle_observation(sim_local_constants.road_map, nv) + veh_obs = _make_vehicle_observation( + sim_local_constants.road_map, nv, sim_frame + ) lane_position_sensor = vehicle_sensors.get("lane_position_sensor") nv_lane_pos = None if veh_obs.lane_id is not LANE_ID_CONSTANT and lane_position_sensor: diff --git a/smarts/core/simulation_frame.py b/smarts/core/simulation_frame.py index c9251a0f02..74a6b0a8fd 100644 --- a/smarts/core/simulation_frame.py +++ b/smarts/core/simulation_frame.py @@ -20,6 +20,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. import logging +import re from dataclasses import dataclass from typing import Any, Dict, List, Optional, Set @@ -55,6 +56,7 @@ class SimulationFrame: vehicle_sensors: Dict[str, Dict[str, Any]] sensor_states: Any + interest_filter: re.Pattern # TODO MTA: renderer can be allowed here as long as it is only type information # renderer_type: Any = None _ground_bullet_id: Optional[str] = None @@ -74,6 +76,28 @@ def actor_states_by_id(self) -> Dict[str, ActorState]: """Get actor states paired by their ids.""" return {a_s.actor_id: a_s for a_s in self.actor_states} + @cached_property + def _interest_actors(self) -> Dict[str, ActorState]: + """Get the actor states of actors that are marked as of interest.""" + if self.interest_filter.pattern: + return { + a_s.actor_id: a_s + for a_s in self.actor_states + if self.interest_filter.match(a_s.actor_id) + } + return {} + + def actor_is_interest(self, actor_id) -> bool: + """Determine if the actor is of interest. + + Args: + actor_id (str): The id of the actor to test. + + Returns: + bool: If the actor is of interest. + """ + return actor_id in self._interest_actors + def vehicle_did_collide(self, vehicle_id) -> bool: """Test if the given vehicle had any collisions in the last physics update.""" vehicle_collisions = self.vehicle_collisions.get(vehicle_id, []) diff --git a/smarts/core/smarts.py b/smarts/core/smarts.py index f6583426e0..dc7eccd304 100644 --- a/smarts/core/smarts.py +++ b/smarts/core/smarts.py @@ -20,6 +20,7 @@ import importlib.resources as pkg_resources import logging import os +import re import warnings from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union @@ -76,6 +77,7 @@ level=logging.ERROR, ) +_DEFAULT_PATTERN = re.compile("") MAX_PYBULLET_FREQ = 240 @@ -1381,7 +1383,6 @@ def neighborhood_vehicles_around_vehicle( ) -> List[VehicleState]: """Find vehicles in the vicinity of the target vehicle.""" self._check_valid() - from smarts.core.sensors import Sensors vehicle = self._vehicle_index.vehicle_by_id(vehicle_id) return neighborhood_vehicles_around_vehicle( @@ -1680,4 +1681,7 @@ def cached_frame(self): vehicle_sensors=self.sensor_manager.sensors_for_actor_ids(vehicle_ids), sensor_states=dict(self.sensor_manager.sensor_states_items()), _ground_bullet_id=self._ground_bullet_id, + interest_filter=self.scenario.metadata.get( + "actor_of_interest_re_filter", _DEFAULT_PATTERN + ), ) diff --git a/smarts/env/utils/observation_conversion.py b/smarts/env/utils/observation_conversion.py index 04bab7a762..1681ba41d1 100644 --- a/smarts/env/utils/observation_conversion.py +++ b/smarts/env/utils/observation_conversion.py @@ -196,6 +196,7 @@ def _format_neighborhood_vehicle_states( "lane_index": np.zeros((des_shp,), dtype=np.int8), "position": np.zeros((des_shp, 3), dtype=np.float64), "speed": np.zeros((des_shp,), dtype=np.float32), + "interest": np.zeros((des_shp,), dtype=np.bool8), } neighborhood_vehicle_states = [ @@ -206,16 +207,20 @@ def _format_neighborhood_vehicle_states( nghb.lane_index, nghb.position, nghb.speed, + nghb.interest, ) for nghb in neighborhood_vehicle_states[:des_shp] ] - box, heading, vehicle_id, lane_index, pos, speed = zip(*neighborhood_vehicle_states) + box, heading, vehicle_id, lane_index, pos, speed, interest = zip( + *neighborhood_vehicle_states + ) box = np.array(box, dtype=np.float32) heading = np.array(heading, dtype=np.float32) lane_index = np.array(lane_index, dtype=np.int8) pos = np.array(pos, dtype=np.float64) speed = np.array(speed, dtype=np.float32) + interest = np.array(interest, dtype=np.bool8) # fmt: off box = np.pad(box, ((0,pad_shp),(0,0)), mode='constant', constant_values=0) @@ -224,6 +229,7 @@ def _format_neighborhood_vehicle_states( lane_index = np.pad(lane_index, ((0,pad_shp)), mode='constant', constant_values=0) pos = np.pad(pos, ((0,pad_shp),(0,0)), mode='constant', constant_values=0) speed = np.pad(speed, ((0,pad_shp)), mode='constant', constant_values=0) + interest = np.pad(interest, ((0,pad_shp)), mode="constant", constant_values=False) # fmt: on return { @@ -233,6 +239,7 @@ def _format_neighborhood_vehicle_states( "lane_index": lane_index, "position": pos, "speed": speed, + "interest": interest, } @@ -707,6 +714,9 @@ def name(self): "speed": gym.spaces.Box( low=0, high=1e10, shape=(_NEIGHBOR_SHP,), dtype=np.float32 ), + "interest": gym.spaces.Box( + low=0, high=1, shape=(_NEIGHBOR_SHP,), dtype=np.bool8 + ), } ), ) diff --git a/smarts/env/wrappers/format_obs.py b/smarts/env/wrappers/format_obs.py index edca5c024e..70b1e5a02d 100644 --- a/smarts/env/wrappers/format_obs.py +++ b/smarts/env/wrappers/format_obs.py @@ -520,6 +520,7 @@ def _std_neighborhood_vehicle_states( "lane_index": np.zeros((des_shp,), dtype=np.int8), "pos": np.zeros((des_shp, 3), dtype=np.float64), "speed": np.zeros((des_shp,), dtype=np.float32), + "interest": np.zeros((des_shp,), dtype=np.bool8), } nghbs = [ @@ -529,16 +530,18 @@ def _std_neighborhood_vehicle_states( nghb.lane_index, nghb.position, nghb.speed, + nghb.interest, ) for nghb in nghbs[:des_shp] ] - box, heading, lane_index, pos, speed = zip(*nghbs) + box, heading, lane_index, pos, speed, interest = zip(*nghbs) box = np.array(box, dtype=np.float32) heading = np.array(heading, dtype=np.float32) lane_index = np.array(lane_index, dtype=np.int8) pos = np.array(pos, dtype=np.float64) speed = np.array(speed, dtype=np.float32) + interest = np.array(interest, dtype=np.bool8) # fmt: off box = np.pad(box, ((0,pad_shp),(0,0)), mode='constant', constant_values=0) @@ -546,6 +549,7 @@ def _std_neighborhood_vehicle_states( lane_index = np.pad(lane_index, ((0,pad_shp)), mode='constant', constant_values=0) pos = np.pad(pos, ((0,pad_shp),(0,0)), mode='constant', constant_values=0) speed = np.pad(speed, ((0,pad_shp)), mode='constant', constant_values=0) + interest = np.pad(interest, ((0,pad_shp)), mode="constant", constant_values=0) # fmt: on return { @@ -554,6 +558,7 @@ def _std_neighborhood_vehicle_states( "lane_index": lane_index, "pos": pos, "speed": speed, + "interest": interest, } From 9868240b5b1d033c8acd166d204b1dbd5f4f059f Mon Sep 17 00:00:00 2001 From: Tucker Date: Mon, 24 Apr 2023 10:46:33 -0400 Subject: [PATCH 02/13] Fix tests. --- smarts/env/wrappers/format_obs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/smarts/env/wrappers/format_obs.py b/smarts/env/wrappers/format_obs.py index 70b1e5a02d..132755d65e 100644 --- a/smarts/env/wrappers/format_obs.py +++ b/smarts/env/wrappers/format_obs.py @@ -364,6 +364,7 @@ def get_spaces() -> Dict[str, Callable[[Any], gym.Space]]: "lane_index": gym.spaces.Box(low=0, high=127, shape=(_NEIGHBOR_SHP,), dtype=np.int8), "pos": gym.spaces.Box(low=-1e10, high=1e10, shape=(_NEIGHBOR_SHP,3), dtype=np.float64), "speed": gym.spaces.Box(low=0, high=1e10, shape=(_NEIGHBOR_SHP,), dtype=np.float32), + "interest": gym.spaces.Box(low=0, high=1, shape=(_NEIGHBOR_SHP,), dtype=np.bool8), }), "occupancy_grid_map": lambda val: gym.spaces.Box(low=0, high=255,shape=(val.height, val.width, 1), dtype=np.uint8), "top_down_rgb": lambda val: gym.spaces.Box(low=0, high=255, shape=(val.height, val.width, 3), dtype=np.uint8), From 4eb80b35f294f698ce013e9d28c9d084208f680d Mon Sep 17 00:00:00 2001 From: Tucker Date: Mon, 24 Apr 2023 11:32:05 -0400 Subject: [PATCH 03/13] Use multibinary --- smarts/env/utils/observation_conversion.py | 8 +++----- smarts/env/wrappers/format_obs.py | 6 +++--- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/smarts/env/utils/observation_conversion.py b/smarts/env/utils/observation_conversion.py index 1681ba41d1..9404d47e1a 100644 --- a/smarts/env/utils/observation_conversion.py +++ b/smarts/env/utils/observation_conversion.py @@ -196,7 +196,7 @@ def _format_neighborhood_vehicle_states( "lane_index": np.zeros((des_shp,), dtype=np.int8), "position": np.zeros((des_shp, 3), dtype=np.float64), "speed": np.zeros((des_shp,), dtype=np.float32), - "interest": np.zeros((des_shp,), dtype=np.bool8), + "interest": np.zeros((des_shp,), dtype=np.int8), } neighborhood_vehicle_states = [ @@ -220,7 +220,7 @@ def _format_neighborhood_vehicle_states( lane_index = np.array(lane_index, dtype=np.int8) pos = np.array(pos, dtype=np.float64) speed = np.array(speed, dtype=np.float32) - interest = np.array(interest, dtype=np.bool8) + interest = np.array(interest, dtype=np.int8) # fmt: off box = np.pad(box, ((0,pad_shp),(0,0)), mode='constant', constant_values=0) @@ -714,9 +714,7 @@ def name(self): "speed": gym.spaces.Box( low=0, high=1e10, shape=(_NEIGHBOR_SHP,), dtype=np.float32 ), - "interest": gym.spaces.Box( - low=0, high=1, shape=(_NEIGHBOR_SHP,), dtype=np.bool8 - ), + "interest": gym.spaces.MultiBinary(_NEIGHBOR_SHP), } ), ) diff --git a/smarts/env/wrappers/format_obs.py b/smarts/env/wrappers/format_obs.py index 132755d65e..dad8bb214a 100644 --- a/smarts/env/wrappers/format_obs.py +++ b/smarts/env/wrappers/format_obs.py @@ -364,7 +364,7 @@ def get_spaces() -> Dict[str, Callable[[Any], gym.Space]]: "lane_index": gym.spaces.Box(low=0, high=127, shape=(_NEIGHBOR_SHP,), dtype=np.int8), "pos": gym.spaces.Box(low=-1e10, high=1e10, shape=(_NEIGHBOR_SHP,3), dtype=np.float64), "speed": gym.spaces.Box(low=0, high=1e10, shape=(_NEIGHBOR_SHP,), dtype=np.float32), - "interest": gym.spaces.Box(low=0, high=1, shape=(_NEIGHBOR_SHP,), dtype=np.bool8), + "interest": gym.spaces.MultiBinary(_NEIGHBOR_SHP), }), "occupancy_grid_map": lambda val: gym.spaces.Box(low=0, high=255,shape=(val.height, val.width, 1), dtype=np.uint8), "top_down_rgb": lambda val: gym.spaces.Box(low=0, high=255, shape=(val.height, val.width, 3), dtype=np.uint8), @@ -521,7 +521,7 @@ def _std_neighborhood_vehicle_states( "lane_index": np.zeros((des_shp,), dtype=np.int8), "pos": np.zeros((des_shp, 3), dtype=np.float64), "speed": np.zeros((des_shp,), dtype=np.float32), - "interest": np.zeros((des_shp,), dtype=np.bool8), + "interest": np.zeros((des_shp,), dtype=np.int8), } nghbs = [ @@ -542,7 +542,7 @@ def _std_neighborhood_vehicle_states( lane_index = np.array(lane_index, dtype=np.int8) pos = np.array(pos, dtype=np.float64) speed = np.array(speed, dtype=np.float32) - interest = np.array(interest, dtype=np.bool8) + interest = np.array(interest, dtype=np.int8) # fmt: off box = np.pad(box, ((0,pad_shp),(0,0)), mode='constant', constant_values=0) From dcd85f6c90d12209ed68f40881f2f88b7d58637d Mon Sep 17 00:00:00 2001 From: Tucker Date: Mon, 24 Apr 2023 12:25:28 -0400 Subject: [PATCH 04/13] Add agent interface. --- smarts/core/agent_interface.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/smarts/core/agent_interface.py b/smarts/core/agent_interface.py index d6f338951f..4a6b2e9522 100644 --- a/smarts/core/agent_interface.py +++ b/smarts/core/agent_interface.py @@ -240,6 +240,8 @@ class DoneCriteria: """If set, triggers the ego agent to be done based on the number of active agents for multi-agent purposes.""" actors_alive: Optional[ActorsAliveDoneCriteria] = None """If set, triggers the ego agent to be done based on actors existing in the simulation.""" + interest: bool = False + """If set, triggers when there are no interest vehicles left existing in the simulation.""" @dataclass From 35e3c742356bb4f66bc09d0e492a2459216b9d74 Mon Sep 17 00:00:00 2001 From: Tucker Date: Mon, 24 Apr 2023 13:04:34 -0400 Subject: [PATCH 05/13] Implement interest actors. --- smarts/core/agent_interface.py | 14 ++++- smarts/core/sensors/__init__.py | 54 +++++++++++++++---- smarts/core/simulation_frame.py | 4 +- smarts/env/gymnasium/platoon_env.py | 2 +- .../env/gymnasium/wrappers/metric/metrics.py | 4 +- 5 files changed, 61 insertions(+), 17 deletions(-) diff --git a/smarts/core/agent_interface.py b/smarts/core/agent_interface.py index 4a6b2e9522..a79b8a3e09 100644 --- a/smarts/core/agent_interface.py +++ b/smarts/core/agent_interface.py @@ -197,10 +197,20 @@ class AgentsAliveDoneCriteria: class ActorsAliveDoneCriteria: """Require actors to persist.""" - actors_of_interest: Tuple[str, ...] = () + actors_filter: Tuple[str, ...] = () """Actors that should exist to continue this agent.""" strict: bool = True + """If strict the agent will be done instantly if a target actor is not available + immediately. + """ + + +@dataclass(frozen=True) +class ScenarioInterestDoneCriteria: + """Require scenario marked interest actors to exist.""" + + strict: bool = False """If strict the agent will be done instantly if an actor of interest is not available immediately. """ @@ -240,7 +250,7 @@ class DoneCriteria: """If set, triggers the ego agent to be done based on the number of active agents for multi-agent purposes.""" actors_alive: Optional[ActorsAliveDoneCriteria] = None """If set, triggers the ego agent to be done based on actors existing in the simulation.""" - interest: bool = False + interest: Optional[ScenarioInterestDoneCriteria] = None """If set, triggers when there are no interest vehicles left existing in the simulation.""" diff --git a/smarts/core/sensors/__init__.py b/smarts/core/sensors/__init__.py index 40589931ba..36ff975a05 100644 --- a/smarts/core/sensors/__init__.py +++ b/smarts/core/sensors/__init__.py @@ -24,7 +24,11 @@ import numpy as np -from smarts.core.agent_interface import ActorsAliveDoneCriteria, AgentsAliveDoneCriteria +from smarts.core.agent_interface import ( + ActorsAliveDoneCriteria, + AgentsAliveDoneCriteria, + ScenarioInterestDoneCriteria, +) from smarts.core.coordinates import Heading, Point from smarts.core.events import Events from smarts.core.observations import ( @@ -105,14 +109,24 @@ def __init__(self, max_episode_steps: int, plan_frame: PlanFrame): self._plan_frame = plan_frame self._step = 0 self._seen_interest_actors = False + self._seen_alive_actors = False def step(self): """Update internal state.""" self._step += 1 + @property + def seen_alive_actors(self) -> bool: + """If an agents alive actor has been spotted before.""" + return self._seen_alive_actors + + @seen_alive_actors.setter + def seen_alive_actors(self, value: bool): + self._seen_alive_actors = value + @property def seen_interest_actors(self) -> bool: - """If a relevant actor has been spotted before.""" + """If an interest actor has been spotted before.""" return self._seen_interest_actors @seen_interest_actors.setter @@ -519,25 +533,42 @@ def _agents_alive_done_check( def _actors_alive_done_check( cls, vehicle_ids, - sensor_state, + sensor_state: SensorState, actors_alive: Optional[ActorsAliveDoneCriteria], ): if actors_alive is None: return False - sensor_state: SensorState = sensor_state - pattern = re.compile( - "|".join(rf"(?:{aoi})" for aoi in actors_alive.actors_of_interest) + "|".join(rf"(?:{aoi})" for aoi in actors_alive.actors_filter) ) - ## TODO optimization to get vehicles that were added and removed last step - ## TODO second optimization to check for already known vehicles + ## TODO optimization to only get vehicles that were removed last step for vehicle_id in vehicle_ids: # get vehicles by pattern if pattern.match(vehicle_id): - sensor_state.seen_interest_actors = True + sensor_state.seen_alive_actors = True return False - if actors_alive.strict or sensor_state.seen_interest_actors: + if actors_alive.strict or sensor_state.seen_alive_actors: + # if agent requires the actor to exist immediately + # OR if previously seen relevant actors but no actors match anymore + return True + + ## if never seen a relevant actor + return False + + @classmethod + def _interest_done_check( + cls, + interest_actors, + sensor_state: SensorState, + scenario_interest: Optional[ScenarioInterestDoneCriteria], + ): + if scenario_interest is None: + return False + + if len(interest_actors): + return True + if scenario_interest.strict or sensor_state.seen_interest_actors: # if agent requires the actor to exist immediately # OR if previously seen relevant actors but no actors match anymore return True @@ -588,6 +619,9 @@ def _is_done_with_events( actors_alive_done = cls._actors_alive_done_check( sim_frame.vehicle_ids, sensor_state, done_criteria.actors_alive ) + interest_done = cls._interest_done_check( + sim_frame.interest_actors, sensor_state, done_criteria.interest + ) done = not sim_frame.resetting and ( (is_off_road and done_criteria.off_road) diff --git a/smarts/core/simulation_frame.py b/smarts/core/simulation_frame.py index 74a6b0a8fd..5934438b04 100644 --- a/smarts/core/simulation_frame.py +++ b/smarts/core/simulation_frame.py @@ -77,7 +77,7 @@ def actor_states_by_id(self) -> Dict[str, ActorState]: return {a_s.actor_id: a_s for a_s in self.actor_states} @cached_property - def _interest_actors(self) -> Dict[str, ActorState]: + def interest_actors(self) -> Dict[str, ActorState]: """Get the actor states of actors that are marked as of interest.""" if self.interest_filter.pattern: return { @@ -96,7 +96,7 @@ def actor_is_interest(self, actor_id) -> bool: Returns: bool: If the actor is of interest. """ - return actor_id in self._interest_actors + return actor_id in self.interest_actors def vehicle_did_collide(self, vehicle_id) -> bool: """Test if the given vehicle had any collisions in the last physics update.""" diff --git a/smarts/env/gymnasium/platoon_env.py b/smarts/env/gymnasium/platoon_env.py index b5c600d148..771c6cdef4 100644 --- a/smarts/env/gymnasium/platoon_env.py +++ b/smarts/env/gymnasium/platoon_env.py @@ -163,7 +163,7 @@ def resolve_agent_interface(agent_interface: AgentInterface): wrong_way=False, not_moving=False, actors_alive=ActorsAliveDoneCriteria( - actors_of_interest=("Leader-007",), + actors_filter=("Leader-007",), strict=True, ), ) diff --git a/smarts/env/gymnasium/wrappers/metric/metrics.py b/smarts/env/gymnasium/wrappers/metric/metrics.py index 585feecd1b..9d19d4f91b 100644 --- a/smarts/env/gymnasium/wrappers/metric/metrics.py +++ b/smarts/env/gymnasium/wrappers/metric/metrics.py @@ -195,7 +195,7 @@ def reset(self, **kwargs): ].done_criteria.actors_alive if isinstance(actors_alive, ActorsAliveDoneCriteria): end_pos, dist_tot = _get_sumo_smarts_dist( - vehicle_name=actors_alive.actors_of_interest[0], + vehicle_name=actors_alive.actors_filter[0], traffic_sims=self.env.smarts.traffic_sims, road_map=self._road_map, ) @@ -397,7 +397,7 @@ def check_intrfc(agent_intrfc: AgentInterface): if ( params.dist_to_destination.active and isinstance(actors_alive, ActorsAliveDoneCriteria) - and len(actors_alive.actors_of_interest) != 1 + and len(actors_alive.actors_filter) != 1 ): raise AttributeError( ( From 23084f511c08331ad7e0a5bb8c5ec771cd339c28 Mon Sep 17 00:00:00 2001 From: Tucker Date: Mon, 24 Apr 2023 13:16:16 -0400 Subject: [PATCH 06/13] Remove test warning. --- envision/tests/test_data_formatter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/envision/tests/test_data_formatter.py b/envision/tests/test_data_formatter.py index 955dd9e00f..6952cad2b2 100644 --- a/envision/tests/test_data_formatter.py +++ b/envision/tests/test_data_formatter.py @@ -373,7 +373,7 @@ def smarts(): # envision.teardown = MagicMock() smarts = SMARTS( agents, - traffic_sim=SumoTrafficSimulation(), + traffic_sims=[SumoTrafficSimulation()], envision=envision, ) yield smarts From f5d767441fd729eb49079a78fc98dafce2639063 Mon Sep 17 00:00:00 2001 From: Tucker Date: Mon, 24 Apr 2023 18:17:53 -0400 Subject: [PATCH 07/13] Combine actors_alive and interest. --- docs/benchmarks/driving_smarts_2023_3.rst | 2 +- examples/rl/platoon/train/reward.py | 2 +- smarts/core/agent_interface.py | 26 ++++------ smarts/core/events.py | 7 ++- smarts/core/sensors/__init__.py | 51 ++++++------------- smarts/core/utils/tests/fixtures.py | 2 +- smarts/env/gymnasium/platoon_env.py | 6 +-- smarts/env/gymnasium/wrappers/metric/costs.py | 4 +- .../env/gymnasium/wrappers/metric/metrics.py | 36 ++++++------- smarts/env/utils/observation_conversion.py | 12 ++--- zoo/policies/__init__.py | 2 +- 11 files changed, 66 insertions(+), 84 deletions(-) diff --git a/docs/benchmarks/driving_smarts_2023_3.rst b/docs/benchmarks/driving_smarts_2023_3.rst index f97bf216af..4c11a69204 100644 --- a/docs/benchmarks/driving_smarts_2023_3.rst +++ b/docs/benchmarks/driving_smarts_2023_3.rst @@ -18,7 +18,7 @@ Objective is to develop a single-ego policy capable of controlling a single ego Each ego is supposed to track and follow its specified leader (i.e., lead vehicle) in a single file or in a platoon fashion. The name identifier of the lead vehicle to be followed is given to the ego through the configuration -of the :attr:`~smarts.core.agent_interface.ActorsAliveDoneCriteria.actors_of_interest` attribute. +of the :attr:`~smarts.core.agent_interface.InterestDoneCriteria.actors_of_interest` attribute. .. figure:: ../_static/driving_smarts_2023/vehicle_following.png diff --git a/examples/rl/platoon/train/reward.py b/examples/rl/platoon/train/reward.py index d7150be12c..ca2390bc7c 100644 --- a/examples/rl/platoon/train/reward.py +++ b/examples/rl/platoon/train/reward.py @@ -45,7 +45,7 @@ def step(self, action): agent_obs["events"]["collisions"] | agent_obs["events"]["off_road"] ): pass - elif agent_obs["events"]["actors_alive_done"]: + elif agent_obs["events"]["interest_done"]: print(f"{agent_id}: Actors alive done triggered.") else: print("Events: ", agent_obs["events"]) diff --git a/smarts/core/agent_interface.py b/smarts/core/agent_interface.py index a79b8a3e09..bd29721ee0 100644 --- a/smarts/core/agent_interface.py +++ b/smarts/core/agent_interface.py @@ -194,21 +194,14 @@ class AgentsAliveDoneCriteria: @dataclass(frozen=True) -class ActorsAliveDoneCriteria: - """Require actors to persist.""" +class InterestDoneCriteria: + """Require scenario marked interest actors to exist.""" actors_filter: Tuple[str, ...] = () - """Actors that should exist to continue this agent.""" + """Interface defined interest actors that should exist to continue this agent.""" - strict: bool = True - """If strict the agent will be done instantly if a target actor is not available - immediately. - """ - - -@dataclass(frozen=True) -class ScenarioInterestDoneCriteria: - """Require scenario marked interest actors to exist.""" + include_scenario_marked: bool = True + """If scenario marked interest actors should be considered as interest vehicles.""" strict: bool = False """If strict the agent will be done instantly if an actor of interest is not available @@ -248,11 +241,14 @@ class DoneCriteria: """ agents_alive: Optional[AgentsAliveDoneCriteria] = None """If set, triggers the ego agent to be done based on the number of active agents for multi-agent purposes.""" - actors_alive: Optional[ActorsAliveDoneCriteria] = None - """If set, triggers the ego agent to be done based on actors existing in the simulation.""" - interest: Optional[ScenarioInterestDoneCriteria] = None + interest: Optional[InterestDoneCriteria] = None """If set, triggers when there are no interest vehicles left existing in the simulation.""" + @property + def actors_alive(self): + """Deprecated. Use interest.""" + raise NameError("Deprecated. Use interest.") + @dataclass class AgentInterface: diff --git a/smarts/core/events.py b/smarts/core/events.py index bccde7828c..61fb939af5 100644 --- a/smarts/core/events.py +++ b/smarts/core/events.py @@ -42,5 +42,10 @@ class Events(NamedTuple): agents_alive_done: bool """True if all configured co-simulating agents are done (if any), else False. This is useful for cases when the vehicle is related to other vehicles.""" - actors_alive_done: bool + interest_done: bool """True if described actors have left the simulation.""" + + @property + def actors_alive_done(self): + """Deprecated. Use interest_done.""" + raise NameError("Deprecated. Use interest_done.") diff --git a/smarts/core/sensors/__init__.py b/smarts/core/sensors/__init__.py index 36ff975a05..73e4cc4e46 100644 --- a/smarts/core/sensors/__init__.py +++ b/smarts/core/sensors/__init__.py @@ -24,11 +24,7 @@ import numpy as np -from smarts.core.agent_interface import ( - ActorsAliveDoneCriteria, - AgentsAliveDoneCriteria, - ScenarioInterestDoneCriteria, -) +from smarts.core.agent_interface import AgentsAliveDoneCriteria, InterestDoneCriteria from smarts.core.coordinates import Heading, Point from smarts.core.events import Events from smarts.core.observations import ( @@ -530,45 +526,28 @@ def _agents_alive_done_check( return False @classmethod - def _actors_alive_done_check( + def _interest_done_check( cls, vehicle_ids, + interest_actors, sensor_state: SensorState, - actors_alive: Optional[ActorsAliveDoneCriteria], + interest_criteria: Optional[InterestDoneCriteria], ): - if actors_alive is None: + if interest_criteria is None: return False pattern = re.compile( - "|".join(rf"(?:{aoi})" for aoi in actors_alive.actors_filter) + "|".join(rf"(?:{aoi})" for aoi in interest_criteria.actors_filter) ) + if interest_criteria.include_scenario_marked and len(interest_actors): + return True ## TODO optimization to only get vehicles that were removed last step for vehicle_id in vehicle_ids: # get vehicles by pattern if pattern.match(vehicle_id): sensor_state.seen_alive_actors = True return False - if actors_alive.strict or sensor_state.seen_alive_actors: - # if agent requires the actor to exist immediately - # OR if previously seen relevant actors but no actors match anymore - return True - - ## if never seen a relevant actor - return False - - @classmethod - def _interest_done_check( - cls, - interest_actors, - sensor_state: SensorState, - scenario_interest: Optional[ScenarioInterestDoneCriteria], - ): - if scenario_interest is None: - return False - - if len(interest_actors): - return True - if scenario_interest.strict or sensor_state.seen_interest_actors: + if interest_criteria.strict or sensor_state.seen_interest_actors: # if agent requires the actor to exist immediately # OR if previously seen relevant actors but no actors match anymore return True @@ -616,11 +595,11 @@ def _is_done_with_events( agents_alive_done = cls._agents_alive_done_check( sim_frame.ego_ids, sim_frame.potential_agent_ids, done_criteria.agents_alive ) - actors_alive_done = cls._actors_alive_done_check( - sim_frame.vehicle_ids, sensor_state, done_criteria.actors_alive - ) interest_done = cls._interest_done_check( - sim_frame.interest_actors, sensor_state, done_criteria.interest + sim_frame.vehicle_ids, + sim_frame.interest_actors, + sensor_state, + done_criteria.interest, ) done = not sim_frame.resetting and ( @@ -633,7 +612,7 @@ def _is_done_with_events( or (is_off_route and done_criteria.off_route) or (is_wrong_way and done_criteria.wrong_way) or agents_alive_done - or actors_alive_done + or interest_done ) events = Events( @@ -646,7 +625,7 @@ def _is_done_with_events( wrong_way=is_wrong_way, not_moving=is_not_moving, agents_alive_done=agents_alive_done, - actors_alive_done=actors_alive_done, + interest_done=interest_done, ) return done, events diff --git a/smarts/core/utils/tests/fixtures.py b/smarts/core/utils/tests/fixtures.py index 357cef105e..3d9a6ec726 100644 --- a/smarts/core/utils/tests/fixtures.py +++ b/smarts/core/utils/tests/fixtures.py @@ -59,7 +59,7 @@ def large_observation(): reached_goal=False, reached_max_episode_steps=False, agents_alive_done=False, - actors_alive_done=False, + interest_done=False, ), ego_vehicle_state=EgoVehicleObservation( id="AGENT-007-07a0ca6e", diff --git a/smarts/env/gymnasium/platoon_env.py b/smarts/env/gymnasium/platoon_env.py index 771c6cdef4..34a44c798a 100644 --- a/smarts/env/gymnasium/platoon_env.py +++ b/smarts/env/gymnasium/platoon_env.py @@ -25,9 +25,9 @@ from envision.client import Client as Envision from envision.client import EnvisionDataFormatterArgs from smarts.core.agent_interface import ( - ActorsAliveDoneCriteria, AgentInterface, DoneCriteria, + InterestDoneCriteria, NeighborhoodVehicles, Waypoints, ) @@ -55,7 +55,7 @@ def platoon_env( """Each ego is supposed to track and follow its specified leader (i.e., lead vehicle) in a single file or in a platoon fashion. The name of the lead vehicle to track is given to the ego through its - :attr:`~smarts.core.agent_interface.ActorsAliveDoneCriteria.actors_of_interest` attribute. + :attr:`~smarts.core.agent_interface.InterestDoneCriteria.actors_of_interest` attribute. The episode ends for an ego when its assigned leader reaches the leader's destination. Egos do not have prior knowledge of their assigned leader's destination. @@ -162,7 +162,7 @@ def resolve_agent_interface(agent_interface: AgentInterface): on_shoulder=False, wrong_way=False, not_moving=False, - actors_alive=ActorsAliveDoneCriteria( + interest=InterestDoneCriteria( actors_filter=("Leader-007",), strict=True, ), diff --git a/smarts/env/gymnasium/wrappers/metric/costs.py b/smarts/env/gymnasium/wrappers/metric/costs.py index aad799b82c..3111b46ac4 100644 --- a/smarts/env/gymnasium/wrappers/metric/costs.py +++ b/smarts/env/gymnasium/wrappers/metric/costs.py @@ -336,7 +336,7 @@ def func( if not done: return Costs(steps=-1) - if obs.events.reached_goal or obs.events.actors_alive_done: + if obs.events.reached_goal or obs.events.interest_done: return Costs(steps=step / max_episode_steps) elif ( len(obs.events.collisions) > 0 @@ -347,7 +347,7 @@ def func( else: raise CostError( "Expected reached_goal, collisions, off_road, " - "max_episode_steps, or actors_alive_done, to be true " + "max_episode_steps, or interest_done, to be true " f"on agent done, but got events: {obs.events}." ) diff --git a/smarts/env/gymnasium/wrappers/metric/metrics.py b/smarts/env/gymnasium/wrappers/metric/metrics.py index 9d19d4f91b..9df85e0c18 100644 --- a/smarts/env/gymnasium/wrappers/metric/metrics.py +++ b/smarts/env/gymnasium/wrappers/metric/metrics.py @@ -24,7 +24,7 @@ import gymnasium as gym -from smarts.core.agent_interface import ActorsAliveDoneCriteria, AgentInterface +from smarts.core.agent_interface import AgentInterface, InterestDoneCriteria from smarts.core.coordinates import Point, RefLinePoint from smarts.core.observations import Observation from smarts.core.plan import EndlessGoal, PositionalGoal @@ -143,11 +143,11 @@ def step(self, action: Dict[str, Any]): or len(base_obs.events.collisions) or base_obs.events.off_road or base_obs.events.reached_max_episode_steps - or base_obs.events.actors_alive_done + or base_obs.events.interest_done ): raise MetricsError( "Expected reached_goal, collisions, off_road, " - "max_episode_steps, or actors_alive_done, to be true " + "max_episode_steps, or interest_done, to be true " f"on agent done, but got events: {base_obs.events}." ) @@ -190,16 +190,16 @@ def reset(self, **kwargs): end_pos = Point(0, 0, 0) dist_tot = 0 if self._params.dist_to_destination.active: - actors_alive = self.env.agent_interfaces[ + interest_criteria = self.env.agent_interfaces[ agent_name - ].done_criteria.actors_alive - if isinstance(actors_alive, ActorsAliveDoneCriteria): + ].done_criteria.interest + if isinstance(interest_criteria, InterestDoneCriteria): end_pos, dist_tot = _get_sumo_smarts_dist( - vehicle_name=actors_alive.actors_filter[0], + vehicle_name=interest_criteria.actors_filter[0], traffic_sims=self.env.smarts.traffic_sims, road_map=self._road_map, ) - elif actors_alive == None: + elif interest_criteria == None: end_pos = self._scen.missions[agent_name].goal.position dist_tot = get_dist( road_map=self._road_map, @@ -393,19 +393,19 @@ def check_intrfc(agent_intrfc: AgentInterface): ).format(agent_name, intrfc) ) - actors_alive = agent_interface.done_criteria.actors_alive + interest_criteria = agent_interface.done_criteria.interest if ( params.dist_to_destination.active - and isinstance(actors_alive, ActorsAliveDoneCriteria) - and len(actors_alive.actors_filter) != 1 + and isinstance(interest_criteria, InterestDoneCriteria) + and len(interest_criteria.actors_filter) != 1 ): raise AttributeError( ( - "ActorsAliveDoneCriteria with none or multiple actors of " + "InterestDoneCriteria with none or multiple actors of " "interest is currently not supported when " "dist_to_destination cost function is enabled. Current " "interface is {0}:{1}." - ).format(agent_name, actors_alive) + ).format(agent_name, interest_criteria) ) @@ -425,15 +425,17 @@ def _check_scen(scenario: Scenario, agent_interfaces: Dict[str, AgentInterface]) } for agent_name, agent_interface in agent_interfaces.items(): - actors_alive = agent_interface.done_criteria.actors_alive + interest_criteria = agent_interface.done_criteria.interest if not ( - (goal_types[agent_name] == PositionalGoal and actors_alive == None) + (goal_types[agent_name] == PositionalGoal and interest_criteria is None) or ( goal_types[agent_name] == EndlessGoal - and isinstance(actors_alive, ActorsAliveDoneCriteria) + and isinstance(interest_criteria, InterestDoneCriteria) ) ): raise AttributeError( "{0} has an unsupported goal type {1} and actors alive done criteria {2} " - "combination.".format(agent_name, goal_types[agent_name], actors_alive) + "combination.".format( + agent_name, goal_types[agent_name], interest_criteria + ) ) diff --git a/smarts/env/utils/observation_conversion.py b/smarts/env/utils/observation_conversion.py index 9404d47e1a..6f5c822ae7 100644 --- a/smarts/env/utils/observation_conversion.py +++ b/smarts/env/utils/observation_conversion.py @@ -579,10 +579,10 @@ def __call__(self, agent_interface: AgentInterface) -> BaseSpaceFormat: ) -events_actors_alive_done_space_format = StandardSpaceFormat( - lambda obs: np.int64(obs.events.actors_alive_done), +events_interest_done_space_format = StandardSpaceFormat( + lambda obs: np.int64(obs.events.interest_done), lambda _: True, - "actors_alive_done", + "interest_done", _DISCRETE2_SPACE, ) @@ -854,7 +854,7 @@ def name(self): events_space_format = StandardCompoundSpaceFormat( space_generators=[ - events_actors_alive_done_space_format, + events_interest_done_space_format, events_agents_alive_done_space_format, events_collisions_space_format, events_not_moving_space_format, @@ -965,8 +965,8 @@ class ObservationSpacesFormatter: A dictionary of event markers. "events": dict({ - "actors_alive_done": - 1 if `DoneCriteria.actors_alive` is triggered, else 0. + "interest_done": + 1 if `DoneCriteria.interest` is triggered, else 0. "agents_alive_done": 1 if `DoneCriteria.agents_alive` is triggered, else 0. "collisions": diff --git a/zoo/policies/__init__.py b/zoo/policies/__init__.py index 055014c14c..abb505aeaa 100644 --- a/zoo/policies/__init__.py +++ b/zoo/policies/__init__.py @@ -56,7 +56,7 @@ wrong_way=False, not_moving=False, agents_alive=None, - actors_alive=None, + interest=None, ), accelerometer=False, drivable_area_grid_map=False, From 6bfd81cbd94cabc35895eed03f47f172713b357b Mon Sep 17 00:00:00 2001 From: Tucker Date: Mon, 24 Apr 2023 18:20:10 -0400 Subject: [PATCH 08/13] Fix error call. --- smarts/core/smarts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/smarts/core/smarts.py b/smarts/core/smarts.py index dc7eccd304..34d2f33942 100644 --- a/smarts/core/smarts.py +++ b/smarts/core/smarts.py @@ -627,7 +627,7 @@ def switch_control_to_agent( ), f"Vehicle has already been hijacked: {vehicle_id}" assert not vehicle_id in self.vehicle_index.agent_vehicle_ids(), ( f"`{agent_id}` can't hijack vehicle that is already controlled by an agent" - f" `{self.vehicle_index.actor_id_from_vehicle_id(vehicle_id)}`: {vehicle_id}" + f" `{self.agent_manager.agent_for_vehicle(vehicle_id)}`: {vehicle_id}" ) # Switch control to agent From f3bb0c0fe2f629fb751fdf129856387d2b5e1592 Mon Sep 17 00:00:00 2001 From: Tucker Date: Mon, 24 Apr 2023 19:07:26 -0400 Subject: [PATCH 09/13] Add interface defined interest to observations. --- smarts/core/agent_interface.py | 7 +++++++ smarts/core/sensors/__init__.py | 36 +++++++++++++++++---------------- smarts/core/simulation_frame.py | 26 ++++++++++++++++++------ 3 files changed, 46 insertions(+), 23 deletions(-) diff --git a/smarts/core/agent_interface.py b/smarts/core/agent_interface.py index bd29721ee0..23fa799fd1 100644 --- a/smarts/core/agent_interface.py +++ b/smarts/core/agent_interface.py @@ -17,9 +17,11 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. +import re import warnings from dataclasses import dataclass, field, replace from enum import IntEnum +from functools import cached_property from typing import List, Optional, Tuple, Union from smarts.core.controllers.action_space_type import ActionSpaceType @@ -208,6 +210,11 @@ class InterestDoneCriteria: immediately. """ + @cached_property + def actors_pattern(self) -> re.Pattern: + """The expression match pattern for actors covered by this interface specifically.""" + return re.compile("|".join(rf"(?:{aoi})" for aoi in self.actors_filter)) + @dataclass(frozen=True) class EventConfiguration: diff --git a/smarts/core/sensors/__init__.py b/smarts/core/sensors/__init__.py index 73e4cc4e46..4742644450 100644 --- a/smarts/core/sensors/__init__.py +++ b/smarts/core/sensors/__init__.py @@ -71,7 +71,10 @@ def _make_vehicle_observation( - road_map, neighborhood_vehicle: VehicleState, sim_frame: SimulationFrame + road_map, + neighborhood_vehicle: VehicleState, + sim_frame: SimulationFrame, + interest_extension: Optional[re.Pattern], ): nv_lane = road_map.nearest_lane(neighborhood_vehicle.pose.point, radius=3) if nv_lane: @@ -93,7 +96,9 @@ def _make_vehicle_observation( lane_id=nv_lane_id, lane_index=nv_lane_index, lane_position=None, - interest=sim_frame.actor_is_interest(neighborhood_vehicle.actor_id), + interest=sim_frame.actor_is_interest( + neighborhood_vehicle.actor_id, extension=interest_extension + ), ) @@ -272,11 +277,18 @@ def process_serialization_safe_sensors( ) if neighborhood_vehicle_states_sensor: neighborhood_vehicle_states = [] + interface = sim_frame.agent_interfaces.get(agent_id) + interest_pattern = ( + interface.done_criteria.interest.actors_pattern + if interface is not None + and interface.done_criteria.interest is not None + else None + ) for nv in neighborhood_vehicle_states_sensor( vehicle_state, sim_frame.vehicle_states.values() ): veh_obs = _make_vehicle_observation( - sim_local_constants.road_map, nv, sim_frame + sim_local_constants.road_map, nv, sim_frame, interest_pattern ) lane_position_sensor = vehicle_sensors.get("lane_position_sensor") nv_lane_pos = None @@ -533,20 +545,10 @@ def _interest_done_check( sensor_state: SensorState, interest_criteria: Optional[InterestDoneCriteria], ): - if interest_criteria is None: + if len(interest_actors) > 0: + sensor_state.seen_alive_actors = True return False - pattern = re.compile( - "|".join(rf"(?:{aoi})" for aoi in interest_criteria.actors_filter) - ) - if interest_criteria.include_scenario_marked and len(interest_actors): - return True - ## TODO optimization to only get vehicles that were removed last step - for vehicle_id in vehicle_ids: - # get vehicles by pattern - if pattern.match(vehicle_id): - sensor_state.seen_alive_actors = True - return False if interest_criteria.strict or sensor_state.seen_interest_actors: # if agent requires the actor to exist immediately # OR if previously seen relevant actors but no actors match anymore @@ -597,9 +599,9 @@ def _is_done_with_events( ) interest_done = cls._interest_done_check( sim_frame.vehicle_ids, - sim_frame.interest_actors, + sim_frame.interest_actors(interface.done_criteria.interest.actors_pattern), sensor_state, - done_criteria.interest, + interest_criteria=interface.done_criteria.interest, ) done = not sim_frame.resetting and ( diff --git a/smarts/core/simulation_frame.py b/smarts/core/simulation_frame.py index 5934438b04..af49a82901 100644 --- a/smarts/core/simulation_frame.py +++ b/smarts/core/simulation_frame.py @@ -22,6 +22,7 @@ import logging import re from dataclasses import dataclass +from functools import lru_cache from typing import Any, Dict, List, Optional, Set from cached_property import cached_property @@ -76,18 +77,31 @@ def actor_states_by_id(self) -> Dict[str, ActorState]: """Get actor states paired by their ids.""" return {a_s.actor_id: a_s for a_s in self.actor_states} - @cached_property - def interest_actors(self) -> Dict[str, ActorState]: - """Get the actor states of actors that are marked as of interest.""" + @lru_cache + def interest_actors(self, extension: Optional[re.Pattern] = None) -> bool: + """Get the actor states of actors that are marked as of interest. + + Args: + extension (re.Pattern): A matching for interest actors not defined in scenario. + """ + + _matchers: List[re.Pattern] = [] if self.interest_filter.pattern: + _matchers.append(self.interest_filter) + if extension is not None and extension.pattern: + _matchers.append(extension) + + if len(_matchers) == 0: return { a_s.actor_id: a_s for a_s in self.actor_states - if self.interest_filter.match(a_s.actor_id) + if any(bool(m.match(a_s.actor_id)) for m in _matchers) } return {} - def actor_is_interest(self, actor_id) -> bool: + def actor_is_interest( + self, actor_id, extension: Optional[re.Pattern] = None + ) -> bool: """Determine if the actor is of interest. Args: @@ -96,7 +110,7 @@ def actor_is_interest(self, actor_id) -> bool: Returns: bool: If the actor is of interest. """ - return actor_id in self.interest_actors + return actor_id in self.interest_actors(extension) def vehicle_did_collide(self, vehicle_id) -> bool: """Test if the given vehicle had any collisions in the last physics update.""" From 7826d156fd32bf429ac4c0a4c04b5b296bf02ae4 Mon Sep 17 00:00:00 2001 From: Tucker Date: Mon, 24 Apr 2023 20:32:01 -0400 Subject: [PATCH 10/13] Fix hashing error. --- smarts/core/sensors/__init__.py | 15 ++++++++------- smarts/core/simulation_frame.py | 11 +++++++++-- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/smarts/core/sensors/__init__.py b/smarts/core/sensors/__init__.py index 4742644450..7d9e0bb393 100644 --- a/smarts/core/sensors/__init__.py +++ b/smarts/core/sensors/__init__.py @@ -540,7 +540,6 @@ def _agents_alive_done_check( @classmethod def _interest_done_check( cls, - vehicle_ids, interest_actors, sensor_state: SensorState, interest_criteria: Optional[InterestDoneCriteria], @@ -572,6 +571,7 @@ def _is_done_with_events( interface = sim_frame.agent_interfaces[agent_id] done_criteria = interface.done_criteria event_config = interface.event_configuration + interest = interface.done_criteria.interest # TODO: the following calls nearest_lanes (expensive) 6 times reached_goal = cls._agent_reached_goal( @@ -597,12 +597,13 @@ def _is_done_with_events( agents_alive_done = cls._agents_alive_done_check( sim_frame.ego_ids, sim_frame.potential_agent_ids, done_criteria.agents_alive ) - interest_done = cls._interest_done_check( - sim_frame.vehicle_ids, - sim_frame.interest_actors(interface.done_criteria.interest.actors_pattern), - sensor_state, - interest_criteria=interface.done_criteria.interest, - ) + interest_done = False + if interest: + cls._interest_done_check( + sim_frame.interest_actors(interest.actors_pattern), + sensor_state, + interest_criteria=interest, + ) done = not sim_frame.resetting and ( (is_off_road and done_criteria.off_road) diff --git a/smarts/core/simulation_frame.py b/smarts/core/simulation_frame.py index af49a82901..ca9923d77e 100644 --- a/smarts/core/simulation_frame.py +++ b/smarts/core/simulation_frame.py @@ -41,7 +41,7 @@ class SimulationFrame: actor_states: List[ActorState] agent_vehicle_controls: Dict[str, str] agent_interfaces: Dict[str, AgentInterface] - ego_ids: str + ego_ids: Set[str] pending_agent_ids: List[str] elapsed_sim_time: float fixed_timestep: float @@ -77,7 +77,7 @@ def actor_states_by_id(self) -> Dict[str, ActorState]: """Get actor states paired by their ids.""" return {a_s.actor_id: a_s for a_s in self.actor_states} - @lru_cache + @lru_cache(1) def interest_actors(self, extension: Optional[re.Pattern] = None) -> bool: """Get the actor states of actors that are marked as of interest. @@ -129,6 +129,13 @@ def filtered_vehicle_collisions(self, vehicle_id) -> List[Collision]: c for c in vehicle_collisions if c.collidee_id != self._ground_bullet_id ] + @cached_property + def _hash(self): + return self.step_count ^ hash(self.fixed_timestep) ^ hash(self.map_spec) + + def __hash__(self): + return self._hash + def __post_init__(self): if logger.isEnabledFor(logging.DEBUG): assert isinstance(self.actor_states, list) From 2da33602b8696dcb087a5ebe9f23ebd9f1cf7a4f Mon Sep 17 00:00:00 2001 From: Tucker Date: Mon, 24 Apr 2023 21:50:08 -0400 Subject: [PATCH 11/13] Fix typing. --- smarts/core/simulation_frame.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/smarts/core/simulation_frame.py b/smarts/core/simulation_frame.py index ca9923d77e..88d58a92be 100644 --- a/smarts/core/simulation_frame.py +++ b/smarts/core/simulation_frame.py @@ -78,7 +78,9 @@ def actor_states_by_id(self) -> Dict[str, ActorState]: return {a_s.actor_id: a_s for a_s in self.actor_states} @lru_cache(1) - def interest_actors(self, extension: Optional[re.Pattern] = None) -> bool: + def interest_actors( + self, extension: Optional[re.Pattern] = None + ) -> Dict[str, ActorState]: """Get the actor states of actors that are marked as of interest. Args: @@ -134,7 +136,7 @@ def _hash(self): return self.step_count ^ hash(self.fixed_timestep) ^ hash(self.map_spec) def __hash__(self): - return self._hash + return self._hash def __post_init__(self): if logger.isEnabledFor(logging.DEBUG): From ff04e9507c44a4717f8688ec249d28a6146a15a1 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Tue, 25 Apr 2023 13:05:14 +0000 Subject: [PATCH 12/13] Clear Optional typing issue. --- smarts/core/sensors/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/smarts/core/sensors/__init__.py b/smarts/core/sensors/__init__.py index 7d9e0bb393..e2f08634d9 100644 --- a/smarts/core/sensors/__init__.py +++ b/smarts/core/sensors/__init__.py @@ -544,7 +544,7 @@ def _interest_done_check( sensor_state: SensorState, interest_criteria: Optional[InterestDoneCriteria], ): - if len(interest_actors) > 0: + if interest_criteria is None or len(interest_actors) > 0: sensor_state.seen_alive_actors = True return False From d7ad346533021681167322d828af244ab414f26c Mon Sep 17 00:00:00 2001 From: Tucker Alban Date: Tue, 25 Apr 2023 10:20:30 -0400 Subject: [PATCH 13/13] Update examples/rl/platoon/train/reward.py Co-authored-by: adai --- examples/rl/platoon/train/reward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/rl/platoon/train/reward.py b/examples/rl/platoon/train/reward.py index ca2390bc7c..11be320e17 100644 --- a/examples/rl/platoon/train/reward.py +++ b/examples/rl/platoon/train/reward.py @@ -46,7 +46,7 @@ def step(self, action): ): pass elif agent_obs["events"]["interest_done"]: - print(f"{agent_id}: Actors alive done triggered.") + print(f"{agent_id}: Interest done triggered.") else: print("Events: ", agent_obs["events"]) raise Exception("Episode ended for unknown reason.")