Skip to content

Commit

Permalink
Upgrade Stable-Baselines3 (#19)
Browse files Browse the repository at this point in the history
* Upgrade Stable-Baselines3

* Fix policy saving/loading
  • Loading branch information
araffin authored Feb 27, 2021
1 parent b15cc3d commit 74e6038
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 21 deletions.
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
Changelog
==========

Pre-Release 0.11.0a5 (WIP)
Pre-Release 0.11.0 (2021-02-27)
-------------------------------

Breaking Changes:
^^^^^^^^^^^^^^^^^
- Upgraded to Stable-Baselines3 >= 0.11.0

New Features:
^^^^^^^^^^^^^
Expand Down
8 changes: 4 additions & 4 deletions sb3_contrib/qrdqn/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Ten
action = q_values.argmax(dim=1).reshape(-1)
return action

def _get_data(self) -> Dict[str, Any]:
data = super()._get_data()
def _get_constructor_parameters(self) -> Dict[str, Any]:
data = super()._get_constructor_parameters()

data.update(
dict(
Expand Down Expand Up @@ -176,8 +176,8 @@ def forward(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
def _predict(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
return self.quantile_net._predict(obs, deterministic=deterministic)

def _get_data(self) -> Dict[str, Any]:
data = super()._get_data()
def _get_constructor_parameters(self) -> Dict[str, Any]:
data = super()._get_constructor_parameters()

data.update(
dict(
Expand Down
7 changes: 2 additions & 5 deletions sb3_contrib/qrdqn/qrdqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,12 @@ class QRDQN(OffPolicyAlgorithm):
:param batch_size: Minibatch size for each gradient update
:param tau: the soft update coefficient ("Polyak update", between 0 and 1) default 1 for hard update
:param gamma: the discount factor
:param train_freq: Update the model every ``train_freq`` steps. Set to `-1` to disable.
:param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
like ``(5, "step")`` or ``(2, "episode")``.
:param gradient_steps: How many gradient steps to do after each rollout
(see ``train_freq``)
Setting this to ``-1`` means doing as many gradient steps as steps done in the environment
during the rollout.
:param n_episodes_rollout: Update the model every ``n_episodes_rollout`` episodes.
Note that this cannot be used at the same time as ``train_freq``. Set to `-1` to disable.
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
at a cost of more complexity.
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
Expand Down Expand Up @@ -66,7 +65,6 @@ def __init__(
gamma: float = 0.99,
train_freq: int = 4,
gradient_steps: int = 1,
n_episodes_rollout: int = -1,
optimize_memory_usage: bool = False,
target_update_interval: int = 10000,
exploration_fraction: float = 0.005,
Expand Down Expand Up @@ -94,7 +92,6 @@ def __init__(
gamma,
train_freq,
gradient_steps,
n_episodes_rollout,
action_noise=None, # No action noise
policy_kwargs=policy_kwargs,
tensorboard_log=tensorboard_log,
Expand Down
8 changes: 4 additions & 4 deletions sb3_contrib/tqc/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ def __init__(
self.mu = nn.Linear(last_layer_dim, action_dim)
self.log_std = nn.Linear(last_layer_dim, action_dim)

def _get_data(self) -> Dict[str, Any]:
data = super()._get_data()
def _get_constructor_parameters(self) -> Dict[str, Any]:
data = super()._get_constructor_parameters()

data.update(
dict(
Expand Down Expand Up @@ -374,8 +374,8 @@ def _build(self, lr_schedule: Callable) -> None:

self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs)

def _get_data(self) -> Dict[str, Any]:
data = super()._get_data()
def _get_constructor_parameters(self) -> Dict[str, Any]:
data = super()._get_constructor_parameters()

data.update(
dict(
Expand Down
7 changes: 2 additions & 5 deletions sb3_contrib/tqc/tqc.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,9 @@ class TQC(OffPolicyAlgorithm):
:param batch_size: Minibatch size for each gradient update
:param tau: the soft update coefficient ("Polyak update", between 0 and 1)
:param gamma: the discount factor
:param train_freq: Update the model every ``train_freq`` steps.
:param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
like ``(5, "step")`` or ``(2, "episode")``.
:param gradient_steps: How many gradient updates to perform after each step
:param n_episodes_rollout: Update the model every ``n_episodes_rollout`` episodes.
Note that this cannot be used at the same time as ``train_freq``
:param action_noise: the action noise type (None by default), this can help
for hard exploration problems. See common.noise for the different action noise types.
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
Expand Down Expand Up @@ -74,7 +73,6 @@ def __init__(
gamma: float = 0.99,
train_freq: int = 1,
gradient_steps: int = 1,
n_episodes_rollout: int = -1,
action_noise: Optional[ActionNoise] = None,
optimize_memory_usage: bool = False,
ent_coef: Union[str, float] = "auto",
Expand Down Expand Up @@ -105,7 +103,6 @@ def __init__(
gamma,
train_freq,
gradient_steps,
n_episodes_rollout,
action_noise,
policy_kwargs=policy_kwargs,
tensorboard_log=tensorboard_log,
Expand Down
2 changes: 1 addition & 1 deletion sb3_contrib/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.11.0a5
0.11.0
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[
"stable_baselines3[tests,docs]>=0.11.0a2",
"stable_baselines3[tests,docs]>=0.11.0",
],
description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin",
Expand Down

0 comments on commit 74e6038

Please sign in to comment.