Use Monitor episode reward/length for evaluate_policy #220

Merged
merged 17 commits into from
Nov 16, 2020
Changes from all commits
3 changes: 3 additions & 0 deletions docs/guide/examples.rst
@@ -79,6 +79,9 @@ In the following example, we will train, save and load a DQN model on the Lunar
model = DQN.load("dqn_lunar")

# Evaluate the agent
# NOTE: If you use wrappers with your environment that modify rewards,
# this will be reflected here. To evaluate with original rewards,
# wrap environment in a "Monitor" wrapper before other wrappers.
mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=10)

# Enjoy trained agent
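A minimal sketch of the wrapping order recommended in the NOTE above (not part of this diff; the ``ScaleReward`` wrapper and the ``CartPole-v1`` id are illustrative placeholders):

import gym

from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor


class ScaleReward(gym.RewardWrapper):
    """Hypothetical reward-modifying wrapper, used only for illustration."""

    def reward(self, reward):
        return 0.1 * reward


# Monitor first, so it records the original (unscaled) episode rewards and lengths
env = ScaleReward(Monitor(gym.make("CartPole-v1")))

model = DQN("MlpPolicy", env, verbose=0)
# evaluate_policy reports the Monitor statistics, i.e. the original rewards
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=5)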
14 changes: 10 additions & 4 deletions docs/guide/rl_tips.rst
@@ -17,7 +17,7 @@ TL;DR

1. Read about RL and Stable Baselines3
2. Do quantitative experiments and hyperparameter tuning if needed
3. Evaluate the performance using a separate test environment
3. Evaluate the performance using a separate test environment (remember to check wrappers!)
4. For better performance, increase the training budget


@@ -68,18 +68,24 @@ Other methods, like ``TRPO`` or ``PPO``, make use of a *trust region* to minimize
How to evaluate an RL algorithm?
--------------------------------

.. note::

Pay attention to environment wrappers when evaluating your agent and comparing results to others' results. Modifications to episode rewards
or lengths may affect evaluation results, which may not be desirable. See the ``evaluate_policy`` helper function in the :ref:`Evaluation Helper <eval>` section.

Because most algorithms use exploration noise during training, you need a separate test environment to evaluate the performance
of your agent at a given time. It is recommended to periodically evaluate your agent for ``n`` test episodes (``n`` is usually between 5 and 20)
and average the reward per episode to have a good estimate.

.. note::

We provide an ``EvalCallback`` for doing such evaluation. You can read more about it in the :ref:`Callbacks <callbacks>` section.

As some policies are stochastic by default (e.g. A2C or PPO), you should also try setting ``deterministic=True`` when calling the ``.predict()`` method;
this frequently leads to better performance.
Looking at the training curve (episode reward as a function of the timesteps) is a good proxy, but it underestimates the agent's true performance.
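A concrete (illustrative, not prescriptive) sketch of this workflow, with a separate ``Monitor``-wrapped test environment, deterministic actions and ``n=10`` evaluation episodes; the environment id and budget are placeholders:

import gym

from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor

train_env = gym.make("CartPole-v1")
eval_env = Monitor(gym.make("CartPole-v1"))  # separate test environment

model = PPO("MlpPolicy", train_env, verbose=0)
model.learn(total_timesteps=10_000)

# Average reward over n test episodes, using deterministic actions
mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10, deterministic=True)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")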




9 changes: 8 additions & 1 deletion docs/misc/changelog.rst
@@ -8,14 +8,21 @@ Pre-Release 0.11.0a0 (WIP)

Breaking Changes:
^^^^^^^^^^^^^^^^^

- ``evaluate_policy`` now returns rewards/episode lengths from a ``Monitor`` wrapper if one is present,
which allows returning the unnormalized reward in the case of Atari games, for instance.
- Renamed ``common.vec_env.is_wrapped`` to ``common.vec_env.is_vecenv_wrapped`` to avoid confusion
with the new ``is_wrapped()`` helper.

New Features:
^^^^^^^^^^^^^
- Added support for ``VecFrameStack`` to stack on the first or last observation dimension, along with
an automatic check for image spaces.
- ``VecFrameStack`` now has a ``channels_order`` argument to specify whether observations should be stacked
on the first or last observation dimension (previously always stacked on the last).
- Added ``common.env_util.is_wrapped`` and ``common.env_util.unwrap_wrapper`` functions for checking/unwrapping
an environment for a specific wrapper.
- Added ``env_is_wrapped()`` method for ``VecEnv`` to check if its environments are wrapped
with given Gym wrappers.
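A short, hedged sketch of how these helpers fit together (not part of this diff; the environment id is a placeholder):

import gym

from stable_baselines3.common.env_util import is_wrapped, unwrap_wrapper
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, is_vecenv_wrapped

venv = DummyVecEnv([lambda: Monitor(gym.make("CartPole-v1"))])
venv = VecFrameStack(venv, n_stack=4)  # channels_order="first"/"last" can force the stacking axis

assert is_vecenv_wrapped(venv, VecFrameStack)  # renamed from common.vec_env.is_wrapped
assert venv.env_is_wrapped(Monitor) == [True]  # new VecEnv method, checks each worker env

env = Monitor(gym.make("CartPole-v1"))
assert is_wrapped(env, Monitor)                # new gym-level helper in common.env_util
assert isinstance(unwrap_wrapper(env, Monitor), Monitor)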

Bug Fixes:
^^^^^^^^^^
4 changes: 2 additions & 2 deletions stable_baselines3/common/base_class.py
@@ -31,7 +31,7 @@
VecEnv,
VecNormalize,
VecTransposeImage,
is_wrapped,
is_vecenv_wrapped,
unwrap_vec_normalize,
)
from stable_baselines3.common.vec_env.obs_dict_wrapper import ObsDictWrapper
@@ -178,7 +178,7 @@ def _wrap_env(env: GymEnv, verbose: int = 0) -> VecEnv:

if (
is_image_space(env.observation_space)
and not is_wrapped(env, VecTransposeImage)
and not is_vecenv_wrapped(env, VecTransposeImage)
and not is_image_space_channels_first(env.observation_space)
):
if verbose >= 1:
5 changes: 5 additions & 0 deletions stable_baselines3/common/callbacks.py
@@ -276,6 +276,8 @@ class EvalCallback(EventCallback):
:param deterministic: Whether the evaluation should use deterministic actions
:param render: Whether to render the environment during evaluation
:param verbose:
:param warn: Passed to ``evaluate_policy`` (warns if ``eval_env`` has not been
wrapped with a Monitor wrapper)
"""

def __init__(
@@ -289,6 +291,7 @@ def __init__(
deterministic: bool = True,
render: bool = False,
verbose: int = 1,
warn: bool = True,
):
super(EvalCallback, self).__init__(callback_on_new_best, verbose=verbose)
self.n_eval_episodes = n_eval_episodes
@@ -297,6 +300,7 @@ def __init__(
self.last_mean_reward = -np.inf
self.deterministic = deterministic
self.render = render
self.warn = warn

# Convert to VecEnv for consistency
if not isinstance(eval_env, VecEnv):
@@ -339,6 +343,7 @@ def _on_step(self) -> bool:
render=self.render,
deterministic=self.deterministic,
return_episode_rewards=True,
warn=self.warn,
)

if self.log_path is not None:
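A brief usage sketch for the new ``warn`` argument (not part of this diff; the environment id, frequency and budget are illustrative):

import gym

from stable_baselines3 import A2C
from stable_baselines3.common.callbacks import EvalCallback

# eval_env is intentionally left without a Monitor wrapper here;
# warn=False silences the warning that evaluate_policy would otherwise emit.
eval_env = gym.make("CartPole-v1")
eval_callback = EvalCallback(eval_env, n_eval_episodes=5, eval_freq=500, warn=False)

model = A2C("MlpPolicy", gym.make("CartPole-v1"), verbose=0)
model.learn(total_timesteps=2_000, callback=eval_callback)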
27 changes: 27 additions & 0 deletions stable_baselines3/common/env_util.py
@@ -8,6 +8,33 @@
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv


def unwrap_wrapper(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> Optional[gym.Wrapper]:
"""
Retrieve a ``gym.Wrapper`` object by recursively searching.

:param env: Environment to unwrap
:param wrapper_class: Wrapper to look for
:return: The first wrapper of type ``wrapper_class`` found in the chain, or None if the environment has not been wrapped with it
"""
env_tmp = env
while isinstance(env_tmp, gym.Wrapper):
if isinstance(env_tmp, wrapper_class):
return env_tmp
env_tmp = env_tmp.env
return None


def is_wrapped(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> bool:
"""
Check if a given environment has been wrapped with a given wrapper.

:param env: Environment to check
:param wrapper_class: Wrapper class to look for
:return: True if environment has been wrapped with ``wrapper_class``.
"""
return unwrap_wrapper(env, wrapper_class) is not None


def make_vec_env(
env_id: Union[str, Type[gym.Env]],
n_envs: int = 1,
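A usage sketch for the two helpers added above; ``MyCustomWrapper`` is a hypothetical wrapper defined only to show the negative case, and the environment id is a placeholder:

import gym

from stable_baselines3.common.env_util import is_wrapped, unwrap_wrapper
from stable_baselines3.common.monitor import Monitor


class MyCustomWrapper(gym.Wrapper):
    """Hypothetical wrapper that is not part of the chain below."""
    pass


# Monitor is not the outermost wrapper here
env = gym.wrappers.TimeLimit(Monitor(gym.make("CartPole-v1")), max_episode_steps=200)

assert is_wrapped(env, Monitor)                      # found by the recursive search
assert isinstance(unwrap_wrapper(env, Monitor), Monitor)
assert unwrap_wrapper(env, MyCustomWrapper) is None  # absent wrapper -> None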
67 changes: 57 additions & 10 deletions stable_baselines3/common/evaluation.py
@@ -1,3 +1,4 @@
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import gym
@@ -16,11 +17,20 @@ def evaluate_policy(
callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], None]] = None,
reward_threshold: Optional[float] = None,
return_episode_rewards: bool = False,
warn: bool = True,
) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]:
"""
Runs policy for ``n_eval_episodes`` episodes and returns average reward.
This is made to work only with one env.

.. note::
If the environment has not been wrapped with a ``Monitor`` wrapper, rewards and
episode lengths are counted as they appear from ``env.step`` calls. If
the environment contains wrappers that modify rewards or episode lengths
(e.g. reward scaling, early episode reset), these will affect the evaluation
results as well. You can avoid this by wrapping the environment with a ``Monitor``
wrapper before anything else.

:param model: The RL agent you want to evaluate.
:param env: The gym environment. In the case of a ``VecEnv``
this must contain only one environment.
@@ -31,33 +41,70 @@
called after each step. Gets locals() and globals() passed as parameters.
:param reward_threshold: Minimum expected reward per episode,
this will raise an error if the performance is not met
:param return_episode_rewards: If True, a list of reward per episode
will be returned instead of the mean.
:return: Mean reward per episode, std of reward per episode
returns ([float], [int]) when ``return_episode_rewards`` is True
:param return_episode_rewards: If True, a list of rewards and episode lengths
per episode will be returned instead of the mean.
:param warn: If True (default), warns the user about the lack of a ``Monitor`` wrapper in the
evaluation environment.
:return: Mean reward per episode, std of reward per episode.
Returns ([float], [int]) when ``return_episode_rewards`` is True, the first
list containing per-episode rewards and the second containing per-episode lengths
(in number of steps).
"""
is_monitor_wrapped = False
# Avoid circular import
from stable_baselines3.common.env_util import is_wrapped
from stable_baselines3.common.monitor import Monitor

if isinstance(env, VecEnv):
assert env.num_envs == 1, "You must pass only one environment when using this function"
is_monitor_wrapped = env.env_is_wrapped(Monitor)[0]
else:
is_monitor_wrapped = is_wrapped(env, Monitor)

if not is_monitor_wrapped and warn:
warnings.warn(
"Evaluation environment is not wrapped with a ``Monitor`` wrapper. "
"This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. "
"Consider wrapping environment first with ``Monitor`` wrapper.",
UserWarning,
)

episode_rewards, episode_lengths = [], []
for i in range(n_eval_episodes):
# Avoid double reset, as VecEnv are reset automatically
if not isinstance(env, VecEnv) or i == 0:
not_reseted = True
while len(episode_rewards) < n_eval_episodes:
# Number of loops here might differ from true episodes
# played, if underlying wrappers modify episode lengths.
# Avoid double reset, as VecEnv are reset automatically.
if not isinstance(env, VecEnv) or not_reseted:
obs = env.reset()
not_reseted = False
done, state = False, None
episode_reward = 0.0
episode_length = 0
while not done:
action, state = model.predict(obs, state=state, deterministic=deterministic)
obs, reward, done, _info = env.step(action)
obs, reward, done, info = env.step(action)
episode_reward += reward
if callback is not None:
callback(locals(), globals())
episode_length += 1
if render:
env.render()
episode_rewards.append(episode_reward)
episode_lengths.append(episode_length)

if is_monitor_wrapped:
# Do not trust "done" with episode endings.
# Remove vecenv stacking (if any)
if isinstance(env, VecEnv):
info = info[0]
if "episode" in info.keys():
# Monitor wrapper includes "episode" key in info if environment
# has been wrapped with it. Use those rewards instead.
episode_rewards.append(info["episode"]["r"])
episode_lengths.append(info["episode"]["l"])
else:
episode_rewards.append(episode_reward)
episode_lengths.append(episode_length)

mean_reward = np.mean(episode_rewards)
std_reward = np.std(episode_rewards)
if reward_threshold is not None:
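To complement the docstring above, a sketch of the two return modes of ``evaluate_policy`` (not part of this diff; model and environment are placeholders):

import gym

from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor

# Monitor first, so its per-episode statistics are the ones reported
eval_env = Monitor(gym.make("CartPole-v1"))
model = PPO("MlpPolicy", eval_env, verbose=0)

# Default: aggregated statistics over the evaluation episodes
mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=5)

# return_episode_rewards=True: per-episode rewards and lengths (in steps)
rewards, lengths = evaluate_policy(model, eval_env, n_eval_episodes=5, return_episode_rewards=True)
assert len(rewards) == len(lengths) == 5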
5 changes: 2 additions & 3 deletions stable_baselines3/common/type_aliases.py
@@ -6,10 +6,9 @@
import numpy as np
import torch as th

from stable_baselines3.common import callbacks
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common import callbacks, vec_env

GymEnv = Union[gym.Env, VecEnv]
GymEnv = Union[gym.Env, vec_env.VecEnv]
GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int]
GymStepReturn = Tuple[GymObs, float, bool, Dict]
TensorDict = Dict[str, th.Tensor]
2 changes: 1 addition & 1 deletion stable_baselines3/common/vec_env/__init__.py
@@ -41,7 +41,7 @@ def unwrap_vec_normalize(env: Union["GymEnv", VecEnv]) -> Optional[VecNormalize]
return unwrap_vec_wrapper(env, VecNormalize) # pytype:disable=bad-return-type


def is_wrapped(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> bool:
def is_vecenv_wrapped(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> bool:
"""
Check if an environment is already wrapped by a given ``VecEnvWrapper``.

18 changes: 17 additions & 1 deletion stable_baselines3/common/vec_env/base_vec_env.py
@@ -1,6 +1,6 @@
import inspect
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union

import cloudpickle
import gym
@@ -139,6 +139,19 @@ def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = No
"""
raise NotImplementedError()

@abstractmethod
def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
"""
Check if environments are wrapped with a given wrapper.

:param wrapper_class: The Gym wrapper class to look for.
:param indices: Indices of envs to check
:return: True if the env is wrapped, False otherwise, for each env queried.
"""
raise NotImplementedError()

def step(self, actions: np.ndarray) -> VecEnvStepReturn:
"""
Step the environments with the given action
@@ -280,6 +293,9 @@ def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) ->
def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
return self.venv.env_method(method_name, *method_args, indices=indices, **method_kwargs)

def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
return self.venv.env_is_wrapped(wrapper_class, indices=indices)

def __getattr__(self, name: str) -> Any:
"""Find attribute from wrapped venv(s) if this wrapper does not have it.
Useful for accessing attributes from venvs which are wrapped with multiple wrappers
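An illustrative call of the new ``env_is_wrapped()`` method on a vectorized environment with mixed wrapping (the setup is a toy example, not taken from this PR):

import gym

from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv

venv = DummyVecEnv(
    [
        lambda: Monitor(gym.make("CartPole-v1")),  # wrapped with Monitor
        lambda: gym.make("CartPole-v1"),           # not wrapped
    ]
)

assert venv.env_is_wrapped(Monitor) == [True, False]
assert venv.env_is_wrapped(Monitor, indices=[0]) == [True]  # indices behave as in env_method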
10 changes: 9 additions & 1 deletion stable_baselines3/common/vec_env/dummy_vec_env.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections import OrderedDict
from copy import deepcopy
from typing import Any, Callable, List, Optional, Sequence, Union
from typing import Any, Callable, List, Optional, Sequence, Type, Union

import gym
import numpy as np
@@ -112,6 +112,14 @@ def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = No
target_envs = self._get_target_envs(indices)
return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs]

def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
"""Check if worker environments are wrapped with a given wrapper"""
target_envs = self._get_target_envs(indices)
# Import here to avoid a circular import
from stable_baselines3.common import env_util

return [env_util.is_wrapped(env_i, wrapper_class) for env_i in target_envs]

def _get_target_envs(self, indices: VecEnvIndices) -> List[gym.Env]:
indices = self._get_indices(indices)
return [self.envs[i] for i in indices]
14 changes: 13 additions & 1 deletion stable_baselines3/common/vec_env/subproc_vec_env.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import multiprocessing as mp
from collections import OrderedDict
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union

import gym
import numpy as np
@@ -17,6 +17,9 @@
def _worker(
remote: mp.connection.Connection, parent_remote: mp.connection.Connection, env_fn_wrapper: CloudpickleWrapper
) -> None:
# Import here to avoid a circular import
from stable_baselines3.common.env_util import is_wrapped

parent_remote.close()
env = env_fn_wrapper.var()
while True:
@@ -49,6 +52,8 @@ def _worker(
remote.send(getattr(env, data))
elif cmd == "set_attr":
remote.send(setattr(env, data[0], data[1]))
elif cmd == "is_wrapped":
remote.send(is_wrapped(env, data))
else:
raise NotImplementedError(f"`{cmd}` is not implemented in the worker")
except EOFError:
@@ -170,6 +175,13 @@ def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = No
remote.send(("env_method", (method_name, method_args, method_kwargs)))
return [remote.recv() for remote in target_remotes]

def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
"""Check if worker environments are wrapped with a given wrapper"""
target_remotes = self._get_target_remotes(indices)
for remote in target_remotes:
remote.send(("is_wrapped", wrapper_class))
return [remote.recv() for remote in target_remotes]

def _get_target_remotes(self, indices: VecEnvIndices) -> List[Any]:
"""
Get the connection object needed to communicate with the wanted
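The ``is_wrapped`` command above follows the same request/reply pattern as the other worker commands. A stripped-down, standalone sketch of that pattern (not SB3's actual worker; the payload and the reply are deliberately simplified):

import multiprocessing as mp


def _toy_worker(remote: mp.connection.Connection, parent_remote: mp.connection.Connection) -> None:
    parent_remote.close()
    while True:
        cmd, data = remote.recv()
        if cmd == "is_wrapped":
            # SB3's worker would call is_wrapped(env, data) here; we just fake a reply
            remote.send(isinstance(data, type))
        elif cmd == "close":
            remote.close()
            break


if __name__ == "__main__":
    parent_conn, child_conn = mp.Pipe()
    proc = mp.Process(target=_toy_worker, args=(child_conn, parent_conn), daemon=True)
    proc.start()
    child_conn.close()

    parent_conn.send(("is_wrapped", dict))  # ask the worker about a (stand-in) wrapper class
    print(parent_conn.recv())               # True
    parent_conn.send(("close", None))
    proc.join()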
7 changes: 6 additions & 1 deletion tests/test_callbacks.py
@@ -33,7 +33,12 @@ def test_callbacks(tmp_path, model_class):
callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-1200, verbose=1)

eval_callback = EvalCallback(
eval_env, callback_on_new_best=callback_on_best, best_model_save_path=log_folder, log_path=log_folder, eval_freq=100
eval_env,
callback_on_new_best=callback_on_best,
best_model_save_path=log_folder,
log_path=log_folder,
eval_freq=100,
warn=False,
)
# Equivalent to the `checkpoint_callback`
# but here in an event-driven manner