Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support only new step API (while retaining compatibility functions) #3019

Merged
merged 17 commits into from
Aug 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ observation, info = env.reset(seed=42)

for _ in range(1000):
action = env.action_space.sample()
observation, reward, done, info = env.step(action)
observation, reward, terminated, truncarted, info = env.step(action)

if done:
if terminated or truncated:
observation, info = env.reset()
env.close()
```
Expand Down
50 changes: 10 additions & 40 deletions gym/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import numpy as np

from gym import spaces
from gym.logger import deprecation, warn
from gym.logger import warn
from gym.utils import seeding

if TYPE_CHECKING:
Expand Down Expand Up @@ -83,16 +83,11 @@ def np_random(self) -> np.random.Generator:
def np_random(self, value: np.random.Generator):
self._np_random = value

def step(
self, action: ActType
) -> Union[
Tuple[ObsType, float, bool, bool, dict], Tuple[ObsType, float, bool, dict]
]:
def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
arjun-kg marked this conversation as resolved.
Show resolved Hide resolved
"""Run one timestep of the environment's dynamics.

When end of episode is reached, you are responsible for calling :meth:`reset` to reset this environment's state.
Accepts an action and returns either a tuple `(observation, reward, terminated, truncated, info)`, or a tuple
(observation, reward, done, info). The latter is deprecated and will be removed in future versions.
Accepts an action and returns either a tuple `(observation, reward, terminated, truncated, info)`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove "either"


Args:
action (ActType): an action provided by the agent
Expand Down Expand Up @@ -226,25 +221,18 @@ class Wrapper(Env[ObsType, ActType]):
Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`.
"""

def __init__(self, env: Env, new_step_api: bool = False):
def __init__(self, env: Env):
"""Wraps an environment to allow a modular transformation of the :meth:`step` and :meth:`reset` methods.

Args:
env: The environment to wrap
new_step_api: Whether the wrapper's step method will output in new or old step API
"""
self.env = env

self._action_space: Optional[spaces.Space] = None
self._observation_space: Optional[spaces.Space] = None
self._reward_range: Optional[Tuple[SupportsFloat, SupportsFloat]] = None
self._metadata: Optional[dict] = None
self.new_step_api = new_step_api

if not self.new_step_api:
deprecation(
"Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future."
)

def __getattr__(self, name):
"""Returns an attribute with ``name``, unless ``name`` starts with an underscore."""
Expand Down Expand Up @@ -326,17 +314,9 @@ def _np_random(self):
"Can't access `_np_random` of a wrapper, use `.unwrapped._np_random` or `.np_random`."
)

def step(
self, action: ActType
) -> Union[
Tuple[ObsType, float, bool, bool, dict], Tuple[ObsType, float, bool, dict]
]:
def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
"""Steps through the environment with action."""
from gym.utils.step_api_compatibility import ( # avoid circular import
step_api_compatibility,
)

return step_api_compatibility(self.env.step(action), self.new_step_api)
return self.env.step(action)

def reset(self, **kwargs) -> Tuple[ObsType, dict]:
"""Resets the environment with kwargs."""
Expand Down Expand Up @@ -401,13 +381,8 @@ def reset(self, **kwargs):

def step(self, action):
"""Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`."""
step_returns = self.env.step(action)
if len(step_returns) == 5:
observation, reward, terminated, truncated, info = step_returns
return self.observation(observation), reward, terminated, truncated, info
else:
observation, reward, done, info = step_returns
return self.observation(observation), reward, done, info
observation, reward, terminated, truncated, info = self.env.step(action)
return self.observation(observation), reward, terminated, truncated, info

def observation(self, observation):
"""Returns a modified observation."""
Expand Down Expand Up @@ -440,13 +415,8 @@ def reward(self, reward):

def step(self, action):
"""Modifies the reward using :meth:`self.reward` after the environment :meth:`env.step`."""
step_returns = self.env.step(action)
if len(step_returns) == 5:
observation, reward, terminated, truncated, info = step_returns
return observation, self.reward(reward), terminated, truncated, info
else:
observation, reward, done, info = step_returns
return observation, self.reward(reward), done, info
observation, reward, terminated, truncated, info = self.env.step(action)
return observation, self.reward(reward), terminated, truncated, info

def reward(self, reward):
"""Returns a modified ``reward``."""
Expand Down
18 changes: 10 additions & 8 deletions gym/envs/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ class EnvSpec:
order_enforce: bool = field(default=True)
autoreset: bool = field(default=False)
disable_env_checker: bool = field(default=False)
new_step_api: bool = field(default=False)
apply_step_compatibility: bool = field(default=False)

# Environment arguments
kwargs: dict = field(default_factory=dict)
Expand Down Expand Up @@ -547,7 +547,7 @@ def make(
id: Union[str, EnvSpec],
max_episode_steps: Optional[int] = None,
autoreset: bool = False,
new_step_api: bool = False,
apply_step_compatibility: bool = False,
disable_env_checker: Optional[bool] = None,
**kwargs,
) -> Env:
Expand All @@ -557,7 +557,7 @@ def make(
id: Name of the environment. Optionally, a module to import can be included, eg. 'module:Env-v0'
max_episode_steps: Maximum length of an episode (TimeLimit wrapper).
autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper).
new_step_api: Whether to use old or new step API (StepAPICompatibility wrapper). Will be removed at v1.0
apply_step_compatibility: Whether to use apply compatibility wrapper that converts step method to return two bools (StepAPICompatibility wrapper)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren't we removing this?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or I guess it might be useful for automatically supporting legacy environments?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think we should keep a parameter in make to easily apply the compatibility wrapper

disable_env_checker: If to run the env checker, None will default to the environment specification `disable_env_checker`
(which is by default False, running the environment checker),
otherwise will run according to this parameter (`True` = not run, `False` = run)
Expand Down Expand Up @@ -684,26 +684,28 @@ def make(
):
env = PassiveEnvChecker(env)

env = StepAPICompatibility(env, new_step_api)

# Add the order enforcing wrapper
if spec_.order_enforce:
env = OrderEnforcing(env)

# Add the time limit wrapper
if max_episode_steps is not None:
env = TimeLimit(env, max_episode_steps, new_step_api)
env = TimeLimit(env, max_episode_steps)
elif spec_.max_episode_steps is not None:
env = TimeLimit(env, spec_.max_episode_steps, new_step_api)
env = TimeLimit(env, spec_.max_episode_steps)

# Add the autoreset wrapper
if autoreset:
env = AutoResetWrapper(env, new_step_api)
env = AutoResetWrapper(env)

# Add human rendering wrapper
if apply_human_rendering:
env = HumanRendering(env)

# Add step API wrapper
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it work if the compatibility wrapper is at the end? As far as I understand, the use case here is if someone has a legacy environment, then it would convert it to a new-style environment. But wouldn't one of the wrappers before this crash out if the compatibility is not handled in advance?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(checked now, it doesn't work, at least assuming my understanding is correct)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I think you are correct, this should occur after the environment checker in order of wrapper

if apply_step_compatibility:
env = StepAPICompatibility(env, True)

return env


Expand Down
16 changes: 10 additions & 6 deletions gym/utils/passive_env_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,12 +195,16 @@ def env_reset_passive_checker(env, **kwargs):
logger.warn(
f"The result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `{type(result)}`"
)

obs, info = result
check_obs(obs, env.observation_space, "reset")
assert isinstance(
info, dict
), f"The second element returned by `env.reset()` was not a dictionary, actual type: {type(info)}"
elif len(result) != 2:
logger.warn(
"The result returned by `env.reset()` should be `(obs, info)` by default, , where `obs` is a observation and `info` is a dictionary containing additional information."
)
else:
obs, info = result
check_obs(obs, env.observation_space, "reset")
assert isinstance(
info, dict
), f"The second element returned by `env.reset()` was not a dictionary, actual type: {type(info)}"
return result


Expand Down
30 changes: 16 additions & 14 deletions gym/utils/play.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def play(
:class:`gym.utils.play.PlayPlot`. Here's a sample code for plotting the reward
for last 150 steps.

>>> def callback(obs_t, obs_tp1, action, rew, done, info):
>>> def callback(obs_t, obs_tp1, action, rew, terminated, truncated, info):
... return [rew,]
>>> plotter = PlayPlot(callback, 150, ["reward"])
>>> play(gym.make("ALE/AirRaid-v5"), callback=plotter.callback)
Expand All @@ -187,7 +187,8 @@ def play(
obs_tp1: observation after performing action
action: action that was executed
rew: reward that was received
done: whether the environment is done or not
terminated: whether the environment is terminated or not
truncated: whether the environment is truncated or not
info: debug info
keys_to_action: Mapping from keys pressed to action performed.
Different formats are supported: Key combinations can either be expressed as a tuple of unicode code
Expand Down Expand Up @@ -219,11 +220,6 @@ def play(
deprecation(
"`play.py` currently supports only the old step API which returns one boolean, however this will soon be updated to support only the new step api that returns two bools."
)
if env.render_mode not in {"rgb_array", "single_rgb_array"}:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why was this removed? Seems irrelevant

logger.error(
"play method works only with rgb_array and single_rgb_array render modes, "
f"but your environment render_mode = {env.render_mode}."
)

env.reset(seed=seed)

Expand Down Expand Up @@ -261,9 +257,10 @@ def play(
else:
action = key_code_to_action.get(tuple(sorted(game.pressed_keys)), noop)
prev_obs = obs
obs, rew, done, info = env.step(action)
obs, rew, terminated, truncated, info = env.step(action)
arjun-kg marked this conversation as resolved.
Show resolved Hide resolved
done = terminated or truncated
if callback is not None:
callback(prev_obs, obs, action, rew, done, info)
callback(prev_obs, obs, action, rew, terminated, truncated, info)
if obs is not None:
rendered = env.render()
if isinstance(rendered, List):
Expand All @@ -290,13 +287,14 @@ class PlayPlot:
- obs_tp1: observation after performing action
- action: action that was executed
- rew: reward that was received
- done: whether the environment is done or not
- terminated: whether the environment is terminated or not
- truncated: whether the environment is truncated or not
- info: debug info

It should return a list of metrics that are computed from this data.
For instance, the function may look like this::

>>> def compute_metrics(obs_t, obs_tp, action, reward, done, info):
>>> def compute_metrics(obs_t, obs_tp, action, reward, terminated, truncated, info):
... return [reward, info["cumulative_reward"], np.linalg.norm(action)]

:class:`PlayPlot` provides the method :meth:`callback` which will pass its arguments along to that function
Expand Down Expand Up @@ -353,7 +351,8 @@ def callback(
obs_tp1: ObsType,
action: ActType,
rew: float,
done: bool,
terminated: bool,
truncated: bool,
info: dict,
):
"""The callback that calls the provided data callback and adds the data to the plots.
Expand All @@ -363,10 +362,13 @@ def callback(
obs_tp1: The observation at time step t+1
action: The action
rew: The reward
done: If the environment is done
terminated: If the environment is terminated
truncated: If the environment is truncated
info: The information from the environment
"""
points = self.data_callback(obs_t, obs_tp1, action, rew, done, info)
points = self.data_callback(
obs_t, obs_tp1, action, rew, terminated, truncated, info
)
for point, data_series in zip(points, self.data):
data_series.append(point)
self.t += 1
Expand Down
45 changes: 23 additions & 22 deletions gym/utils/step_api_compatibility.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
"""Contains methods for step compatibility, from old-to-new and new-to-old API, to be removed in 1.0."""
"""Contains methods for step compatibility, from old-to-new and new-to-old API."""
from typing import Tuple, Union

import numpy as np

from gym.core import ObsType

OldStepType = Tuple[
DoneStepType = Tuple[
Union[ObsType, np.ndarray],
Union[float, np.ndarray],
Union[bool, np.ndarray],
Union[dict, list],
]

NewStepType = Tuple[
TerminatedTruncatedStepType = Tuple[
Union[ObsType, np.ndarray],
Union[float, np.ndarray],
Union[bool, np.ndarray],
Expand All @@ -21,9 +21,9 @@
]


def step_to_new_api(
step_returns: Union[OldStepType, NewStepType], is_vector_env=False
) -> NewStepType:
def convert_to_terminated_truncated_step_api(
step_returns: Union[DoneStepType, TerminatedTruncatedStepType], is_vector_env=False
) -> TerminatedTruncatedStepType:
"""Function to transform step returns to new step API irrespective of input API.

Args:
Expand Down Expand Up @@ -73,9 +73,10 @@ def step_to_new_api(
)


def step_to_old_api(
step_returns: Union[NewStepType, OldStepType], is_vector_env: bool = False
) -> OldStepType:
def convert_to_done_step_api(
step_returns: Union[TerminatedTruncatedStepType, DoneStepType],
is_vector_env: bool = False,
) -> DoneStepType:
"""Function to transform step returns to old step API irrespective of input API.

Args:
Expand Down Expand Up @@ -128,33 +129,33 @@ def step_to_old_api(


def step_api_compatibility(
step_returns: Union[NewStepType, OldStepType],
new_step_api: bool = False,
step_returns: Union[TerminatedTruncatedStepType, DoneStepType],
output_truncation_bool: bool = True,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like this argument name, it's pretty unclear. I think there's a different name used earlier?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to_termination_truncation_api ? I agree it is not a great name but uncertain of a better one

is_vector_env: bool = False,
) -> Union[NewStepType, OldStepType]:
"""Function to transform step returns to the API specified by `new_step_api` bool.
) -> Union[TerminatedTruncatedStepType, DoneStepType]:
"""Function to transform step returns to the API specified by `output_truncation_bool` bool.

Old step API refers to step() method returning (observation, reward, done, info)
New step API refers to step() method returning (observation, reward, terminated, truncated, info)
Done (old) step API refers to step() method returning (observation, reward, done, info)
Terminated Truncated (new) step API refers to step() method returning (observation, reward, terminated, truncated, info)
(Refer to docs for details on the API change)

Args:
step_returns (tuple): Items returned by step(). Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info)
new_step_api (bool): Whether the output should be in new step API or old (False by default)
output_truncation_bool (bool): Whether the output should return two booleans (new API) or one (old) (True by default)
is_vector_env (bool): Whether the step_returns are from a vector environment

Returns:
step_returns (tuple): Depending on `new_step_api` bool, it can return (obs, rew, done, info) or (obs, rew, terminated, truncated, info)
step_returns (tuple): Depending on `output_truncation_bool` bool, it can return (obs, rew, done, info) or (obs, rew, terminated, truncated, info)

Examples:
This function can be used to ensure compatibility in step interfaces with conflicting API. Eg. if env is written in old API,
wrapper is written in new API, and the final step output is desired to be in old API.

>>> obs, rew, done, info = step_api_compatibility(env.step(action))
>>> obs, rew, terminated, truncated, info = step_api_compatibility(env.step(action), new_step_api=True)
>>> obs, rew, done, info = step_api_compatibility(env.step(action), output_truncation_bool=False)
>>> obs, rew, terminated, truncated, info = step_api_compatibility(env.step(action), output_truncation_bool=True)
>>> observations, rewards, dones, infos = step_api_compatibility(vec_env.step(action), is_vector_env=True)
"""
if new_step_api:
return step_to_new_api(step_returns, is_vector_env)
if output_truncation_bool:
return convert_to_terminated_truncated_step_api(step_returns, is_vector_env)
else:
return step_to_old_api(step_returns, is_vector_env)
return convert_to_done_step_api(step_returns, is_vector_env)
Loading