-
Notifications
You must be signed in to change notification settings - Fork 290
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] Categorical encoding for action space #593
Changes from 11 commits
7303b6f
8b0d505
1fc515d
282af05
b0e5991
1eb5bde
6ea6512
a9dce75
4425417
54e256d
bb902f3
37071fe
40809eb
c9cc56f
813381a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
import pytest | ||
import torch | ||
from torchrl.modules.tensordict_module.actors import ( | ||
QValueHook, | ||
DistributionalQValueHook, | ||
) | ||
|
||
|
||
def test_qvalue_hook_wrong_action_space():
    """Constructing QValueHook with an unknown action_space must raise ValueError."""
    # NOTE(review): consider pytest.raises(..., match=...) so an unrelated
    # ValueError is not silently captured — confirm the exact message first.
    with pytest.raises(ValueError):
        QValueHook(action_space="wrong_value")
|
||
|
||
def test_distributional_qvalue_hook_wrong_action_space():
    """Constructing DistributionalQValueHook with an unknown action_space must raise ValueError."""
    # NOTE(review): consider pytest.raises(..., match=...) so an unrelated
    # ValueError is not silently captured — confirm the exact message first.
    with pytest.raises(ValueError):
        DistributionalQValueHook(action_space="wrong_value", support=None)
|
||
|
||
@pytest.mark.parametrize(
    "action_space, expected_action",
    (
        ("one_hot", [0, 0, 1, 0, 0]),
        ("categorical", 2),
    ),
)
def test_qvalue_hook_0_dim_batch(action_space, expected_action):
    """QValueHook on an unbatched (0-dim batch) value vector.

    The greedy action should point at the largest Q-value (index 2 here),
    encoded either one-hot or as a categorical index.
    """
    qvalue = QValueHook(action_space=action_space)

    q_values = torch.tensor([1.0, -1.0, 100.0, -2.0, -3.0])
    action, out_values, best_value = qvalue(
        net=None, observation=None, values=q_values
    )

    # Action matches the expected encoding; values pass through untouched;
    # the chosen action's value is the maximum Q-value.
    assert (action == torch.tensor(expected_action, dtype=torch.long)).all()
    assert (out_values == q_values).all()
    assert (best_value == torch.tensor([100.0])).all()
|
||
|
||
@pytest.mark.parametrize(
    "action_space, expected_action",
    (
        ("one_hot", [[0, 0, 1, 0, 0], [1, 0, 0, 0, 0]]),
        ("categorical", [2, 0]),
    ),
)
def test_qvalue_hook_1_dim_batch(action_space, expected_action):
    """QValueHook on a batch of two value vectors.

    Greedy actions are taken row-wise (indices 2 and 0), encoded per
    ``action_space``; the per-row maxima come back as the chosen values.
    """
    qvalue = QValueHook(action_space=action_space)

    q_values = torch.tensor(
        [
            [1.0, -1.0, 100.0, -2.0, -3.0],
            [5.0, 4.0, 3.0, 2.0, -5.0],
        ]
    )
    action, out_values, best_values = qvalue(
        net=None, observation=None, values=q_values
    )

    assert (action == torch.tensor(expected_action, dtype=torch.long)).all()
    assert (out_values == q_values).all()
    assert (best_values == torch.tensor([[100.0], [5.0]])).all()
|
||
|
||
@pytest.mark.parametrize(
    "action_space, expected_action",
    (
        ("one_hot", [0, 0, 1, 0, 0]),
        ("categorical", 2),
    ),
)
def test_distributional_qvalue_hook_0_dim_batch(action_space, expected_action):
    """DistributionalQValueHook on an unbatched distributional value map.

    Input is a (support, action) grid of log-probabilities over a 3-atom
    support; the hook should select action 2 and return the log-probs
    unchanged.
    """
    support = torch.tensor([-2.0, 0.0, 2.0])
    dist_hook = DistributionalQValueHook(action_space=action_space, support=support)

    log_probs = torch.nn.LogSoftmax(dim=-1)(
        torch.tensor(
            [
                [1.0, -1.0, 11.0, -2.0, 30.0],
                [1.0, -1.0, 1.0, -2.0, -3.0],
                [1.0, -1.0, 10.0, -2.0, -3.0],
            ]
        )
    )
    action, out_values = dist_hook(net=None, observation=None, values=log_probs)
    expected = torch.tensor(expected_action, dtype=torch.long)

    assert action.shape == expected.shape
    assert (action == expected).all()
    assert out_values.shape == log_probs.shape
    assert (out_values == log_probs).all()
|
||
|
||
@pytest.mark.parametrize(
    "action_space, expected_action",
    (
        ("one_hot", [[0, 0, 1, 0, 0], [1, 0, 0, 0, 0]]),
        ("categorical", [2, 0]),
    ),
)
def test_distributional_qvalue_hook_1_dim_batch(action_space, expected_action):
    """DistributionalQValueHook on a batch of two distributional value maps.

    Renamed from ``test_qvalue_hook_categorical_1_dim_batch``: the test
    exercises ``DistributionalQValueHook`` (not ``QValueHook``) and covers
    both encodings, so the old name was misleading; the new name matches
    its 0-dim counterpart ``test_distributional_qvalue_hook_0_dim_batch``.

    Each batch element is a (support, action) grid of log-probabilities
    over a 3-atom support; the hook should pick actions 2 and 0 and return
    the log-probs unchanged.
    """
    support = torch.tensor([-2.0, 0.0, 2.0])
    hook = DistributionalQValueHook(action_space=action_space, support=support)

    in_values = torch.nn.LogSoftmax(dim=-1)(
        torch.tensor(
            [
                [
                    [1.0, -1.0, 11.0, -2.0, 30.0],
                    [1.0, -1.0, 1.0, -2.0, -3.0],
                    [1.0, -1.0, 10.0, -2.0, -3.0],
                ],
                [
                    [11.0, -1.0, 7.0, -1.0, 20.0],
                    [10.0, 19.0, 1.0, -2.0, -3.0],
                    [1.0, -1.0, 0.0, -2.0, -3.0],
                ],
            ]
        )
    )
    action, values = hook(net=None, observation=None, values=in_values)
    expected_action = torch.tensor(expected_action, dtype=torch.long)

    assert action.shape == expected_action.shape
    assert (action == expected_action).all()
    assert values.shape == in_values.shape
    assert (values == in_values).all()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we put all those `test_qvalue` tests under a `TestQValue` class?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess that if we add a `test_actors.py`, we should also move some tests there in a future PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.