diff --git a/CHANGELOG.md b/CHANGELOG.md index 3620217df5..2b07701cc4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -69,6 +69,7 @@ Copy and pasting the git commit messages is __NOT__ enough. - `VehicleIndex` no longer segfaults when attempting to `repr()` it. - Fixed issues related to waypoints in SUMO maps. Waypoints in junctions should now return all possible paths through the junction. - Fixed CI tests for metrics. +- Fixed an issue where the actor states and vehicle states were not synchronized after simulation vehicle updates resulting in different values from the simulation frame. - Minor fix in regular expression compilation of `actor_of_interest_re_filter` from scenario metadata. - Fixed acceleration and jerk computation in comfort metric, by ignoring vehicle position jitters smaller than a threshold. ### Removed diff --git a/smarts/core/actor.py b/smarts/core/actor.py index 2b05627f4f..9665aac02a 100644 --- a/smarts/core/actor.py +++ b/smarts/core/actor.py @@ -63,6 +63,6 @@ def __hash__(self) -> int: return hash(self.actor_id) def __eq__(self, other) -> bool: - return self.__class__ == other.__class__ and hash(self.actor_id) == hash( + return isinstance(other, type(self)) and hash(self.actor_id) == hash( other.actor_id ) diff --git a/smarts/core/provider.py b/smarts/core/provider.py index 043a86e8e1..7985b88d3c 100644 --- a/smarts/core/provider.py +++ b/smarts/core/provider.py @@ -89,6 +89,21 @@ def intersects(self, actor_ids: Set[str]) -> bool: intersection = actor_ids & provider_actor_ids return bool(intersection) + def replace_actor_type( + self, updated_actors: List[ActorState], actor_state_type: type + ): + """Replaces all actors of the given type. + + Args: + updated_actors (List[ActorState]): The actors to use as replacement. + actor_type (str): The actor type to replace. + """ + self.actors = [ + actor_state + for actor_state in self.actors + if not issubclass(actor_state.__class__, actor_state_type) + ] + updated_actors + class ProviderManager: """Interface to be implemented by a class that manages a set of Providers diff --git a/smarts/core/road_map.py b/smarts/core/road_map.py index 5fa3be671a..b1fda43f3e 100644 --- a/smarts/core/road_map.py +++ b/smarts/core/road_map.py @@ -447,7 +447,7 @@ def center_pose_at_point(self, point: Point) -> Pose: position = self.from_lane_coord(RefLinePoint(s=offset)) desired_vector = self.vector_at_offset(offset) orientation = fast_quaternion_from_angle(vec_to_radians(desired_vector[:2])) - return Pose(position=position, orientation=orientation) + return Pose(position=position.as_np_array, orientation=orientation) def curvature_radius_at_offset( self, offset: float, lookahead: int = 5 diff --git a/smarts/core/smarts.py b/smarts/core/smarts.py index acb8b8a4b8..940dd2081f 100644 --- a/smarts/core/smarts.py +++ b/smarts/core/smarts.py @@ -296,7 +296,6 @@ def _step(self, agent_actions, time_delta_since_last_step: Optional[float] = Non # 2. Step all providers and harmonize state with timeit("Stepping all providers and harmonizing state", self._log.debug): provider_state = self._step_providers(all_agent_actions) - self._last_provider_state = provider_state with timeit("Checking if all agents are active", self._log.debug): self._check_if_acting_on_active_agents(agent_actions) @@ -314,7 +313,7 @@ def _step(self, agent_actions, time_delta_since_last_step: Optional[float] = Non # want these during their observation/reward computations. # This is a hack to give us some short term perf wins. Longer term we # need to expose better support for batched computations - self._vehicle_states = [v.state for v in self._vehicle_index.vehicles] + self._sync_smarts_and_provider_actor_states(provider_state) self._sensor_manager.clean_up_sensors_for_actors( set(v.actor_id for v in self._vehicle_states), renderer=self.renderer_ref ) @@ -1332,6 +1331,13 @@ def _step_providers(self, actions) -> ProviderState: self._harmonize_providers(accumulated_provider_state) return accumulated_provider_state + def _sync_smarts_and_provider_actor_states( + self, external_provider_state: ProviderState + ): + self._last_provider_state = external_provider_state + self._vehicle_states = [v.state for v in self._vehicle_index.vehicles] + self._last_provider_state.replace_actor_type(self._vehicle_states, VehicleState) + @property def should_reset(self): """If the simulation requires a reset.""" @@ -1671,7 +1677,8 @@ def cached_frame(self): step_count=self.step_count, vehicle_collisions=self._vehicle_collisions, vehicle_states={ - vehicle_id: vehicle.state for vehicle_id, vehicle in vehicles.items() + vehicle_state.actor_id: vehicle_state + for vehicle_state in self._vehicle_states }, vehicles_for_agents={ agent_id: self.vehicle_index.vehicle_ids_by_owner_id( diff --git a/smarts/core/tests/test_observations.py b/smarts/core/tests/test_observations.py index 069de2cef1..310894c923 100644 --- a/smarts/core/tests/test_observations.py +++ b/smarts/core/tests/test_observations.py @@ -98,19 +98,19 @@ def agent_spec(agent_interface): @pytest.fixture -def env(agent_spec): - env = gym.make( +def env(agent_spec: AgentSpec): + _env = gym.make( "smarts.env:hiway-v0", scenarios=["scenarios/sumo/figure_eight"], - agent_specs={AGENT_ID: agent_spec}, + agent_interfaces={AGENT_ID: agent_spec.interface}, headless=True, visdom=False, fixed_timestep_sec=0.1, seed=42, ) - yield env - env.close() + yield _env + _env.close() def project_2d(lens, img_metadata: GridMapMetadata, pos): diff --git a/smarts/core/tests/test_simulation_state_frame.py b/smarts/core/tests/test_simulation_state_frame.py index dca791478d..6c0fbaff9f 100644 --- a/smarts/core/tests/test_simulation_state_frame.py +++ b/smarts/core/tests/test_simulation_state_frame.py @@ -89,6 +89,23 @@ def test_state(sim: SMARTS, scenario): assert hasattr(frame, "vehicle_collisions") +def test_vehicles_in_actors(sim: SMARTS, scenario): + sim.setup(scenario) + frame: SimulationFrame = sim.cached_frame + + while (frame := sim.cached_frame) and len(frame.vehicle_states) < 1: + sim.step({}) + + assert set(k for k in frame.vehicle_states) == set( + actor_state.actor_id for actor_state in frame.actor_states + ) + actor_states = { + actor_state.actor_id: actor_state for actor_state in frame.actor_states + } + for k, vehicle_state in frame.vehicle_states.items(): + assert vehicle_state == actor_states[k] + + def test_state_serialization(sim: SMARTS, scenario: Scenario): sim.setup(scenario) sim.reset(scenario, start_time=10) diff --git a/smarts/core/vehicle_state.py b/smarts/core/vehicle_state.py index cee9b8c10c..5bfe11d4c3 100644 --- a/smarts/core/vehicle_state.py +++ b/smarts/core/vehicle_state.py @@ -120,7 +120,11 @@ def __post_init__(self): assert self.pose is not None and self.dimensions is not None def __eq__(self, __o: object): - return super().__eq__(__o) + return ( + isinstance(__o, type(self)) + and super().__eq__(__o) + and self.pose == __o.pose + ) @property def bounding_box_points(