-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Match performance with stable-baselines (discrete case) #110
Conversation
@@ -74,7 +74,7 @@ def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512): | |||
nn.ReLU(), | |||
nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0), | |||
nn.ReLU(), | |||
nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=0), | |||
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice catch 🙈 Please don't tell me that solve your performance issue.
I know where it comes from ... I shouldn't have copy-pasted from https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail/blob/master/a2c_ppo_acktr/model.py#L169
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thinking about that, we need to double check VecFrameStack
, even though it is the same as SB2.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sadly (luckily? =) ) it did not fix the issues yet. SB3 is still consistently worse in a few of the Atari games I have tested. I am in the process of step-by-step comparisons, will see how that goes.
Edit: Ah yes, stacking on the wrong channels?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
or having that kind of issue: ikostrikov/pytorch-a2c-ppo-acktr-gail@84a7582
btw, is it better now with OMP_NUM_THREADS=1
w.r.t. fps? (maybe you should write in the comment the current stand of SB2 vs SB3)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
one thing that may change is the optimizer implementation and default parameters, for the initialization, I think (at least I tried) to reproduce what was done in SB2.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
my question was more what is the fps we want to reach? (what did you have with SB2?)
Hmm I do not have conclusive numbers just yet because I have been running many experiments on same system and can not guarantee fair comparisons, but I think PyTorch variants are about 10% slower with Atari games and 25% slower on toy environments. The latter required the OMP_NUM_THREADS tuning. This sounds reasonable, given the non-compiled nature of PyTorch and the fact the code has not been optimized much yet.
Yes, the issue was that nminibatches lead to different mini-batchsize depending on the number of environments
Ah alright. I will write big notes about this on the "moving from stable-baselines" docs :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One major change in parameters is the use of
batch_size=64
rather thannminibatches=4
in PPO. Using such small batch-size made things very slow FPS-wise, but in some cases sped up the learning. I will focus more on these running-speed things in an another PR.
I would like to add that we may be able to gain a non minuscule speedup by avoiding single data stores but instead storing a whole batch at once.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I started documenting the migration here ;)
#123
I would like to add that we may be able to gain a non minuscule speedup by avoiding single data stores but instead storing a while batch at once.
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mistyped, I meant that if we store a whole batch at once, we should get a sizeable speedup over storing one transition at a time.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
still not sure what you mean...
Listing what can be different from PyTorch vs Tensorflow:
EDIT: the tf clip norm seems to be here https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/clip_ops.py#L291 (not that easy to read vs pytorch), and the doc: https://www.tensorflow.org/api_docs/python/tf/clip_by_global_norm |
I am wondering: are you using |
I am using the parameters from rl-zoo for Atari PPO runs, where vf clipping is disabled and I set |
Some progression with A2C: With CartPole you get very similar learning curves (below, averaged over 10 random seeds) with rl-zoo parameters, after you update the PyTorch RMSProp to match TF's implementation. Turns out PyTorch RMSProp does things a little bit different, and these are crucial for stable learning like shown. These changes require a new optimizer or changes to PyTorch code, so should we include a modified RMSProp in stable-baselines3 like done here in another repo? We could include this as an additional optimizer and instruct to use it if one wants to replicate sb2 results, but we could also consider making it default RMSProp optimizer because of its (apparent) stability. |
A2C seems to check out mostly (see the original post with plots) with the fixed RMSprop that is now included under |
including Atari games?
Sounds reasonable, I don't see any better solution... The only thing is which default should we use? |
Yup! See the original post with plots. To me they seem "close enough" (with this limited amount of runs), except for Q*Bert which at end gets a sudden boost in performance in sb2. I will be checking PPO next and see if there is something common to A2C and PPO the is derp.
TF variant seems more stable and pytorch-image-models repo guys also say they have had better success with it. I'd personally go with that one by default.
Remember to set the parameters manually! I forgot this first time around ^^ policy_kwargs["optimizer_class"] = RMSpropTFLike
policy_kwargs["optimizer_kwargs"] = dict(alpha=0.99, eps=1e-5, weight_decay=0) |
After a quick run on Bullet envs, tf-rmsprop: 1192 |
How is the stability, though? I noticed Edit: In the light of these results we could keep the original enabled by default, though, and instruct people to use the TF-variant if they are experiencing unstable learning. |
A bit unstable at the beginning.
Yes, and add the tf-version as default in the zoo for Atari?
I only see the plots where the two are similar.
In the original post, I only see ppo plots... |
Works for me 👍
Hmm there should be four A2C plots in total under "TODO" heading: A2C cartpole comparisons (with rmsprop fixes), sb2 and sb3 Atari results for A2C and sb3 Atari results without rmsprop fix. |
🙈 I was looking at the issue, not the PR... |
Ran some more Atari PPO runs and now the performance seems to match (see the original post for plots). SB3 seems to be consistently lower than SB2 but nothing seems horribly broken. Q*Bert has an edge on SB2 for some reason with both PPO and A2C. I will be re-running experiments with more seeds, but that will take time. @araffin could you comment on the learning curves and tell what you think about the results? |
Do you know if the ADAM implementation is the same for A2C/PPO? |
Quick googling and skimming over the codes they seem to match, and also the A2C experiments matched with Adam (equally unstable :D), so I believe that part checks out. |
And how many random seeds did you try? |
Each of the curves is slightly different setup but, in general, tend to have the same result (see Figure 5 here, where we have five random seeds per curve). I.e. you can treat each curve as separate run with different random seed. But I will run some more for better conclusion. |
Note: I will retry to run DQN with the updated network and maybe with the updated RMSprop |
Actually DQN uses Adam for optimizing, and it has been using it since stable-baselinse2, while (I think) the original implementation used rmsprop. It might be worth of trying out what happens if you change the optimizer to stabler rmsprop, as Adam made things unstable with PPO. On sidenote: I ran Pong on sb3 DQN and was not able to get any improvement while sb2 learns it quickly (inside ~2M steps). I thought sb3 DQN was able to learn Pong, tho? Using parameters from rl-zoo, minus prioritized memory etc. |
It was but not as good as expected... SB2 DQN has nothing to do with vanilla DQN... |
To clarify to others: araffin referred to the fact how, by default, SB2 DQN has bunch of modifications enabled (Double-Q, Dueling). Those were disabled for those runs. I ran more experiments with Atari with the recent hotfix #132 . The learning curves are included in the main post and match mostly. While not perfect I can not tell if issue is in lack of random seeds used (three is rather low), and in any case I do not have the compute to run enough training runs to debug deeper if something differs. |
Looks good, no? SB3 DQN has even slightly better performance on one and I'm pretty sure SB3 DQN is faster than SB2, no? |
Otherwise, it looks like it is ready to merge, no? |
Preferably I would want to performance match in both good and bad (i.e. not better or worse) just to keep consistent results, but that'd still require a lot of work ^^. I used the hyperparameters from sb2 rl-zoo, plus disabling all the DQN improvements for SB2. I am not quite sure what you mean by "update defaults". |
I meant updating the default hyperparameters. The current ones are from the DQN nature paper and therefore do no correspond to your benchmark. The main differences are the buffer size and the final value of the exploration rate. |
Hmm I would those values from the original paper, as this is what users would expect when seeing "DQN". I do not think these parameters I used are the best (do not learn fastest / stablest), but I needed the replay-buffer size at the very least to be able to fit multiple experiments at same time on same machine. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, very impressive and valuable detective work =)
Thank you for your hard work on this to investigate and align the performance! This PR is currently referenced in the Atari Results section of the documentation here: https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html Regarding the learning curves, would you please be able to clarify,
Thank you for your help |
Thanks for the kind words! I ran these experiments using a different code base from zoo (one I was most familiar at the time), so replicating results exactly might be bit tricky.
|
This PR will be done when stable-baselines3 agent performance matches stable-baselines in discrete envs. Will be tested on discrete control tasks and Atari environments.
Closes #49
Closes #105
PS: Sorry about the confusing branch-name.
Changes
common.sb2_compat.RMSpropTFLike
, which is a modification of RMSprop that matches TF version, and is required for matching performance in A2C.TODO
Match performance of A2C and PPO.
A2C Cartpole matches (mostly, see this. Averaged over 10 random seeds for both. Requires the TF-like RMSprop, and even still in the very end SB3 seems more unstable.)
A2C Atari matches (mostly, see sb2 and sb3. Original sb3 result here. Three random seeds, each line separate run (ignore legend). Using TF-like RMSprop. Performance and stability mostly matches, except sb2 has sudden spike in performance in Q*Bert. Something to do with stability in distributions?)
PPO Cartpole (using rl-zoo parameters, see learning curves, averaged over 20 random seeds)
PPO Atari (mostly, see sb2 and sb3 results (shaded curves averaged over two seeds). Q*Bert still seems to have an edge on SB2 for unknown reasons)
Check and match performance of DQN. Seems ok. See following learning curves, each curve is an average over three random seeds:
atari_spaceinvaders.pdf
atari_qbert.pdf
atari_breakout.pdf
atari_pong.pdf
Check if "dones" fix can (and should) be moved to computing GAE side.
Write docs on how to match A2C and PPO settings to stable-baselines ("moving from stable-baselines"). There are some important quirks to note here.Move this to migration guide PR Migration Guide #123 .Types of changes
Checklist:
make format
(required)make check-codestyle
andmake lint
(required)make pytest
andmake type
both pass. (required)