SB3 and SBX versions of SAC have radically different behaviours #55
Hello,
Hi @araffin, thanks for pointing out that difference. I hadn't seen it. I'm not sure I understand your point about jax not supporting lr schedules, though. The Adam optimizer in optax, which sbx uses, does support learning rate schedules (see here). Indeed, optax has a whole zoo of schedules available. Am I missing something?
The optax schedules are defined in terms of gradient steps, whereas the SB3 schedules use total timesteps (which are not known when the model is created). Last time I tried, it was not possible to do something like:
Looking at google-deepmind/optax#4, re-creating the optimizer using the previous state would be an option (although it sounds a bit overkill).
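For context on the schedule point above (this is not the snippet that was elided from the comment, just an illustration): an optax schedule is attached when the optimizer is created and is indexed by the optimizer's own step count, i.e. gradient steps. A minimal sketch, assuming the total number of gradient steps is known up front:

```python
# Minimal sketch: an optax schedule counts *gradient steps*, not environment timesteps.
import optax

total_gradient_steps = 100_000  # must be known when the optimizer is created
lr_schedule = optax.linear_schedule(
    init_value=3e-4,
    end_value=0.0,
    transition_steps=total_gradient_steps,
)
tx = optax.adam(learning_rate=lr_schedule)
# Each call to tx.update(...) advances the schedule by one step, regardless of how
# many environment steps were collected in between.
```

If the learning rate needs to track environment timesteps instead, one possible route (in the spirit of the re-creation idea above) is optax.inject_hyperparams, which exposes the learning rate in the optimizer state so it can be overwritten between gradient steps.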
Things are more stable when I use an optax linear learning rate schedule, but the performance is still bad. I noticed that SB3 SAC optimizes the log of the entropy coefficient instead of the entropy coefficient itself (link), as discussed here. In contrast, SBX SAC optimizes the entropy coefficient directly (link). I've modified SBX SAC so that it optimizes the log of the entropy coefficient, as in SB3 SAC, and now the performance is good (I haven't done extensive testing to see if it is as good as SB3, but it is certainly respectable now). I've made a PR for this change, which should improve SBX SAC and make it equivalent (in this respect) to SB3 SAC.

I'm happy to look into incorporating optax schedules too, in a way that's consistent with how schedules are used in SB3. It seems like it shouldn't be difficult in principle. On this note, is there a reason why you prefer to update the learning rate based on environment timesteps rather than gradient steps? The latter seems more natural to me. For example, say you wait until the end of an episode to perform multiple gradient steps (e.g. as many gradient steps as timesteps in the episode); then the learning rate changes abruptly at the end of each episode and is constant for all the gradient steps within that episode. In contrast, if it were updated based on gradient steps, the learning rate would change gradually. If the schedule is linear, the former gives a piecewise-constant learning rate while the latter is actually linear.
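To make the difference concrete, here is a minimal sketch in plain PyTorch (not code copied from either library; the numbers are placeholders) of the SAC entropy-coefficient update under the two parameterizations:

```python
# Sketch of the SAC entropy-coefficient update under two parameterizations.
# target_entropy and log_prob are placeholders for the usual SAC quantities.
import torch as th

target_entropy = -4.0                 # typically -(action dimension)
log_prob = th.tensor([-2.5])          # log pi(a|s) for a sampled action (placeholder)

# SB3-style: optimize log(ent_coef); the coefficient exp(log_ent_coef) is positive
# by construction, and gradient steps act multiplicatively on it.
log_ent_coef = th.zeros(1, requires_grad=True)
opt = th.optim.Adam([log_ent_coef], lr=3e-4)
loss = -(log_ent_coef * (log_prob + target_entropy).detach()).mean()
opt.zero_grad()
loss.backward()
opt.step()
ent_coef = log_ent_coef.detach().exp()   # always > 0

# Direct parameterization (what SBX used before the PR): optimize ent_coef itself.
# Nothing constrains it to stay positive, so a few large updates can push it to zero
# or below, after which the actor and critic losses can diverge.
ent_coef_direct = th.ones(1, requires_grad=True)
opt_direct = th.optim.Adam([ent_coef_direct], lr=3e-4)
loss_direct = -(ent_coef_direct * (log_prob + target_entropy).detach()).mean()
opt_direct.zero_grad()
loss_direct.backward()
opt_direct.step()
```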
When not talking about RL, I agree it is more natural. In RL, you don't always know in advance how many gradient steps you are going to do.
That's true for episodic RL (which is not what most people do nowadays), although I would disagree that the lr changes abruptly: the length of an episode is usually much shorter than the timescale over which the lr schedule changes.
That would be a nice addition =)
Fixed by #56 normally.
I am following a tutorial that trains the myosuite myoHandReorient8-v0 environment using the stable baselines 3 version of SAC. The main block of code for performing training (which details the SAC parameters, hence why I'm putting it here) is:
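The tutorial's code block is not reproduced in this thread. Purely as an illustrative sketch (the hyperparameters are placeholders, and the environment-registration step is an assumption that may differ by myosuite version, not the tutorial's values), a comparable train function might look like:

```python
# Illustrative sketch only -- not the tutorial's code; hyperparameters are placeholders.
import gymnasium as gym
import myosuite  # noqa: F401  -- assumed to register the Myo envs (may differ by version)
from stable_baselines3 import SAC  # swapping this for `from sbx import SAC` triggers the issue


def train(total_timesteps: int = 1_000_000) -> SAC:
    env = gym.make("myoHandReorient8-v0")
    model = SAC(
        "MlpPolicy",
        env,
        learning_rate=3e-4,   # placeholder values
        batch_size=256,
        gamma=0.99,
        verbose=1,
    )
    model.learn(total_timesteps=total_timesteps)
    return model
```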
When I call this train function and use the stable baselines 3 version of SAC (`from stable_baselines3 import SAC`), the model trains well. However, if I instead use the sbx version of SAC (`from sbx import SAC`), the actor loss, critic loss, and entropy coefficient diverge. The MuJoCo simulation also often becomes unstable in the SBX case.
Naively, I would have thought that the SB3 and SBX versions of SAC would perform approximately the same for the same training parameters. Can you help me understand why this is not the case, and why parameters that work well for SB3 SAC are catastrophic for SBX SAC?
I am using stable_baselines3 2.3.2 and sbx 0.13.0.