Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

type hint fixes #1216

Merged
merged 2 commits into from
Dec 29, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions smarts/core/smarts.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import os
import warnings
from collections import defaultdict
from typing import List, Optional, Sequence
from typing import List, Optional, Set

import numpy as np

Expand Down Expand Up @@ -81,8 +81,8 @@ def __init__(
self,
agent_interfaces,
traffic_sim, # SumoTrafficSimulation
envision: EnvisionClient = None,
visdom: VisdomClient = None,
envision: Optional[EnvisionClient] = None,
visdom: Optional[VisdomClient] = None,
fixed_timestep_sec: float = 0.1,
reset_agents_only: bool = False,
zoo_addrs=None,
Expand All @@ -91,10 +91,10 @@ def __init__(
self._log = logging.getLogger(self.__class__.__name__)
self._sim_id = Id.new("smarts")
self._is_setup = False
self._scenario: Scenario = None
self._scenario: Optional[Scenario] = None
self._renderer = None
self._envision: EnvisionClient = envision
self._visdom: VisdomClient = visdom
self._envision: Optional[EnvisionClient] = envision
self._visdom: Optional[VisdomClient] = visdom
self._traffic_sim = traffic_sim
self._external_provider = None

Expand Down Expand Up @@ -148,7 +148,7 @@ def __init__(
self._vehicle_states = []

self._bubble_manager = None
self._trap_manager: TrapManager = None
self._trap_manager: Optional[TrapManager] = None

self._ground_bullet_id = None
self._map_bb = None
Expand Down Expand Up @@ -320,6 +320,7 @@ def reset(self, scenario: Scenario):
ids = self._vehicle_index.vehicle_ids_by_actor_id(agent_id)
vehicle_ids_to_teardown.extend(ids)
self._teardown_vehicles(set(vehicle_ids_to_teardown))
assert self._trap_manager
self._trap_manager.init_traps(scenario.road_map, scenario.missions)
self._agent_manager.init_ego_agents(self)
if self._renderer:
Expand Down Expand Up @@ -399,7 +400,7 @@ def add_agent_and_switch_control(
agent_id: str,
agent_interface: AgentInterface,
mission: Mission,
) -> Vehicle:
):
self.agent_manager.add_ego_agent(agent_id, agent_interface, for_trap=False)
vehicle = self.switch_control_to_agent(
vehicle_id, agent_id, mission, recreate=False, is_hijacked=True
Expand Down Expand Up @@ -651,7 +652,7 @@ def elapsed_sim_time(self) -> float:
def version(self) -> str:
return VERSION

def teardown_agents_without_vehicles(self, agent_ids: Sequence):
def teardown_agents_without_vehicles(self, agent_ids: Set[str]):
"""
Teardown agents in the given list that have no vehicles registered as
controlled-by or shadowed-by
Expand Down Expand Up @@ -816,7 +817,7 @@ def _reset_providers(self):
for provider in self.providers:
provider.reset()

def _step_providers(self, actions) -> List[VehicleState]:
def _step_providers(self, actions) -> ProviderState:
accumulated_provider_state = ProviderState()

def agent_controls_vehicles(agent_id):
Expand Down Expand Up @@ -1073,7 +1074,7 @@ def _check_ground_plane(self):
)
self._setup_pybullet_ground_plane(self._bullet_client)

def _try_emit_envision_state(self, provider_state, obs, scores):
def _try_emit_envision_state(self, provider_state: ProviderState, obs, scores):
if not self._envision:
return

Expand Down
4 changes: 2 additions & 2 deletions smarts/core/traffic_history_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

from typing import Optional, Set
from typing import Iterable, Optional, Set

import numpy as np

Expand Down Expand Up @@ -59,7 +59,7 @@ def setup(self, scenario) -> ProviderState:
self._is_setup = True
return ProviderState()

def set_replaced_ids(self, vehicle_ids: list):
def set_replaced_ids(self, vehicle_ids: Iterable[str]):
self._replaced_vehicle_ids.update(vehicle_ids)

def get_history_id(self, vehicle_id: str) -> Optional[str]:
Expand Down