-
Notifications
You must be signed in to change notification settings - Fork 314
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] Making action masks compatible with q value modules and e-greedy #1499
Conversation
Signed-off-by: Matteo Bettini <[email protected]>
raise KeyError( | ||
f"Action mask key {self.action_mask_key} not found in {tensordict}." | ||
) | ||
action_values[action_mask] = torch.finfo(action_values.dtype).min |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we need to discuss if this is the best choice for representing masked values
Signed-off-by: Matteo Bettini <[email protected]>
Signed-off-by: Matteo Bettini <[email protected]>
Signed-off-by: Matteo Bettini <[email protected]>
Signed-off-by: Matteo Bettini <[email protected]>
Signed-off-by: Matteo Bettini <[email protected]>
Signed-off-by: Matteo Bettini <[email protected]>
We have to decide what to do with the wrapper. |
Signed-off-by: Matteo Bettini <[email protected]>
raise KeyError( | ||
f"Action mask key {self.action_mask_key} not found in {tensordict}." | ||
) | ||
action_values[action_mask] = torch.finfo(action_values.dtype).min |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
2 things
(1) I think this is wrong, should be ~action_mask
no?
(2) we should not modify the values in-place
rather torch.where(action_mask, action_values, torch.finfo(action_values.dtype).min)
wdyt?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- oops this was left over
- those that pass gradients when the condition is applied? not sure
Add a deprecation warning in the constructor, say that it'll be deprecated in v0.3 |
Signed-off-by: Matteo Bettini <[email protected]>
Signed-off-by: Matteo Bettini <[email protected]>
Signed-off-by: Matteo Bettini <[email protected]>
should be ready |
Co-authored-by: Vincent Moens <[email protected]>
Co-authored-by: Vincent Moens <[email protected]>
Signed-off-by: Matteo Bettini <[email protected]>
Signed-off-by: Matteo Bettini <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
…reedy (pytorch#1499) Signed-off-by: Matteo Bettini <[email protected]> Co-authored-by: Vincent Moens <[email protected]>
Action masks where introduced in #1421.
This PR has the job of making the components in the training pipeline use this mask.
The components that need updating are:
This addresses some of the issues brought up in #1404
cc @Kang-SungKu @1030852813 @fedebotu