-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Prioritized experience replay #1622
base: master
Are you sure you want to change the base?
Prioritized experience replay #1622
Conversation
- Created SumTree (to be ultimated) - Started PrioritizedReplayBuffer - constructor and 'sample' method - to be tested
@araffin could you (or anyone) please have a look at the 2 pytype errors? I don't quite understand how to fix them |
Thanks @araffin ! |
to be consistent with the rest of the buffers and because PyTorch is not needed here (no gpu computation needed). |
Hello @araffin , |
Added list of rainbow extensions, specifying which ones are currently implemented in the library
yes probably, but the most important thing for now is to test the implementation (performance test, check we can reproduce the results from the paper), document it and add additional tests/doc (for sumtree for instance). |
After some initial test on Breakout following hyperparameters from the paper, the run didn't improve or worsen DQN performance so far... |
Thanks for starting to test it! |
@araffin I've also done some initial tests and it looks like PER might lead to a slightly faster convergence, for example on cartpole, but nothing super evident unfortunately. |
# Special case when using PrioritizedReplayBuffer (PER) | ||
if isinstance(self.replay_buffer, PrioritizedReplayBuffer): | ||
# TD error in absolute value | ||
td_error = th.abs(current_q_values - target_q_values) | ||
# Weighted Huber loss using importance sampling weights | ||
loss = (replay_data.weights * th.where(td_error < 1.0, 0.5 * td_error**2, td_error - 0.5)).mean() | ||
# Update priorities, they will be proportional to the td error | ||
assert replay_data.leaf_nodes_indices is not None, "Node leaf node indices provided" | ||
self.replay_buffer.update_priorities( | ||
replay_data.leaf_nodes_indices, td_error, self._current_progress_remaining | ||
) | ||
else: | ||
# Compute Huber loss (less sensitive to outliers) | ||
loss = F.smooth_l1_loss(current_q_values, target_q_values) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@AlexPasqua Ideally, we'd like to be able to associate it with all off-policy algo's without adaptation, but I don't see a simple way of doing it at this stage.
Also related, we had discussed not modifying DQN: Stable-Baselines-Team/stable-baselines3-contrib#127 (comment)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm interested in this PR. Since every algo-specific train
method includes a replay_buffer.sample
line, couldn't we just additionally add a replay_buffer.update
line? The update function could take in the current and target q values whenever a value function is present or maybe even all the local variables. It would do nothing for the vanilla replay buffer. Would this be an acceptable modification?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for your comment!
How do you handle the loss in your proposal?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we want this to work for general off-policy algorithms, we could update the ReplayBufferSample-like classes to additionally include an importance_sampling_weight
attribute which would be updated from the replay_buffer.update
method.
Then I see two ways to handle the loss under this interface:
- Estimate TD error from the loss as such:
losses = loss_fn(current_q_values, target_q_values, reduction='none')
# e.g. If loss is L2, then it's basically th.sqrt(loss). If loss is L1, td_error = loss
td_error = importance_sampling_weight * function_to_approx_td_error(losses)
loss = losses.mean()
Obviously the downside of this is that it requires hand engineering for the different types of loss functions or priority metrics.
- Make any value-based
train
methods "td-error" centric in the sense that we always computetd_error = importance_sampling_weight * th.abs(current_q_values - target_q_values)
first, then the lossloss = loss_fn(td_error)
. The downsides of this approach is that we cant use the pytorch api for computing the loss, and would have to write functions for those.
Either approach requires computing a td_error
variable which unfortunately requires somewhat intrusive code changes. What do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe to make things clearer: my plan is not to have PER for all algorithms, mainly for two reasons:
- Keep the code concise (in fact, I would like to have RAINBOW and keep vanilla DQN, see [Feature Request] RAINBOW #622)
- I don't think it works for entropy-RL algorithms (SAC and derivates), so it would be limited to DQN/QR-DQN and TD3
If the users really want PER in other algo, they would take inspiration from a reference implementation in SB3 and integrate it (the same way we don't provide maskable + recurrent PPO at the same time).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"just" yes, I would be happy to receive such PR =)
the main thing is to benchmark the implementation and reproduce the published results.
This PR is also still open because I was not satisfied by the result of DQN + PER (I couldn't see significant different with respect to DQN).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One thing I had in mind was to implement CNN for SBX (https://github.com/araffin/sbx) in order to iterate faster and check the PER, but I had no time to do so until now...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why don't we implement the toy environment from figure 1 of https://arxiv.org/pdf/1511.05952 as the PER benchmark? It would be a simpler initial check for correctness than the Atari environments
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see, will definitely work towards it!
Just a comment, I've tested this implementation with QR-DQN with Vecenv multiple environment but it fails because of the missing part. But good job to start the work on it! I hope it will be merged soon! 👍 |
I've just tried validating the implementation on blind cliffwalk and it seems much slower (~an order of magnitude) than the uniform replay buffer. The results below are for a one seed: Not sure why this is. The details for blind cliffwalk are a bit vague from the paper (no code available as well), but I've tried to implement it as close to the description as possible. Code for the test is in this gist: |
weights = (self.size() * probs) ** -self.beta | ||
weights = weights / weights.max() | ||
|
||
# TODO: add proper support for multi env |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How could we add proper support for multiple envs? Is there any idea? Does the random line below could work?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not sure yet, the random line below might work but we need to check if it won't affect performance first.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
araffin/sbx@b5ce091 should be better, see araffin/sbx#50
Some update from my part, I just added CNN support for SBX (SB3 + Jax) DQN, and it is 10x faster than the PyTorch equivalent: araffin/sbx#49 That should allow to test and debug things more quickly on Atari (~1h40 for 10M steps instead of 15h =D) Perf report: https://wandb.ai/openrlbenchmark/sbx?nw=nwuseraraffin (on-going) |
Some additional update: when trying to plug the PER implementation of this PR inside the Jax DQN implementation, the experience replay was the bottleneck (by a good margin, making things 40x slower...), so I investigated different ways to speed things up. After playing with many different implementation (pure python, numpy, jax, jax jitted, ...), I decided to re-use the SB2 "SegmentTree" vectorized implementation and also implement proper multi-env support. (still debugging, but at least I've got the first sign of life and this implementation is so much faster) |
Hey @araffin , it is great to hear that. Does SBX/Jax means this much speed improvement? If you think it is ready for testing I can give a try, just let me know when it is ready to be tested. :) |
With the right parameters (see the exact command line argument for the RL Zoo in the OpenRL benchmark organization run on W&B), yes, around 10x faster.
SBX version is ready to be tested but so far, I didn't manage to see any gain from the PER. I also experienced some explosion in the qf value when using multiple env (so there is probably a bug here). |
When I tested this PR I also noticed an explosion in loss, in that time I felt that it is because of the tweaking here and there. and I also noticed that it doesn't give me any advantage over a normal buffer(and I used Dobule DQN, even tried duelling), but I tried to tweak an N-step buffer which had a strong effect on the learning, AFAIK N-step(multi step) is also part of Rainbow and giving substantial part of the success. The key parts are the distributional, PER and N-step parts, as far as I understand the concept. The others are kinda tasks specific parts and can be detrimental to use them. |
Description
Implementation of prioritized replay buffer for DQN.
Closes #1242
Motivation and Context
In accordance with #1242
Types of changes
Checklist
make format
(required)make check-codestyle
andmake lint
(required)make pytest
andmake type
both pass. (required)make doc
(required)Note: You can run most of the checks using
make commit-checks
.Note: we are using a maximum length of 127 characters per line