Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Categorical encoding for action space #593

Merged
merged 15 commits into from
Oct 25, 2022
3 changes: 3 additions & 0 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
BinaryDiscreteTensorSpec,
BoundedTensorSpec,
CompositeSpec,
DiscreteTensorSpec,
MultOneHotDiscreteTensorSpec,
NdBoundedTensorSpec,
NdUnboundedContinuousTensorSpec,
Expand All @@ -24,6 +25,7 @@
spec_dict = {
"bounded": BoundedTensorSpec,
"one_hot": OneHotDiscreteTensorSpec,
"categorical": DiscreteTensorSpec,
"unbounded": UnboundedContinuousTensorSpec,
"ndbounded": NdBoundedTensorSpec,
"ndunbounded": NdUnboundedContinuousTensorSpec,
Expand All @@ -35,6 +37,7 @@
default_spec_kwargs = {
BoundedTensorSpec: {"minimum": -1.0, "maximum": 1.0},
OneHotDiscreteTensorSpec: {"n": 7},
DiscreteTensorSpec: {"n": 7},
UnboundedContinuousTensorSpec: {},
NdBoundedTensorSpec: {"minimum": -torch.ones(4), "maximum": torch.ones(4)},
NdUnboundedContinuousTensorSpec: {
Expand Down
126 changes: 126 additions & 0 deletions test/test_actors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import pytest
import torch
from torchrl.modules.tensordict_module.actors import (
QValueHook,
DistributionalQValueHook,
)


def test_qvalue_hook_wrong_action_space():
    """Constructing a QValueHook with an unknown action-space string raises ValueError."""
    with pytest.raises(ValueError):
        QValueHook(action_space="wrong_value")


def test_distributional_qvalue_hook_wrong_action_space():
    """Constructing a DistributionalQValueHook with an unknown action-space string raises ValueError."""
    with pytest.raises(ValueError):
        DistributionalQValueHook(action_space="wrong_value", support=None)

@pytest.mark.parametrize(
    "action_space, expected_action",
    (
        ("one_hot", [0, 0, 1, 0, 0]),
        ("categorical", 2),
    ),
)
def test_qvalue_hook_0_dim_batch(action_space, expected_action):
    """On an unbatched value vector the hook selects the argmax action.

    Checks the action encoding (one-hot vs. categorical index), that the raw
    values pass through unchanged, and that the chosen action's value is
    returned.
    """
    q_values = torch.tensor([1.0, -1.0, 100.0, -2.0, -3.0])
    hook = QValueHook(action_space=action_space)

    action, values, chosen_action_value = hook(
        net=None, observation=None, values=q_values
    )

    expected = torch.tensor(expected_action, dtype=torch.long)
    assert (action == expected).all()
    assert (values == q_values).all()
    # index 2 holds the maximum (100.0)
    assert (chosen_action_value == torch.tensor([100.0])).all()


@pytest.mark.parametrize(
    "action_space, expected_action",
    (
        ("one_hot", [[0, 0, 1, 0, 0], [1, 0, 0, 0, 0]]),
        ("categorical", [2, 0]),
    ),
)
def test_qvalue_hook_1_dim_batch(action_space, expected_action):
    """On a batch of two value vectors the hook selects each row's argmax.

    Checks per-row action encoding, pass-through of the raw values, and the
    per-row chosen action values.
    """
    q_values = torch.tensor(
        [
            [1.0, -1.0, 100.0, -2.0, -3.0],
            [5.0, 4.0, 3.0, 2.0, -5.0],
        ]
    )
    hook = QValueHook(action_space=action_space)

    action, values, chosen_action_value = hook(
        net=None, observation=None, values=q_values
    )

    expected = torch.tensor(expected_action, dtype=torch.long)
    assert (action == expected).all()
    assert (values == q_values).all()
    # row maxima: 100.0 (index 2) and 5.0 (index 0)
    assert (chosen_action_value == torch.tensor([[100.0], [5.0]])).all()


@pytest.mark.parametrize(
    "action_space, expected_action",
    (
        ("one_hot", [0, 0, 1, 0, 0]),
        ("categorical", 2),
    ),
)
def test_distributional_qvalue_hook_0_dim_batch(action_space, expected_action):
    """Unbatched distributional case: action with the best expected value wins.

    The value tensor is (atoms, actions) log-probabilities over a 3-atom
    support; the hook must pick action index 2 in either encoding and leave
    the log-probabilities untouched.
    """
    support = torch.tensor([-2.0, 0.0, 2.0])
    hook = DistributionalQValueHook(action_space=action_space, support=support)

    logits = torch.tensor(
        [
            [1.0, -1.0, 11.0, -2.0, 30.0],
            [1.0, -1.0, 1.0, -2.0, -3.0],
            [1.0, -1.0, 10.0, -2.0, -3.0],
        ]
    )
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

    action, values = hook(net=None, observation=None, values=log_probs)
    expected = torch.tensor(expected_action, dtype=torch.long)

    assert action.shape == expected.shape
    assert (action == expected).all()
    assert values.shape == log_probs.shape
    assert (values == log_probs).all()


@pytest.mark.parametrize(
    "action_space, expected_action",
    (
        ("one_hot", [[0, 0, 1, 0, 0], [1, 0, 0, 0, 0]]),
        ("categorical", [2, 0]),
    ),
)
def test_distributional_qvalue_hook_1_dim_batch(action_space, expected_action):
    """Batched distributional case: per-sample action with best expected value.

    Renamed from ``test_qvalue_hook_categorical_1_dim_batch``: the test
    exercises ``DistributionalQValueHook`` over BOTH action encodings (it is
    parametrized on ``action_space``), so the old name was doubly misleading;
    the new name mirrors ``test_distributional_qvalue_hook_0_dim_batch``.

    The value tensor is (batch, atoms, actions) log-probabilities over a
    3-atom support; the hook must pick actions [2, 0] in either encoding and
    leave the log-probabilities untouched.
    """
    support = torch.tensor([-2.0, 0.0, 2.0])
    hook = DistributionalQValueHook(action_space=action_space, support=support)

    in_values = torch.nn.LogSoftmax(dim=-1)(
        torch.tensor(
            [
                [
                    [1.0, -1.0, 11.0, -2.0, 30.0],
                    [1.0, -1.0, 1.0, -2.0, -3.0],
                    [1.0, -1.0, 10.0, -2.0, -3.0],
                ],
                [
                    [11.0, -1.0, 7.0, -1.0, 20.0],
                    [10.0, 19.0, 1.0, -2.0, -3.0],
                    [1.0, -1.0, 0.0, -2.0, -3.0],
                ],
            ]
        )
    )
    action, values = hook(net=None, observation=None, values=in_values)
    expected_action = torch.tensor(expected_action, dtype=torch.long)

    assert action.shape == expected_action.shape
    assert (action == expected_action).all()
    assert values.shape == in_values.shape
    assert (values == in_values).all()
107 changes: 90 additions & 17 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
NdBoundedTensorSpec,
NdUnboundedContinuousTensorSpec,
TensorDict,
OneHotDiscreteTensorSpec,
DiscreteTensorSpec,
)
from torchrl.data.postprocs.postprocs import MultiStep

Expand Down Expand Up @@ -112,11 +114,21 @@ def get_devices():
class TestDQN:
seed = 0

def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
def _create_mock_actor(
self, action_spec_type, batch=2, obs_dim=3, action_dim=4, device="cpu"
):
# Actor
action_spec = NdBoundedTensorSpec(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
if action_spec_type == "one_hot":
action_spec = OneHotDiscreteTensorSpec(action_dim)
elif action_spec_type == "categorical":
action_spec = DiscreteTensorSpec(action_dim)
elif action_spec_type == "nd_bounded":
action_spec = NdBoundedTensorSpec(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
else:
raise ValueError(f"Wrong {action_spec_type}")

module = nn.Linear(obs_dim, action_dim)
actor = QValueActor(
spec=CompositeSpec(
Expand All @@ -127,21 +139,44 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
return actor

def _create_mock_distributional_actor(
self, batch=2, obs_dim=3, action_dim=4, atoms=5, vmin=1, vmax=5
self,
action_spec_type,
batch=2,
obs_dim=3,
action_dim=4,
atoms=5,
vmin=1,
vmax=5,
):
# Actor
action_spec = MultOneHotDiscreteTensorSpec([atoms] * action_dim)
if action_spec_type == "mult_one_hot":
action_spec = MultOneHotDiscreteTensorSpec([atoms] * action_dim)
elif action_spec_type == "one_hot":
action_spec = OneHotDiscreteTensorSpec(action_dim)
elif action_spec_type == "categorical":
action_spec = DiscreteTensorSpec(action_dim)
else:
raise ValueError(f"Wrong {action_spec_type}")
support = torch.linspace(vmin, vmax, atoms, dtype=torch.float)
module = MLP(obs_dim, (atoms, action_dim))
actor = DistributionalQValueActor(
spec=CompositeSpec(action=action_spec, action_value=None),
module=module,
support=support,
action_space="categorical"
if isinstance(action_spec, DiscreteTensorSpec)
else "one_hot",
)
return actor

def _create_mock_data_dqn(
self, batch=2, obs_dim=3, action_dim=4, atoms=None, device="cpu"
self,
action_spec_type,
batch=2,
obs_dim=3,
action_dim=4,
atoms=None,
device="cpu",
):
# create a tensordict
obs = torch.randn(batch, obs_dim)
Expand All @@ -154,6 +189,10 @@ def _create_mock_data_dqn(
else:
action_value = torch.randn(batch, action_dim)
action = (action_value == action_value.max(-1, True)[0]).to(torch.long)

if action_spec_type == "categorical":
action_value = torch.max(action_value, -1, keepdim=True)[0]
action = torch.argmax(action, -1, keepdim=True)
reward = torch.randn(batch, 1)
done = torch.zeros(batch, 1, dtype=torch.bool)
td = TensorDict(
Expand All @@ -171,7 +210,14 @@ def _create_mock_data_dqn(
return td

def _create_seq_mock_data_dqn(
self, batch=2, T=4, obs_dim=3, action_dim=4, atoms=None, device="cpu"
self,
action_spec_type,
batch=2,
T=4,
obs_dim=3,
action_dim=4,
atoms=None,
device="cpu",
):
# create a tensordict
total_obs = torch.randn(batch, T + 1, obs_dim, device=device)
Expand All @@ -187,6 +233,10 @@ def _create_seq_mock_data_dqn(
else:
action_value = torch.randn(batch, T, action_dim, device=device)
action = (action_value == action_value.max(-1, True)[0]).to(torch.long)

if action_spec_type == "categorical":
action_value = torch.max(action_value, -1, keepdim=True)[0]
action = torch.argmax(action, -1, keepdim=True)
reward = torch.randn(batch, T, 1, device=device)
done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
Expand All @@ -207,10 +257,17 @@ def _create_seq_mock_data_dqn(

@pytest.mark.parametrize("delay_value", (False, True))
@pytest.mark.parametrize("device", get_available_devices())
def test_dqn(self, delay_value, device):
@pytest.mark.parametrize(
"action_spec_type", ("nd_bounded", "one_hot", "categorical")
)
def test_dqn(self, delay_value, device, action_spec_type):
torch.manual_seed(self.seed)
actor = self._create_mock_actor(device=device)
td = self._create_mock_data_dqn(device=device)
actor = self._create_mock_actor(
action_spec_type=action_spec_type, device=device
)
td = self._create_mock_data_dqn(
action_spec_type=action_spec_type, device=device
)
loss_fn = DQNLoss(actor, gamma=0.9, loss_function="l2", delay_value=delay_value)
with _check_td_steady(td):
loss = loss_fn(td)
Expand Down Expand Up @@ -240,11 +297,18 @@ def test_dqn(self, delay_value, device):
@pytest.mark.parametrize("n", range(4))
@pytest.mark.parametrize("delay_value", (False, True))
@pytest.mark.parametrize("device", get_available_devices())
def test_dqn_batcher(self, n, delay_value, device, gamma=0.9):
@pytest.mark.parametrize(
"action_spec_type", ("nd_bounded", "one_hot", "categorical")
)
def test_dqn_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9):
torch.manual_seed(self.seed)
actor = self._create_mock_actor(device=device)
actor = self._create_mock_actor(
action_spec_type=action_spec_type, device=device
)

td = self._create_seq_mock_data_dqn(device=device)
td = self._create_seq_mock_data_dqn(
action_spec_type=action_spec_type, device=device
)
loss_fn = DQNLoss(
actor, gamma=gamma, loss_function="l2", delay_value=delay_value
)
Expand Down Expand Up @@ -292,11 +356,20 @@ def test_dqn_batcher(self, n, delay_value, device, gamma=0.9):
@pytest.mark.parametrize("atoms", range(4, 10))
@pytest.mark.parametrize("delay_value", (False, True))
@pytest.mark.parametrize("device", get_devices())
def test_distributional_dqn(self, atoms, delay_value, device, gamma=0.9):
@pytest.mark.parametrize(
"action_spec_type", ("mult_one_hot", "one_hot", "categorical")
)
def test_distributional_dqn(
self, atoms, delay_value, device, action_spec_type, gamma=0.9
):
torch.manual_seed(self.seed)
actor = self._create_mock_distributional_actor(atoms=atoms).to(device)
actor = self._create_mock_distributional_actor(
action_spec_type=action_spec_type, atoms=atoms
).to(device)

td = self._create_mock_data_dqn(atoms=atoms).to(device)
td = self._create_mock_data_dqn(
action_spec_type=action_spec_type, atoms=atoms
).to(device)
loss_fn = DistributionalDQNLoss(actor, gamma=gamma, delay_value=delay_value)

with _check_td_steady(td):
Expand Down
Loading