Allow configuration of the env step return shapes #1920

Merged · 9 commits · Jun 12, 2023
CHANGELOG.md — 2 changes: 2 additions & 0 deletions
@@ -32,6 +32,8 @@ Copy and pasting the git commit messages is __NOT__ enough.
- Changed instances of `hiway-v0` and `gym` to use `hiway-v1` and `gymnasium`, respectively.
- `RoadMap.Route` now optionally stores the start and end lanes of the route.
- `DistToDestination` metric now adds lane error penalty when agent terminates in different lane but same road as the goal position.
- `hiway-v1` can now be configured for per-agent or environment reward(s), truncation(s), termination(s), and info(s) through `environment_return_mode`.
- `hiway-v1`'s `observation_options` no longer determines whether the rewards, truncations, and terminations follow the `agent` or `environment` style return mode.
### Deprecated
- `visdom` is set to be removed from the SMARTS object parameters.
- Deprecated `start_time` on missions.
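For context, a minimal usage sketch of the new option. This is hedged: the `smarts.env:hiway-v1` registration id follows SMARTS' documented `gymnasium` usage, but the scenario path and agent interface below are illustrative placeholders, not taken from this diff.

```python
import gymnasium as gym
from smarts.core.agent_interface import AgentInterface, AgentType

# Placeholder scenario and interface; substitute your own.
agent_interfaces = {
    "Agent-0": AgentInterface.from_type(AgentType.Laner, max_episode_steps=150),
}

env = gym.make(
    "smarts.env:hiway-v1",
    scenarios=["scenarios/sumo/loop"],
    agent_interfaces=agent_interfaces,
    # New in this PR: accepts an EnvReturnMode member or its string name.
    environment_return_mode="environment",  # or "per_agent" (the default)
)
```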
smarts/env/gymnasium/hiway_env_v1.py — 34 changes: 27 additions & 7 deletions
@@ -21,7 +21,7 @@
# THE SOFTWARE.
import logging
import os
-from enum import IntEnum
+from enum import IntEnum, auto
from functools import partial
from pathlib import Path
from typing import (
@@ -90,6 +90,19 @@ class SumoOptions(NamedTuple):
)


class EnvReturnMode(IntEnum):
    """Configuration to determine the interface type of the step function.

    This selects between the environment-level status return (i.e. ``reward`` is the single
    environment reward) and the per-agent status return (i.e. ``rewards`` maps each agent id to its reward).
    """

    per_agent = auto()
    """Generate per-agent mode step returns in the form ``(rewards({id: float}), terminateds({id: bool}), truncateds({id: bool}), info)``."""
    environment = auto()
    """Generate environment mode step returns in the form ``(reward (float), terminated (bool), truncated (bool), info)``."""


class HiWayEnvV1(gym.Env):
    """A generic environment for various driving tasks simulated by SMARTS.

@@ -125,6 +138,10 @@ class HiWayEnvV1(gym.Env):
            for how the formatting matches the action space. String version
            can be used instead. See :class:`~smarts.env.utils.action_conversion.ActionOptions`. Defaults to
            :attr:`~smarts.env.utils.action_conversion.ActionOptions.default`.
        environment_return_mode (EnvReturnMode, str): Selects between the environment-level
            step return (i.e. ``reward`` is the single summed environment reward) and the
            per-agent step return (i.e. ``reward`` maps each agent id to its reward). Defaults to
            :attr:`~smarts.env.gymnasium.hiway_env_v1.EnvReturnMode.per_agent`.
    """

    metadata = {"render_modes": ["human"]}

@@ -159,6 +176,7 @@ def __init__(
            ObservationOptions, str
        ] = ObservationOptions.default,
        action_options: Union[ActionOptions, str] = ActionOptions.default,
        environment_return_mode: Union[EnvReturnMode, str] = EnvReturnMode.per_agent,
    ):
        self._log = logging.getLogger(self.__class__.__name__)
        smarts_seed(seed)
@@ -198,6 +216,11 @@ def __init__(
        smarts_traffic = LocalTrafficProvider()
        traffic_sims += [smarts_traffic]

        if isinstance(environment_return_mode, str):
            self._environment_return_mode = EnvReturnMode[environment_return_mode]
        else:
            self._environment_return_mode = environment_return_mode

        if isinstance(action_options, str):
            action_options = ActionOptions[action_options]
        self._action_formatter = ActionSpacesFormatter(
@@ -288,18 +311,15 @@ def step(

        assert all("score" in v for v in info.values())

-        if self._observations_formatter.observation_options == ObservationOptions.full:
+        if self._environment_return_mode == EnvReturnMode.environment:
            return (
                self._observations_formatter.format(observations),
                sum(r for r in rewards.values()),
                dones["__all__"],
                dones["__all__"],
                info,
            )
-        elif self._observations_formatter.observation_options in (
-            ObservationOptions.multi_agent,
-            ObservationOptions.unformatted,
-        ):
+        elif self._environment_return_mode == EnvReturnMode.per_agent:
            return (
                self._observations_formatter.format(observations),
                rewards,
@@ -308,7 +328,7 @@
                info,
            )
        raise RuntimeError(
-            f"Invalid observation configuration using {self._observations_formatter.observation_options}"
+            f"Invalid observation configuration using {self._environment_return_mode}"
        )

    def reset(
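To make the two return shapes concrete, here is a sketch of how a stepping loop would unpack each mode. The `env` and `actions` are assumed from the sketch above, and the aggregation shown is one illustrative choice, not something prescribed by this PR.

```python
# environment mode: scalar reward and flags, like the single-agent Gymnasium API.
obs, reward, terminated, truncated, info = env.step(actions)
episode_over = terminated or truncated

# per_agent mode (the default): rewards/terminateds/truncateds are dicts keyed by agent id.
obs, rewards, terminateds, truncateds, info = env.step(actions)
total_reward = sum(rewards.values())  # illustrative aggregation for logging
episode_over = all(terminateds.values()) or all(truncateds.values())
```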