
Commit 2cd6a4f

Miffyli and araffin authored

Match performance with stable-baselines (discrete case) (#110)

* Fix storing correct episode dones
* Fix number of filters in NatureCNN network
* Add TF-like RMSprop for matching performance with sb2
* Remove stuff that was accidentally included
* Reformat
* Clarify variable naming
* Update changelog
* Add comment on RMSprop implementations to A2C
* Add test for RMSpropTFLike

Co-authored-by: Antonin RAFFIN <[email protected]>
1 parent 3253ee1 commit 2cd6a4f

File tree

9 files changed: +149 -2 lines changed


docs/misc/changelog.rst

+3
@@ -34,6 +34,8 @@ Bug Fixes:
 - Fixed a bug with orthogonal initialization when `bias=False` in custom policy (@rk37)
 - Fixed approximate entropy calculation in PPO and A2C. (@andyshih12)
 - Fixed DQN target network sharing feature extractor with the main network.
+- Fixed storing correct ``dones`` in on-policy algorithm rollout collection. (@andyshih12)
+- Fixed number of filters in final convolutional layer in NatureCNN to match original implementation.
 
 Deprecations:
 ^^^^^^^^^^^^^
@@ -49,6 +51,7 @@ Others:
 - Ignored errors from newer pytype version
 - Added a check when using ``gSDE``
 - Removed codacy dependency from Dockerfile
+- Added ``common.sb2_compat.RMSpropTFLike`` optimizer, which corresponds closer to the implementation of RMSprop from Tensorflow.
 
 Documentation:
 ^^^^^^^^^^^^^^

docs/modules/a2c.rst

+8
@@ -10,6 +10,14 @@ A synchronous, deterministic variant of `Asynchronous Advantage Actor Critic (A3
 It uses multiple workers to avoid the use of a replay buffer.
 
 
+.. warning::
+
+  If you find training unstable or want to match performance of stable-baselines A2C, consider using
+  ``RMSpropTFLike`` optimizer from ``stable_baselines3.common.sb2_compat.rmsprop_tf_like``.
+  You can change optimizer with ``A2C(policy_kwargs=dict(optimizer_class=RMSpropTFLike))``.
+  Read more `here <https://github.com/DLR-RM/stable-baselines3/pull/110#issuecomment-663255241>`_.
+
+
 Notes
 -----
 
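For readers following along, here is a minimal usage sketch of the optimizer switch described in the warning above. The environment id and timestep budget are illustrative choices, not part of this commit.

from stable_baselines3 import A2C
from stable_baselines3.common.sb2_compat.rmsprop_tf_like import RMSpropTFLike

# Illustrative environment and budget; any environment supported by A2C works the same way.
model = A2C(
    "MlpPolicy",
    "CartPole-v1",
    policy_kwargs=dict(optimizer_class=RMSpropTFLike),
    verbose=1,
)
model.learn(total_timesteps=10_000)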

stable_baselines3/a2c/a2c.py

+1
@@ -116,6 +116,7 @@ def train(self) -> None:
         # Update optimizer learning rate
         self._update_learning_rate(self.policy.optimizer)
 
+        # This will only loop once (get all data in one go)
         for rollout_data in self.rollout_buffer.get(batch_size=None):
 
             actions = rollout_data.actions

stable_baselines3/common/base_class.py

+2
@@ -123,6 +123,7 @@ def __init__(
         self.tensorboard_log = tensorboard_log
         self.lr_schedule = None  # type: Optional[Callable]
         self._last_obs = None  # type: Optional[np.ndarray]
+        self._last_dones = None  # type: Optional[np.ndarray]
         # When using VecNormalize:
         self._last_original_obs = None  # type: Optional[np.ndarray]
         self._episode_num = 0
@@ -474,6 +475,7 @@ def _setup_learn(
         # Avoid resetting the environment when calling ``.learn()`` consecutive times
         if reset_num_timesteps or self._last_obs is None:
             self._last_obs = self.env.reset()
+            self._last_dones = np.zeros((self.env.num_envs,), dtype=np.bool)
             # Retrieve unnormalized observation for saving into the buffer
             if self._vec_normalize_env is not None:
                 self._last_original_obs = self._vec_normalize_env.get_original_obs()

stable_baselines3/common/on_policy_algorithm.py

+2 -1
@@ -173,8 +173,9 @@ def collect_rollouts(
             if isinstance(self.action_space, gym.spaces.Discrete):
                 # Reshape in case of discrete action
                 actions = actions.reshape(-1, 1)
-            rollout_buffer.add(self._last_obs, actions, rewards, dones, values, log_probs)
+            rollout_buffer.add(self._last_obs, actions, rewards, self._last_dones, values, log_probs)
             self._last_obs = new_obs
+            self._last_dones = dones
 
         rollout_buffer.compute_returns_and_advantage(values, dones=dones)
 
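To make the off-by-one explicit, the sketch below (schematic pseudocode, not the library source) shows the bookkeeping this change implements: the done flags stored alongside an observation are the ones returned by the previous step, so they indicate whether self._last_obs starts a new episode.

# Schematic rollout loop; names (env, policy, buffer, n_envs) are illustrative.
last_obs = env.reset()
last_dones = [False] * n_envs  # nothing terminated before the first step

for _ in range(n_rollout_steps):
    actions, values, log_probs = policy(last_obs)
    new_obs, rewards, dones, infos = env.step(actions)

    # Store the dones from the *previous* transition: they say whether
    # last_obs is the first observation of a fresh episode.
    buffer.add(last_obs, actions, rewards, last_dones, values, log_probs)

    last_obs, last_dones = new_obs, dones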

stable_baselines3/common/sb2_compat/__init__.py

Whitespace-only changes.

stable_baselines3/common/sb2_compat/rmsprop_tf_like.py

+126

@@ -0,0 +1,126 @@
import torch
from torch.optim import Optimizer


class RMSpropTFLike(Optimizer):
    r"""Implements RMSprop algorithm with closer match to Tensorflow version.

    For reproducibility with original stable-baselines. Use this
    version with e.g. A2C for stabler learning than with the PyTorch
    RMSProp. Based on the PyTorch v1.5.0 implementation of RMSprop.

    See a more thorough conversion in pytorch-image-models repository:
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/rmsprop_tf.py

    Changes to the original RMSprop:
    - Move epsilon inside square root
    - Initialize squared gradient to ones rather than zeros

    Proposed by G. Hinton in his
    `course <http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_.

    The centered version first appears in `Generating Sequences
    With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_.

    The implementation here adds epsilon to the gradient average before
    taking the square root (the same ordering as TensorFlow). The effective
    learning rate is thus :math:`\alpha/\sqrt{v + \epsilon}` where :math:`\alpha`
    is the scheduled learning rate and :math:`v` is the weighted moving average
    of the squared gradient.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-2)
        momentum (float, optional): momentum factor (default: 0)
        alpha (float, optional): smoothing constant (default: 0.99)
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        centered (bool, optional) : if ``True``, compute the centered RMSProp,
            the gradient is normalized by an estimation of its variance
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)

    """

    def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= momentum:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= alpha:
            raise ValueError("Invalid alpha value: {}".format(alpha))

        defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay)
        super(RMSpropTFLike, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RMSpropTFLike, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault("momentum", 0)
            group.setdefault("centered", False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError("RMSpropTF does not support sparse gradients")
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # PyTorch initialized to zeros here
                    state["square_avg"] = torch.ones_like(p, memory_format=torch.preserve_format)
                    if group["momentum"] > 0:
                        state["momentum_buffer"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if group["centered"]:
                        state["grad_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)

                square_avg = state["square_avg"]
                alpha = group["alpha"]

                state["step"] += 1

                if group["weight_decay"] != 0:
                    grad = grad.add(p, alpha=group["weight_decay"])

                square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)

                if group["centered"]:
                    grad_avg = state["grad_avg"]
                    grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha)
                    # PyTorch added epsilon after square root
                    # avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_().add_(group['eps'])
                    avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).add_(group["eps"]).sqrt_()
                else:
                    # PyTorch added epsilon after square root
                    # avg = square_avg.sqrt().add_(group['eps'])
                    avg = square_avg.add(group["eps"]).sqrt_()

                if group["momentum"] > 0:
                    buf = state["momentum_buffer"]
                    buf.mul_(group["momentum"]).addcdiv_(grad, avg)
                    p.add_(buf, alpha=-group["lr"])
                else:
                    p.addcdiv_(grad, avg, value=-group["lr"])

        return loss
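As a small illustration of the first change listed in the docstring (epsilon moved inside the square root), the standalone snippet below compares the two effective step sizes for a single gradient. The numbers are arbitrary examples and the snippet is not part of the commit.

import torch as th

g = th.tensor(1e-4)   # current gradient
v = th.tensor(1e-12)  # tiny squared-gradient average, as can happen early in training
lr, eps = 7e-4, 1e-5

step_pytorch = lr * g / (v.sqrt() + eps)   # PyTorch RMSprop: eps added after the sqrt
step_tf_like = lr * g / (v + eps).sqrt()   # RMSpropTFLike: eps added before the sqrt

print(step_pytorch.item(), step_tf_like.item())
# The TF-like placement takes a much smaller step when v is tiny,
# which is one source of the stability difference mentioned above.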

stable_baselines3/common/torch_layers.py

+1 -1
@@ -74,7 +74,7 @@ def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512):
             nn.ReLU(),
             nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
             nn.ReLU(),
-            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=0),
+            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
             nn.ReLU(),
             nn.Flatten(),
         )
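As a standalone sanity check on this fix (a sketch assuming the usual 4x84x84 Atari observation, not code from the commit): with 64 filters in the final convolution the feature map is 64x7x7, i.e. 3136 flattened features, matching the original Nature DQN architecture.

import torch as th
import torch.nn as nn

# Rebuild the convolutional stack with the corrected filter count.
cnn = nn.Sequential(
    nn.Conv2d(4, 32, kernel_size=8, stride=4, padding=0),
    nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
    nn.ReLU(),
    nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),  # 64 filters, as in the fix
    nn.ReLU(),
    nn.Flatten(),
)

with th.no_grad():
    print(cnn(th.zeros(1, 4, 84, 84)).shape)  # torch.Size([1, 3136]) == 64 * 7 * 7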

tests/test_custom_policy.py

+6
@@ -2,6 +2,7 @@
 import torch as th
 
 from stable_baselines3 import A2C, PPO, SAC, TD3
+from stable_baselines3.common.sb2_compat.rmsprop_tf_like import RMSpropTFLike
 
 
 @pytest.mark.parametrize(
@@ -32,3 +33,8 @@ def test_custom_offpolicy(model_class, net_arch):
 def test_custom_optimizer(model_class, optimizer_kwargs):
     policy_kwargs = dict(optimizer_class=th.optim.AdamW, optimizer_kwargs=optimizer_kwargs, net_arch=[32])
     _ = model_class("MlpPolicy", "Pendulum-v0", policy_kwargs=policy_kwargs).learn(1000)
+
+
+def test_tf_like_rmsprop_optimizer():
+    policy_kwargs = dict(optimizer_class=RMSpropTFLike, net_arch=[32])
+    _ = A2C("MlpPolicy", "Pendulum-v0", policy_kwargs=policy_kwargs).learn(1000)
