Skip to content

Commit

Permalink
Merge pull request #89 from DLR-RM/base-class-review
Browse files Browse the repository at this point in the history
Refactor and clean-up of common code
  • Loading branch information
AdamGleave authored Jul 8, 2020
2 parents 3756d05 + bea1e02 commit c39ed39
Show file tree
Hide file tree
Showing 9 changed files with 208 additions and 191 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ keys/

# Virtualenv
/env
/venv


*.sublime-project
Expand Down
5 changes: 3 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
SHELL=/bin/bash
LINT_PATHS=stable_baselines3/ tests/ docs/conf.py setup.py

pytest:
./scripts/run_tests.sh
Expand All @@ -9,9 +10,9 @@ type:
lint:
# stop the build if there are Python syntax errors or undefined names
# see https://lintlyci.github.io/Flake8Rules/
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
flake8 ${LINT_PATHS} --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings.
flake8 . --count --exit-zero --statistics
flake8 ${LINT_PATHS} --count --exit-zero --statistics

doc:
cd docs && make html
Expand Down
3 changes: 2 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
# We CANNOT enable 'sphinxcontrib.spelling' because ReadTheDocs.org does not support
# PyEnchant.
try:
import sphinxcontrib.spelling
import sphinxcontrib.spelling # noqa: F401
enable_spell_check = True
except ImportError:
enable_spell_check = False
Expand Down Expand Up @@ -129,6 +129,7 @@ def __getattr__(cls, name):
def setup(app):
app.add_stylesheet("css/baselines_theme.css")


# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
Expand Down
98 changes: 52 additions & 46 deletions stable_baselines3/common/base_class.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Abstract base classes for RL algorithms."""

import time
from typing import Union, Type, Optional, Dict, Any, List, Tuple, Callable
from typing import Union, Type, Optional, Dict, Any, Iterable, List, Tuple, Callable
from abc import ABC, abstractmethod
from collections import deque
import pathlib
Expand All @@ -23,12 +25,30 @@
from stable_baselines3.common.noise import ActionNoise


def maybe_make_env(env: Union[GymEnv, str, None], monitor_wrapper: bool, verbose: int) -> Optional[GymEnv]:
"""If env is a string, make the environment; otherwise, return env.
:param env: (Union[GymEnv, str, None]) The environment to learn from.
:param monitor_wrapper: (bool) Whether to wrap env in a Monitor when creating env.
:param verbose: (int) logging verbosity
:return A Gym (vector) environment.
"""
if isinstance(env, str):
if verbose >= 1:
print(f"Creating environment from the given name '{env}'")
env = gym.make(env)
if monitor_wrapper:
env = Monitor(env, filename=None)

return env


class BaseAlgorithm(ABC):
"""
The base of RL algorithms
:param policy: (Type[BasePolicy]) Policy object
:param env: (Union[GymEnv, str]) The environment to learn from
:param env: (Union[GymEnv, str, None]) The environment to learn from
(if registered in Gym, can be str. Can be None for loading trained models)
:param policy_base: (Type[BasePolicy]) The base policy used by this method
:param learning_rate: (float or callable) learning rate for the optimizer,
Expand All @@ -54,7 +74,7 @@ class BaseAlgorithm(ABC):

def __init__(self,
policy: Type[BasePolicy],
env: Union[GymEnv, str],
env: Union[GymEnv, str, None],
policy_base: Type[BasePolicy],
learning_rate: Union[float, Callable],
policy_kwargs: Dict[str, Any] = None,
Expand Down Expand Up @@ -116,18 +136,9 @@ def __init__(self,
if env is not None:
if isinstance(env, str):
if create_eval_env:
eval_env = gym.make(env)
if monitor_wrapper:
eval_env = Monitor(eval_env, filename=None)
self.eval_env = DummyVecEnv([lambda: eval_env])
if self.verbose >= 1:
print("Creating environment from the given name, wrapped in a DummyVecEnv.")

env = gym.make(env)
if monitor_wrapper:
env = Monitor(env, filename=None)
env = DummyVecEnv([lambda: env])
self.eval_env = maybe_make_env(env, monitor_wrapper, self.verbose)

env = maybe_make_env(env, monitor_wrapper, self.verbose)
env = self._wrap_env(env)

self.observation_space = env.observation_space
Expand All @@ -136,8 +147,8 @@ def __init__(self,
self.env = env

if not support_multi_env and self.n_envs > 1:
raise ValueError("Error: the model does not support multiple envs requires a single vectorized"
" environment.")
raise ValueError("Error: the model does not support multiple envs; it requires "
"a single vectorized environment.")

def _wrap_env(self, env: GymEnv) -> VecEnv:
if not isinstance(env, VecEnv):
Expand All @@ -153,10 +164,7 @@ def _wrap_env(self, env: GymEnv) -> VecEnv:

@abstractmethod
def _setup_model(self) -> None:
"""
Create networks, buffer and optimizers
"""
raise NotImplementedError()
"""Create networks, buffer and optimizers."""

def _get_eval_env(self, eval_env: Optional[GymEnv]) -> Optional[GymEnv]:
"""
Expand Down Expand Up @@ -238,7 +246,7 @@ def set_env(self, env: GymEnv) -> None:

def get_torch_variables(self) -> Tuple[List[str], List[str]]:
"""
Get the name of the torch variable that will be saved.
Get the name of the torch variables that will be saved.
``th.save`` and ``th.load`` will be used with the right device
instead of the default pickling strategy.
Expand All @@ -263,18 +271,16 @@ def learn(self, total_timesteps: int,
Return a trained model.
:param total_timesteps: (int) The total number of samples (env steps) to train on
:param callback: (function (dict, dict)) -> boolean function called at every steps with state of the algorithm.
It takes the local and global variables. If it returns False, training is aborted.
:param callback: (MaybeCallback) callback(s) called at every step with state of the algorithm.
:param log_interval: (int) The number of timesteps before logging.
:param tb_log_name: (str) the name of the run for tensorboard log
:param tb_log_name: (str) the name of the run for TensorBoard logging
:param eval_env: (gym.Env) Environment that will be used to evaluate the agent
:param eval_freq: (int) Evaluate the agent every ``eval_freq`` timesteps (this may vary a little)
:param n_eval_episodes: (int) Number of episode to evaluate the agent
:param eval_log_path: (Optional[str]) Path to a folder where the evaluations will be saved
:param reset_num_timesteps: (bool) whether or not to reset the current timestep number (used in logging)
:return: (BaseAlgorithm) the trained model
"""
raise NotImplementedError()

def predict(self, observation: np.ndarray,
state: Optional[np.ndarray] = None,
Expand Down Expand Up @@ -329,8 +335,6 @@ def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs) -> 'BaseAl
# load parameters
model.__dict__.update(data)
model.__dict__.update(kwargs)
if not hasattr(model, "_setup_model") and len(params) > 0:
raise NotImplementedError(f"{cls} has no ``_setup_model()`` method")
model._setup_model()

# put state_dicts back in place
Expand Down Expand Up @@ -366,14 +370,18 @@ def set_random_seed(self, seed: Optional[int] = None) -> None:
self.eval_env.seed(seed)

def _init_callback(self,
callback: Union[None, Callable, List[BaseCallback], BaseCallback],
callback: MaybeCallback,
eval_env: Optional[VecEnv] = None,
eval_freq: int = 10000,
n_eval_episodes: int = 5,
log_path: Optional[str] = None) -> BaseCallback:
"""
:param callback: (Union[callable, [BaseCallback], BaseCallback, None])
:return: (BaseCallback)
:param callback: (MaybeCallback) Callback(s) called at every step with state of the algorithm.
:param eval_freq: (Optional[int]) How many steps between evaluations; if None, do not evaluate.
:param n_eval_episodes: (int) How many episodes to play per evaluation
:param n_eval_episodes: (int) Number of episodes to rollout during evaluation.
:param log_path: (Optional[str]) Path to a folder where the evaluations will be saved
:return: (BaseCallback) A hybrid callback calling `callback` and performing evaluation.
"""
# Convert a list of callbacks into a callback
if isinstance(callback, list):
Expand All @@ -396,7 +404,7 @@ def _init_callback(self,
def _setup_learn(self,
total_timesteps: int,
eval_env: Optional[GymEnv],
callback: Union[None, Callable, List[BaseCallback], BaseCallback] = None,
callback: MaybeCallback = None,
eval_freq: int = 10000,
n_eval_episodes: int = 5,
log_path: Optional[str] = None,
Expand All @@ -407,11 +415,11 @@ def _setup_learn(self,
Initialize different variables needed for training.
:param total_timesteps: (int) The total number of samples (env steps) to train on
:param eval_env: (Optional[GymEnv])
:param callback: (Union[None, BaseCallback, List[BaseCallback, Callable]])
:param eval_env: (Optional[VecEnv]) Environment to use for evaluation.
:param callback: (MaybeCallback) Callback(s) called at every step with state of the algorithm.
:param eval_freq: (int) How many steps between evaluations
:param n_eval_episodes: (int) How many episodes to play per evaluation
:param log_path (Optional[str]): Path to a log folder
:param log_path: (Optional[str]) Path to a folder where the evaluations will be saved
:param reset_num_timesteps: (bool) Whether to reset or not the ``num_timesteps`` attribute
:param tb_log_name: (str) the name of the run for tensorboard log
:return: (Tuple[int, BaseCallback])
Expand Down Expand Up @@ -480,8 +488,8 @@ def excluded_save_params(self) -> List[str]:
def save(
self,
path: Union[str, pathlib.Path, io.BufferedIOBase],
exclude: Optional[List[str]] = None,
include: Optional[List[str]] = None,
exclude: Optional[Iterable[str]] = None,
include: Optional[Iterable[str]] = None,
) -> None:
"""
Save all the attributes of the object and the model parameters in a zip-file.
Expand All @@ -492,29 +500,27 @@ def save(
"""
# copy parameter list so we don't mutate the original dict
data = self.__dict__.copy()
# use standard list of excluded parameters if none given

# Exclude is union of specified parameters (if any) and standard exclusions
if exclude is None:
exclude = self.excluded_save_params()
else:
# append standard exclude params to the given params
exclude.extend([param for param in self.excluded_save_params() if param not in exclude])
exclude = []
exclude = set(exclude).union(self.excluded_save_params())

# do not exclude params if they are specifically included
# Do not exclude params if they are specifically included
if include is not None:
exclude = [param_name for param_name in exclude if param_name not in include]
exclude = exclude.difference(include)

state_dicts_names, tensors_names = self.get_torch_variables()
# any params that are in the save vars must not be saved by data
torch_variables = state_dicts_names + tensors_names
for torch_var in torch_variables:
# we need to get only the name of the top most module as we'll remove that
var_name = torch_var.split('.')[0]
exclude.append(var_name)
exclude.add(var_name)

# Remove parameter entries of parameters which are to be excluded
for param_name in exclude:
if param_name in data:
data.pop(param_name, None)
data.pop(param_name, None)

# Build dict of tensor variables
tensors = None
Expand Down
Loading

0 comments on commit c39ed39

Please sign in to comment.