Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Categorical encoding for action space #593

Merged
merged 15 commits into from
Oct 25, 2022
1 change: 1 addition & 0 deletions docs/source/reference/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ as shape, device, dtype and domain.
NdUnboundedContinuousTensorSpec
BinaryDiscreteTensorSpec
MultOneHotDiscreteTensorSpec
DiscreteTensorSpec
CompositeSpec


Expand Down
25 changes: 22 additions & 3 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
BinaryDiscreteTensorSpec,
BoundedTensorSpec,
CompositeSpec,
DiscreteTensorSpec,
MultOneHotDiscreteTensorSpec,
NdBoundedTensorSpec,
NdUnboundedContinuousTensorSpec,
Expand All @@ -24,6 +25,7 @@
spec_dict = {
"bounded": BoundedTensorSpec,
"one_hot": OneHotDiscreteTensorSpec,
"categorical": DiscreteTensorSpec,
"unbounded": UnboundedContinuousTensorSpec,
"ndbounded": NdBoundedTensorSpec,
"ndunbounded": NdUnboundedContinuousTensorSpec,
Expand All @@ -35,6 +37,7 @@
default_spec_kwargs = {
BoundedTensorSpec: {"minimum": -1.0, "maximum": 1.0},
OneHotDiscreteTensorSpec: {"n": 7},
DiscreteTensorSpec: {"n": 7},
UnboundedContinuousTensorSpec: {},
NdBoundedTensorSpec: {"minimum": -torch.ones(4), "maximum": torch.ones(4)},
NdUnboundedContinuousTensorSpec: {
Expand Down Expand Up @@ -277,6 +280,7 @@ def __new__(
input_spec=None,
reward_spec=None,
from_pixels=False,
categorical_action_encoding=False,
**kwargs,
):
size = cls.size = 7
Expand All @@ -291,7 +295,12 @@ def __new__(
),
)
if action_spec is None:
action_spec = OneHotDiscreteTensorSpec(7)
action_spec_cls = (
DiscreteTensorSpec
if categorical_action_encoding
else OneHotDiscreteTensorSpec
)
action_spec = action_spec_cls(7)
if reward_spec is None:
reward_spec = UnboundedContinuousTensorSpec()

Expand All @@ -307,6 +316,7 @@ def __new__(
cls._observation_spec = observation_spec
cls._input_spec = input_spec
cls.from_pixels = from_pixels
cls.categorical_action_encoding = categorical_action_encoding
return super().__new__(*args, **kwargs)

def _get_in_obs(self, obs):
Expand All @@ -333,7 +343,9 @@ def _step(
) -> TensorDictBase:
tensordict = tensordict.to(self.device)
a = tensordict.get("action")
assert (a.sum(-1) == 1).all()

if not self.categorical_action_encoding:
assert (a.sum(-1) == 1).all()
assert not self.is_done, "trying to execute step in done env"

obs = self._get_in_obs(tensordict.get(self._out_key)) + a / self.maxstep
Expand Down Expand Up @@ -519,6 +531,7 @@ def __new__(
input_spec=None,
reward_spec=None,
from_pixels=True,
categorical_action_encoding=False,
**kwargs,
):
if observation_spec is None:
Expand All @@ -532,7 +545,12 @@ def __new__(
),
)
if action_spec is None:
action_spec = OneHotDiscreteTensorSpec(7)
action_spec_cls = (
DiscreteTensorSpec
if categorical_action_encoding
else OneHotDiscreteTensorSpec
)
action_spec = action_spec_cls(7)
if input_spec is None:
cls._out_key = "pixels_orig"
input_spec = CompositeSpec(
Expand All @@ -549,6 +567,7 @@ def __new__(
reward_spec=reward_spec,
input_spec=input_spec,
from_pixels=from_pixels,
categorical_action_encoding=categorical_action_encoding,
**kwargs,
)

Expand Down
126 changes: 126 additions & 0 deletions test/test_actors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import pytest
import torch
from torchrl.modules.tensordict_module.actors import (
QValueHook,
DistributionalQValueHook,
)


class TestQValue:
    """Tests for QValueHook and DistributionalQValueHook action selection.

    Each selection test is parametrized over the two supported action
    encodings: ``"one_hot"`` (the action is a one-hot vector over the action
    dimension) and ``"categorical"`` (the action is an integer index).
    """

    def test_qvalue_hook_wrong_action_space(self):
        """An unsupported ``action_space`` string raises ``ValueError``."""
        with pytest.raises(ValueError) as exc:
            QValueHook(action_space="wrong_value")
        assert "action_space must be one of" in str(exc.value)

    def test_distributional_qvalue_hook_wrong_action_space(self):
        """An unsupported ``action_space`` string raises ``ValueError``."""
        with pytest.raises(ValueError) as exc:
            DistributionalQValueHook(action_space="wrong_value", support=None)
        assert "action_space must be one of" in str(exc.value)

    @pytest.mark.parametrize(
        "action_space, expected_action",
        (
            ("one_hot", [0, 0, 1, 0, 0]),
            ("categorical", 2),
        ),
    )
    def test_qvalue_hook_0_dim_batch(self, action_space, expected_action):
        """QValueHook selects the argmax action from an unbatched value vector.

        The maximum value (100.0) sits at index 2, so the expected action is
        one-hot position 2 or categorical index 2 depending on the encoding.
        """
        hook = QValueHook(action_space=action_space)

        in_values = torch.tensor([1.0, -1.0, 100.0, -2.0, -3.0])
        action, values, chosen_action_value = hook(
            net=None, observation=None, values=in_values
        )

        assert (torch.tensor(expected_action, dtype=torch.long) == action).all()
        # The raw values must be passed through unchanged.
        assert (values == in_values).all()
        assert (torch.tensor([100.0]) == chosen_action_value).all()

    @pytest.mark.parametrize(
        "action_space, expected_action",
        (
            ("one_hot", [[0, 0, 1, 0, 0], [1, 0, 0, 0, 0]]),
            ("categorical", [2, 0]),
        ),
    )
    def test_qvalue_hook_1_dim_batch(self, action_space, expected_action):
        """QValueHook selects the per-row argmax for a batch of value vectors."""
        hook = QValueHook(action_space=action_space)

        in_values = torch.tensor(
            [
                [1.0, -1.0, 100.0, -2.0, -3.0],
                [5.0, 4.0, 3.0, 2.0, -5.0],
            ]
        )
        action, values, chosen_action_value = hook(
            net=None, observation=None, values=in_values
        )

        assert (torch.tensor(expected_action, dtype=torch.long) == action).all()
        assert (values == in_values).all()
        assert (torch.tensor([[100.0], [5.0]]) == chosen_action_value).all()

    @pytest.mark.parametrize(
        "action_space, expected_action",
        (
            ("one_hot", [0, 0, 1, 0, 0]),
            ("categorical", 2),
        ),
    )
    def test_distributional_qvalue_hook_0_dim_batch(
        self, action_space, expected_action
    ):
        """DistributionalQValueHook picks the action with the best expected value.

        ``in_values`` is a (3 atoms x 5 actions) tensor of log-probabilities
        (normalized over the atom dimension via LogSoftmax); ``support`` holds
        the 3 atom values the distribution is defined over.
        """
        support = torch.tensor([-2.0, 0.0, 2.0])
        hook = DistributionalQValueHook(action_space=action_space, support=support)

        in_values = torch.nn.LogSoftmax(dim=-1)(
            torch.tensor(
                [
                    [1.0, -1.0, 11.0, -2.0, 30.0],
                    [1.0, -1.0, 1.0, -2.0, -3.0],
                    [1.0, -1.0, 10.0, -2.0, -3.0],
                ]
            )
        )
        action, values = hook(net=None, observation=None, values=in_values)
        expected_action = torch.tensor(expected_action, dtype=torch.long)

        assert action.shape == expected_action.shape
        assert (action == expected_action).all()
        # The distributional values must be passed through unchanged.
        assert values.shape == in_values.shape
        assert (values == in_values).all()

    @pytest.mark.parametrize(
        "action_space, expected_action",
        (
            ("one_hot", [[0, 0, 1, 0, 0], [1, 0, 0, 0, 0]]),
            ("categorical", [2, 0]),
        ),
    )
    def test_distributional_qvalue_hook_1_dim_batch(
        self, action_space, expected_action
    ):
        """Batched counterpart of the 0-dim distributional test.

        NOTE(review): renamed from ``test_qvalue_hook_categorical_1_dim_batch``
        — the old name was misleading twice over: the method exercises the
        *distributional* hook, and it is parametrized over both encodings, not
        only ``categorical``. The new name mirrors
        ``test_distributional_qvalue_hook_0_dim_batch``.
        """
        support = torch.tensor([-2.0, 0.0, 2.0])
        hook = DistributionalQValueHook(action_space=action_space, support=support)

        in_values = torch.nn.LogSoftmax(dim=-1)(
            torch.tensor(
                [
                    [
                        [1.0, -1.0, 11.0, -2.0, 30.0],
                        [1.0, -1.0, 1.0, -2.0, -3.0],
                        [1.0, -1.0, 10.0, -2.0, -3.0],
                    ],
                    [
                        [11.0, -1.0, 7.0, -1.0, 20.0],
                        [10.0, 19.0, 1.0, -2.0, -3.0],
                        [1.0, -1.0, 0.0, -2.0, -3.0],
                    ],
                ]
            )
        )
        action, values = hook(net=None, observation=None, values=in_values)
        expected_action = torch.tensor(expected_action, dtype=torch.long)

        assert action.shape == expected_action.shape
        assert (action == expected_action).all()
        assert values.shape == in_values.shape
        assert (values == in_values).all()
Loading