[Feature] Making action masks compatible with q value modules and e-greedy #1499

Merged on Sep 7, 2023 (15 commits)

Conversation

matteobettini
Copy link
Contributor

Action masks were introduced in #1421.

This PR makes the components in the training pipeline use this mask.

The components that need updating are:

  • q modules and actors
  • e-greedy
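
As a rough illustration of what mask-aware e-greedy selection could look like, here is a minimal sketch. It is not the implementation from this PR: the function name `masked_epsilon_greedy` is hypothetical, and it assumes a boolean mask where True marks an available action and that every row has at least one available action.

```python
import torch


def masked_epsilon_greedy(action_values, action_mask, eps, generator=None):
    """Hypothetical helper: e-greedy over discrete actions, restricted to the mask.

    Assumes action_mask is boolean, True = action available, and that each
    row of the mask contains at least one True entry.
    """
    # Greedy branch: give unavailable actions the lowest finite value (out of place).
    fill = torch.tensor(
        torch.finfo(action_values.dtype).min, dtype=action_values.dtype
    )
    greedy = torch.where(action_mask, action_values, fill).argmax(dim=-1)

    # Random branch: sample uniformly among the available actions only.
    weights = action_mask.to(action_values.dtype)
    random_action = torch.multinomial(
        weights, num_samples=1, generator=generator
    ).squeeze(-1)

    # Per-sample choice between exploring and exploiting.
    explore = torch.rand(greedy.shape, generator=generator) < eps
    return torch.where(explore, random_action, greedy)


action_values = torch.tensor([[1.0, 5.0, 3.0], [2.0, 0.5, 4.0]])
action_mask = torch.tensor([[True, False, True], [True, True, False]])

greedy_only = masked_epsilon_greedy(action_values, action_mask, eps=0.0)
random_only = masked_epsilon_greedy(action_values, action_mask, eps=1.0)
```

With `eps=0.0` the result is the masked argmax; with `eps=1.0` every sampled action still respects the mask.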

This addresses some of the issues brought up in #1404

cc @Kang-SungKu @1030852813 @fedebotu

Signed-off-by: Matteo Bettini <[email protected]>
@facebook-github-bot added the CLA Signed label Sep 6, 2023
raise KeyError(
f"Action mask key {self.action_mask_key} not found in {tensordict}."
)
action_values[action_mask] = torch.finfo(action_values.dtype).min
Copy link
Contributor Author

we need to discuss if this is the best choice for representing masked values

Signed-off-by: Matteo Bettini <[email protected]>
@matteobettini matteobettini marked this pull request as ready for review September 7, 2023 08:59
@matteobettini
Copy link
Contributor Author

We have to decide what to do with the wrapper.
As of now, I have just added a new greedy module.
Do we want to deprecate the wrapper? What is the process for that?

Signed-off-by: Matteo Bettini <[email protected]>
raise KeyError(
f"Action mask key {self.action_mask_key} not found in {tensordict}."
)
action_values[action_mask] = torch.finfo(action_values.dtype).min
Copy link
Contributor

Two things:
(1) I think this is wrong; it should be ~action_mask, no?
(2) We should not modify the values in place.

Rather, use torch.where(action_mask, action_values, torch.finfo(action_values.dtype).min). WDYT?
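
The out-of-place variant suggested above can be sketched as follows. This is only an illustration under stated assumptions, not the code that was merged: the helper name `mask_action_values` is hypothetical, and the mask is assumed to be boolean with True meaning the action is available.

```python
import torch


def mask_action_values(action_values, action_mask):
    """Hypothetical helper: return a new tensor, leaving action_values untouched.

    Assumes action_mask is boolean with True = action available.
    """
    # Use a 0-dim tensor of the same dtype so the fill value broadcasts cleanly.
    fill = torch.tensor(
        torch.finfo(action_values.dtype).min,
        dtype=action_values.dtype,
        device=action_values.device,
    )
    # Keep values where the mask is True, use the lowest finite value elsewhere.
    return torch.where(action_mask, action_values, fill)


q_values = torch.tensor([[1.0, 5.0, 3.0], [2.0, 0.5, 4.0]])
available = torch.tensor([[True, False, True], [True, True, False]])

masked_q = mask_action_values(q_values, available)
best_action = masked_q.argmax(dim=-1)
```

The original `q_values` tensor is not modified, and the argmax can only land on an available action.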

Copy link
Contributor Author

@matteobettini matteobettini Sep 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Oops, this was left over.
  2. Does that pass gradients when the condition is applied? Not sure.
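
On the gradient question, a small check like the one below can help. It is a sketch with made-up values: `torch.where` is differentiable with respect to the selected tensor, so entries kept by the condition receive gradient and replaced entries receive zero.

```python
import torch

values = torch.tensor([1.0, 5.0, 3.0], requires_grad=True)
mask = torch.tensor([True, False, True])  # assumed convention: True = available
fill = torch.tensor(torch.finfo(values.dtype).min, dtype=values.dtype)

# Out-of-place masking, then a simple scalar reduction to backpropagate through.
masked = torch.where(mask, values, fill)
masked.sum().backward()

grad = values.grad  # gradient flows only through the entries where mask is True
```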

@vmoens
Copy link
Contributor

vmoens commented Sep 7, 2023

We have to decide what to do with the wrapper. As of now, I have just added a new greedy module. Do we want to deprecate the wrapper? What is the process for that?

Add a deprecation warning in the constructor, saying that it will be deprecated in v0.3.
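
A constructor-level deprecation warning of this kind is commonly written with Python's standard `warnings` module. The sketch below uses a hypothetical class name, `LegacyExplorationWrapper`, and an illustrative message; neither is taken from the library.

```python
import warnings


class LegacyExplorationWrapper:
    """Hypothetical stand-in for a wrapper that is being phased out."""

    def __init__(self, policy):
        # Warn at construction time; stacklevel=2 points at the caller's line.
        warnings.warn(
            "LegacyExplorationWrapper will be deprecated in v0.3; "
            "consider using the new module instead.",
            category=DeprecationWarning,
            stacklevel=2,
        )
        self.policy = policy


# Usage: record warnings so the deprecation message can be inspected.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    wrapper = LegacyExplorationWrapper(policy=lambda obs: obs)
```

Using `DeprecationWarning` lets users and test suites filter or escalate the warning with the usual `warnings` filters.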

Signed-off-by: Matteo Bettini <[email protected]>
@matteobettini
Copy link
Contributor Author

should be ready

@vmoens added the enhancement and Refactoring labels Sep 7, 2023
matteobettini and others added 4 commits September 7, 2023 13:39
Signed-off-by: Matteo Bettini <[email protected]>
Signed-off-by: Matteo Bettini <[email protected]>
Copy link
Contributor

@vmoens vmoens left a comment

LGTM thanks for taking care of this

@vmoens vmoens merged commit 786020d into pytorch:main Sep 7, 2023
48 of 56 checks passed
@matteobettini matteobettini deleted the mask_qvalue branch September 7, 2023 14:30
vmoens added a commit to hyerra/rl that referenced this pull request Oct 10, 2023
Labels
CLA Signed, enhancement, Refactoring
3 participants