Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Categorical encoding for action space #593

Merged
merged 15 commits into from
Oct 25, 2022

Conversation

artkorenev
Copy link
Contributor

@artkorenev artkorenev commented Oct 20, 2022

Description

Added alternative to one-hot encoding for action spaces with categorical features.

Motivation and Context

Implementing feature #538.

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • New feature (non-breaking change which adds core functionality)

Checklist

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

Checklist from #538:

Current implementation seems to couple different parts of TorchRL quite tightly as there is quite a few places where it is assumed that the action space is one-hot. Perhaps, a more generic action-space object that encapsulates manipulation with values and that is shared across losses, actors, hooks, etc. (as a nn.Module maybe) might be a worth investment in future.

The mode is turned on by export CATEGORICAL_ACTION_ENCODING=True command, and it is turned off by default. Another alternative could be to provide it through the config-file, however this, again, would require a bit more generic solution to configure module that will be passed to all the necessary modules.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 20, 2022
@codecov
Copy link

codecov bot commented Oct 21, 2022

Codecov Report

Merging #593 (813381a) into main (e1fbf86) will increase coverage by 0.14%.
The diff coverage is 98.68%.

@@            Coverage Diff             @@
##             main     #593      +/-   ##
==========================================
+ Coverage   87.24%   87.39%   +0.14%     
==========================================
  Files         122      124       +2     
  Lines       22532    22784     +252     
==========================================
+ Hits        19658    19911     +253     
+ Misses       2874     2873       -1     
Flag Coverage Δ
linux-cpu 85.79% <98.68%> (+0.16%) ⬆️
linux-gpu 87.17% <98.68%> (+0.14%) ⬆️
linux-outdeps-gpu 76.19% <94.07%> (+0.23%) ⬆️
linux-stable-cpu 85.77% <98.68%> (+0.16%) ⬆️
linux-stable-gpu 87.17% <98.68%> (+0.14%) ⬆️
macos-cpu 85.56% <98.35%> (+0.16%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
test/test_cost.py 96.29% <93.93%> (-0.11%) ⬇️
torchrl/objectives/dqn.py 93.07% <96.29%> (+0.34%) ⬆️
test/test_utils.py 97.43% <97.43%> (ø)
test/mocking_classes.py 97.88% <100.00%> (+0.02%) ⬆️
test/test_actors.py 100.00% <100.00%> (ø)
test/test_env.py 98.87% <100.00%> (+0.03%) ⬆️
test/test_helpers.py 90.25% <100.00%> (+0.02%) ⬆️
test/test_modules.py 99.37% <100.00%> (+0.04%) ⬆️
test/test_tensor_spec.py 99.53% <100.00%> (+0.02%) ⬆️
test/test_trainer.py 97.88% <100.00%> (+0.01%) ⬆️
... and 8 more

📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more

Copy link
Contributor

@vmoens vmoens left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'd like to add the new classes to the docs (this is not done automatically atm).
Just add them in docs/source/data.

key (str): name of the environment variable.
"""
val = os.environ.get(key, False)
if val in ("0", "False", False):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we consider the option False why not plain 0?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed it to "False" by default. Environment variables cannot be set to something other than str.

torchrl/_utils.py Show resolved Hide resolved
torchrl/data/tensor_specs.py Show resolved Hide resolved
torchrl/data/tensor_specs.py Show resolved Hide resolved
def __init__(
self,
n: int,
shape: Optional[torch.Size] = torch.Size((1,)),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC we once had a linting issue with this (that seems to have disappeared now).
This is why we usually have a None and if None then shape is replaced by Size([1]).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Noted, changed to that

torchrl/modules/tensordict_module/actors.py Show resolved Hide resolved

if _CATEGORICAL_ACTION_ENCODING:
batch_size = action.size(0)
pred_val_index = pred_val[range(batch_size), action.squeeze(-1)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we'd like to be able to work with batch sizes that are more than unidimensional.

We can use gather (I think)

x = torch.randn(3, 4, 5)
idx = torch.randint(5, (3, 4))
x.gather(-1, idx)

This will index x along the last dimension.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is not covered by the tests (it should be done in test_costs.py I believe).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reworked indexing here, thx

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also added tests so should be covered by coverage by now

pred_val = td_copy.get("action_value")
pred_val_index = (pred_val * action).sum(-1)

if _CATEGORICAL_ACTION_ENCODING:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit tricky

What if someone explicitly builds a policy that uses categorical actions, but does not set the env variable?

My view of the env variable is more along the line of "a way of easily change a whole training script from categorical to one hot and vice-versa".

But the reversed dependency should not hold, and the environment variable should only be checked when building the network.

One option could be to infer the type of action by looking at the value_network: if it has registered the action spec, then we can use it to infer the type of action we will see. Otherwise, we must ask to the user to tell us (via an arg in the constructor) what kind of action we should be expecting.

Open to other suggestions obviously :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if someone explicitly builds a policy that uses categorical actions, but does not set the env variable?

This is exactly what I meant when mentioned that the current solution couples things a bit too much.

Since actions are passed as pure tensors (even though, wrapped in TensorDict) it is either assumed or configured separately for each class what to do with action tensors.

Now I can see two options what we can do here:

  1. Whenever we deal with actions, we either pass action_space in the constructor or derive it from some other entity (e.g. value_network as you proposed). However here we rely on careful configuration of all modules together (in theory, somebody could explicitly configure value_network incompatible with an environment). Another drawback is that we "spreading out" the logic for each case across the code base which can make code a harder to understand (I think in case of one-hot/categorical it is fine, but for future approaches it can blow up).
  2. We can introduce something like ActionSpec which would be an extended version of existing specs, though purely for action processing (deriving actions from value_network, etc.). We instantiate the ActionSpec once and pass its instance everywhere we need encapsulating work with actions. Although, it is hard for me to estimate what interface ActionSpec needs in order to cover all possible cases (I mean globally for all RL cases). And this also would require larger scale refactoring.

tldr; I think 1st option seems better choice for this situation and somewhat follows current code base practices (e.g. QValueHook)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, I reworked it so it can be specified through config and this eliminates the whole problem with working with environment variables. Now we just specify binary flag that will set the mode of handling discrete gym environments and which also be passed as action space to value network and al the hooks.

log_ps_a = log_ps_a.view(batch_size, atoms) # log p(s_t, a_t; θonline)

if _CATEGORICAL_ACTION_ENCODING:
log_ps_a = action_log_softmax[range(batch_size), :, action.squeeze(-1)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comments as above regarding indexing, usage of env variable and test coverage

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reworked this as well with a few tweaks since we work with atoms here.


if _CATEGORICAL_ACTION_ENCODING:
actor_kwargs.update({"action_space": "categorical"})
out_features = env_specs["action_spec"].space.n
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perhaps we could have an arg in the function that indicates if categorical has to be used.
If default (e.g. None) then it falls back on _CATEGORICAL_ACTION_ENCODING.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed it here to be inferred from action spec that is set during the env setup.

@vmoens vmoens added the enhancement New feature or request label Oct 21, 2022
Copy link
Contributor

@vmoens vmoens left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great! Amazing work, if fits in the library very well now. I think we're there.

A couple of things before we merge this

  • Can you have a look at the few comments I left?
  • Can you have a look at the coverage report (some lines are not covered -- don't worry about things that are not covered but are in the test directory).

Oh and one more thing: Can you add the new classes to the doc? Have a look in docs/source/reference



def test_qvalue_hook_wrong_action_space():
with pytest.raises(ValueError):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's check that the message match, to make sure we're not capturing the wrong error

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a message check (had to make it short since it relies on order of the items in dict)



def test_distributional_qvalue_hook_wrong_action_space():
with pytest.raises(ValueError):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as above

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed as in test_qvalue_hook_wrong_action_space

)


def test_qvalue_hook_wrong_action_space():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could we put all those test_qvalue under a TestQValue class?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess that if we add a test_actors.py we should also move some tests there in a future PR

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could we put all those test_qvalue under a TestQValue class?
Fixed, thanks

@@ -0,0 +1,62 @@
import os
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same note to myself: we should move some tests here (e.g. timeit etc)

@artkorenev
Copy link
Contributor Author

Awesome! Thank you for the review!
It seems like the coverage is now fixed completely.

@vmoens vmoens merged commit 61b80f8 into pytorch:main Oct 25, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants