Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add low dimension interest observation. #1977

Merged
merged 13 commits into from
Apr 25, 2023
2 changes: 1 addition & 1 deletion docs/benchmarks/driving_smarts_2023_3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Objective is to develop a single-ego policy capable of controlling a single ego

Each ego is supposed to track and follow its specified leader (i.e., lead vehicle) in a single file or in a
platoon fashion. The name identifier of the lead vehicle to be followed is given to the ego through the configuration
of the :attr:`~smarts.core.agent_interface.ActorsAliveDoneCriteria.actors_of_interest` attribute.
of the :attr:`~smarts.core.agent_interface.InterestDoneCriteria.actors_of_interest` attribute.

.. figure:: ../_static/driving_smarts_2023/vehicle_following.png

Expand Down
2 changes: 1 addition & 1 deletion envision/tests/test_data_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def smarts():
# envision.teardown = MagicMock()
smarts = SMARTS(
agents,
traffic_sim=SumoTrafficSimulation(),
traffic_sims=[SumoTrafficSimulation()],
envision=envision,
)
yield smarts
Expand Down
4 changes: 2 additions & 2 deletions examples/rl/platoon/train/reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def step(self, action):
agent_obs["events"]["collisions"] | agent_obs["events"]["off_road"]
):
pass
elif agent_obs["events"]["actors_alive_done"]:
print(f"{agent_id}: Actors alive done triggered.")
elif agent_obs["events"]["interest_done"]:
print(f"{agent_id}: Interest done triggered.")
else:
print("Events: ", agent_obs["events"])
raise Exception("Episode ended for unknown reason.")
Expand Down
29 changes: 22 additions & 7 deletions smarts/core/agent_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
import re
import warnings
from dataclasses import dataclass, field, replace
from enum import IntEnum
from functools import cached_property
from typing import List, Optional, Tuple, Union

from smarts.core.controllers.action_space_type import ActionSpaceType
Expand Down Expand Up @@ -194,17 +196,25 @@ class AgentsAliveDoneCriteria:


@dataclass(frozen=True)
class ActorsAliveDoneCriteria:
"""Require actors to persist."""
class InterestDoneCriteria:
"""Require scenario marked interest actors to exist."""

actors_of_interest: Tuple[str, ...] = ()
"""Actors that should exist to continue this agent."""
actors_filter: Tuple[str, ...] = ()
"""Interface defined interest actors that should exist to continue this agent."""

strict: bool = True
include_scenario_marked: bool = True
"""If scenario marked interest actors should be considered as interest vehicles."""

strict: bool = False
"""If strict the agent will be done instantly if an actor of interest is not available
immediately.
"""

@cached_property
def actors_pattern(self) -> re.Pattern:
"""The expression match pattern for actors covered by this interface specifically."""
return re.compile("|".join(rf"(?:{aoi})" for aoi in self.actors_filter))

Comment on lines +213 to +217
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be better to place this method elsewhere?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Placing it elsewhere would complicate its use. It is fine since it has no side effects.

My consideration is if I should compress actors_pattern and actors_filter to be a single attribute since it is not clear that actors_filter uses regular expression matching.


@dataclass(frozen=True)
class EventConfiguration:
Expand Down Expand Up @@ -238,8 +248,13 @@ class DoneCriteria:
"""
agents_alive: Optional[AgentsAliveDoneCriteria] = None
"""If set, triggers the ego agent to be done based on the number of active agents for multi-agent purposes."""
actors_alive: Optional[ActorsAliveDoneCriteria] = None
"""If set, triggers the ego agent to be done based on actors existing in the simulation."""
interest: Optional[InterestDoneCriteria] = None
"""If set, triggers when there are no interest vehicles left existing in the simulation."""

@property
def actors_alive(self):
"""Deprecated. Use interest."""
raise NameError("Deprecated. Use interest.")


@dataclass
Expand Down
7 changes: 6 additions & 1 deletion smarts/core/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,10 @@ class Events(NamedTuple):
agents_alive_done: bool
"""True if all configured co-simulating agents are done (if any), else False.
This is useful for cases when the vehicle is related to other vehicles."""
actors_alive_done: bool
interest_done: bool
"""True if described actors have left the simulation."""

@property
def actors_alive_done(self):
"""Deprecated. Use interest_done."""
raise NameError("Deprecated. Use interest_done.")
2 changes: 2 additions & 0 deletions smarts/core/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class VehicleObservation(NamedTuple):
lane_position: Optional[RefLinePoint] = None
"""(s,t,h) coordinates within the lane, where s is the longitudinal offset along the lane, t is the lateral displacement from the lane center, and h (not yet supported) is the vertical displacement from the lane surface.
See the Reference Line coordinate system in OpenDRIVE here: https://www.asam.net/index.php?eID=dumpFile&t=f&f=4089&token=deea5d707e2d0edeeb4fccd544a973de4bc46a09#_coordinate_systems """
interest: bool = False
"""If this vehicle is of interest in the current scenario."""


class EgoVehicleObservation(NamedTuple):
Expand Down
75 changes: 48 additions & 27 deletions smarts/core/sensors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import numpy as np

from smarts.core.agent_interface import ActorsAliveDoneCriteria, AgentsAliveDoneCriteria
from smarts.core.agent_interface import AgentsAliveDoneCriteria, InterestDoneCriteria
from smarts.core.coordinates import Heading, Point
from smarts.core.events import Events
from smarts.core.observations import (
Expand Down Expand Up @@ -70,7 +70,12 @@
LANE_INDEX_CONSTANT = -1


def _make_vehicle_observation(road_map, neighborhood_vehicle):
def _make_vehicle_observation(
road_map,
neighborhood_vehicle: VehicleState,
sim_frame: SimulationFrame,
interest_extension: Optional[re.Pattern],
):
nv_lane = road_map.nearest_lane(neighborhood_vehicle.pose.point, radius=3)
if nv_lane:
nv_road_id = nv_lane.road.road_id
Expand All @@ -91,6 +96,9 @@ def _make_vehicle_observation(road_map, neighborhood_vehicle):
lane_id=nv_lane_id,
lane_index=nv_lane_index,
lane_position=None,
interest=sim_frame.actor_is_interest(
neighborhood_vehicle.actor_id, extension=interest_extension
),
)


Expand All @@ -102,14 +110,24 @@ def __init__(self, max_episode_steps: int, plan_frame: PlanFrame):
self._plan_frame = plan_frame
self._step = 0
self._seen_interest_actors = False
self._seen_alive_actors = False

def step(self):
"""Update internal state."""
self._step += 1

@property
def seen_alive_actors(self) -> bool:
"""If an agents alive actor has been spotted before."""
return self._seen_alive_actors

@seen_alive_actors.setter
def seen_alive_actors(self, value: bool):
self._seen_alive_actors = value

@property
def seen_interest_actors(self) -> bool:
"""If a relevant actor has been spotted before."""
"""If an interest actor has been spotted before."""
return self._seen_interest_actors

@seen_interest_actors.setter
Expand Down Expand Up @@ -259,10 +277,19 @@ def process_serialization_safe_sensors(
)
if neighborhood_vehicle_states_sensor:
neighborhood_vehicle_states = []
interface = sim_frame.agent_interfaces.get(agent_id)
interest_pattern = (
interface.done_criteria.interest.actors_pattern
if interface is not None
and interface.done_criteria.interest is not None
else None
)
for nv in neighborhood_vehicle_states_sensor(
vehicle_state, sim_frame.vehicle_states.values()
):
veh_obs = _make_vehicle_observation(sim_local_constants.road_map, nv)
veh_obs = _make_vehicle_observation(
sim_local_constants.road_map, nv, sim_frame, interest_pattern
)
lane_position_sensor = vehicle_sensors.get("lane_position_sensor")
nv_lane_pos = None
if veh_obs.lane_id is not LANE_ID_CONSTANT and lane_position_sensor:
Expand Down Expand Up @@ -511,28 +538,17 @@ def _agents_alive_done_check(
return False

@classmethod
def _actors_alive_done_check(
def _interest_done_check(
cls,
vehicle_ids,
sensor_state,
actors_alive: Optional[ActorsAliveDoneCriteria],
interest_actors,
sensor_state: SensorState,
interest_criteria: Optional[InterestDoneCriteria],
):
if actors_alive is None:
if interest_criteria is None or len(interest_actors) > 0:
sensor_state.seen_alive_actors = True
return False

sensor_state: SensorState = sensor_state

pattern = re.compile(
"|".join(rf"(?:{aoi})" for aoi in actors_alive.actors_of_interest)
)
## TODO optimization to get vehicles that were added and removed last step
## TODO second optimization to check for already known vehicles
for vehicle_id in vehicle_ids:
# get vehicles by pattern
if pattern.match(vehicle_id):
sensor_state.seen_interest_actors = True
return False
if actors_alive.strict or sensor_state.seen_interest_actors:
if interest_criteria.strict or sensor_state.seen_interest_actors:
# if agent requires the actor to exist immediately
# OR if previously seen relevant actors but no actors match anymore
return True
Expand All @@ -555,6 +571,7 @@ def _is_done_with_events(
interface = sim_frame.agent_interfaces[agent_id]
done_criteria = interface.done_criteria
event_config = interface.event_configuration
interest = interface.done_criteria.interest

# TODO: the following calls nearest_lanes (expensive) 6 times
reached_goal = cls._agent_reached_goal(
Expand All @@ -580,9 +597,13 @@ def _is_done_with_events(
agents_alive_done = cls._agents_alive_done_check(
sim_frame.ego_ids, sim_frame.potential_agent_ids, done_criteria.agents_alive
)
actors_alive_done = cls._actors_alive_done_check(
sim_frame.vehicle_ids, sensor_state, done_criteria.actors_alive
)
interest_done = False
if interest:
cls._interest_done_check(
sim_frame.interest_actors(interest.actors_pattern),
sensor_state,
interest_criteria=interest,
)

done = not sim_frame.resetting and (
(is_off_road and done_criteria.off_road)
Expand All @@ -594,7 +615,7 @@ def _is_done_with_events(
or (is_off_route and done_criteria.off_route)
or (is_wrong_way and done_criteria.wrong_way)
or agents_alive_done
or actors_alive_done
or interest_done
)

events = Events(
Expand All @@ -607,7 +628,7 @@ def _is_done_with_events(
wrong_way=is_wrong_way,
not_moving=is_not_moving,
agents_alive_done=agents_alive_done,
actors_alive_done=actors_alive_done,
interest_done=interest_done,
)

return done, events
Expand Down
49 changes: 48 additions & 1 deletion smarts/core/simulation_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
import logging
import re
from dataclasses import dataclass
from functools import lru_cache
from typing import Any, Dict, List, Optional, Set

from cached_property import cached_property
Expand All @@ -39,7 +41,7 @@ class SimulationFrame:
actor_states: List[ActorState]
agent_vehicle_controls: Dict[str, str]
agent_interfaces: Dict[str, AgentInterface]
ego_ids: str
ego_ids: Set[str]
pending_agent_ids: List[str]
elapsed_sim_time: float
fixed_timestep: float
Expand All @@ -55,6 +57,7 @@ class SimulationFrame:
vehicle_sensors: Dict[str, Dict[str, Any]]

sensor_states: Any
interest_filter: re.Pattern
# TODO MTA: renderer can be allowed here as long as it is only type information
# renderer_type: Any = None
_ground_bullet_id: Optional[str] = None
Expand All @@ -74,6 +77,43 @@ def actor_states_by_id(self) -> Dict[str, ActorState]:
"""Get actor states paired by their ids."""
return {a_s.actor_id: a_s for a_s in self.actor_states}

@lru_cache(1)
def interest_actors(
self, extension: Optional[re.Pattern] = None
) -> Dict[str, ActorState]:
"""Get the actor states of actors that are marked as of interest.

Args:
extension (re.Pattern): A matching for interest actors not defined in scenario.
"""

_matchers: List[re.Pattern] = []
if self.interest_filter.pattern:
_matchers.append(self.interest_filter)
if extension is not None and extension.pattern:
_matchers.append(extension)

if len(_matchers) == 0:
return {
a_s.actor_id: a_s
for a_s in self.actor_states
if any(bool(m.match(a_s.actor_id)) for m in _matchers)
}
return {}

def actor_is_interest(
self, actor_id, extension: Optional[re.Pattern] = None
) -> bool:
"""Determine if the actor is of interest.

Args:
actor_id (str): The id of the actor to test.

Returns:
bool: If the actor is of interest.
"""
return actor_id in self.interest_actors(extension)

def vehicle_did_collide(self, vehicle_id) -> bool:
"""Test if the given vehicle had any collisions in the last physics update."""
vehicle_collisions = self.vehicle_collisions.get(vehicle_id, [])
Expand All @@ -91,6 +131,13 @@ def filtered_vehicle_collisions(self, vehicle_id) -> List[Collision]:
c for c in vehicle_collisions if c.collidee_id != self._ground_bullet_id
]

@cached_property
def _hash(self):
return self.step_count ^ hash(self.fixed_timestep) ^ hash(self.map_spec)

def __hash__(self):
return self._hash

def __post_init__(self):
if logger.isEnabledFor(logging.DEBUG):
assert isinstance(self.actor_states, list)
Expand Down
8 changes: 6 additions & 2 deletions smarts/core/smarts.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import importlib.resources as pkg_resources
import logging
import os
import re
import warnings
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union

Expand Down Expand Up @@ -76,6 +77,7 @@
level=logging.ERROR,
)

_DEFAULT_PATTERN = re.compile("")
MAX_PYBULLET_FREQ = 240


Expand Down Expand Up @@ -625,7 +627,7 @@ def switch_control_to_agent(
), f"Vehicle has already been hijacked: {vehicle_id}"
assert not vehicle_id in self.vehicle_index.agent_vehicle_ids(), (
f"`{agent_id}` can't hijack vehicle that is already controlled by an agent"
f" `{self.vehicle_index.actor_id_from_vehicle_id(vehicle_id)}`: {vehicle_id}"
f" `{self.agent_manager.agent_for_vehicle(vehicle_id)}`: {vehicle_id}"
)

# Switch control to agent
Expand Down Expand Up @@ -1381,7 +1383,6 @@ def neighborhood_vehicles_around_vehicle(
) -> List[VehicleState]:
"""Find vehicles in the vicinity of the target vehicle."""
self._check_valid()
from smarts.core.sensors import Sensors

vehicle = self._vehicle_index.vehicle_by_id(vehicle_id)
return neighborhood_vehicles_around_vehicle(
Expand Down Expand Up @@ -1680,4 +1681,7 @@ def cached_frame(self):
vehicle_sensors=self.sensor_manager.sensors_for_actor_ids(vehicle_ids),
sensor_states=dict(self.sensor_manager.sensor_states_items()),
_ground_bullet_id=self._ground_bullet_id,
interest_filter=self.scenario.metadata.get(
"actor_of_interest_re_filter", _DEFAULT_PATTERN
),
)
Loading