Skip to content

Commit

Permalink
Address comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
Gamenot committed Jan 6, 2022
1 parent 36fac79 commit 4c9935e
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 15 deletions.
8 changes: 5 additions & 3 deletions smarts/core/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,16 @@
# THE SOFTWARE.
from dataclasses import dataclass, field
from enum import IntFlag
from typing import List, Optional, Set
from typing import List, Optional, Set, Tuple

from .controllers import ActionSpaceType
from .scenario import Scenario
from .vehicle import VehicleState


class ProviderRecoveryFlags(IntFlag):
"""This describes actions to be taken with a provider should it fail."""

NOT_REQUIRED = 0x00000000
"""Not needed for the current step. Error causes skip."""
EPISODE_REQUIRED = 0x00000010
Expand Down Expand Up @@ -91,7 +93,7 @@ def teardown(self):

def recover(
self, scenario, elapsed_sim_time: float, error: Optional[Exception] = None
) -> bool:
) -> Tuple[ProviderState, bool]:
"""Attempt to reconnect the provider if an error or disconnection occured.
Implementations may choose to e-raise the passed in exception.
Args:
Expand All @@ -103,7 +105,7 @@ def recover(
"""
if error:
raise error
return False
return ProviderState(), False

@property
def connected(self) -> bool:
Expand Down
20 changes: 9 additions & 11 deletions smarts/core/smarts.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
logging.basicConfig(
format="%(asctime)s.%(msecs)03d %(levelname)s: {%(module)s} %(message)s",
datefmt="%Y-%m-%d,%H:%M:%S",
level=logging.INFO,
level=logging.ERROR,
)

MAX_PYBULLET_FREQ = 240
Expand Down Expand Up @@ -415,8 +415,7 @@ def add_provider(
recovery_flags: ProviderRecoveryFlags = ProviderRecoveryFlags.EXPERIMENT_REQUIRED,
):
assert isinstance(provider, Provider)
self._providers.append(provider)
self._provider_recovery_flags[provider] = recovery_flags
self._insert_provider(len(self._providers), provider, recovery_flags)

def _insert_provider(
self,
Expand Down Expand Up @@ -850,8 +849,7 @@ def _setup_providers(self, scenario) -> ProviderState:
try:
new_provider_state = provider.setup(scenario)
except Exception as provider_error:
self._handle_provider(provider, provider_error)
new_provider_state = ProviderState()
new_provider_state = self._handle_provider(provider, provider_error)
provider_state.merge(new_provider_state)
return provider_state

Expand All @@ -877,7 +875,7 @@ def _reset_providers(self):
except Exception as provider_error:
self._handle_provider(provider, provider_error)

def _handle_provider(self, provider: Provider, provider_error):
def _handle_provider(self, provider: Provider, provider_error) -> ProviderState:
provider_problem = bool(provider_error or not provider.connected)
if not provider_problem:
return
Expand All @@ -887,12 +885,13 @@ def _handle_provider(self, provider: Provider, provider_error):
)
recovered = False
if recovery_flags & ProviderRecoveryFlags.ATTEMPT_RECOVERY:
recovered = provider.recover(
provider_state, recovered = provider.recover(
self._scenario, self.elapsed_sim_time, provider_error
)

provider_state = provider_state or ProviderState()
if recovered:
return
return provider_state

if recovery_flags & ProviderRecoveryFlags.EPISODE_REQUIRED:
self._reset_required = True
Expand All @@ -901,7 +900,7 @@ def _handle_provider(self, provider: Provider, provider_error):
f"`Provider {provider.__class__.__name__} has crashed during reset`"
)
raise provider_error
return
return provider_state
elif recovery_flags & ProviderRecoveryFlags.EXPERIMENT_REQUIRED:
raise provider_error

Expand Down Expand Up @@ -961,8 +960,7 @@ def matches_no_provider_action_space(agent_id):
try:
provider_state = self._step_provider(provider, actions)
except Exception as provider_error:
self._handle_provider(provider, provider_error)
provider_state = ProviderState()
provider_state = self._handle_provider(provider, provider_error)

if provider == self._traffic_sim:
# Remove agent vehicles from provider vehicles
Expand Down
2 changes: 1 addition & 1 deletion smarts/core/sumo_traffic_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def recover(
self._handle_traci_disconnect(error)
elif isinstance(error, Exception):
raise error
return False
return ProviderState(), False

def step(self, provider_actions, dt, elapsed_sim_time) -> ProviderState:
"""
Expand Down

0 comments on commit 4c9935e

Please sign in to comment.