From d412e400e5fcac5c9154f2daeccfcc9092bc36f3 Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Fri, 5 Aug 2022 05:46:52 +0530 Subject: [PATCH 01/12] removing compatibility stuff --- gym/core.py | 45 ++++------------------- gym/envs/registration.py | 15 +++++--- gym/utils/step_api_compatibility.py | 8 ++-- gym/vector/__init__.py | 9 +---- gym/vector/async_vector_env.py | 28 +++++--------- gym/vector/sync_vector_env.py | 23 +++++------- gym/vector/vector_env.py | 8 ---- gym/wrappers/atari_preprocessing.py | 17 ++------- gym/wrappers/autoreset.py | 15 ++------ gym/wrappers/clip_action.py | 2 +- gym/wrappers/env_checker.py | 2 +- gym/wrappers/filter_observation.py | 2 +- gym/wrappers/flatten_observation.py | 2 +- gym/wrappers/frame_stack.py | 14 ++----- gym/wrappers/gray_scale_observation.py | 2 +- gym/wrappers/human_rendering.py | 2 +- gym/wrappers/normalize.py | 30 ++++----------- gym/wrappers/order_enforcing.py | 2 +- gym/wrappers/pixel_observation.py | 2 +- gym/wrappers/record_episode_statistics.py | 24 +++++------- gym/wrappers/record_video.py | 13 ++----- gym/wrappers/rescale_action.py | 2 +- gym/wrappers/resize_observation.py | 2 +- gym/wrappers/step_api_compatibility.py | 17 ++++----- gym/wrappers/time_aware_observation.py | 10 ++--- gym/wrappers/time_limit.py | 15 ++------ gym/wrappers/transform_observation.py | 2 +- gym/wrappers/transform_reward.py | 2 +- gym/wrappers/vector_list_info.py | 16 ++------ 29 files changed, 101 insertions(+), 230 deletions(-) diff --git a/gym/core.py b/gym/core.py index 984ea7e0f6f..77991193728 100644 --- a/gym/core.py +++ b/gym/core.py @@ -131,11 +131,7 @@ def np_random(self) -> RandomNumberGenerator: def np_random(self, value: RandomNumberGenerator): self._np_random = value - def step( - self, action: ActType - ) -> Union[ - Tuple[ObsType, float, bool, bool, dict], Tuple[ObsType, float, bool, dict] - ]: + def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]: """Run one timestep of the environment's dynamics. When end of episode is reached, you are responsible for calling :meth:`reset` to reset this environment's state. @@ -311,12 +307,11 @@ class Wrapper(Env[ObsType, ActType]): Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`. """ - def __init__(self, env: Env, new_step_api: bool = False): + def __init__(self, env: Env): """Wraps an environment to allow a modular transformation of the :meth:`step` and :meth:`reset` methods. Args: env: The environment to wrap - new_step_api: Whether the wrapper's step method will output in new or old step API """ self.env = env @@ -324,12 +319,6 @@ def __init__(self, env: Env, new_step_api: bool = False): self._observation_space: Optional[spaces.Space] = None self._reward_range: Optional[Tuple[SupportsFloat, SupportsFloat]] = None self._metadata: Optional[dict] = None - self.new_step_api = new_step_api - - if not self.new_step_api: - deprecation( - "Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future." - ) def __getattr__(self, name): """Returns an attribute with ``name``, unless ``name`` starts with an underscore.""" @@ -411,17 +400,9 @@ def _np_random(self): "Can't access `_np_random` of a wrapper, use `.unwrapped._np_random` or `.np_random`." ) - def step( - self, action: ActType - ) -> Union[ - Tuple[ObsType, float, bool, bool, dict], Tuple[ObsType, float, bool, dict] - ]: + def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]: """Steps through the environment with action.""" - from gym.utils.step_api_compatibility import ( # avoid circular import - step_api_compatibility, - ) - - return step_api_compatibility(self.env.step(action), self.new_step_api) + return self.env.step(action) def reset(self, **kwargs) -> Union[ObsType, Tuple[ObsType, dict]]: """Resets the environment with kwargs.""" @@ -493,13 +474,8 @@ def reset(self, **kwargs): def step(self, action): """Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`.""" - step_returns = self.env.step(action) - if len(step_returns) == 5: - observation, reward, terminated, truncated, info = step_returns - return self.observation(observation), reward, terminated, truncated, info - else: - observation, reward, done, info = step_returns - return self.observation(observation), reward, done, info + observation, reward, terminated, truncated, info = self.env.step(action) + return self.observation(observation), reward, terminated, truncated, info def observation(self, observation): """Returns a modified observation.""" @@ -532,13 +508,8 @@ def reward(self, reward): def step(self, action): """Modifies the reward using :meth:`self.reward` after the environment :meth:`env.step`.""" - step_returns = self.env.step(action) - if len(step_returns) == 5: - observation, reward, terminated, truncated, info = step_returns - return observation, self.reward(reward), terminated, truncated, info - else: - observation, reward, done, info = step_returns - return observation, self.reward(reward), done, info + observation, reward, terminated, truncated, info = self.env.step(action) + return observation, self.reward(reward), terminated, truncated, info def reward(self, reward): """Returns a modified ``reward``.""" diff --git a/gym/envs/registration.py b/gym/envs/registration.py index 318243a9eec..5846478bc27 100644 --- a/gym/envs/registration.py +++ b/gym/envs/registration.py @@ -547,7 +547,7 @@ def make( id: Union[str, EnvSpec], max_episode_steps: Optional[int] = None, autoreset: bool = False, - new_step_api: bool = False, + new_step_api: bool = True, disable_env_checker: Optional[bool] = None, **kwargs, ) -> Env: @@ -557,7 +557,7 @@ def make( id: Name of the environment. Optionally, a module to import can be included, eg. 'module:Env-v0' max_episode_steps: Maximum length of an episode (TimeLimit wrapper). autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper). - new_step_api: Whether to use old or new step API (StepAPICompatibility wrapper). Will be removed at v1.0 + new_step_api: Whether to use old or new step API (StepAPICompatibility wrapper) disable_env_checker: If to run the env checker, None will default to the environment specification `disable_env_checker` (which is by default False, running the environment checker), otherwise will run according to this parameter (`True` = not run, `False` = run) @@ -684,7 +684,6 @@ def make( ): env = PassiveEnvChecker(env) - env = StepAPICompatibility(env, new_step_api) # Add the order enforcing wrapper if spec_.order_enforce: @@ -692,18 +691,22 @@ def make( # Add the time limit wrapper if max_episode_steps is not None: - env = TimeLimit(env, max_episode_steps, new_step_api) + env = TimeLimit(env, max_episode_steps) elif spec_.max_episode_steps is not None: - env = TimeLimit(env, spec_.max_episode_steps, new_step_api) + env = TimeLimit(env, spec_.max_episode_steps) # Add the autoreset wrapper if autoreset: - env = AutoResetWrapper(env, new_step_api) + env = AutoResetWrapper(env) # Add human rendering wrapper if apply_human_rendering: env = HumanRendering(env) + # Add step API wrapper + if not new_step_api: + env = StepAPICompatibility(env, new_step_api) + return env diff --git a/gym/utils/step_api_compatibility.py b/gym/utils/step_api_compatibility.py index 2be07dbe35c..c1a0a8c27f8 100644 --- a/gym/utils/step_api_compatibility.py +++ b/gym/utils/step_api_compatibility.py @@ -1,4 +1,4 @@ -"""Contains methods for step compatibility, from old-to-new and new-to-old API, to be removed in 1.0.""" +"""Contains methods for step compatibility, from old-to-new and new-to-old API""" from typing import Tuple, Union import numpy as np @@ -149,7 +149,7 @@ def step_to_old_api( def step_api_compatibility( step_returns: Union[NewStepType, OldStepType], - new_step_api: bool = False, + new_step_api: bool = True, is_vector_env: bool = False, ) -> Union[NewStepType, OldStepType]: """Function to transform step returns to the API specified by `new_step_api` bool. @@ -160,7 +160,7 @@ def step_api_compatibility( Args: step_returns (tuple): Items returned by step(). Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info) - new_step_api (bool): Whether the output should be in new step API or old (False by default) + new_step_api (bool): Whether the output should be in new step API or old (True by default) is_vector_env (bool): Whether the step_returns are from a vector environment Returns: @@ -170,7 +170,7 @@ def step_api_compatibility( This function can be used to ensure compatibility in step interfaces with conflicting API. Eg. if env is written in old API, wrapper is written in new API, and the final step output is desired to be in old API. - >>> obs, rew, done, info = step_api_compatibility(env.step(action)) + >>> obs, rew, done, info = step_api_compatibility(env.step(action), new_step_api=False) >>> obs, rew, terminated, truncated, info = step_api_compatibility(env.step(action), new_step_api=True) >>> observations, rewards, dones, infos = step_api_compatibility(vec_env.step(action), is_vector_env=True) """ diff --git a/gym/vector/__init__.py b/gym/vector/__init__.py index 3dc4998fa8d..1eb9653f29c 100644 --- a/gym/vector/__init__.py +++ b/gym/vector/__init__.py @@ -15,7 +15,6 @@ def make( asynchronous: bool = True, wrappers: Optional[Union[callable, List[callable]]] = None, disable_env_checker: Optional[bool] = None, - new_step_api: bool = False, **kwargs, ) -> VectorEnv: """Create a vectorized environment from multiple copies of an environment, from its id. @@ -37,7 +36,6 @@ def make( wrappers: If not ``None``, then apply the wrappers to each internal environment during creation. disable_env_checker: If to run the env checker for the first environment only. None will default to the environment spec `disable_env_checker` parameter (that is by default False), otherwise will run according to this argument (True = not run, False = run) - new_step_api: If True, the vector environment's step method outputs two booleans `terminated`, `truncated` instead of one `done`. **kwargs: Keywords arguments applied during `gym.make` Returns: @@ -53,7 +51,6 @@ def _make_env(): env = gym.envs.registration.make( id, disable_env_checker=_disable_env_checker, - new_step_api=True, **kwargs, ) if wrappers is not None: @@ -73,8 +70,4 @@ def _make_env(): env_fns = [ create_env(disable_env_checker or env_num > 0) for env_num in range(num_envs) ] - return ( - AsyncVectorEnv(env_fns, new_step_api=new_step_api) - if asynchronous - else SyncVectorEnv(env_fns, new_step_api=new_step_api) - ) + return AsyncVectorEnv(env_fns) if asynchronous else SyncVectorEnv(env_fns) diff --git a/gym/vector/async_vector_env.py b/gym/vector/async_vector_env.py index 0c71d959736..069aa36b44b 100644 --- a/gym/vector/async_vector_env.py +++ b/gym/vector/async_vector_env.py @@ -17,7 +17,6 @@ CustomSpaceError, NoAsyncCallError, ) -from gym.utils.step_api_compatibility import step_api_compatibility from gym.vector.utils import ( CloudpickleWrapper, clear_mpi_env_vars, @@ -67,7 +66,6 @@ def __init__( context: Optional[str] = None, daemon: bool = True, worker: Optional[callable] = None, - new_step_api: bool = False, ): """Vectorized environment that runs multiple environments in parallel. @@ -87,7 +85,6 @@ def __init__( so for some environments you may want to have it set to ``False``. worker: If set, then use that worker in a subprocess instead of a default one. Can be useful to override some inner vector env logic, for instance, how resets on termination or truncation are handled. - new_step_api: If True, step method returns 2 bools - terminated, truncated, instead of 1 bool - done Warnings: worker is an advanced mode option. It provides a high degree of flexibility and a high chance to shoot yourself in the foot; thus, if you are writing your own worker, it is recommended to start @@ -115,7 +112,6 @@ def __init__( num_envs=len(env_fns), observation_space=observation_space, action_space=action_space, - new_step_api=new_step_api, ) if self.shared_memory: @@ -342,7 +338,7 @@ def step_wait( timeout: Number of seconds before the call to :meth:`step_wait` times out. If ``None``, the call to :meth:`step_wait` never times out. Returns: - The batched environment step information, (obs, reward, terminated, truncated, info) or (obs, reward, done, info) depending on new_step_api + The batched environment step information, (obs, reward, terminated, truncated, info) Raises: ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called). @@ -366,7 +362,7 @@ def step_wait( successes = [] for i, pipe in enumerate(self.parent_pipes): result, success = pipe.recv() - obs, rew, terminated, truncated, info = step_api_compatibility(result, True) + obs, rew, terminated, truncated, info = result successes.append(success) observations_list.append(obs) @@ -385,16 +381,12 @@ def step_wait( self.observations, ) - return step_api_compatibility( - ( - deepcopy(self.observations) if self.copy else self.observations, - np.array(rewards), - np.array(terminateds, dtype=np.bool_), - np.array(truncateds, dtype=np.bool_), - infos, - ), - self.new_step_api, - True, + return ( + deepcopy(self.observations) if self.copy else self.observations, + np.array(rewards), + np.array(terminateds, dtype=np.bool_), + np.array(truncateds, dtype=np.bool_), + infos, ) def call_async(self, name: str, *args, **kwargs): @@ -620,7 +612,7 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue): terminated, truncated, info, - ) = step_api_compatibility(env.step(data), True) + ) = env.step(data) if terminated or truncated: info["final_observation"] = observation observation = env.reset() @@ -695,7 +687,7 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error terminated, truncated, info, - ) = step_api_compatibility(env.step(data), True) + ) = env.step(data) if terminated or truncated: info["final_observation"] = observation observation = env.reset() diff --git a/gym/vector/sync_vector_env.py b/gym/vector/sync_vector_env.py index cc3408e7adb..a3c9f2fdb5b 100644 --- a/gym/vector/sync_vector_env.py +++ b/gym/vector/sync_vector_env.py @@ -6,7 +6,6 @@ from gym import Env from gym.spaces import Space -from gym.utils.step_api_compatibility import step_api_compatibility from gym.vector.utils import concatenate, create_empty_array, iterate from gym.vector.vector_env import VectorEnv @@ -34,7 +33,6 @@ def __init__( observation_space: Space = None, action_space: Space = None, copy: bool = True, - new_step_api: bool = False, ): """Vectorized environment that serially runs multiple environments. @@ -62,7 +60,6 @@ def __init__( num_envs=len(self.envs), observation_space=observation_space, action_space=action_space, - new_step_api=new_step_api, ) self._check_spaces() @@ -156,13 +153,15 @@ def step_wait(self): """ observations, infos = [], {} for i, (env, action) in enumerate(zip(self.envs, self._actions)): + ( observation, self._rewards[i], self._terminateds[i], self._truncateds[i], info, - ) = step_api_compatibility(env.step(action), True) + ) = env.step(action) + if self._terminateds[i] or self._truncateds[i]: info["final_observation"] = observation observation = env.reset() @@ -172,16 +171,12 @@ def step_wait(self): self.single_observation_space, observations, self.observations ) - return step_api_compatibility( - ( - deepcopy(self.observations) if self.copy else self.observations, - np.copy(self._rewards), - np.copy(self._terminateds), - np.copy(self._truncateds), - infos, - ), - new_step_api=self.new_step_api, - is_vector_env=True, + return ( + deepcopy(self.observations) if self.copy else self.observations, + np.copy(self._rewards), + np.copy(self._terminateds), + np.copy(self._truncateds), + infos, ) def call(self, name, *args, **kwargs) -> tuple: diff --git a/gym/vector/vector_env.py b/gym/vector/vector_env.py index 3ca4663d822..86deff1b2ea 100644 --- a/gym/vector/vector_env.py +++ b/gym/vector/vector_env.py @@ -28,7 +28,6 @@ def __init__( num_envs: int, observation_space: gym.Space, action_space: gym.Space, - new_step_api: bool = False, ): """Base class for vectorized environments. @@ -36,7 +35,6 @@ def __init__( num_envs: Number of environments in the vectorized environment. observation_space: Observation space of a single environment. action_space: Action space of a single environment. - new_step_api (bool): Whether the vector environment's step method outputs two boolean arrays (new API) or one boolean array (old API) """ self.num_envs = num_envs self.is_vector_env = True @@ -51,12 +49,6 @@ def __init__( self.single_observation_space = observation_space self.single_action_space = action_space - self.new_step_api = new_step_api - if not self.new_step_api: - deprecation( - "Initializing vector env in old step API which returns one bool array instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future." - ) - def reset_async( self, seed: Optional[Union[int, List[int]]] = None, diff --git a/gym/wrappers/atari_preprocessing.py b/gym/wrappers/atari_preprocessing.py index 96c3ee7b176..58bf2947bfd 100644 --- a/gym/wrappers/atari_preprocessing.py +++ b/gym/wrappers/atari_preprocessing.py @@ -3,7 +3,6 @@ import gym from gym.spaces import Box -from gym.utils.step_api_compatibility import step_api_compatibility try: import cv2 @@ -38,7 +37,6 @@ def __init__( grayscale_obs: bool = True, grayscale_newaxis: bool = False, scale_obs: bool = False, - new_step_api: bool = False, ): """Wrapper for Atari 2600 preprocessing. @@ -60,7 +58,7 @@ def __init__( DependencyNotInstalled: opencv-python package not installed ValueError: Disable frame-skipping in the original env """ - super().__init__(env, new_step_api) + super().__init__(env) if cv2 is None: raise gym.error.DependencyNotInstalled( "opencv-python package not installed, run `pip install gym[other]` to get dependencies for atari" @@ -119,9 +117,7 @@ def step(self, action): total_reward, terminated, truncated, info = 0.0, False, False, {} for t in range(self.frame_skip): - _, reward, terminated, truncated, info = step_api_compatibility( - self.env.step(action), True - ) + _, reward, terminated, truncated, info = self.env.step(action) total_reward += reward self.game_over = terminated @@ -143,10 +139,7 @@ def step(self, action): self.ale.getScreenGrayscale(self.obs_buffer[0]) else: self.ale.getScreenRGB(self.obs_buffer[0]) - return step_api_compatibility( - (self._get_obs(), total_reward, terminated, truncated, info), - self.new_step_api, - ) + return self._get_obs(), total_reward, terminated, truncated, info def reset(self, **kwargs): """Resets the environment using preprocessing.""" @@ -163,9 +156,7 @@ def reset(self, **kwargs): else 0 ) for _ in range(noops): - _, _, terminated, truncated, step_info = step_api_compatibility( - self.env.step(0), True - ) + _, _, terminated, truncated, step_info = self.env.step(0) reset_info.update(step_info) if terminated or truncated: if kwargs.get("return_info", False): diff --git a/gym/wrappers/autoreset.py b/gym/wrappers/autoreset.py index 6e20c92ffed..a07b67a389f 100644 --- a/gym/wrappers/autoreset.py +++ b/gym/wrappers/autoreset.py @@ -1,6 +1,5 @@ """Wrapper that autoreset environments when `terminated=True` or `truncated=True`.""" import gym -from gym.utils.step_api_compatibility import step_api_compatibility class AutoResetWrapper(gym.Wrapper): @@ -24,14 +23,13 @@ class AutoResetWrapper(gym.Wrapper): Make sure you know what you're doing if you use this wrapper! """ - def __init__(self, env: gym.Env, new_step_api: bool = False): + def __init__(self, env: gym.Env): """A class for providing an automatic reset functionality for gym environments when calling :meth:`self.step`. Args: env (gym.Env): The environment to apply the wrapper - new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) """ - super().__init__(env, new_step_api) + super().__init__(env) def step(self, action): """Steps through the environment with action and resets the environment if a terminated or truncated signal is encountered. @@ -42,10 +40,7 @@ def step(self, action): Returns: The autoreset environment :meth:`step` """ - obs, reward, terminated, truncated, info = step_api_compatibility( - self.env.step(action), True - ) - + obs, reward, terminated, truncated, info = self.env.step(action) if terminated or truncated: new_obs, new_info = self.env.reset(return_info=True) @@ -62,6 +57,4 @@ def step(self, action): obs = new_obs info = new_info - return step_api_compatibility( - (obs, reward, terminated, truncated, info), self.new_step_api - ) + return obs, reward, terminated, truncated, info diff --git a/gym/wrappers/clip_action.py b/gym/wrappers/clip_action.py index 58d981e96a1..de236384768 100644 --- a/gym/wrappers/clip_action.py +++ b/gym/wrappers/clip_action.py @@ -26,7 +26,7 @@ def __init__(self, env: gym.Env): env: The environment to apply the wrapper """ assert isinstance(env.action_space, Box) - super().__init__(env, new_step_api=True) + super().__init__(env) def action(self, action): """Clips the action within the valid bounds. diff --git a/gym/wrappers/env_checker.py b/gym/wrappers/env_checker.py index 412689c5c5b..9c8d4b63f30 100644 --- a/gym/wrappers/env_checker.py +++ b/gym/wrappers/env_checker.py @@ -15,7 +15,7 @@ class PassiveEnvChecker(gym.Wrapper): def __init__(self, env): """Initialises the wrapper with the environments, run the observation and action space tests.""" - super().__init__(env, new_step_api=True) + super().__init__(env) assert hasattr( env, "action_space" diff --git a/gym/wrappers/filter_observation.py b/gym/wrappers/filter_observation.py index bcbe13b5065..922c8288038 100644 --- a/gym/wrappers/filter_observation.py +++ b/gym/wrappers/filter_observation.py @@ -35,7 +35,7 @@ def __init__(self, env: gym.Env, filter_keys: Sequence[str] = None): ValueError: If the environment's observation space is not :class:`spaces.Dict` ValueError: If any of the `filter_keys` are not included in the original `env`'s observation space """ - super().__init__(env, new_step_api=True) + super().__init__(env) wrapped_observation_space = env.observation_space if not isinstance(wrapped_observation_space, spaces.Dict): diff --git a/gym/wrappers/flatten_observation.py b/gym/wrappers/flatten_observation.py index 95aa13e0d01..fe6518b875b 100644 --- a/gym/wrappers/flatten_observation.py +++ b/gym/wrappers/flatten_observation.py @@ -25,7 +25,7 @@ def __init__(self, env: gym.Env): Args: env: The environment to apply the wrapper """ - super().__init__(env, new_step_api=True) + super().__init__(env) self.observation_space = spaces.flatten_space(env.observation_space) def observation(self, observation): diff --git a/gym/wrappers/frame_stack.py b/gym/wrappers/frame_stack.py index 388e1e75de5..9fd6d73e4c9 100644 --- a/gym/wrappers/frame_stack.py +++ b/gym/wrappers/frame_stack.py @@ -7,7 +7,6 @@ import gym from gym.error import DependencyNotInstalled from gym.spaces import Box -from gym.utils.step_api_compatibility import step_api_compatibility class LazyFrames: @@ -128,7 +127,6 @@ def __init__( env: gym.Env, num_stack: int, lz4_compress: bool = False, - new_step_api: bool = False, ): """Observation wrapper that stacks the observations in a rolling manner. @@ -136,9 +134,8 @@ def __init__( env (Env): The environment to apply the wrapper num_stack (int): The number of frames to stack lz4_compress (bool): Use lz4 to compress the frames internally - new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) """ - super().__init__(env, new_step_api) + super().__init__(env) self.num_stack = num_stack self.lz4_compress = lz4_compress @@ -173,14 +170,9 @@ def step(self, action): Returns: Stacked observations, reward, terminated, truncated, and information from the environment """ - observation, reward, terminated, truncated, info = step_api_compatibility( - self.env.step(action), True - ) + observation, reward, terminated, truncated, info = self.env.step(action) self.frames.append(observation) - return step_api_compatibility( - (self.observation(None), reward, terminated, truncated, info), - self.new_step_api, - ) + return self.observation(None), reward, terminated, truncated, info def reset(self, **kwargs): """Reset the environment with kwargs. diff --git a/gym/wrappers/gray_scale_observation.py b/gym/wrappers/gray_scale_observation.py index cf8a2ea05c7..1c626f41f4f 100644 --- a/gym/wrappers/gray_scale_observation.py +++ b/gym/wrappers/gray_scale_observation.py @@ -28,7 +28,7 @@ def __init__(self, env: gym.Env, keep_dim: bool = False): keep_dim (bool): If `True`, a singleton dimension will be added, i.e. observations are of the shape AxBx1. Otherwise, they are of shape AxB. """ - super().__init__(env, new_step_api=True) + super().__init__(env) self.keep_dim = keep_dim assert ( diff --git a/gym/wrappers/human_rendering.py b/gym/wrappers/human_rendering.py index e2b234cf903..50bc12751d9 100644 --- a/gym/wrappers/human_rendering.py +++ b/gym/wrappers/human_rendering.py @@ -45,7 +45,7 @@ def __init__(self, env): Args: env: The environment that is being wrapped """ - super().__init__(env, new_step_api=True) + super().__init__(env) assert env.render_mode in [ "single_rgb_array", "rgb_array", diff --git a/gym/wrappers/normalize.py b/gym/wrappers/normalize.py index 0c6ab04a48b..25571ca3200 100644 --- a/gym/wrappers/normalize.py +++ b/gym/wrappers/normalize.py @@ -2,7 +2,6 @@ import numpy as np import gym -from gym.utils.step_api_compatibility import step_api_compatibility # taken from https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_normalize.py @@ -55,15 +54,14 @@ class NormalizeObservation(gym.core.Wrapper): newly instantiated or the policy was changed recently. """ - def __init__(self, env: gym.Env, epsilon: float = 1e-8, new_step_api: bool = False): + def __init__(self, env: gym.Env, epsilon: float = 1e-8): """This wrapper will normalize observations s.t. each coordinate is centered with unit variance. Args: env (Env): The environment to apply the wrapper epsilon: A stability parameter that is used when scaling the observations. - new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) """ - super().__init__(env, new_step_api) + super().__init__(env) self.num_envs = getattr(env, "num_envs", 1) self.is_vector_env = getattr(env, "is_vector_env", False) if self.is_vector_env: @@ -74,18 +72,12 @@ def __init__(self, env: gym.Env, epsilon: float = 1e-8, new_step_api: bool = Fal def step(self, action): """Steps through the environment and normalizes the observation.""" - obs, rews, terminateds, truncateds, infos = step_api_compatibility( - self.env.step(action), True, self.is_vector_env - ) + obs, rews, terminateds, truncateds, infos = self.env.step(action) if self.is_vector_env: obs = self.normalize(obs) else: obs = self.normalize(np.array([obs]))[0] - return step_api_compatibility( - (obs, rews, terminateds, truncateds, infos), - self.new_step_api, - self.is_vector_env, - ) + return obs, rews, terminateds, truncateds, infos def reset(self, **kwargs): """Resets the environment and normalizes the observation.""" @@ -125,7 +117,6 @@ def __init__( env: gym.Env, gamma: float = 0.99, epsilon: float = 1e-8, - new_step_api: bool = False, ): """This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance. @@ -133,9 +124,8 @@ def __init__( env (env): The environment to apply the wrapper epsilon (float): A stability parameter gamma (float): The discount factor that is used in the exponential moving average. - new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) """ - super().__init__(env, new_step_api) + super().__init__(env) self.num_envs = getattr(env, "num_envs", 1) self.is_vector_env = getattr(env, "is_vector_env", False) self.return_rms = RunningMeanStd(shape=()) @@ -145,9 +135,7 @@ def __init__( def step(self, action): """Steps through the environment, normalizing the rewards returned.""" - obs, rews, terminateds, truncateds, infos = step_api_compatibility( - self.env.step(action), True, self.is_vector_env - ) + obs, rews, terminateds, truncateds, infos = self.env.step(action) if not self.is_vector_env: rews = np.array([rews]) self.returns = self.returns * self.gamma + rews @@ -159,11 +147,7 @@ def step(self, action): self.returns[dones] = 0.0 if not self.is_vector_env: rews = rews[0] - return step_api_compatibility( - (obs, rews, terminateds, truncateds, infos), - self.new_step_api, - self.is_vector_env, - ) + return obs, rews, terminateds, truncateds, infos def normalize(self, rews): """Normalizes the rewards with the running mean rewards and their variance.""" diff --git a/gym/wrappers/order_enforcing.py b/gym/wrappers/order_enforcing.py index 0e9da7f878e..d9f853e72bc 100644 --- a/gym/wrappers/order_enforcing.py +++ b/gym/wrappers/order_enforcing.py @@ -26,7 +26,7 @@ def __init__(self, env: gym.Env, disable_render_order_enforcing: bool = False): env: The environment to wrap disable_render_order_enforcing: If to disable render order enforcing """ - super().__init__(env, new_step_api=True) + super().__init__(env) self._has_reset: bool = False self._disable_render_order_enforcing: bool = disable_render_order_enforcing diff --git a/gym/wrappers/pixel_observation.py b/gym/wrappers/pixel_observation.py index 2d47d4829c5..628e53e2428 100644 --- a/gym/wrappers/pixel_observation.py +++ b/gym/wrappers/pixel_observation.py @@ -77,7 +77,7 @@ def __init__( specified ``pixel_keys``. TypeError: When an unexpected pixel type is used """ - super().__init__(env, new_step_api=True) + super().__init__(env) # Avoid side-effects that occur when render_kwargs is manipulated render_kwargs = copy.deepcopy(render_kwargs) diff --git a/gym/wrappers/record_episode_statistics.py b/gym/wrappers/record_episode_statistics.py index 26cdb98f895..0a822cea4ea 100644 --- a/gym/wrappers/record_episode_statistics.py +++ b/gym/wrappers/record_episode_statistics.py @@ -6,7 +6,6 @@ import numpy as np import gym -from gym.utils.step_api_compatibility import step_api_compatibility def add_vector_episode_statistics( @@ -77,15 +76,14 @@ class RecordEpisodeStatistics(gym.Wrapper): length_queue: The lengths of the last ``deque_size``-many episodes """ - def __init__(self, env: gym.Env, deque_size: int = 100, new_step_api: bool = False): + def __init__(self, env: gym.Env, deque_size: int = 100): """This wrapper will keep track of cumulative rewards and episode lengths. Args: env (Env): The environment to apply the wrapper deque_size: The size of the buffers :attr:`return_queue` and :attr:`length_queue` - new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) """ - super().__init__(env, new_step_api) + super().__init__(env) self.num_envs = getattr(env, "num_envs", 1) self.t0 = time.perf_counter() self.episode_count = 0 @@ -110,7 +108,7 @@ def step(self, action): terminateds, truncateds, infos, - ) = step_api_compatibility(self.env.step(action), True, self.is_vector_env) + ) = self.env.step(action) assert isinstance( infos, dict ), f"`info` dtype is {type(infos)} while supported dtype is `dict`. This may be due to usage of other wrappers in the wrong order." @@ -144,14 +142,10 @@ def step(self, action): self.episode_count += 1 self.episode_returns[i] = 0 self.episode_lengths[i] = 0 - return step_api_compatibility( - ( - observations, - rewards, - terminateds if self.is_vector_env else terminateds[0], - truncateds if self.is_vector_env else truncateds[0], - infos, - ), - self.new_step_api, - self.is_vector_env, + return ( + observations, + rewards, + terminateds if self.is_vector_env else terminateds[0], + truncateds if self.is_vector_env else truncateds[0], + infos, ) diff --git a/gym/wrappers/record_video.py b/gym/wrappers/record_video.py index 8736915576e..e2fb59f0dec 100644 --- a/gym/wrappers/record_video.py +++ b/gym/wrappers/record_video.py @@ -4,7 +4,6 @@ import gym from gym import logger -from gym.utils.step_api_compatibility import step_api_compatibility from gym.wrappers.monitoring import video_recorder @@ -46,7 +45,6 @@ def __init__( step_trigger: Callable[[int], bool] = None, video_length: int = 0, name_prefix: str = "rl-video", - new_step_api: bool = False, ): """Wrapper records videos of rollouts. @@ -58,9 +56,8 @@ def __init__( video_length (int): The length of recorded episodes. If 0, entire episodes are recorded. Otherwise, snippets of the specified length are captured name_prefix (str): Will be prepended to the filename of the recordings - new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) """ - super().__init__(env, new_step_api) + super().__init__(env) if episode_trigger is None and step_trigger is None: episode_trigger = capped_cubic_video_schedule @@ -143,7 +140,7 @@ def step(self, action): terminateds, truncateds, infos, - ) = step_api_compatibility(self.env.step(action), True, self.is_vector_env) + ) =self.env.step(action) if not (self.terminated or self.truncated): # increment steps and episodes @@ -175,11 +172,7 @@ def step(self, action): elif self._video_enabled(): self.start_video_recorder() - return step_api_compatibility( - (observations, rewards, terminateds, truncateds, infos), - self.new_step_api, - self.is_vector_env, - ) + return observations, rewards, terminateds, truncateds, infos def close_video_recorder(self): """Closes the video recorder if currently recording.""" diff --git a/gym/wrappers/rescale_action.py b/gym/wrappers/rescale_action.py index c5f2238159e..bf3cf6cd157 100644 --- a/gym/wrappers/rescale_action.py +++ b/gym/wrappers/rescale_action.py @@ -45,7 +45,7 @@ def __init__( ), f"expected Box action space, got {type(env.action_space)}" assert np.less_equal(min_action, max_action).all(), (min_action, max_action) - super().__init__(env, new_step_api=True) + super().__init__(env) self.min_action = ( np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + min_action ) diff --git a/gym/wrappers/resize_observation.py b/gym/wrappers/resize_observation.py index 29116be7f09..4f486a97bdf 100644 --- a/gym/wrappers/resize_observation.py +++ b/gym/wrappers/resize_observation.py @@ -32,7 +32,7 @@ def __init__(self, env: gym.Env, shape: Union[tuple, int]): env: The environment to apply the wrapper shape: The shape of the resized observations """ - super().__init__(env, new_step_api=True) + super().__init__(env) if isinstance(shape, int): shape = (shape, shape) assert all(x > 0 for x in shape), shape diff --git a/gym/wrappers/step_api_compatibility.py b/gym/wrappers/step_api_compatibility.py index 72d4c8a1e07..8f67167254e 100644 --- a/gym/wrappers/step_api_compatibility.py +++ b/gym/wrappers/step_api_compatibility.py @@ -11,33 +11,32 @@ class StepAPICompatibility(gym.Wrapper): New step API refers to step() method returning (observation, reward, terminated, truncated, info) (Refer to docs for details on the API change) - This wrapper is to be used to ease transition to new API and for backward compatibility. - Args: env (gym.Env): the env to wrap. Can be in old or new API - new_step_api (bool): True to use env with new step API, False to use env with old step API. (False by default) + new_step_api (bool): True to use env with new step API, False to use env with old step API. (True by default) Examples: >>> env = gym.make("CartPole-v1") - >>> env # wrapper applied by default, set to old API - >>>> - >>> env = gym.make("CartPole-v1", new_step_api=True) # set to new API + >>> env # wrapper not applied by default, set to new API + >>>> + >>> env = gym.make("CartPole-v1", new_step_api=False) # set to old API + >>>>> >>> env = StepAPICompatibility(CustomEnv(), new_step_api=True) # manually using wrapper on unregistered envs """ - def __init__(self, env: gym.Env, new_step_api=False): + def __init__(self, env: gym.Env, new_step_api=True): """A wrapper which can transform an environment from new step API to old and vice-versa. Args: env (gym.Env): the env to wrap. Can be in old or new API new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) """ - super().__init__(env, new_step_api) + super().__init__(env) self.new_step_api = new_step_api if not self.new_step_api: deprecation( - "Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future." + "Initializing environment in old step API which returns one bool instead of two." ) def step(self, action): diff --git a/gym/wrappers/time_aware_observation.py b/gym/wrappers/time_aware_observation.py index 2307eb06334..781a77c2533 100644 --- a/gym/wrappers/time_aware_observation.py +++ b/gym/wrappers/time_aware_observation.py @@ -3,7 +3,6 @@ import gym from gym.spaces import Box -from gym.utils.step_api_compatibility import step_api_compatibility class TimeAwareObservation(gym.ObservationWrapper): @@ -22,14 +21,13 @@ class TimeAwareObservation(gym.ObservationWrapper): array([ 0.03881167, -0.16021058, 0.0220928 , 0.28875574, 1. ]) """ - def __init__(self, env: gym.Env, new_step_api: bool = False): + def __init__(self, env: gym.Env): """Initialize :class:`TimeAwareObservation` that requires an environment with a flat :class:`Box` observation space. Args: env: The environment to apply the wrapper - new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) """ - super().__init__(env, new_step_api) + super().__init__(env) assert isinstance(env.observation_space, Box) assert env.observation_space.dtype == np.float32 low = np.append(self.observation_space.low, 0.0) @@ -58,9 +56,7 @@ def step(self, action): The environment's step using the action. """ self.t += 1 - return step_api_compatibility( - super().step(action), self.new_step_api, self.is_vector_env - ) + return super().step(action) def reset(self, **kwargs): """Reset the environment setting the time to zero. diff --git a/gym/wrappers/time_limit.py b/gym/wrappers/time_limit.py index 8e9f67f4ae9..735f3b820cb 100644 --- a/gym/wrappers/time_limit.py +++ b/gym/wrappers/time_limit.py @@ -2,7 +2,6 @@ from typing import Optional import gym -from gym.utils.step_api_compatibility import step_api_compatibility class TimeLimit(gym.Wrapper): @@ -28,16 +27,14 @@ def __init__( self, env: gym.Env, max_episode_steps: Optional[int] = None, - new_step_api: bool = False, ): """Initializes the :class:`TimeLimit` wrapper with an environment and the number of steps after which truncation will occur. Args: env: The environment to apply the wrapper max_episode_steps: An optional max episode steps (if ``Ǹone``, ``env.spec.max_episode_steps`` is used) - new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) """ - super().__init__(env, new_step_api) + super().__init__(env) if max_episode_steps is None and self.env.spec is not None: max_episode_steps = env.spec.max_episode_steps if self.env.spec is not None: @@ -56,19 +53,13 @@ def step(self, action): when truncated (the number of steps elapsed >= max episode steps) or "TimeLimit.truncated"=False if the environment terminated """ - observation, reward, terminated, truncated, info = step_api_compatibility( - self.env.step(action), - True, - ) + observation, reward, terminated, truncated, info = (self.env.step(action),) self._elapsed_steps += 1 if self._elapsed_steps >= self._max_episode_steps: truncated = True - return step_api_compatibility( - (observation, reward, terminated, truncated, info), - self.new_step_api, - ) + return observation, reward, terminated, truncated, info def reset(self, **kwargs): """Resets the environment with :param:`**kwargs` and sets the number of steps elapsed to zero. diff --git a/gym/wrappers/transform_observation.py b/gym/wrappers/transform_observation.py index 4da9db5bac9..2af2e9afb40 100644 --- a/gym/wrappers/transform_observation.py +++ b/gym/wrappers/transform_observation.py @@ -27,7 +27,7 @@ def __init__(self, env: gym.Env, f: Callable[[Any], Any]): env: The environment to apply the wrapper f: A function that transforms the observation """ - super().__init__(env, new_step_api=True) + super().__init__(env) assert callable(f) self.f = f diff --git a/gym/wrappers/transform_reward.py b/gym/wrappers/transform_reward.py index 13278182d6b..a17a8ef1bc0 100644 --- a/gym/wrappers/transform_reward.py +++ b/gym/wrappers/transform_reward.py @@ -28,7 +28,7 @@ def __init__(self, env: gym.Env, f: Callable[[float], float]): env: The environment to apply the wrapper f: A function that transforms the reward """ - super().__init__(env, new_step_api=True) + super().__init__(env) assert callable(f) self.f = f diff --git a/gym/wrappers/vector_list_info.py b/gym/wrappers/vector_list_info.py index 727d4981066..d5a4c98db0a 100644 --- a/gym/wrappers/vector_list_info.py +++ b/gym/wrappers/vector_list_info.py @@ -3,7 +3,6 @@ from typing import List import gym -from gym.utils.step_api_compatibility import step_api_compatibility class VectorListInfo(gym.Wrapper): @@ -30,30 +29,23 @@ class VectorListInfo(gym.Wrapper): """ - def __init__(self, env, new_step_api=False): + def __init__(self, env): """This wrapper will convert the info into the list format. Args: env (Env): The environment to apply the wrapper - new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) """ assert getattr( env, "is_vector_env", False ), "This wrapper can only be used in vectorized environments." - super().__init__(env, new_step_api) + super().__init__(env) def step(self, action): """Steps through the environment, convert dict info to list.""" - observation, reward, terminated, truncated, infos = step_api_compatibility( - self.env.step(action), True, True - ) + observation, reward, terminated, truncated, infos = self.env.step(action) list_info = self._convert_info_to_list(infos) - return step_api_compatibility( - (observation, reward, terminated, truncated, list_info), - self.new_step_api, - True, - ) + return observation, reward, terminated, truncated, list_info def reset(self, **kwargs): """Resets the environment using kwargs.""" From fd427bdd66ac24f17f8bf58d8c1dc97592e8aa24 Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Fri, 5 Aug 2022 23:01:02 +0530 Subject: [PATCH 02/12] update wrappers (except timelimit) --- gym/vector/async_vector_env.py | 2 +- tests/envs/test_action_dim_check.py | 8 +- tests/envs/test_envs.py | 15 +++- tests/envs/test_mujoco.py | 19 ++-- tests/utils/test_terminated_truncated.py | 16 ++-- tests/vector/test_async_vector_env.py | 19 ++-- .../vector/test_step_compatibility_vector.py | 88 ------------------- tests/vector/test_sync_vector_env.py | 17 ++-- tests/vector/test_vector_env.py | 9 +- tests/vector/test_vector_env_info.py | 16 ++-- tests/wrappers/test_atari_preprocessing.py | 8 +- tests/wrappers/test_autoreset.py | 20 ++--- tests/wrappers/test_clip_action.py | 7 +- tests/wrappers/test_frame_stack.py | 9 +- tests/wrappers/test_human_rendering.py | 4 +- .../test_record_episode_statistics.py | 4 +- tests/wrappers/test_record_video.py | 8 +- tests/wrappers/test_rescale_action.py | 4 +- tests/wrappers/test_step_compatibility.py | 22 ++--- tests/wrappers/test_time_aware_observation.py | 4 +- tests/wrappers/test_time_limit.py | 13 +-- tests/wrappers/test_transform_observation.py | 13 ++- tests/wrappers/test_transform_reward.py | 12 +-- tests/wrappers/test_vector_list_info.py | 12 +-- 24 files changed, 145 insertions(+), 204 deletions(-) delete mode 100644 tests/vector/test_step_compatibility_vector.py diff --git a/gym/vector/async_vector_env.py b/gym/vector/async_vector_env.py index 069aa36b44b..40ccf9ffc09 100644 --- a/gym/vector/async_vector_env.py +++ b/gym/vector/async_vector_env.py @@ -331,7 +331,7 @@ def step_async(self, actions: np.ndarray): def step_wait( self, timeout: Optional[Union[int, float]] = None - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[dict]]: + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, List[dict]]: """Wait for the calls to :obj:`step` in each sub-environment to finish. Args: diff --git a/tests/envs/test_action_dim_check.py b/tests/envs/test_action_dim_check.py index af857567795..1643f6e34f2 100644 --- a/tests/envs/test_action_dim_check.py +++ b/tests/envs/test_action_dim_check.py @@ -112,14 +112,12 @@ def test_box_actions_out_of_bound(env: gym.Env): zip(env.action_space.bounded_above, env.action_space.bounded_below) ): if is_upper_bound: - obs, _, _, _, _ = env.step( - upper_bounds - ) # `env` is unwrapped, and in new step API + obs, _, _, _, _ = env.step(upper_bounds) oob_action = upper_bounds.copy() oob_action[i] += np.cast[dtype](OOB_VALUE) assert oob_action[i] > upper_bounds[i] - oob_obs, _, _, _ = oob_env.step(oob_action) + oob_obs, _, _, _, _ = oob_env.step(oob_action) assert np.alltrue(obs == oob_obs) @@ -131,7 +129,7 @@ def test_box_actions_out_of_bound(env: gym.Env): oob_action[i] -= np.cast[dtype](OOB_VALUE) assert oob_action[i] < lower_bounds[i] - oob_obs, _, _, _ = oob_env.step(oob_action) + oob_obs, _, _, _, _ = oob_env.step(oob_action) assert np.alltrue(obs == oob_obs) diff --git a/tests/envs/test_envs.py b/tests/envs/test_envs.py index db9312a2129..3526ed8cb8c 100644 --- a/tests/envs/test_envs.py +++ b/tests/envs/test_envs.py @@ -85,8 +85,8 @@ def test_env_determinism_rollout(env_spec: EnvSpec): # We don't evaluate the determinism of actions action = env_1.action_space.sample() - obs_1, rew_1, done_1, info_1 = env_1.step(action) - obs_2, rew_2, done_2, info_2 = env_2.step(action) + obs_1, rew_1, terminated_1, truncated_1, info_1 = env_1.step(action) + obs_2, rew_2, terminated_2, truncated_2, info_2 = env_2.step(action) assert_equals(obs_1, obs_2, f"[{time_step}] ") assert env_1.observation_space.contains( @@ -94,10 +94,17 @@ def test_env_determinism_rollout(env_spec: EnvSpec): ) # obs_2 verified by previous assertion assert rew_1 == rew_2, f"[{time_step}] reward 1={rew_1}, reward 2={rew_2}" - assert done_1 == done_2, f"[{time_step}] done 1={done_1}, done 2={done_2}" + assert ( + terminated_1 == terminated_2 + ), f"[{time_step}] done 1={terminated_1}, done 2={terminated_2}" + assert ( + truncated_1 == truncated_2 + ), f"[{time_step}] done 1={truncated_1}, done 2={truncated_2}" assert_equals(info_1, info_2, f"[{time_step}] ") - if done_1: # done_2 verified by previous assertion + if ( + terminated_1 or truncated_1 + ): # terminated_2, truncated_2 verified by previous assertion env_1.reset(seed=SEED) env_2.reset(seed=SEED) diff --git a/tests/envs/test_mujoco.py b/tests/envs/test_mujoco.py index 894c74566e6..a26f35056af 100644 --- a/tests/envs/test_mujoco.py +++ b/tests/envs/test_mujoco.py @@ -24,17 +24,22 @@ def verify_environments_match( for i in range(num_actions): action = old_env.action_space.sample() - old_obs, old_reward, old_done, old_info = old_env.step(action) - new_obs, new_reward, new_done, new_info = new_env.step(action) + old_obs, old_reward, old_terminated, old_truncated, old_info = old_env.step( + action + ) + new_obs, new_reward, new_terminated, new_truncated, new_info = new_env.step( + action + ) np.testing.assert_allclose(old_obs, new_obs, atol=EPS) np.testing.assert_allclose(old_reward, new_reward, atol=EPS) - np.testing.assert_equal(old_done, new_done) + np.testing.assert_equal(old_terminated, new_terminated) + np.testing.assert_equal(old_truncated, new_truncated) for key in old_info: np.testing.assert_allclose(old_info[key], new_info[key], atol=EPS) - if old_done: + if old_terminated or old_truncated: break @@ -62,7 +67,7 @@ def test_obs_space_mujoco_environments(env_spec: EnvSpec): ), f"Obseravtion returned by reset() of {env_spec.id} is not contained in the default observation space {env.observation_space}." action = env.action_space.sample() - step_obs, _, _, _ = env.step(action) + step_obs, _, _, _, _ = env.step(action) assert env.observation_space.contains( step_obs ), f"Obseravtion returned by step(action) of {env_spec.id} is not contained in the default observation space {env.observation_space}." @@ -78,7 +83,7 @@ def test_obs_space_mujoco_environments(env_spec: EnvSpec): reset_obs ), f"Obseravtion of {env_spec.id} is not contained in the default observation space {env.observation_space} when excluding current position from observation." - step_obs, _, _, _ = env.step(action) + step_obs, _, _, _, _ = env.step(action) assert env.observation_space.contains( step_obs ), f"Obseravtion returned by step(action) of {env_spec.id} is not contained in the default observation space {env.observation_space} when excluding current position from observation." @@ -91,7 +96,7 @@ def test_obs_space_mujoco_environments(env_spec: EnvSpec): reset_obs ), f"Obseravtion of {env_spec.id} is not contained in the default observation space {env.observation_space} when using contact forces." - step_obs, _, _, _ = env.step(action) + step_obs, _, _, _, _ = env.step(action) assert env.observation_space.contains( step_obs ), f"Obseravtion returned by step(action) of {env_spec.id} is not contained in the default observation space {env.observation_space} when using contact forces." diff --git a/tests/utils/test_terminated_truncated.py b/tests/utils/test_terminated_truncated.py index e74fdc85378..30014bd995b 100644 --- a/tests/utils/test_terminated_truncated.py +++ b/tests/utils/test_terminated_truncated.py @@ -29,7 +29,7 @@ def reset(self): @pytest.mark.parametrize("time_limit", [10, 20, 30]) def test_terminated_truncated(time_limit): - test_env = TimeLimit(DummyEnv(), time_limit, new_step_api=True) + test_env = TimeLimit(DummyEnv(), time_limit) terminated = False truncated = False @@ -53,13 +53,11 @@ def test_terminated_truncated(time_limit): def test_terminated_truncated_vector(): - env0 = TimeLimit(DummyEnv(), 10, new_step_api=True) - env1 = TimeLimit(DummyEnv(), 20, new_step_api=True) - env2 = TimeLimit(DummyEnv(), 30, new_step_api=True) + env0 = TimeLimit(DummyEnv(), 10) + env1 = TimeLimit(DummyEnv(), 20) + env2 = TimeLimit(DummyEnv(), 30) - async_env = AsyncVectorEnv( - [lambda: env0, lambda: env1, lambda: env2], new_step_api=True - ) + async_env = AsyncVectorEnv([lambda: env0, lambda: env1, lambda: env2]) async_env.reset() terminateds = [False, False, False] truncateds = [False, False, False] @@ -74,9 +72,7 @@ def test_terminated_truncated_vector(): assert all(terminateds == [False, True, True]) assert all(truncateds == [True, True, False]) - sync_env = SyncVectorEnv( - [lambda: env0, lambda: env1, lambda: env2], new_step_api=True - ) + sync_env = SyncVectorEnv([lambda: env0, lambda: env1, lambda: env2]) sync_env.reset() terminateds = [False, False, False] truncateds = [False, False, False] diff --git a/tests/vector/test_async_vector_env.py b/tests/vector/test_async_vector_env.py index 46ea85d73ed..9cd67c4b508 100644 --- a/tests/vector/test_async_vector_env.py +++ b/tests/vector/test_async_vector_env.py @@ -80,7 +80,7 @@ def test_step_async_vector_env(shared_memory, use_single_action_space): actions = [env.single_action_space.sample() for _ in range(8)] else: actions = env.action_space.sample() - observations, rewards, dones, _ = env.step(actions) + observations, rewards, terminateds, truncateds, _ = env.step(actions) env.close() @@ -95,10 +95,15 @@ def test_step_async_vector_env(shared_memory, use_single_action_space): assert rewards.ndim == 1 assert rewards.size == 8 - assert isinstance(dones, np.ndarray) - assert dones.dtype == np.bool_ - assert dones.ndim == 1 - assert dones.size == 8 + assert isinstance(terminateds, np.ndarray) + assert terminateds.dtype == np.bool_ + assert terminateds.ndim == 1 + assert terminateds.size == 8 + + assert isinstance(truncateds, np.ndarray) + assert truncateds.dtype == np.bool_ + assert truncateds.ndim == 1 + assert truncateds.size == 8 @pytest.mark.parametrize("shared_memory", [True, False]) @@ -181,7 +186,7 @@ def test_step_timeout_async_vector_env(shared_memory): with pytest.raises(TimeoutError): env.reset() env.step_async(np.array([0.1, 0.1, 0.3, 0.1])) - observations, rewards, dones, _ = env.step_wait(timeout=0.1) + observations, rewards, terminateds, truncateds, _ = env.step_wait(timeout=0.1) env.close(terminate=True) @@ -274,7 +279,7 @@ def test_custom_space_async_vector_env(): assert isinstance(env.action_space, Tuple) actions = ("action-2", "action-3", "action-5", "action-7") - step_observations, rewards, dones, _ = env.step(actions) + step_observations, rewards, terminateds, truncateds, _ = env.step(actions) env.close() diff --git a/tests/vector/test_step_compatibility_vector.py b/tests/vector/test_step_compatibility_vector.py deleted file mode 100644 index d0305300fc7..00000000000 --- a/tests/vector/test_step_compatibility_vector.py +++ /dev/null @@ -1,88 +0,0 @@ -import numpy as np -import pytest - -import gym -from gym.spaces import Discrete -from gym.vector import AsyncVectorEnv, SyncVectorEnv - - -class OldStepEnv(gym.Env): - def __init__(self): - self.action_space = Discrete(2) - self.observation_space = Discrete(2) - - def reset(self): - return 0 - - def step(self, action): - obs = self.observation_space.sample() - rew = 0 - done = False - info = {} - return obs, rew, done, info - - -class NewStepEnv(gym.Env): - def __init__(self): - self.action_space = Discrete(2) - self.observation_space = Discrete(2) - - def reset(self): - return 0 - - def step(self, action): - obs = self.observation_space.sample() - rew = 0 - terminated = False - truncated = False - info = {} - return obs, rew, terminated, truncated, info - - -@pytest.mark.parametrize("VecEnv", [AsyncVectorEnv, SyncVectorEnv]) -def test_vector_step_compatibility_new_env(VecEnv): - - envs = [ - OldStepEnv(), - NewStepEnv(), - ] - - vec_env = VecEnv([lambda: env for env in envs]) - vec_env.reset() - step_returns = vec_env.step([0, 0]) - assert len(step_returns) == 4 - _, _, dones, _ = step_returns - assert dones.dtype == np.bool_ - vec_env.close() - - vec_env = VecEnv([lambda: env for env in envs], new_step_api=True) - vec_env.reset() - step_returns = vec_env.step([0, 0]) - assert len(step_returns) == 5 - _, _, terminateds, truncateds, _ = step_returns - assert terminateds.dtype == np.bool_ - assert truncateds.dtype == np.bool_ - vec_env.close() - - -@pytest.mark.parametrize("async_bool", [True, False]) -def test_vector_step_compatibility_existing(async_bool): - - env = gym.vector.make("CartPole-v1", num_envs=3, asynchronous=async_bool) - env.reset() - step_returns = env.step(env.action_space.sample()) - assert len(step_returns) == 4 - _, _, dones, _ = step_returns - assert dones.dtype == np.bool_ - env.close() - - env = gym.vector.make( - "CartPole-v1", num_envs=3, asynchronous=async_bool, new_step_api=True - ) - env.reset() - step_returns = env.step(env.action_space.sample()) - assert len(step_returns) == 5 - _, _, terminateds, truncateds, _ = step_returns - assert terminateds.dtype == np.bool_ - assert truncateds.dtype == np.bool_ - env.close() diff --git a/tests/vector/test_sync_vector_env.py b/tests/vector/test_sync_vector_env.py index eb01c5edbd6..4653ba0344a 100644 --- a/tests/vector/test_sync_vector_env.py +++ b/tests/vector/test_sync_vector_env.py @@ -76,7 +76,7 @@ def test_step_sync_vector_env(use_single_action_space): actions = [env.single_action_space.sample() for _ in range(8)] else: actions = env.action_space.sample() - observations, rewards, dones, _ = env.step(actions) + observations, rewards, terminateds, truncateds, _ = env.step(actions) env.close() @@ -91,10 +91,15 @@ def test_step_sync_vector_env(use_single_action_space): assert rewards.ndim == 1 assert rewards.size == 8 - assert isinstance(dones, np.ndarray) - assert dones.dtype == np.bool_ - assert dones.ndim == 1 - assert dones.size == 8 + assert isinstance(terminateds, np.ndarray) + assert terminateds.dtype == np.bool_ + assert terminateds.ndim == 1 + assert terminateds.size == 8 + + assert isinstance(truncateds, np.ndarray) + assert truncateds.dtype == np.bool_ + assert truncateds.ndim == 1 + assert truncateds.size == 8 def test_call_sync_vector_env(): @@ -151,7 +156,7 @@ def test_custom_space_sync_vector_env(): assert isinstance(env.action_space, Tuple) actions = ("action-2", "action-3", "action-5", "action-7") - step_observations, rewards, dones, _ = env.step(actions) + step_observations, rewards, terminateds, truncateds, _ = env.step(actions) env.close() diff --git a/tests/vector/test_vector_env.py b/tests/vector/test_vector_env.py index d74e646bedc..7c445ba64d2 100644 --- a/tests/vector/test_vector_env.py +++ b/tests/vector/test_vector_env.py @@ -31,11 +31,11 @@ def test_vector_env_equal(shared_memory): assert actions in sync_env.action_space # fmt: off - async_observations, async_rewards, async_dones, async_infos = async_env.step(actions) - sync_observations, sync_rewards, sync_dones, sync_infos = sync_env.step(actions) + async_observations, async_rewards, async_terminateds, async_truncateds, async_infos = async_env.step(actions) + sync_observations, sync_rewards, sync_terminateds, sync_truncateds, sync_infos = sync_env.step(actions) # fmt: on - if any(sync_dones): + if any(sync_terminateds) or any(sync_truncateds): assert "final_observation" in async_infos assert "_final_observation" in async_infos assert "final_observation" in sync_infos @@ -43,7 +43,8 @@ def test_vector_env_equal(shared_memory): assert np.all(async_observations == sync_observations) assert np.all(async_rewards == sync_rewards) - assert np.all(async_dones == sync_dones) + assert np.all(async_terminateds == sync_terminateds) + assert np.all(async_truncateds == sync_truncateds) async_env.close() sync_env.close() diff --git a/tests/vector/test_vector_env_info.py b/tests/vector/test_vector_env_info.py index 33849bdeca3..989673e5a04 100644 --- a/tests/vector/test_vector_env_info.py +++ b/tests/vector/test_vector_env_info.py @@ -20,16 +20,16 @@ def test_vector_env_info(asynchronous): for _ in range(ENV_STEPS): env.action_space.seed(SEED) action = env.action_space.sample() - _, _, dones, infos = env.step(action) - if any(dones): + _, _, terminateds, truncateds, infos = env.step(action) + if any(terminateds) or any(truncateds): assert len(infos["final_observation"]) == NUM_ENVS assert len(infos["_final_observation"]) == NUM_ENVS assert isinstance(infos["final_observation"], np.ndarray) assert isinstance(infos["_final_observation"], np.ndarray) - for i, done in enumerate(dones): - if done: + for i, (terminated, truncated) in enumerate(zip(terminateds, truncateds)): + if terminated or truncated: assert infos["_final_observation"][i] else: assert not infos["_final_observation"][i] @@ -44,11 +44,11 @@ def test_vector_env_info_concurrent_termination(concurrent_ends): envs = SyncVectorEnv(envs) for _ in range(ENV_STEPS): - _, _, dones, infos = envs.step(actions) - if any(dones): - for i, done in enumerate(dones): + _, _, terminateds, truncateds, infos = envs.step(actions) + if any(terminateds) or any(truncateds): + for i, (terminated, truncated) in enumerate(zip(terminateds, truncateds)): if i < concurrent_ends: - assert done + assert terminated or truncated assert infos["_final_observation"][i] else: assert not infos["_final_observation"][i] diff --git a/tests/wrappers/test_atari_preprocessing.py b/tests/wrappers/test_atari_preprocessing.py index e083d3f0ed5..a34dbcf0674 100644 --- a/tests/wrappers/test_atari_preprocessing.py +++ b/tests/wrappers/test_atari_preprocessing.py @@ -91,7 +91,7 @@ def test_atari_preprocessing_grayscale(env, obs_shape): obs, _ = env.reset(seed=0, return_info=True) assert obs in env.observation_space - obs, _, _, _ = env.step(env.action_space.sample()) + obs, _, _, _, _ = env.step(env.action_space.sample()) assert obs in env.observation_space env.close() @@ -115,9 +115,9 @@ def test_atari_preprocessing_scale(grayscale, scaled, max_test_steps=10): max_obs = 1 if scaled else 255 assert np.all(0 <= obs) and np.all(obs <= max_obs) - done, step_i = False, 0 - while not done and step_i <= max_test_steps: - obs, _, done, _ = env.step(env.action_space.sample()) + terminated, truncated, step_i = False, False, 0 + while not (terminated or truncated)) and step_i <= max_test_steps: + obs, _, terminated, truncated, _ = env.step(env.action_space.sample()) assert np.all(0 <= obs) and np.all(obs <= max_obs) step_i += 1 diff --git a/tests/wrappers/test_autoreset.py b/tests/wrappers/test_autoreset.py index e4ed3f9b593..541e1d30900 100644 --- a/tests/wrappers/test_autoreset.py +++ b/tests/wrappers/test_autoreset.py @@ -80,8 +80,8 @@ def test_make_autoreset_true(spec): env.unwrapped.reset = MagicMock(side_effect=env.unwrapped.reset) done = False - while not done: - obs, reward, done, info = env.step(env.action_space.sample()) + while not (terminated or truncated): + obs, reward, terminated, truncated, info = env.step(env.action_space.sample()) assert env.unwrapped.reset.called env.close() @@ -118,21 +118,21 @@ def test_autoreset_wrapper_autoreset(): assert info == {"count": 0} action = 0 - obs, reward, done, info = env.step(action) + obs, reward, terminated, truncated, info = env.step(action) assert obs == np.array([1]) assert reward == 0 - assert done is False + assert (terminated or truncated) is False assert info == {"count": 1} - obs, reward, done, info = env.step(action) + obs, reward, terminated, truncated, info = env.step(action) assert obs == np.array([2]) - assert done is False + assert (terminated or truncated) is False assert reward == 0 assert info == {"count": 2} - obs, reward, done, info = env.step(action) + obs, reward, terminated, truncated, info = env.step(action) assert obs == np.array([0]) - assert done is True + assert (terminated or truncated) is True assert reward == 1 assert info == { "count": 0, @@ -140,10 +140,10 @@ def test_autoreset_wrapper_autoreset(): "final_info": {"count": 3}, } - obs, reward, done, info = env.step(action) + obs, reward, terminated, truncated, info = env.step(action) assert obs == np.array([1]) assert reward == 0 - assert done is False + assert (terminated or truncated) is False assert info == {"count": 1} env.close() diff --git a/tests/wrappers/test_clip_action.py b/tests/wrappers/test_clip_action.py index 8614c104b4f..1696034dda7 100644 --- a/tests/wrappers/test_clip_action.py +++ b/tests/wrappers/test_clip_action.py @@ -18,10 +18,11 @@ def test_clip_action(): actions = [[0.4], [1.2], [-0.3], [0.0], [-2.5]] for action in actions: - obs1, r1, d1, _ = env.step( + obs1, r1, ter1, trunc1, _ = env.step( np.clip(action, env.action_space.low, env.action_space.high) ) - obs2, r2, d2, _ = wrapped_env.step(action) + obs2, r2, ter2, trunc2, _ = wrapped_env.step(action) assert np.allclose(r1, r2) assert np.allclose(obs1, obs2) - assert d1 == d2 + assert ter1 == ter2 + assert trunc1 == trunc2 diff --git a/tests/wrappers/test_frame_stack.py b/tests/wrappers/test_frame_stack.py index 8c4ed0664de..b07a164bbc1 100644 --- a/tests/wrappers/test_frame_stack.py +++ b/tests/wrappers/test_frame_stack.py @@ -39,13 +39,14 @@ def test_frame_stack(env_id, num_stack, lz4_compress): for _ in range(num_stack**2): action = env.action_space.sample() - dup_obs, _, dup_done, _ = dup.step(action) - obs, _, done, _ = env.step(action) + dup_obs, _, dup_terminated, dup_truncated, _ = dup.step(action) + obs, _, terminated, truncated, _ = env.step(action) - assert dup_done == done + assert dup_terminated == terminated + assert dup_truncated == truncated assert np.allclose(obs[-1], dup_obs) - if done: + if terminated or truncated: break assert len(obs) == num_stack diff --git a/tests/wrappers/test_human_rendering.py b/tests/wrappers/test_human_rendering.py index ae34acf6a54..8910d186e7f 100644 --- a/tests/wrappers/test_human_rendering.py +++ b/tests/wrappers/test_human_rendering.py @@ -15,8 +15,8 @@ def test_human_rendering(): env.reset() for _ in range(75): - _, _, done, _ = env.step(env.action_space.sample()) - if done: + _, _, terminated, truncated, _ = env.step(env.action_space.sample()) + if terminated or truncated: env.reset() env.close() diff --git a/tests/wrappers/test_record_episode_statistics.py b/tests/wrappers/test_record_episode_statistics.py index c3b2d3c247d..483bf0cd486 100644 --- a/tests/wrappers/test_record_episode_statistics.py +++ b/tests/wrappers/test_record_episode_statistics.py @@ -18,8 +18,8 @@ def test_record_episode_statistics(env_id, deque_size): assert env.episode_returns[0] == 0.0 assert env.episode_lengths[0] == 0 for t in range(env.spec.max_episode_steps): - _, _, done, info = env.step(env.action_space.sample()) - if done: + _, _, terminated, truncated, info = env.step(env.action_space.sample()) + if terminated or truncated: assert "episode" in info assert all([item in info["episode"] for item in ["r", "l", "t"]]) break diff --git a/tests/wrappers/test_record_video.py b/tests/wrappers/test_record_video.py index 81a2e148126..43c7e24a3d2 100644 --- a/tests/wrappers/test_record_video.py +++ b/tests/wrappers/test_record_video.py @@ -11,8 +11,8 @@ def test_record_video_using_default_trigger(): env.reset() for _ in range(199): action = env.action_space.sample() - _, _, done, _ = env.step(action) - if done: + _, _, terminated, truncated, _ = env.step(action) + if terminated or truncated: env.reset() env.close() assert os.path.isdir("videos") @@ -60,8 +60,8 @@ def test_record_video_step_trigger(): env.reset() for _ in range(199): action = env.action_space.sample() - _, _, done, _ = env.step(action) - if done: + _, _, terminated, truncated, _ = env.step(action) + if terminated or truncated: env.reset() env.close() assert os.path.isdir("videos") diff --git a/tests/wrappers/test_rescale_action.py b/tests/wrappers/test_rescale_action.py index 45b9b1ff1cd..0e7037a1be7 100644 --- a/tests/wrappers/test_rescale_action.py +++ b/tests/wrappers/test_rescale_action.py @@ -22,10 +22,10 @@ def test_rescale_action(): wrapped_obs = wrapped_env.reset(seed=seed) assert np.allclose(obs, wrapped_obs) - obs, reward, _, _ = env.step([1.5]) + obs, reward, _, _, _ = env.step([1.5]) with pytest.raises(AssertionError): wrapped_env.step([1.5]) - wrapped_obs, wrapped_reward, _, _ = wrapped_env.step([0.75]) + wrapped_obs, wrapped_reward, _, _, _ = wrapped_env.step([0.75]) assert np.allclose(obs, wrapped_obs) assert np.allclose(reward, wrapped_reward) diff --git a/tests/wrappers/test_step_compatibility.py b/tests/wrappers/test_step_compatibility.py index 83557f02db6..996da50b2a3 100644 --- a/tests/wrappers/test_step_compatibility.py +++ b/tests/wrappers/test_step_compatibility.py @@ -33,8 +33,12 @@ def step(self, action): @pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv]) -def test_step_compatibility_to_new_api(env): - env = StepAPICompatibility(env(), True) +@pytest.mark.parametrize("new_step_api", [None, False]) +def test_step_compatibility_to_new_api(env, new_step_api): + if new_step_api is None: + env = StepAPICompatibility(env()) + else: + env = StepAPICompatibility(env(), new_step_api) step_returns = env.step(0) _, _, terminated, truncated, _ = step_returns assert isinstance(terminated, bool) @@ -42,12 +46,8 @@ def test_step_compatibility_to_new_api(env): @pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv]) -@pytest.mark.parametrize("new_step_api", [None, False]) -def test_step_compatibility_to_old_api(env, new_step_api): - if new_step_api is None: - env = StepAPICompatibility(env()) # default behavior is to retain old API - else: - env = StepAPICompatibility(env(), new_step_api) +def test_step_compatibility_to_old_api(env): + env = StepAPICompatibility(env(), False) step_returns = env.step(0) assert len(step_returns) == 4 _, _, done, _ = step_returns @@ -56,11 +56,13 @@ def test_step_compatibility_to_old_api(env, new_step_api): @pytest.mark.parametrize("new_step_api", [None, True, False]) def test_step_compatibility_in_make(new_step_api): - if new_step_api is None: + if new_step_api is False: with pytest.warns( DeprecationWarning, match="Initializing environment in old step API" ): - env = gym.make("CartPole-v1") + env = gym.make("CartPole-v1", new_step_api=False) + elif new_step_api is None: + env = gym.make("CartPole-v1") else: env = gym.make("CartPole-v1", new_step_api=new_step_api) diff --git a/tests/wrappers/test_time_aware_observation.py b/tests/wrappers/test_time_aware_observation.py index 7588089c329..e3733ecac27 100644 --- a/tests/wrappers/test_time_aware_observation.py +++ b/tests/wrappers/test_time_aware_observation.py @@ -20,12 +20,12 @@ def test_time_aware_observation(env_id): assert wrapped_obs[-1] == 0.0 assert wrapped_obs.shape[0] == obs.shape[0] + 1 - wrapped_obs, _, _, _ = wrapped_env.step(env.action_space.sample()) + wrapped_obs, _, _, _, _ = wrapped_env.step(env.action_space.sample()) assert wrapped_env.t == 1.0 assert wrapped_obs[-1] == 1.0 assert wrapped_obs.shape[0] == obs.shape[0] + 1 - wrapped_obs, _, _, _ = wrapped_env.step(env.action_space.sample()) + wrapped_obs, _,_, _, _ = wrapped_env.step(env.action_space.sample()) assert wrapped_env.t == 2.0 assert wrapped_obs[-1] == 2.0 assert wrapped_obs.shape[0] == obs.shape[0] + 1 diff --git a/tests/wrappers/test_time_limit.py b/tests/wrappers/test_time_limit.py index e4732db5984..01ea4723e40 100644 --- a/tests/wrappers/test_time_limit.py +++ b/tests/wrappers/test_time_limit.py @@ -33,12 +33,12 @@ def test_time_limit_wrapper(double_wrap): # if it was already set env = TimeLimit(env, max_episode_length) env.reset() - done = False + terminated, truncated = False, False n_steps = 0 info = {} - while not done: + while not terminated or truncated: n_steps += 1 - _, _, done, info = env.step(env.action_space.sample()) + _, _, terminated, truncated, info = env.step(env.action_space.sample()) assert n_steps == max_episode_length assert "TimeLimit.truncated" in info @@ -61,7 +61,8 @@ def patched_step(_action): if double_wrap: env = TimeLimit(env, max_episode_length) env.reset() - _, _, done, info = env.step(env.action_space.sample()) - assert done is True - assert "TimeLimit.truncated" in info + _, _, terminated, truncated, info = env.step(env.action_space.sample()) + assert terminated is True + assert truncated is True + assert "TimeLimit.truncated" in info # part of old API but retained assert info["TimeLimit.truncated"] is False diff --git a/tests/wrappers/test_transform_observation.py b/tests/wrappers/test_transform_observation.py index 949302770d5..0363212046a 100644 --- a/tests/wrappers/test_transform_observation.py +++ b/tests/wrappers/test_transform_observation.py @@ -20,8 +20,15 @@ def affine_transform(x): assert np.allclose(wrapped_obs, affine_transform(obs)) action = env.action_space.sample() - obs, reward, done, _ = env.step(action) - wrapped_obs, wrapped_reward, wrapped_done, _ = wrapped_env.step(action) + obs, reward, terminated, truncated, _ = env.step(action) + ( + wrapped_obs, + wrapped_reward, + wrapped_terminated, + wrapped_truncated, + _, + ) = wrapped_env.step(action) assert np.allclose(wrapped_obs, affine_transform(obs)) assert np.allclose(wrapped_reward, reward) - assert wrapped_done == done + assert wrapped_terminated == terminated + assert wrapped_truncated == truncated diff --git a/tests/wrappers/test_transform_reward.py b/tests/wrappers/test_transform_reward.py index e3140ae0cf5..2687f7f8a5e 100644 --- a/tests/wrappers/test_transform_reward.py +++ b/tests/wrappers/test_transform_reward.py @@ -19,8 +19,8 @@ def test_transform_reward(env_id): env.reset(seed=0) wrapped_env.reset(seed=0) - _, reward, _, _ = env.step(action) - _, wrapped_reward, _, _ = wrapped_env.step(action) + _, reward, _, _, _ = env.step(action) + _, wrapped_reward, _, _, _ = wrapped_env.step(action) assert wrapped_reward == scale * reward del env, wrapped_env @@ -37,8 +37,8 @@ def test_transform_reward(env_id): env.reset(seed=0) wrapped_env.reset(seed=0) - _, reward, _, _ = env.step(action) - _, wrapped_reward, _, _ = wrapped_env.step(action) + _, reward, _, _, _ = env.step(action) + _, wrapped_reward, _, _, _ = wrapped_env.step(action) assert abs(wrapped_reward) < abs(reward) assert wrapped_reward == -0.0005 or wrapped_reward == 0.0002 @@ -55,8 +55,8 @@ def test_transform_reward(env_id): for _ in range(1000): action = env.action_space.sample() - _, wrapped_reward, done, _ = wrapped_env.step(action) + _, wrapped_reward, terminated, truncated, _ = wrapped_env.step(action) assert wrapped_reward in [-1.0, 0.0, 1.0] - if done: + if terminated or truncated: break del env, wrapped_env diff --git a/tests/wrappers/test_vector_list_info.py b/tests/wrappers/test_vector_list_info.py index 26c6e772876..d410746f1fd 100644 --- a/tests/wrappers/test_vector_list_info.py +++ b/tests/wrappers/test_vector_list_info.py @@ -29,9 +29,9 @@ def test_info_to_list(): for _ in range(ENV_STEPS): action = wrapped_env.action_space.sample() - _, _, dones, list_info = wrapped_env.step(action) - for i, done in enumerate(dones): - if done: + _, _, terminateds, truncateds, list_info = wrapped_env.step(action) + for i, (terminated, truncated) in enumerate(zip(terminateds, truncateds)): + if terminated or truncated: assert "final_observation" in list_info[i] else: assert "final_observation" not in list_info[i] @@ -47,9 +47,9 @@ def test_info_to_list_statistics(): for _ in range(ENV_STEPS): action = wrapped_env.action_space.sample() - _, _, dones, list_info = wrapped_env.step(action) - for i, done in enumerate(dones): - if done: + _, _, terminateds, truncateds, list_info = wrapped_env.step(action) + for i, (terminated, truncated) in enumerate(zip(terminateds, truncateds)): + if terminated or truncated: assert "episode" in list_info[i] for stats in ["r", "l", "t"]: assert stats in list_info[i]["episode"] From 515fe5596b8807c20e1ad77146e543d75520addc Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Wed, 10 Aug 2022 23:42:48 +0530 Subject: [PATCH 03/12] change param name, other fixes --- gym/envs/registration.py | 11 +++++----- gym/utils/play.py | 18 +++++++-------- gym/utils/step_api_compatibility.py | 16 +++++++------- gym/wrappers/record_video.py | 2 +- gym/wrappers/step_api_compatibility.py | 18 +++++++-------- gym/wrappers/time_limit.py | 14 ++++-------- tests/envs/test_envs.py | 6 ++--- tests/utils/test_play.py | 18 +++++++-------- tests/vector/utils.py | 8 +++---- tests/wrappers/test_atari_preprocessing.py | 16 +++++++++----- tests/wrappers/test_autoreset.py | 9 ++++---- tests/wrappers/test_normalize.py | 16 +++++++++----- .../test_record_episode_statistics.py | 6 ++--- tests/wrappers/test_step_compatibility.py | 22 +++++++++---------- tests/wrappers/test_time_aware_observation.py | 2 +- tests/wrappers/test_time_limit.py | 15 +++++-------- 16 files changed, 97 insertions(+), 100 deletions(-) diff --git a/gym/envs/registration.py b/gym/envs/registration.py index 5846478bc27..0e4878bc266 100644 --- a/gym/envs/registration.py +++ b/gym/envs/registration.py @@ -140,7 +140,7 @@ class EnvSpec: order_enforce: bool = field(default=True) autoreset: bool = field(default=False) disable_env_checker: bool = field(default=False) - new_step_api: bool = field(default=False) + apply_step_compatibility: bool = field(default=False) # Environment arguments kwargs: dict = field(default_factory=dict) @@ -547,7 +547,7 @@ def make( id: Union[str, EnvSpec], max_episode_steps: Optional[int] = None, autoreset: bool = False, - new_step_api: bool = True, + apply_step_compatibility: bool = False, disable_env_checker: Optional[bool] = None, **kwargs, ) -> Env: @@ -557,7 +557,7 @@ def make( id: Name of the environment. Optionally, a module to import can be included, eg. 'module:Env-v0' max_episode_steps: Maximum length of an episode (TimeLimit wrapper). autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper). - new_step_api: Whether to use old or new step API (StepAPICompatibility wrapper) + apply_step_compatibility: Whether to use apply compatibility wrapper that converts step method to return two bools (StepAPICompatibility wrapper) disable_env_checker: If to run the env checker, None will default to the environment specification `disable_env_checker` (which is by default False, running the environment checker), otherwise will run according to this parameter (`True` = not run, `False` = run) @@ -684,7 +684,6 @@ def make( ): env = PassiveEnvChecker(env) - # Add the order enforcing wrapper if spec_.order_enforce: env = OrderEnforcing(env) @@ -704,8 +703,8 @@ def make( env = HumanRendering(env) # Add step API wrapper - if not new_step_api: - env = StepAPICompatibility(env, new_step_api) + if apply_step_compatibility: + env = StepAPICompatibility(env, True) return env diff --git a/gym/utils/play.py b/gym/utils/play.py index 5793021c2a1..a187c4165fc 100644 --- a/gym/utils/play.py +++ b/gym/utils/play.py @@ -211,10 +211,6 @@ def play( seed: Random seed used when resetting the environment. If None, no seed is used. noop: The action used when no key input has been entered, or the entered key combination is unknown. """ - deprecation( - "`play.py` currently supports only the old step API which returns one boolean, however this will soon be updated to support only the new step api that returns two bools." - ) - env.reset(seed=seed) if keys_to_action is None: @@ -251,9 +247,9 @@ def play( else: action = key_code_to_action.get(tuple(sorted(game.pressed_keys)), noop) prev_obs = obs - obs, rew, done, info = env.step(action) + obs, rew, terminated, truncated, info = env.step(action) if callback is not None: - callback(prev_obs, obs, action, rew, done, info) + callback(prev_obs, obs, action, rew, terminated, truncated, info) if obs is not None: # TODO: this needs to be updated when the render API change goes through rendered = env.render(mode="rgb_array") @@ -341,7 +337,8 @@ def callback( obs_tp1: ObsType, action: ActType, rew: float, - done: bool, + terminated: bool, + truncated: bool, info: dict, ): """The callback that calls the provided data callback and adds the data to the plots. @@ -351,10 +348,13 @@ def callback( obs_tp1: The observation at time step t+1 action: The action rew: The reward - done: If the environment is done + terminated: If the environment is terminated + truncated: If the environment is truncated info: The information from the environment """ - points = self.data_callback(obs_t, obs_tp1, action, rew, done, info) + points = self.data_callback( + obs_t, obs_tp1, action, rew, terminated, truncated, info + ) for point, data_series in zip(points, self.data): data_series.append(point) self.t += 1 diff --git a/gym/utils/step_api_compatibility.py b/gym/utils/step_api_compatibility.py index c1a0a8c27f8..a69784e2885 100644 --- a/gym/utils/step_api_compatibility.py +++ b/gym/utils/step_api_compatibility.py @@ -1,4 +1,4 @@ -"""Contains methods for step compatibility, from old-to-new and new-to-old API""" +"""Contains methods for step compatibility, from old-to-new and new-to-old API.""" from typing import Tuple, Union import numpy as np @@ -149,10 +149,10 @@ def step_to_old_api( def step_api_compatibility( step_returns: Union[NewStepType, OldStepType], - new_step_api: bool = True, + output_truncation_bool: bool = True, is_vector_env: bool = False, ) -> Union[NewStepType, OldStepType]: - """Function to transform step returns to the API specified by `new_step_api` bool. + """Function to transform step returns to the API specified by `output_truncation_bool` bool. Old step API refers to step() method returning (observation, reward, done, info) New step API refers to step() method returning (observation, reward, terminated, truncated, info) @@ -160,21 +160,21 @@ def step_api_compatibility( Args: step_returns (tuple): Items returned by step(). Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info) - new_step_api (bool): Whether the output should be in new step API or old (True by default) + output_truncation_bool (bool): Whether the output should return two booleans (new API) or one (old) (True by default) is_vector_env (bool): Whether the step_returns are from a vector environment Returns: - step_returns (tuple): Depending on `new_step_api` bool, it can return (obs, rew, done, info) or (obs, rew, terminated, truncated, info) + step_returns (tuple): Depending on `output_truncation_bool` bool, it can return (obs, rew, done, info) or (obs, rew, terminated, truncated, info) Examples: This function can be used to ensure compatibility in step interfaces with conflicting API. Eg. if env is written in old API, wrapper is written in new API, and the final step output is desired to be in old API. - >>> obs, rew, done, info = step_api_compatibility(env.step(action), new_step_api=False) - >>> obs, rew, terminated, truncated, info = step_api_compatibility(env.step(action), new_step_api=True) + >>> obs, rew, done, info = step_api_compatibility(env.step(action), output_truncation_bool=False) + >>> obs, rew, terminated, truncated, info = step_api_compatibility(env.step(action), output_truncation_bool=True) >>> observations, rewards, dones, infos = step_api_compatibility(vec_env.step(action), is_vector_env=True) """ - if new_step_api: + if output_truncation_bool: return step_to_new_api(step_returns, is_vector_env) else: return step_to_old_api(step_returns, is_vector_env) diff --git a/gym/wrappers/record_video.py b/gym/wrappers/record_video.py index e2fb59f0dec..d3e92b5348a 100644 --- a/gym/wrappers/record_video.py +++ b/gym/wrappers/record_video.py @@ -140,7 +140,7 @@ def step(self, action): terminateds, truncateds, infos, - ) =self.env.step(action) + ) = self.env.step(action) if not (self.terminated or self.truncated): # increment steps and episodes diff --git a/gym/wrappers/step_api_compatibility.py b/gym/wrappers/step_api_compatibility.py index 8f67167254e..661fb8a5b8e 100644 --- a/gym/wrappers/step_api_compatibility.py +++ b/gym/wrappers/step_api_compatibility.py @@ -13,34 +13,34 @@ class StepAPICompatibility(gym.Wrapper): Args: env (gym.Env): the env to wrap. Can be in old or new API - new_step_api (bool): True to use env with new step API, False to use env with old step API. (True by default) + apply_step_compatibility (bool): Apply to convert environment to use new step API that returns two bools. (False by default) Examples: >>> env = gym.make("CartPole-v1") >>> env # wrapper not applied by default, set to new API >>>> - >>> env = gym.make("CartPole-v1", new_step_api=False) # set to old API + >>> env = gym.make("CartPole-v1", apply_step_compatibility=True) # set to old API >>>>> - >>> env = StepAPICompatibility(CustomEnv(), new_step_api=True) # manually using wrapper on unregistered envs + >>> env = StepAPICompatibility(CustomEnv(), apply_step_compatibility=False) # manually using wrapper on unregistered envs """ - def __init__(self, env: gym.Env, new_step_api=True): + def __init__(self, env: gym.Env, output_truncation_bool: bool = True): """A wrapper which can transform an environment from new step API to old and vice-versa. Args: env (gym.Env): the env to wrap. Can be in old or new API - new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) + output_truncation_bool (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) """ super().__init__(env) - self.new_step_api = new_step_api - if not self.new_step_api: + self.output_truncation_bool = output_truncation_bool + if not self.output_truncation_bool: deprecation( "Initializing environment in old step API which returns one bool instead of two." ) def step(self, action): - """Steps through the environment, returning 5 or 4 items depending on `new_step_api`. + """Steps through the environment, returning 5 or 4 items depending on `apply_step_compatibility`. Args: action: action to step through the environment with @@ -49,7 +49,7 @@ def step(self, action): (observation, reward, terminated, truncated, info) or (observation, reward, done, info) """ step_returns = self.env.step(action) - if self.new_step_api: + if self.output_truncation_bool: return step_to_new_api(step_returns) else: return step_to_old_api(step_returns) diff --git a/gym/wrappers/time_limit.py b/gym/wrappers/time_limit.py index 735f3b820cb..854876e3a90 100644 --- a/gym/wrappers/time_limit.py +++ b/gym/wrappers/time_limit.py @@ -10,12 +10,6 @@ class TimeLimit(gym.Wrapper): If a truncation is not defined inside the environment itself, this is the only place that the truncation signal is issued. Critically, this is different from the `terminated` signal that originates from the underlying environment as part of the MDP. - (deprecated) - This information is passed through ``info`` that is returned when `done`-signal was issued. - The done-signal originates from the time limit (i.e. it signifies a *truncation*) if and only if - the key `"TimeLimit.truncated"` exists in ``info`` and the corresponding value is ``True``. This will be removed in favour - of only issuing a `truncated` signal in future versions. - Example: >>> from gym.envs.classic_control import CartPoleEnv >>> from gym.wrappers import TimeLimit @@ -49,11 +43,11 @@ def step(self, action): action: The environment step action Returns: - The environment step ``(observation, reward, done, info)`` with "TimeLimit.truncated"=True - when truncated (the number of steps elapsed >= max episode steps) or - "TimeLimit.truncated"=False if the environment terminated + The environment step ``(observation, reward, terminated, truncated, info)`` with `truncated=True` + if the number of steps elapsed >= max episode steps + """ - observation, reward, terminated, truncated, info = (self.env.step(action),) + observation, reward, terminated, truncated, info = self.env.step(action) self._elapsed_steps += 1 if self._elapsed_steps >= self._max_episode_steps: diff --git a/tests/envs/test_envs.py b/tests/envs/test_envs.py index 3526ed8cb8c..58ebea64ef7 100644 --- a/tests/envs/test_envs.py +++ b/tests/envs/test_envs.py @@ -11,8 +11,7 @@ f"\x1b[33mWARN: {message}\x1b[0m" for message in [ "This version of the mujoco environments depends on the mujoco-py bindings, which are no longer maintained and may stop working. Please upgrade to the v4 versions of the environments (which depend on the mujoco python bindings instead), unless you are trying to precisely replicate previous works).", - "Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.", - "Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.", + "Initializing environment in old step API which returns one bool instead of two.", ] ] @@ -23,8 +22,7 @@ "A Box observation space minimum value is -infinity. This is probably too low.", "A Box observation space maximum value is -infinity. This is probably too high.", "For Box action spaces, we recommend using a symmetric and normalized space (range=[-1, 1] or [0, 1]). See https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html for more information.", - "Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.", - "Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.", + "Initializing environment in old step API which returns one bool instead of two.", ] ] diff --git a/tests/utils/test_play.py b/tests/utils/test_play.py index cc4171502ee..a6f9d2ecb25 100644 --- a/tests/utils/test_play.py +++ b/tests/utils/test_play.py @@ -24,8 +24,8 @@ class DummyEnvSpec: class DummyPlayEnv(gym.Env): def step(self, action): obs = np.zeros((1, 1)) - rew, done, info = 1, False, {} - return obs, rew, done, info + rew, terminated, truncated, info = 1, False, False, {} + return obs, rew, terminated, truncated, info def reset(self, seed=None): ... @@ -49,9 +49,9 @@ def __init__(self, callback: Callable): self.cumulative_reward = 0 self.last_observation = None - def callback(self, obs_t, obs_tp1, action, rew, done, info): - _, obs_tp1, _, rew, _, _ = self.data_callback( - obs_t, obs_tp1, action, rew, done, info + def callback(self, obs_t, obs_tp1, action, rew, terminated, truncated, info): + _, obs_tp1, _, rew, _, _, _ = self.data_callback( + obs_t, obs_tp1, action, rew, terminated, truncated, info ) self.cumulative_reward += rew self.last_observation = obs_tp1 @@ -174,7 +174,7 @@ def test_play_loop_real_env(): ] keydown_events = [k for k in callback_events if k.type == KEYDOWN] - def callback(obs_t, obs_tp1, action, rew, done, info): + def callback(obs_t, obs_tp1, action, rew, terminated, truncated, info): pygame_event = callback_events.pop(0) event.post(pygame_event) @@ -184,7 +184,7 @@ def callback(obs_t, obs_tp1, action, rew, done, info): pygame_event = callback_events.pop(0) event.post(pygame_event) - return obs_t, obs_tp1, action, rew, done, info + return obs_t, obs_tp1, action, rew, terminated, truncated, info env = gym.make(ENV, disable_env_checker=True) env.reset(seed=SEED) @@ -194,10 +194,10 @@ def callback(obs_t, obs_tp1, action, rew, done, info): # first action is 0 because at the first iteration # we can not inject a callback event into play() - obs, _, _, _ = env.step(0) + obs, _, _, _, _ = env.step(0) for e in keydown_events: action = keys_to_action[chr(e.key) if str_keys else (e.key,)] - obs, _, _, _ = env.step(action) + obs, _, _, _, _ = env.step(action) env_play = gym.make(ENV, disable_env_checker=True) if apply_wrapper: diff --git a/tests/vector/utils.py b/tests/vector/utils.py index 99d95ada4a1..2ea7b34c293 100644 --- a/tests/vector/utils.py +++ b/tests/vector/utils.py @@ -68,8 +68,8 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def step(self, action): time.sleep(action) observation = self.observation_space.sample() - reward, done = 0.0, False - return observation, reward, done, {} + reward, terminated, truncated = 0.0, False, False + return observation, reward, terminated, truncated, {} class CustomSpace(gym.Space): @@ -103,8 +103,8 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def step(self, action): observation = f"step({action:s})" - reward, done = 0.0, False - return observation, reward, done, {} + reward, terminated, truncated = 0.0, False, False + return observation, reward, terminated, truncated, {} def make_env(env_name, seed, **kwargs): diff --git a/tests/wrappers/test_atari_preprocessing.py b/tests/wrappers/test_atari_preprocessing.py index a34dbcf0674..292e58faa75 100644 --- a/tests/wrappers/test_atari_preprocessing.py +++ b/tests/wrappers/test_atari_preprocessing.py @@ -2,7 +2,7 @@ import pytest from gym.spaces import Box, Discrete -from gym.wrappers import AtariPreprocessing +from gym.wrappers import AtariPreprocessing, StepAPICompatibility from tests.testing_env import GenericTestEnv, old_step_fn @@ -49,7 +49,7 @@ def get_action_meanings(self): (AtariTestingEnv(), (210, 160, 3)), ( AtariPreprocessing( - AtariTestingEnv(), + StepAPICompatibility(AtariTestingEnv(), output_truncation_bool=True), screen_size=84, grayscale_obs=True, frame_skip=1, @@ -59,7 +59,7 @@ def get_action_meanings(self): ), ( AtariPreprocessing( - AtariTestingEnv(), + StepAPICompatibility(AtariTestingEnv(), output_truncation_bool=True), screen_size=84, grayscale_obs=False, frame_skip=1, @@ -69,7 +69,7 @@ def get_action_meanings(self): ), ( AtariPreprocessing( - AtariTestingEnv(), + StepAPICompatibility(AtariTestingEnv(), output_truncation_bool=True), screen_size=84, grayscale_obs=True, frame_skip=1, @@ -86,6 +86,10 @@ def test_atari_preprocessing_grayscale(env, obs_shape): # It is not possible to test the outputs as we are not using actual observations. # todo: update when ale-py is compatible with the ci + env = StepAPICompatibility( + env, output_truncation_bool=True + ) # using compatibility wrapper since ale-py uses old step API + obs = env.reset(seed=0) assert obs in env.observation_space obs, _ = env.reset(seed=0, return_info=True) @@ -102,7 +106,7 @@ def test_atari_preprocessing_grayscale(env, obs_shape): def test_atari_preprocessing_scale(grayscale, scaled, max_test_steps=10): # arbitrarily chosen number for stepping into env. and ensuring all observations are in the required range env = AtariPreprocessing( - AtariTestingEnv(), + StepAPICompatibility(AtariTestingEnv(), output_truncation_bool=True), screen_size=84, grayscale_obs=grayscale, scale_obs=scaled, @@ -116,7 +120,7 @@ def test_atari_preprocessing_scale(grayscale, scaled, max_test_steps=10): assert np.all(0 <= obs) and np.all(obs <= max_obs) terminated, truncated, step_i = False, False, 0 - while not (terminated or truncated)) and step_i <= max_test_steps: + while not (terminated or truncated) and step_i <= max_test_steps: obs, _, terminated, truncated, _ = env.step(env.action_space.sample()) assert np.all(0 <= obs) and np.all(obs <= max_obs) diff --git a/tests/wrappers/test_autoreset.py b/tests/wrappers/test_autoreset.py index 541e1d30900..736485af370 100644 --- a/tests/wrappers/test_autoreset.py +++ b/tests/wrappers/test_autoreset.py @@ -14,7 +14,7 @@ class DummyResetEnv(gym.Env): """A dummy environment which returns ascending numbers starting at `0` when :meth:`self.step()` is called. - After the second call to :meth:`self.step()` done is true. + After the second call to :meth:`self.step()` terminated is true. Info dicts are also returned containing the same number returned as an observation, accessible via the key "count". This environment is provided for the purpose of testing the autoreset wrapper. """ @@ -30,12 +30,13 @@ def __init__(self): self.count = 0 def step(self, action: int): - """Steps the DummyEnv with the incremented step, reward and done `if self.count > 1` and updated info.""" + """Steps the DummyEnv with the incremented step, reward and terminated `if self.count > 1` and updated info.""" self.count += 1 return ( np.array([self.count]), # Obs self.count > 2, # Reward - self.count > 2, # Done + self.count > 2, # Terminated + False, # Truncated {"count": self.count}, # Info ) @@ -79,7 +80,7 @@ def test_make_autoreset_true(spec): env.reset(seed=0) env.unwrapped.reset = MagicMock(side_effect=env.unwrapped.reset) - done = False + terminated, truncated = False, False while not (terminated or truncated): obs, reward, terminated, truncated, info = env.step(env.action_space.sample()) diff --git a/tests/wrappers/test_normalize.py b/tests/wrappers/test_normalize.py index 13bf32011be..ee73163a1d4 100644 --- a/tests/wrappers/test_normalize.py +++ b/tests/wrappers/test_normalize.py @@ -22,7 +22,13 @@ def __init__(self, return_reward_idx=0): def step(self, action): self.t += 1 - return np.array([self.t]), self.t, self.t == len(self.returned_rewards), {} + return ( + np.array([self.t]), + self.t, + self.t == len(self.returned_rewards), + False, + {}, + ) def reset( self, @@ -94,7 +100,7 @@ def test_normalize_observation_vector_env(): env_fns = [make_env(0), make_env(1)] envs = gym.vector.SyncVectorEnv(env_fns) envs.reset() - obs, reward, _, _ = envs.step(envs.action_space.sample()) + obs, reward, _, _, _ = envs.step(envs.action_space.sample()) np.testing.assert_almost_equal(obs, np.array([[1], [2]]), decimal=4) np.testing.assert_almost_equal(reward, np.array([1, 2]), decimal=4) @@ -107,7 +113,7 @@ def test_normalize_observation_vector_env(): np.mean([0.5]), # the mean of first observations [[0, 1]] decimal=4, ) - obs, reward, _, _ = envs.step(envs.action_space.sample()) + obs, reward, _, _, _ = envs.step(envs.action_space.sample()) assert_almost_equal( envs.obs_rms.mean, np.mean([1.0]), # the mean of first and second observations [[0, 1], [1, 2]] @@ -120,13 +126,13 @@ def test_normalize_return_vector_env(): envs = gym.vector.SyncVectorEnv(env_fns) envs = NormalizeReward(envs) obs = envs.reset() - obs, reward, _, _ = envs.step(envs.action_space.sample()) + obs, reward, _, _, _ = envs.step(envs.action_space.sample()) assert_almost_equal( envs.return_rms.mean, np.mean([1.5]), # the mean of first returns [[1, 2]] decimal=4, ) - obs, reward, _, _ = envs.step(envs.action_space.sample()) + obs, reward, _, _, _ = envs.step(envs.action_space.sample()) assert_almost_equal( envs.return_rms.mean, np.mean( diff --git a/tests/wrappers/test_record_episode_statistics.py b/tests/wrappers/test_record_episode_statistics.py index 483bf0cd486..5107a88b773 100644 --- a/tests/wrappers/test_record_episode_statistics.py +++ b/tests/wrappers/test_record_episode_statistics.py @@ -58,11 +58,11 @@ def test_record_episode_statistics_with_vectorenv(num_envs, asynchronous): ) envs.reset() for _ in range(max_episode_step + 1): - _, _, dones, infos = envs.step(envs.action_space.sample()) - if any(dones): + _, _, terminateds, truncateds, infos = envs.step(envs.action_space.sample()) + if any(terminateds) or any(truncateds): assert "episode" in infos assert "_episode" in infos - assert all(infos["_episode"] == dones) + assert all(infos["_episode"] == np.bitwise_or(terminateds, truncateds)) assert all([item in infos["episode"] for item in ["r", "l", "t"]]) break else: diff --git a/tests/wrappers/test_step_compatibility.py b/tests/wrappers/test_step_compatibility.py index 996da50b2a3..c44a1780871 100644 --- a/tests/wrappers/test_step_compatibility.py +++ b/tests/wrappers/test_step_compatibility.py @@ -33,12 +33,12 @@ def step(self, action): @pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv]) -@pytest.mark.parametrize("new_step_api", [None, False]) -def test_step_compatibility_to_new_api(env, new_step_api): - if new_step_api is None: +@pytest.mark.parametrize("output_truncation_bool", [None, True]) +def test_step_compatibility_to_new_api(env, output_truncation_bool): + if output_truncation_bool is None: env = StepAPICompatibility(env()) else: - env = StepAPICompatibility(env(), new_step_api) + env = StepAPICompatibility(env(), output_truncation_bool) step_returns = env.step(0) _, _, terminated, truncated, _ = step_returns assert isinstance(terminated, bool) @@ -54,21 +54,21 @@ def test_step_compatibility_to_old_api(env): assert isinstance(done, bool) -@pytest.mark.parametrize("new_step_api", [None, True, False]) -def test_step_compatibility_in_make(new_step_api): - if new_step_api is False: +@pytest.mark.parametrize("output_truncation_bool", [None, True, False]) +def test_step_compatibility_in_make(output_truncation_bool): + if output_truncation_bool is False: with pytest.warns( DeprecationWarning, match="Initializing environment in old step API" ): - env = gym.make("CartPole-v1", new_step_api=False) - elif new_step_api is None: + env = gym.make("CartPole-v1", output_truncation_bool=False) + elif output_truncation_bool is None: env = gym.make("CartPole-v1") else: - env = gym.make("CartPole-v1", new_step_api=new_step_api) + env = gym.make("CartPole-v1", output_truncation_bool=output_truncation_bool) env.reset() step_returns = env.step(0) - if new_step_api: + if output_truncation_bool: assert len(step_returns) == 5 _, _, terminated, truncated, _ = step_returns assert isinstance(terminated, bool) diff --git a/tests/wrappers/test_time_aware_observation.py b/tests/wrappers/test_time_aware_observation.py index e3733ecac27..b9301c19e34 100644 --- a/tests/wrappers/test_time_aware_observation.py +++ b/tests/wrappers/test_time_aware_observation.py @@ -25,7 +25,7 @@ def test_time_aware_observation(env_id): assert wrapped_obs[-1] == 1.0 assert wrapped_obs.shape[0] == obs.shape[0] + 1 - wrapped_obs, _,_, _, _ = wrapped_env.step(env.action_space.sample()) + wrapped_obs, _, _, _, _ = wrapped_env.step(env.action_space.sample()) assert wrapped_env.t == 2.0 assert wrapped_obs[-1] == 2.0 assert wrapped_obs.shape[0] == obs.shape[0] + 1 diff --git a/tests/wrappers/test_time_limit.py b/tests/wrappers/test_time_limit.py index 01ea4723e40..d887d02f60b 100644 --- a/tests/wrappers/test_time_limit.py +++ b/tests/wrappers/test_time_limit.py @@ -28,9 +28,6 @@ def test_time_limit_wrapper(double_wrap): max_episode_length = 20 env = TimeLimit(env, max_episode_length) if double_wrap: - # TimeLimit wrapper should not overwrite - # the TimeLimit.truncated key - # if it was already set env = TimeLimit(env, max_episode_length) env.reset() terminated, truncated = False, False @@ -41,14 +38,14 @@ def test_time_limit_wrapper(double_wrap): _, _, terminated, truncated, info = env.step(env.action_space.sample()) assert n_steps == max_episode_length - assert "TimeLimit.truncated" in info - assert info["TimeLimit.truncated"] is True + assert truncated @pytest.mark.parametrize("double_wrap", [False, True]) def test_termination_on_last_step(double_wrap): # Special case: termination at the last timestep - # but not due to timeout + # Truncation due to timeout also happens at the same step + env = PendulumEnv() def patched_step(_action): @@ -61,8 +58,6 @@ def patched_step(_action): if double_wrap: env = TimeLimit(env, max_episode_length) env.reset() - _, _, terminated, truncated, info = env.step(env.action_space.sample()) - assert terminated is True + _, _, terminated, truncated, _ = env.step(env.action_space.sample()) + assert terminated is True assert truncated is True - assert "TimeLimit.truncated" in info # part of old API but retained - assert info["TimeLimit.truncated"] is False From af9f0e98fed26737c1a5a1156c8925ff1f44336e Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Thu, 11 Aug 2022 07:49:11 +0530 Subject: [PATCH 04/12] fix tests --- tests/wrappers/test_record_video.py | 2 +- tests/wrappers/test_step_compatibility.py | 23 +++++++++++------------ tests/wrappers/test_time_limit.py | 4 ++-- 3 files changed, 14 insertions(+), 15 deletions(-) diff --git a/tests/wrappers/test_record_video.py b/tests/wrappers/test_record_video.py index 43c7e24a3d2..c25c2bb8b5a 100644 --- a/tests/wrappers/test_record_video.py +++ b/tests/wrappers/test_record_video.py @@ -90,7 +90,7 @@ def test_record_video_within_vector(): envs = gym.wrappers.RecordEpisodeStatistics(envs) envs.reset() for i in range(199): - _, _, _, infos = envs.step(envs.action_space.sample()) + _, _, _, _, infos = envs.step(envs.action_space.sample()) # break when every env is done if "episode" in infos and all(infos["_episode"]): diff --git a/tests/wrappers/test_step_compatibility.py b/tests/wrappers/test_step_compatibility.py index c44a1780871..1a3a8de15cd 100644 --- a/tests/wrappers/test_step_compatibility.py +++ b/tests/wrappers/test_step_compatibility.py @@ -54,21 +54,20 @@ def test_step_compatibility_to_old_api(env): assert isinstance(done, bool) -@pytest.mark.parametrize("output_truncation_bool", [None, True, False]) -def test_step_compatibility_in_make(output_truncation_bool): - if output_truncation_bool is False: - with pytest.warns( - DeprecationWarning, match="Initializing environment in old step API" - ): - env = gym.make("CartPole-v1", output_truncation_bool=False) - elif output_truncation_bool is None: - env = gym.make("CartPole-v1") - else: - env = gym.make("CartPole-v1", output_truncation_bool=output_truncation_bool) +@pytest.mark.parametrize("apply_step_compatibility", [None, True, False]) +def test_step_compatibility_in_make(apply_step_compatibility): + gym.register("OldStepEnv-v0", entry_point=OldStepEnv) + + if apply_step_compatibility is not None: + env = gym.make( + "OldStepEnv-v0", apply_step_compatibility=apply_step_compatibility + ) + elif apply_step_compatibility is None: + env = gym.make("OldStepEnv-v0") env.reset() step_returns = env.step(0) - if output_truncation_bool: + if apply_step_compatibility: assert len(step_returns) == 5 _, _, terminated, truncated, _ = step_returns assert isinstance(terminated, bool) diff --git a/tests/wrappers/test_time_limit.py b/tests/wrappers/test_time_limit.py index d887d02f60b..b35164b9589 100644 --- a/tests/wrappers/test_time_limit.py +++ b/tests/wrappers/test_time_limit.py @@ -33,7 +33,7 @@ def test_time_limit_wrapper(double_wrap): terminated, truncated = False, False n_steps = 0 info = {} - while not terminated or truncated: + while not (terminated or truncated): n_steps += 1 _, _, terminated, truncated, info = env.step(env.action_space.sample()) @@ -49,7 +49,7 @@ def test_termination_on_last_step(double_wrap): env = PendulumEnv() def patched_step(_action): - return env.observation_space.sample(), 0.0, True, {} + return env.observation_space.sample(), 0.0, True, False, {} env.step = patched_step From be7fa6f000352c5edbef445170fbc4dee4f3ed6b Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Wed, 17 Aug 2022 15:15:40 +0530 Subject: [PATCH 05/12] language changes --- gym/utils/step_api_compatibility.py | 25 +++++++++++++------------ gym/wrappers/step_api_compatibility.py | 9 ++++++--- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/gym/utils/step_api_compatibility.py b/gym/utils/step_api_compatibility.py index 2954dae5636..0a4d9a4675f 100644 --- a/gym/utils/step_api_compatibility.py +++ b/gym/utils/step_api_compatibility.py @@ -5,14 +5,14 @@ from gym.core import ObsType -OldStepType = Tuple[ +DoneStepType = Tuple[ Union[ObsType, np.ndarray], Union[float, np.ndarray], Union[bool, np.ndarray], Union[dict, list], ] -NewStepType = Tuple[ +TerminatedTruncatedStepType = Tuple[ Union[ObsType, np.ndarray], Union[float, np.ndarray], Union[bool, np.ndarray], @@ -21,9 +21,9 @@ ] -def step_to_new_api( - step_returns: Union[OldStepType, NewStepType], is_vector_env=False -) -> NewStepType: +def convert_to_terminated_truncated_step_api( + step_returns: Union[DoneStepType, TerminatedTruncatedStepType], is_vector_env=False +) -> TerminatedTruncatedStepType: """Function to transform step returns to new step API irrespective of input API. Args: @@ -98,9 +98,10 @@ def step_to_new_api( ) -def step_to_old_api( - step_returns: Union[NewStepType, OldStepType], is_vector_env: bool = False -) -> OldStepType: +def convert_to_done_step_api( + step_returns: Union[TerminatedTruncatedStepType, DoneStepType], + is_vector_env: bool = False, +) -> DoneStepType: """Function to transform step returns to old step API irrespective of input API. Args: @@ -152,10 +153,10 @@ def step_to_old_api( def step_api_compatibility( - step_returns: Union[NewStepType, OldStepType], + step_returns: Union[TerminatedTruncatedStepType, DoneStepType], output_truncation_bool: bool = True, is_vector_env: bool = False, -) -> Union[NewStepType, OldStepType]: +) -> Union[TerminatedTruncatedStepType, DoneStepType]: """Function to transform step returns to the API specified by `output_truncation_bool` bool. Old step API refers to step() method returning (observation, reward, done, info) @@ -179,6 +180,6 @@ def step_api_compatibility( >>> observations, rewards, dones, infos = step_api_compatibility(vec_env.step(action), is_vector_env=True) """ if output_truncation_bool: - return step_to_new_api(step_returns, is_vector_env) + return convert_to_terminated_truncated_step_api(step_returns, is_vector_env) else: - return step_to_old_api(step_returns, is_vector_env) + return convert_to_done_step_api(step_returns, is_vector_env) diff --git a/gym/wrappers/step_api_compatibility.py b/gym/wrappers/step_api_compatibility.py index 661fb8a5b8e..108c23362ab 100644 --- a/gym/wrappers/step_api_compatibility.py +++ b/gym/wrappers/step_api_compatibility.py @@ -1,7 +1,10 @@ """Implementation of StepAPICompatibility wrapper class for transforming envs between new and old step API.""" import gym from gym.logger import deprecation -from gym.utils.step_api_compatibility import step_to_new_api, step_to_old_api +from gym.utils.step_api_compatibility import ( + convert_to_done_step_api, + convert_to_terminated_truncated_step_api, +) class StepAPICompatibility(gym.Wrapper): @@ -50,6 +53,6 @@ def step(self, action): """ step_returns = self.env.step(action) if self.output_truncation_bool: - return step_to_new_api(step_returns) + return convert_to_terminated_truncated_step_api(step_returns) else: - return step_to_old_api(step_returns) + return convert_to_done_step_api(step_returns) From fb61509ff18f21f072f17efcfa324e8aa6cbaeea Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Tue, 23 Aug 2022 09:42:39 +0530 Subject: [PATCH 06/12] fix tests --- tests/utils/test_step_api_compatibility.py | 16 ++++++++-------- tests/wrappers/test_autoreset.py | 1 - 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/utils/test_step_api_compatibility.py b/tests/utils/test_step_api_compatibility.py index ade45892072..883e88c2ce7 100644 --- a/tests/utils/test_step_api_compatibility.py +++ b/tests/utils/test_step_api_compatibility.py @@ -2,7 +2,7 @@ import pytest from gym.utils.env_checker import data_equivalence -from gym.utils.step_api_compatibility import step_to_new_api, step_to_old_api +from gym.utils.step_api_compatibility import convert_to_done_step_api, convert_to_terminated_truncated_step_api @pytest.mark.parametrize( @@ -54,7 +54,7 @@ def test_to_done_step_api( is_vector_env, done_returns, expected_terminated, expected_truncated ): - _, _, terminated, truncated, info = step_to_new_api( + _, _, terminated, truncated, info = convert_to_terminated_truncated_step_api( done_returns, is_vector_env=is_vector_env ) assert np.all(terminated == expected_terminated) @@ -67,7 +67,7 @@ def test_to_done_step_api( else: # isinstance(info, dict) assert "TimeLimit.truncated" not in info - roundtripped_returns = step_to_old_api( + roundtripped_returns = convert_to_done_step_api( (0, 0, terminated, truncated, info), is_vector_env=is_vector_env ) assert data_equivalence(done_returns, roundtripped_returns) @@ -112,7 +112,7 @@ def test_to_done_step_api( def test_to_terminated_truncated_step_api( is_vector_env, terminated_truncated_returns, expected_done, expected_truncated ): - _, _, done, info = step_to_old_api( + _, _, done, info = convert_to_done_step_api( terminated_truncated_returns, is_vector_env=is_vector_env ) assert np.all(done == expected_done) @@ -136,7 +136,7 @@ def test_to_terminated_truncated_step_api( else: assert "TimeLimit.truncated" not in info - roundtripped_returns = step_to_new_api( + roundtripped_returns = convert_to_terminated_truncated_step_api( (0, 0, done, info), is_vector_env=is_vector_env ) assert data_equivalence(terminated_truncated_returns, roundtripped_returns) @@ -146,19 +146,19 @@ def test_edge_case(): # When converting between the two-step APIs this is not possible in a single case # terminated=True and truncated=True -> done=True and info={} # We cannot test this in test_to_terminated_truncated_step_api as the roundtripping test will fail - _, _, done, info = step_to_old_api((0, 0, True, True, {})) + _, _, done, info = convert_to_done_step_api((0, 0, True, True, {})) assert done is True assert info == {"TimeLimit.truncated": False} # Test with vector dict info - _, _, done, info = step_to_old_api( + _, _, done, info = convert_to_done_step_api( (0, 0, np.array([True]), np.array([True]), {}), is_vector_env=True ) assert np.all(done) assert info == {"TimeLimit.truncated": np.array([False])} # Test with vector list info - _, _, done, info = step_to_old_api( + _, _, done, info = convert_to_done_step_api( (0, 0, np.array([True]), np.array([True]), [{"Test-Info": True}]), is_vector_env=True, ) diff --git a/tests/wrappers/test_autoreset.py b/tests/wrappers/test_autoreset.py index 4500f57fd3e..736485af370 100644 --- a/tests/wrappers/test_autoreset.py +++ b/tests/wrappers/test_autoreset.py @@ -139,7 +139,6 @@ def test_autoreset_wrapper_autoreset(): "count": 0, "final_observation": np.array([3]), "final_info": {"count": 3}, - "TimeLimit.truncated": False, } obs, reward, terminated, truncated, info = env.step(action) From 48dad41fbdc0dd9bf2c348ffdba463c14ec94a09 Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Tue, 23 Aug 2022 22:27:40 +0530 Subject: [PATCH 07/12] pre-commit --- gym/core.py | 2 +- gym/vector/vector_env.py | 1 - tests/utils/test_step_api_compatibility.py | 5 ++++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/gym/core.py b/gym/core.py index c52afc7482b..6ee32f11c1d 100644 --- a/gym/core.py +++ b/gym/core.py @@ -16,7 +16,7 @@ import numpy as np from gym import spaces -from gym.logger import deprecation, warn +from gym.logger import warn from gym.utils import seeding if TYPE_CHECKING: diff --git a/gym/vector/vector_env.py b/gym/vector/vector_env.py index d16d59dceeb..450fa77de43 100644 --- a/gym/vector/vector_env.py +++ b/gym/vector/vector_env.py @@ -4,7 +4,6 @@ import numpy as np import gym -from gym.logger import deprecation from gym.vector.utils.spaces import batch_space __all__ = ["VectorEnv"] diff --git a/tests/utils/test_step_api_compatibility.py b/tests/utils/test_step_api_compatibility.py index 883e88c2ce7..a3b4ea343d0 100644 --- a/tests/utils/test_step_api_compatibility.py +++ b/tests/utils/test_step_api_compatibility.py @@ -2,7 +2,10 @@ import pytest from gym.utils.env_checker import data_equivalence -from gym.utils.step_api_compatibility import convert_to_done_step_api, convert_to_terminated_truncated_step_api +from gym.utils.step_api_compatibility import ( + convert_to_done_step_api, + convert_to_terminated_truncated_step_api, +) @pytest.mark.parametrize( From 21bb6ee961a8fbfcd6e929f8fc4fd548ca864600 Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Tue, 23 Aug 2022 23:05:54 +0530 Subject: [PATCH 08/12] fix docstrings etc --- README.md | 4 +- gym/core.py | 3 +- gym/utils/play.py | 11 ++- gym/utils/step_api_compatibility.py | 4 +- gym/wrappers/autoreset.py | 5 +- gym/wrappers/normalize.py | 5 +- tests/envs/test_envs.py | 4 +- .../vector/test_step_compatibility_vector.py | 88 ------------------- 8 files changed, 18 insertions(+), 106 deletions(-) delete mode 100644 tests/vector/test_step_compatibility_vector.py diff --git a/README.md b/README.md index d5f80857471..33e5916922a 100644 --- a/README.md +++ b/README.md @@ -31,9 +31,9 @@ observation, info = env.reset(seed=42) for _ in range(1000): action = env.action_space.sample() - observation, reward, done, info = env.step(action) + observation, reward, terminated, truncarted, info = env.step(action) - if done: + if terminated or truncated: observation, info = env.reset() env.close() ``` diff --git a/gym/core.py b/gym/core.py index 6ee32f11c1d..dea7dfa7b52 100644 --- a/gym/core.py +++ b/gym/core.py @@ -87,8 +87,7 @@ def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]: """Run one timestep of the environment's dynamics. When end of episode is reached, you are responsible for calling :meth:`reset` to reset this environment's state. - Accepts an action and returns either a tuple `(observation, reward, terminated, truncated, info)`, or a tuple - (observation, reward, done, info). The latter is deprecated and will be removed in future versions. + Accepts an action and returns either a tuple `(observation, reward, terminated, truncated, info)`. Args: action (ActType): an action provided by the agent diff --git a/gym/utils/play.py b/gym/utils/play.py index 21f4ac73439..f4fbdbf49d8 100644 --- a/gym/utils/play.py +++ b/gym/utils/play.py @@ -170,7 +170,7 @@ def play( :class:`gym.utils.play.PlayPlot`. Here's a sample code for plotting the reward for last 150 steps. - >>> def callback(obs_t, obs_tp1, action, rew, done, info): + >>> def callback(obs_t, obs_tp1, action, rew, terminated, truncated, info): ... return [rew,] >>> plotter = PlayPlot(callback, 150, ["reward"]) >>> play(gym.make("ALE/AirRaid-v5"), callback=plotter.callback) @@ -187,7 +187,8 @@ def play( obs_tp1: observation after performing action action: action that was executed rew: reward that was received - done: whether the environment is done or not + terminated: whether the environment is terminated or not + truncated: whether the environment is truncated or not info: debug info keys_to_action: Mapping from keys pressed to action performed. Different formats are supported: Key combinations can either be expressed as a tuple of unicode code @@ -257,6 +258,7 @@ def play( action = key_code_to_action.get(tuple(sorted(game.pressed_keys)), noop) prev_obs = obs obs, rew, terminated, truncated, info = env.step(action) + done = terminated or truncated if callback is not None: callback(prev_obs, obs, action, rew, terminated, truncated, info) if obs is not None: @@ -285,13 +287,14 @@ class PlayPlot: - obs_tp1: observation after performing action - action: action that was executed - rew: reward that was received - - done: whether the environment is done or not + - terminated: whether the environment is terminated or not + - truncated: whether the environment is truncated or not - info: debug info It should return a list of metrics that are computed from this data. For instance, the function may look like this:: - >>> def compute_metrics(obs_t, obs_tp, action, reward, done, info): + >>> def compute_metrics(obs_t, obs_tp, action, reward, terminated, truncated, info): ... return [reward, info["cumulative_reward"], np.linalg.norm(action)] :class:`PlayPlot` provides the method :meth:`callback` which will pass its arguments along to that function diff --git a/gym/utils/step_api_compatibility.py b/gym/utils/step_api_compatibility.py index 771df73dd2d..2ac2f9b2384 100644 --- a/gym/utils/step_api_compatibility.py +++ b/gym/utils/step_api_compatibility.py @@ -135,8 +135,8 @@ def step_api_compatibility( ) -> Union[TerminatedTruncatedStepType, DoneStepType]: """Function to transform step returns to the API specified by `output_truncation_bool` bool. - Old step API refers to step() method returning (observation, reward, done, info) - New step API refers to step() method returning (observation, reward, terminated, truncated, info) + Done (old) step API refers to step() method returning (observation, reward, done, info) + Terminated Truncated (new) step API refers to step() method returning (observation, reward, terminated, truncated, info) (Refer to docs for details on the API change) Args: diff --git a/gym/wrappers/autoreset.py b/gym/wrappers/autoreset.py index 5c4026da0bc..17646abfea0 100644 --- a/gym/wrappers/autoreset.py +++ b/gym/wrappers/autoreset.py @@ -10,14 +10,15 @@ class AutoResetWrapper(gym.Wrapper): with new step API and ``(new_obs, final_reward, final_done, info)`` with the old step API. - ``new_obs`` is the first observation after calling :meth:`self.env.reset` - ``final_reward`` is the reward after calling :meth:`self.env.step`, prior to calling :meth:`self.env.reset`. - - ``final_done`` is always True. In the new API, either ``final_terminated`` or ``final_truncated`` is True + - ``final_terminated`` is the terminated value before calling :meth:`self.env.reset`. + - ``final_truncated`` is the truncated value before calling :meth:`self.env.reset`. Both `final_terminated` and `final_truncated` cannot be False. - ``info`` is a dict containing all the keys from the info dict returned by the call to :meth:`self.env.reset`, with an additional key "final_observation" containing the observation returned by the last call to :meth:`self.env.step` and "final_info" containing the info dict returned by the last call to :meth:`self.env.step`. Warning: When using this wrapper to collect rollouts, note that when :meth:`Env.step` returns `terminated` or `truncated`, a new observation from after calling :meth:`Env.reset` is returned by :meth:`Env.step` alongside the - final reward and done state from the previous episode. + final reward, terminated and truncated state from the previous episode. If you need the final state from the previous episode, you need to retrieve it via the "final_observation" key in the info dict. Make sure you know what you're doing if you use this wrapper! diff --git a/gym/wrappers/normalize.py b/gym/wrappers/normalize.py index 41b63a2371f..e8b51675c02 100644 --- a/gym/wrappers/normalize.py +++ b/gym/wrappers/normalize.py @@ -132,10 +132,7 @@ def step(self, action): rews = np.array([rews]) self.returns = self.returns * self.gamma + rews rews = self.normalize(rews) - if not self.is_vector_env: - dones = terminateds or truncateds - else: - dones = np.bitwise_or(terminateds, truncateds) + dones = np.logical_or(terminateds, truncateds) self.returns[dones] = 0.0 if not self.is_vector_env: rews = rews[0] diff --git a/tests/envs/test_envs.py b/tests/envs/test_envs.py index 7517a58b033..f23eb3148f4 100644 --- a/tests/envs/test_envs.py +++ b/tests/envs/test_envs.py @@ -18,7 +18,7 @@ f"\x1b[33mWARN: {message}\x1b[0m" for message in [ "This version of the mujoco environments depends on the mujoco-py bindings, which are no longer maintained and may stop working. Please upgrade to the v4 versions of the environments (which depend on the mujoco python bindings instead), unless you are trying to precisely replicate previous works).", - "Initializing environment in old step API which returns one bool instead of two.", + "Initializing environment in done (old) step API which returns one bool instead of two.", ] ] @@ -29,7 +29,7 @@ "A Box observation space minimum value is -infinity. This is probably too low.", "A Box observation space maximum value is -infinity. This is probably too high.", "For Box action spaces, we recommend using a symmetric and normalized space (range=[-1, 1] or [0, 1]). See https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html for more information.", - "Initializing environment in old step API which returns one bool instead of two.", + "Initializing environment in done (old) step API which returns one bool instead of two.", ] ] diff --git a/tests/vector/test_step_compatibility_vector.py b/tests/vector/test_step_compatibility_vector.py deleted file mode 100644 index b031b670104..00000000000 --- a/tests/vector/test_step_compatibility_vector.py +++ /dev/null @@ -1,88 +0,0 @@ -import numpy as np -import pytest - -import gym -from gym.spaces import Discrete -from gym.vector import AsyncVectorEnv, SyncVectorEnv - - -class OldStepEnv(gym.Env): - def __init__(self): - self.action_space = Discrete(2) - self.observation_space = Discrete(2) - - def reset(self): - return 0, {} - - def step(self, action): - obs = self.observation_space.sample() - rew = 0 - done = False - info = {} - return obs, rew, done, info - - -class NewStepEnv(gym.Env): - def __init__(self): - self.action_space = Discrete(2) - self.observation_space = Discrete(2) - - def reset(self): - return 0, {} - - def step(self, action): - obs = self.observation_space.sample() - rew = 0 - terminated = False - truncated = False - info = {} - return obs, rew, terminated, truncated, info - - -@pytest.mark.parametrize("VecEnv", [AsyncVectorEnv, SyncVectorEnv]) -def test_vector_step_compatibility_new_env(VecEnv): - - envs = [ - OldStepEnv(), - NewStepEnv(), - ] - - vec_env = VecEnv([lambda: env for env in envs]) - vec_env.reset() - step_returns = vec_env.step([0, 0]) - assert len(step_returns) == 4 - _, _, dones, _ = step_returns - assert dones.dtype == np.bool_ - vec_env.close() - - vec_env = VecEnv([lambda: env for env in envs], new_step_api=True) - vec_env.reset() - step_returns = vec_env.step([0, 0]) - assert len(step_returns) == 5 - _, _, terminateds, truncateds, _ = step_returns - assert terminateds.dtype == np.bool_ - assert truncateds.dtype == np.bool_ - vec_env.close() - - -@pytest.mark.parametrize("async_bool", [True, False]) -def test_vector_step_compatibility_existing(async_bool): - - env = gym.vector.make("CartPole-v1", num_envs=3, asynchronous=async_bool) - env.reset() - step_returns = env.step(env.action_space.sample()) - assert len(step_returns) == 4 - _, _, dones, _ = step_returns - assert dones.dtype == np.bool_ - env.close() - - env = gym.vector.make( - "CartPole-v1", num_envs=3, asynchronous=async_bool, new_step_api=True - ) - env.reset() - step_returns = env.step(env.action_space.sample()) - assert len(step_returns) == 5 - _, _, terminateds, truncateds, _ = step_returns - assert terminateds.dtype == np.bool_ - assert truncateds.dtype == np.bool_ - env.close() From a00163103f61732ac884d962176a852386d3ced0 Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Thu, 25 Aug 2022 18:57:33 +0530 Subject: [PATCH 09/12] delete old test --- tests/wrappers/test_step_compatibility.py | 78 ----------------------- 1 file changed, 78 deletions(-) delete mode 100644 tests/wrappers/test_step_compatibility.py diff --git a/tests/wrappers/test_step_compatibility.py b/tests/wrappers/test_step_compatibility.py deleted file mode 100644 index 1a3a8de15cd..00000000000 --- a/tests/wrappers/test_step_compatibility.py +++ /dev/null @@ -1,78 +0,0 @@ -import pytest - -import gym -from gym.spaces import Discrete -from gym.wrappers import StepAPICompatibility - - -class OldStepEnv(gym.Env): - def __init__(self): - self.action_space = Discrete(2) - self.observation_space = Discrete(2) - - def step(self, action): - obs = self.observation_space.sample() - rew = 0 - done = False - info = {} - return obs, rew, done, info - - -class NewStepEnv(gym.Env): - def __init__(self): - self.action_space = Discrete(2) - self.observation_space = Discrete(2) - - def step(self, action): - obs = self.observation_space.sample() - rew = 0 - terminated = False - truncated = False - info = {} - return obs, rew, terminated, truncated, info - - -@pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv]) -@pytest.mark.parametrize("output_truncation_bool", [None, True]) -def test_step_compatibility_to_new_api(env, output_truncation_bool): - if output_truncation_bool is None: - env = StepAPICompatibility(env()) - else: - env = StepAPICompatibility(env(), output_truncation_bool) - step_returns = env.step(0) - _, _, terminated, truncated, _ = step_returns - assert isinstance(terminated, bool) - assert isinstance(truncated, bool) - - -@pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv]) -def test_step_compatibility_to_old_api(env): - env = StepAPICompatibility(env(), False) - step_returns = env.step(0) - assert len(step_returns) == 4 - _, _, done, _ = step_returns - assert isinstance(done, bool) - - -@pytest.mark.parametrize("apply_step_compatibility", [None, True, False]) -def test_step_compatibility_in_make(apply_step_compatibility): - gym.register("OldStepEnv-v0", entry_point=OldStepEnv) - - if apply_step_compatibility is not None: - env = gym.make( - "OldStepEnv-v0", apply_step_compatibility=apply_step_compatibility - ) - elif apply_step_compatibility is None: - env = gym.make("OldStepEnv-v0") - - env.reset() - step_returns = env.step(0) - if apply_step_compatibility: - assert len(step_returns) == 5 - _, _, terminated, truncated, _ = step_returns - assert isinstance(terminated, bool) - assert isinstance(truncated, bool) - else: - assert len(step_returns) == 4 - _, _, done, _ = step_returns - assert isinstance(done, bool) From 0277ca42e6aa13c54650d44203838919c05e5cfd Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Thu, 25 Aug 2022 19:18:15 +0530 Subject: [PATCH 10/12] restore test --- tests/wrappers/test_step_compatibility.py | 78 +++++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 tests/wrappers/test_step_compatibility.py diff --git a/tests/wrappers/test_step_compatibility.py b/tests/wrappers/test_step_compatibility.py new file mode 100644 index 00000000000..1a3a8de15cd --- /dev/null +++ b/tests/wrappers/test_step_compatibility.py @@ -0,0 +1,78 @@ +import pytest + +import gym +from gym.spaces import Discrete +from gym.wrappers import StepAPICompatibility + + +class OldStepEnv(gym.Env): + def __init__(self): + self.action_space = Discrete(2) + self.observation_space = Discrete(2) + + def step(self, action): + obs = self.observation_space.sample() + rew = 0 + done = False + info = {} + return obs, rew, done, info + + +class NewStepEnv(gym.Env): + def __init__(self): + self.action_space = Discrete(2) + self.observation_space = Discrete(2) + + def step(self, action): + obs = self.observation_space.sample() + rew = 0 + terminated = False + truncated = False + info = {} + return obs, rew, terminated, truncated, info + + +@pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv]) +@pytest.mark.parametrize("output_truncation_bool", [None, True]) +def test_step_compatibility_to_new_api(env, output_truncation_bool): + if output_truncation_bool is None: + env = StepAPICompatibility(env()) + else: + env = StepAPICompatibility(env(), output_truncation_bool) + step_returns = env.step(0) + _, _, terminated, truncated, _ = step_returns + assert isinstance(terminated, bool) + assert isinstance(truncated, bool) + + +@pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv]) +def test_step_compatibility_to_old_api(env): + env = StepAPICompatibility(env(), False) + step_returns = env.step(0) + assert len(step_returns) == 4 + _, _, done, _ = step_returns + assert isinstance(done, bool) + + +@pytest.mark.parametrize("apply_step_compatibility", [None, True, False]) +def test_step_compatibility_in_make(apply_step_compatibility): + gym.register("OldStepEnv-v0", entry_point=OldStepEnv) + + if apply_step_compatibility is not None: + env = gym.make( + "OldStepEnv-v0", apply_step_compatibility=apply_step_compatibility + ) + elif apply_step_compatibility is None: + env = gym.make("OldStepEnv-v0") + + env.reset() + step_returns = env.step(0) + if apply_step_compatibility: + assert len(step_returns) == 5 + _, _, terminated, truncated, _ = step_returns + assert isinstance(terminated, bool) + assert isinstance(truncated, bool) + else: + assert len(step_returns) == 4 + _, _, done, _ = step_returns + assert isinstance(done, bool) From f8add0fdec30da2aad17a477e1dac978a283e0c7 Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Thu, 25 Aug 2022 19:28:18 +0530 Subject: [PATCH 11/12] passive env checker, fix old test --- gym/utils/passive_env_checker.py | 18 +++++++++++------- tests/wrappers/test_step_compatibility.py | 6 ++++-- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/gym/utils/passive_env_checker.py b/gym/utils/passive_env_checker.py index 4c76db3fbcd..86a80fef8f5 100644 --- a/gym/utils/passive_env_checker.py +++ b/gym/utils/passive_env_checker.py @@ -195,13 +195,17 @@ def env_reset_passive_checker(env, **kwargs): logger.warn( f"The result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `{type(result)}`" ) - - obs, info = result - check_obs(obs, env.observation_space, "reset") - assert isinstance( - info, dict - ), f"The second element returned by `env.reset()` was not a dictionary, actual type: {type(info)}" - return result + elif len(result) != 2: + logger.warn( + "The result returned by `env.reset()` should be `(obs, info)` by default, , where `obs` is a observation and `info` is a dictionary containing additional information." + ) + else: + obs, info = result + check_obs(obs, env.observation_space, "reset") + assert isinstance( + info, dict + ), f"The second element returned by `env.reset()` was not a dictionary, actual type: {type(info)}" + return result def env_step_passive_checker(env, action): diff --git a/tests/wrappers/test_step_compatibility.py b/tests/wrappers/test_step_compatibility.py index 1a3a8de15cd..5e0bd39f1ea 100644 --- a/tests/wrappers/test_step_compatibility.py +++ b/tests/wrappers/test_step_compatibility.py @@ -60,10 +60,12 @@ def test_step_compatibility_in_make(apply_step_compatibility): if apply_step_compatibility is not None: env = gym.make( - "OldStepEnv-v0", apply_step_compatibility=apply_step_compatibility + "OldStepEnv-v0", + apply_step_compatibility=apply_step_compatibility, + disable_env_checker=True, ) elif apply_step_compatibility is None: - env = gym.make("OldStepEnv-v0") + env = gym.make("OldStepEnv-v0", disable_env_checker=True) env.reset() step_returns = env.step(0) From 7a8f6d4885c73dd5adc021703cb2925fcc7546f6 Mon Sep 17 00:00:00 2001 From: arjun_kg Date: Thu, 25 Aug 2022 19:33:44 +0530 Subject: [PATCH 12/12] de-indent --- gym/utils/passive_env_checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gym/utils/passive_env_checker.py b/gym/utils/passive_env_checker.py index 86a80fef8f5..bd826510f48 100644 --- a/gym/utils/passive_env_checker.py +++ b/gym/utils/passive_env_checker.py @@ -205,7 +205,7 @@ def env_reset_passive_checker(env, **kwargs): assert isinstance( info, dict ), f"The second element returned by `env.reset()` was not a dictionary, actual type: {type(info)}" - return result + return result def env_step_passive_checker(env, action):