Importance Sampling #3
Comments
Nope, you're totally right! It's a bug. According to the original paper, you should multiply the TD error by the IS weights. You can do this in the repo code by changing the
I'm not sure if the weighting should go inside the MSE calculation or can be applied outside, like I've done. By that, I mean I don't know whether the weights should go inside the square operation or not. I think outside the MSE is the correct place.
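For what it's worth, a quick way to see which placement matches the paper (writing delta_i for the TD error and w_i for the IS weight, as in the PER paper):

```latex
% weight applied outside the square (per-sample loss)
L_i = w_i \, \delta_i^{2}

% weight applied inside the square
L_i = (w_i \, \delta_i)^{2} = w_i^{2} \, \delta_i^{2}
```

The paper's correction uses w_i to the first power, so the first form (outside the square) is the one that matches; putting the weight inside the square effectively squares the correction.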
I think each sample in a minibatch of losses should be multiplied by its corresponding IS weight. However, torch by default returns an averaged MSE loss over the minibatch. Hence, to multiply each loss value in a batch by its corresponding weight, the MSE loss should be calculated with the reduce flag in torch.nn.functional.mse_loss() set to False. Also, I like to calculate the loss values in a minibatch by passing the targets and the predictions in the format [[loss for sample 1], [loss for sample 2], ..., [loss for sample n]] instead of [loss for sample 1, loss for sample 2, ..., loss for sample n]. I'm not sure if the second format is a correct method or not. After multiplication, this batch of weighted losses can then be averaged using tensor.mean(), and then you can call backward() on it.
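A minimal PyTorch sketch of that recipe (the tensor names below are placeholders rather than this repo's actual identifiers, and newer PyTorch versions spell reduce=False as reduction='none'):

```python
import torch
import torch.nn.functional as F

batch_size = 32
# Placeholder tensors standing in for the network output, the TD target,
# and the per-sample importance-sampling weights from the replay buffer.
q_pred = torch.randn(batch_size, 1, requires_grad=True)
q_target = torch.randn(batch_size, 1)
is_weights = torch.rand(batch_size, 1)

# reduction='none' keeps one squared TD error per sample instead of
# averaging over the minibatch (older PyTorch used reduce=False).
elementwise_loss = F.mse_loss(q_pred, q_target, reduction='none')

# Scale each sample's loss by its IS weight, then average and backprop.
loss = (is_weights * elementwise_loss).mean()
loss.backward()
```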
In the original paper, the loss is also multiplied by the TD error. It seems that if we calculate the gradient of your loss, there is no TD error as a factor in the coefficient? @stormont
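If the weighted per-sample loss is taken to be (1/2) * w_i * delta_i^2 with the target held fixed (the usual semi-gradient assumption), the TD error does appear as a factor in the gradient:

```latex
\nabla_\theta \left[ \tfrac{1}{2} \, w_i \, \delta_i^{2} \right]
  = w_i \, \delta_i \, \nabla_\theta \delta_i
  = - \, w_i \, \delta_i \, \nabla_\theta Q(s_i, a_i; \theta)
```

Up to the sign convention, this is the same w_i * delta_i * grad Q coefficient as the paper's update, so weighting the squared error outside the square should already reintroduce the TD-error factor.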
First of all, thanks for this implementation.
One question however: I see that the weights for the importance sampling are calculated and returned when a batch is sampled.
However, the weights aren't used further in the code.
Is this some legacy code from a feature that didn't work? Or is the code not finished yet?
Andreas