diff --git a/CHANGELOG.md b/CHANGELOG.md
index 299f561a22..68348be668 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -33,6 +33,8 @@ Copy and pasting the git commit messages is __NOT__ enough.
 - `info` returned by `hiway-v1` in `reset()` and `step()` methods are unified.
 - Changed instances of `hiway-v0` and `gym` to use `hiway-v1` and `gymnasium`, respectively.
 - `RoadMap.Route` now optionally stores the start and end lanes of the route.
+- `hiway-v1` can now be configured for per-agent or environment reward(s), truncation(s), termination(s), and info(s) through `environment_return_mode`.
+- `hiway-v1`'s `observation_options` no longer has an effect on the `agent|environment` style return mode of the rewards, truncations, and terminations; that is now controlled by `environment_return_mode`.
 - `DistToDestination` metric is now computed by summing the (i) off-route distance driven by the vehicle from its last on-route position, and (ii) the distance to goal from the vehicle's last on-route position.
 - `Steps` metric is capped by scenario duration set in the scenario metadata.
 - Overall metric score is weighted by each agent's task difficulty.
diff --git a/smarts/env/gymnasium/hiway_env_v1.py b/smarts/env/gymnasium/hiway_env_v1.py
index 074d66dcfe..10552a4ca7 100644
--- a/smarts/env/gymnasium/hiway_env_v1.py
+++ b/smarts/env/gymnasium/hiway_env_v1.py
@@ -21,7 +21,7 @@
 # THE SOFTWARE.
 import logging
 import os
-from enum import IntEnum
+from enum import IntEnum, auto
 from functools import partial
 from pathlib import Path
 from typing import (
@@ -90,6 +90,19 @@ class SumoOptions(NamedTuple):
     )
 
 
+class EnvReturnMode(IntEnum):
+    """Configuration that determines the interface type of the step function.
+
+    This selects between the environment-level status return (i.e. ``reward`` is the single
+    environment reward) and the per-agent status return (i.e. ``rewards`` maps each agent ID to its reward).
+    """
+
+    per_agent = auto()
+    """Generate per-agent mode step returns in the form ``(rewards({id: float}), terminateds({id: bool}), truncateds({id: bool}), info)``."""
+    environment = auto()
+    """Generate environment mode step returns in the form ``(reward (float), terminated (bool), truncated (bool), info)``."""
+
+
 class HiWayEnvV1(gym.Env):
     """A generic environment for various driving tasks simulated by SMARTS.
 
@@ -125,6 +138,10 @@ class HiWayEnvV1(gym.Env):
             for how the formatting matches the action space. String version can be used
             instead. See :class:`~smarts.env.utils.action_conversion.ActionOptions`. Defaults to
             :attr:`~smarts.env.utils.action_conversion.ActionOptions.default`.
+        environment_return_mode (EnvReturnMode, str): Configures whether the step return
+            information is environment-level (i.e. ``reward`` is the single summed environment
+            reward) or per-agent (i.e. ``rewards`` maps each agent ID to its reward). Defaults to
+            :attr:`~smarts.env.gymnasium.hiway_env_v1.EnvReturnMode.per_agent`.
""" metadata = {"render_modes": ["human"]} @@ -159,6 +176,7 @@ def __init__( ObservationOptions, str ] = ObservationOptions.default, action_options: Union[ActionOptions, str] = ActionOptions.default, + environment_return_mode: Union[EnvReturnMode, str] = EnvReturnMode.per_agent, ): self._log = logging.getLogger(self.__class__.__name__) smarts_seed(seed) @@ -198,6 +216,11 @@ def __init__( smarts_traffic = LocalTrafficProvider() traffic_sims += [smarts_traffic] + if isinstance(environment_return_mode, str): + self._environment_return_mode = EnvReturnMode[environment_return_mode] + else: + self._environment_return_mode = environment_return_mode + if isinstance(action_options, str): action_options = ActionOptions[action_options] self._action_formatter = ActionSpacesFormatter( @@ -288,7 +311,7 @@ def step( assert all("score" in v for v in info.values()) - if self._observations_formatter.observation_options == ObservationOptions.full: + if self._environment_return_mode == EnvReturnMode.environment: return ( self._observations_formatter.format(observations), sum(r for r in rewards.values()), @@ -296,19 +319,30 @@ def step( dones["__all__"], info, ) - elif self._observations_formatter.observation_options in ( - ObservationOptions.multi_agent, - ObservationOptions.unformatted, - ): - return ( - self._observations_formatter.format(observations), - rewards, - dones, - dones, - info, - ) + elif self._environment_return_mode == EnvReturnMode.per_agent: + observations = self._observations_formatter.format(observations) + if ( + self._observations_formatter.observation_options + == ObservationOptions.full + ): + dones = {**{id_: False for id_ in observations}, **dones} + return ( + observations, + {**{id_: np.nan for id_ in observations}, **rewards}, + dones, + dones, + info, + ) + else: + return ( + observations, + rewards, + dones, + dones, + info, + ) raise RuntimeError( - f"Invalid observation configuration using {self._observations_formatter.observation_options}" + f"Invalid observation configuration using {self._environment_return_mode}" ) def reset(