Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add low dimension interest observation. #1977

Merged
merged 13 commits into from
Apr 25, 2023
2 changes: 2 additions & 0 deletions smarts/core/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class VehicleObservation(NamedTuple):
lane_position: Optional[RefLinePoint] = None
"""(s,t,h) coordinates within the lane, where s is the longitudinal offset along the lane, t is the lateral displacement from the lane center, and h (not yet supported) is the vertical displacement from the lane surface.
See the Reference Line coordinate system in OpenDRIVE here: https://www.asam.net/index.php?eID=dumpFile&t=f&f=4089&token=deea5d707e2d0edeeb4fccd544a973de4bc46a09#_coordinate_systems """
interest: bool = False
"""If this vehicle is of interest in the current scenario."""


class EgoVehicleObservation(NamedTuple):
Expand Down
9 changes: 7 additions & 2 deletions smarts/core/sensors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@
LANE_INDEX_CONSTANT = -1


def _make_vehicle_observation(road_map, neighborhood_vehicle):
def _make_vehicle_observation(
road_map, neighborhood_vehicle: VehicleState, sim_frame: SimulationFrame
):
nv_lane = road_map.nearest_lane(neighborhood_vehicle.pose.point, radius=3)
if nv_lane:
nv_road_id = nv_lane.road.road_id
Expand All @@ -91,6 +93,7 @@ def _make_vehicle_observation(road_map, neighborhood_vehicle):
lane_id=nv_lane_id,
lane_index=nv_lane_index,
lane_position=None,
interest=sim_frame.actor_is_interest(neighborhood_vehicle.actor_id),
)


Expand Down Expand Up @@ -262,7 +265,9 @@ def process_serialization_safe_sensors(
for nv in neighborhood_vehicle_states_sensor(
vehicle_state, sim_frame.vehicle_states.values()
):
veh_obs = _make_vehicle_observation(sim_local_constants.road_map, nv)
veh_obs = _make_vehicle_observation(
sim_local_constants.road_map, nv, sim_frame
)
lane_position_sensor = vehicle_sensors.get("lane_position_sensor")
nv_lane_pos = None
if veh_obs.lane_id is not LANE_ID_CONSTANT and lane_position_sensor:
Expand Down
24 changes: 24 additions & 0 deletions smarts/core/simulation_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
import logging
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set

Expand Down Expand Up @@ -55,6 +56,7 @@ class SimulationFrame:
vehicle_sensors: Dict[str, Dict[str, Any]]

sensor_states: Any
interest_filter: re.Pattern
# TODO MTA: renderer can be allowed here as long as it is only type information
# renderer_type: Any = None
_ground_bullet_id: Optional[str] = None
Expand All @@ -74,6 +76,28 @@ def actor_states_by_id(self) -> Dict[str, ActorState]:
"""Get actor states paired by their ids."""
return {a_s.actor_id: a_s for a_s in self.actor_states}

@cached_property
def _interest_actors(self) -> Dict[str, ActorState]:
"""Get the actor states of actors that are marked as of interest."""
if self.interest_filter.pattern:
return {
a_s.actor_id: a_s
for a_s in self.actor_states
if self.interest_filter.match(a_s.actor_id)
}
return {}

def actor_is_interest(self, actor_id) -> bool:
"""Determine if the actor is of interest.

Args:
actor_id (str): The id of the actor to test.

Returns:
bool: If the actor is of interest.
"""
return actor_id in self._interest_actors

def vehicle_did_collide(self, vehicle_id) -> bool:
"""Test if the given vehicle had any collisions in the last physics update."""
vehicle_collisions = self.vehicle_collisions.get(vehicle_id, [])
Expand Down
6 changes: 5 additions & 1 deletion smarts/core/smarts.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import importlib.resources as pkg_resources
import logging
import os
import re
import warnings
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union

Expand Down Expand Up @@ -76,6 +77,7 @@
level=logging.ERROR,
)

_DEFAULT_PATTERN = re.compile("")
MAX_PYBULLET_FREQ = 240


Expand Down Expand Up @@ -1381,7 +1383,6 @@ def neighborhood_vehicles_around_vehicle(
) -> List[VehicleState]:
"""Find vehicles in the vicinity of the target vehicle."""
self._check_valid()
from smarts.core.sensors import Sensors

vehicle = self._vehicle_index.vehicle_by_id(vehicle_id)
return neighborhood_vehicles_around_vehicle(
Expand Down Expand Up @@ -1680,4 +1681,7 @@ def cached_frame(self):
vehicle_sensors=self.sensor_manager.sensors_for_actor_ids(vehicle_ids),
sensor_states=dict(self.sensor_manager.sensor_states_items()),
_ground_bullet_id=self._ground_bullet_id,
interest_filter=self.scenario.metadata.get(
"actor_of_interest_re_filter", _DEFAULT_PATTERN
),
)
12 changes: 11 additions & 1 deletion smarts/env/utils/observation_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def _format_neighborhood_vehicle_states(
"lane_index": np.zeros((des_shp,), dtype=np.int8),
"position": np.zeros((des_shp, 3), dtype=np.float64),
"speed": np.zeros((des_shp,), dtype=np.float32),
"interest": np.zeros((des_shp,), dtype=np.bool8),
}

neighborhood_vehicle_states = [
Expand All @@ -206,16 +207,20 @@ def _format_neighborhood_vehicle_states(
nghb.lane_index,
nghb.position,
nghb.speed,
nghb.interest,
)
for nghb in neighborhood_vehicle_states[:des_shp]
]
box, heading, vehicle_id, lane_index, pos, speed = zip(*neighborhood_vehicle_states)
box, heading, vehicle_id, lane_index, pos, speed, interest = zip(
*neighborhood_vehicle_states
)

box = np.array(box, dtype=np.float32)
heading = np.array(heading, dtype=np.float32)
lane_index = np.array(lane_index, dtype=np.int8)
pos = np.array(pos, dtype=np.float64)
speed = np.array(speed, dtype=np.float32)
interest = np.array(interest, dtype=np.bool8)

# fmt: off
box = np.pad(box, ((0,pad_shp),(0,0)), mode='constant', constant_values=0)
Expand All @@ -224,6 +229,7 @@ def _format_neighborhood_vehicle_states(
lane_index = np.pad(lane_index, ((0,pad_shp)), mode='constant', constant_values=0)
pos = np.pad(pos, ((0,pad_shp),(0,0)), mode='constant', constant_values=0)
speed = np.pad(speed, ((0,pad_shp)), mode='constant', constant_values=0)
interest = np.pad(interest, ((0,pad_shp)), mode="constant", constant_values=False)
# fmt: on

return {
Expand All @@ -233,6 +239,7 @@ def _format_neighborhood_vehicle_states(
"lane_index": lane_index,
"position": pos,
"speed": speed,
"interest": interest,
}


Expand Down Expand Up @@ -707,6 +714,9 @@ def name(self):
"speed": gym.spaces.Box(
low=0, high=1e10, shape=(_NEIGHBOR_SHP,), dtype=np.float32
),
"interest": gym.spaces.Box(
low=0, high=1, shape=(_NEIGHBOR_SHP,), dtype=np.bool8
),
Gamenot marked this conversation as resolved.
Show resolved Hide resolved
}
),
)
Expand Down
7 changes: 6 additions & 1 deletion smarts/env/wrappers/format_obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,7 @@ def _std_neighborhood_vehicle_states(
"lane_index": np.zeros((des_shp,), dtype=np.int8),
"pos": np.zeros((des_shp, 3), dtype=np.float64),
"speed": np.zeros((des_shp,), dtype=np.float32),
"interest": np.zeros((des_shp,), dtype=np.bool8),
}

nghbs = [
Expand All @@ -529,23 +530,26 @@ def _std_neighborhood_vehicle_states(
nghb.lane_index,
nghb.position,
nghb.speed,
nghb.interest,
)
for nghb in nghbs[:des_shp]
]
box, heading, lane_index, pos, speed = zip(*nghbs)
box, heading, lane_index, pos, speed, interest = zip(*nghbs)

box = np.array(box, dtype=np.float32)
heading = np.array(heading, dtype=np.float32)
lane_index = np.array(lane_index, dtype=np.int8)
pos = np.array(pos, dtype=np.float64)
speed = np.array(speed, dtype=np.float32)
interest = np.array(interest, dtype=np.bool8)

# fmt: off
box = np.pad(box, ((0,pad_shp),(0,0)), mode='constant', constant_values=0)
heading = np.pad(heading, ((0,pad_shp)), mode='constant', constant_values=0)
lane_index = np.pad(lane_index, ((0,pad_shp)), mode='constant', constant_values=0)
pos = np.pad(pos, ((0,pad_shp),(0,0)), mode='constant', constant_values=0)
speed = np.pad(speed, ((0,pad_shp)), mode='constant', constant_values=0)
interest = np.pad(interest, ((0,pad_shp)), mode="constant", constant_values=0)
# fmt: on

return {
Expand All @@ -554,6 +558,7 @@ def _std_neighborhood_vehicle_states(
"lane_index": lane_index,
"pos": pos,
"speed": speed,
"interest": interest,
}


Expand Down