Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test to check for desync'ed vehicle state. #1988

Merged
merged 7 commits into from
Apr 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ Copy and pasting the git commit messages is __NOT__ enough.
- `VehicleIndex` no longer segfaults when attempting to `repr()` it.
- Fixed issues related to waypoints in SUMO maps. Waypoints in junctions should now return all possible paths through the junction.
- Fixed CI tests for metrics.
- Fixed an issue where the actor states and vehicle states were not synchronized after simulation vehicle updates resulting in different values from the simulation frame.
- Minor fix in regular expression compilation of `actor_of_interest_re_filter` from scenario metadata.
- Fixed acceleration and jerk computation in comfort metric, by ignoring vehicle position jitters smaller than a threshold.
### Removed
Expand Down
2 changes: 1 addition & 1 deletion smarts/core/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,6 @@ def __hash__(self) -> int:
return hash(self.actor_id)

def __eq__(self, other) -> bool:
return self.__class__ == other.__class__ and hash(self.actor_id) == hash(
return isinstance(other, type(self)) and hash(self.actor_id) == hash(
other.actor_id
)
15 changes: 15 additions & 0 deletions smarts/core/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,21 @@ def intersects(self, actor_ids: Set[str]) -> bool:
intersection = actor_ids & provider_actor_ids
return bool(intersection)

def replace_actor_type(
self, updated_actors: List[ActorState], actor_state_type: type
):
"""Replaces all actors of the given type.

Args:
updated_actors (List[ActorState]): The actors to use as replacement.
actor_type (str): The actor type to replace.
"""
self.actors = [
actor_state
for actor_state in self.actors
if not issubclass(actor_state.__class__, actor_state_type)
] + updated_actors


class ProviderManager:
"""Interface to be implemented by a class that manages a set of Providers
Expand Down
2 changes: 1 addition & 1 deletion smarts/core/road_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ def center_pose_at_point(self, point: Point) -> Pose:
position = self.from_lane_coord(RefLinePoint(s=offset))
desired_vector = self.vector_at_offset(offset)
orientation = fast_quaternion_from_angle(vec_to_radians(desired_vector[:2]))
return Pose(position=position, orientation=orientation)
return Pose(position=position.as_np_array, orientation=orientation)

def curvature_radius_at_offset(
self, offset: float, lookahead: int = 5
Expand Down
13 changes: 10 additions & 3 deletions smarts/core/smarts.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,6 @@ def _step(self, agent_actions, time_delta_since_last_step: Optional[float] = Non
# 2. Step all providers and harmonize state
with timeit("Stepping all providers and harmonizing state", self._log.debug):
provider_state = self._step_providers(all_agent_actions)
self._last_provider_state = provider_state
with timeit("Checking if all agents are active", self._log.debug):
self._check_if_acting_on_active_agents(agent_actions)

Expand All @@ -314,7 +313,7 @@ def _step(self, agent_actions, time_delta_since_last_step: Optional[float] = Non
# want these during their observation/reward computations.
# This is a hack to give us some short term perf wins. Longer term we
# need to expose better support for batched computations
self._vehicle_states = [v.state for v in self._vehicle_index.vehicles]
self._sync_smarts_and_provider_actor_states(provider_state)
self._sensor_manager.clean_up_sensors_for_actors(
set(v.actor_id for v in self._vehicle_states), renderer=self.renderer_ref
)
Expand Down Expand Up @@ -1332,6 +1331,13 @@ def _step_providers(self, actions) -> ProviderState:
self._harmonize_providers(accumulated_provider_state)
return accumulated_provider_state

def _sync_smarts_and_provider_actor_states(
self, external_provider_state: ProviderState
):
self._last_provider_state = external_provider_state
self._vehicle_states = [v.state for v in self._vehicle_index.vehicles]
self._last_provider_state.replace_actor_type(self._vehicle_states, VehicleState)

@property
def should_reset(self):
"""If the simulation requires a reset."""
Expand Down Expand Up @@ -1671,7 +1677,8 @@ def cached_frame(self):
step_count=self.step_count,
vehicle_collisions=self._vehicle_collisions,
vehicle_states={
vehicle_id: vehicle.state for vehicle_id, vehicle in vehicles.items()
vehicle_state.actor_id: vehicle_state
for vehicle_state in self._vehicle_states
},
vehicles_for_agents={
agent_id: self.vehicle_index.vehicle_ids_by_owner_id(
Expand Down
10 changes: 5 additions & 5 deletions smarts/core/tests/test_observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,19 +98,19 @@ def agent_spec(agent_interface):


@pytest.fixture
def env(agent_spec):
env = gym.make(
def env(agent_spec: AgentSpec):
_env = gym.make(
"smarts.env:hiway-v0",
scenarios=["scenarios/sumo/figure_eight"],
agent_specs={AGENT_ID: agent_spec},
agent_interfaces={AGENT_ID: agent_spec.interface},
headless=True,
visdom=False,
fixed_timestep_sec=0.1,
seed=42,
)

yield env
env.close()
yield _env
_env.close()


def project_2d(lens, img_metadata: GridMapMetadata, pos):
Expand Down
17 changes: 17 additions & 0 deletions smarts/core/tests/test_simulation_state_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,23 @@ def test_state(sim: SMARTS, scenario):
assert hasattr(frame, "vehicle_collisions")


def test_vehicles_in_actors(sim: SMARTS, scenario):
sim.setup(scenario)
frame: SimulationFrame = sim.cached_frame

while (frame := sim.cached_frame) and len(frame.vehicle_states) < 1:
sim.step({})

assert set(k for k in frame.vehicle_states) == set(
actor_state.actor_id for actor_state in frame.actor_states
)
actor_states = {
actor_state.actor_id: actor_state for actor_state in frame.actor_states
}
for k, vehicle_state in frame.vehicle_states.items():
assert vehicle_state == actor_states[k]


def test_state_serialization(sim: SMARTS, scenario: Scenario):
sim.setup(scenario)
sim.reset(scenario, start_time=10)
Expand Down
6 changes: 5 additions & 1 deletion smarts/core/vehicle_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,11 @@ def __post_init__(self):
assert self.pose is not None and self.dimensions is not None

def __eq__(self, __o: object):
return super().__eq__(__o)
return (
isinstance(__o, type(self))
and super().__eq__(__o)
and self.pose == __o.pose
)

@property
def bounding_box_points(
Expand Down