Commit 786020d

[Feature] Making action masks compatible with q value modules and e-greedy (#1499)
Signed-off-by: Matteo Bettini <[email protected]>
Co-authored-by: Vincent Moens <[email protected]>
1 parent 153337e commit 786020d

File tree: 9 files changed (+442, -19 lines)
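For orientation before the per-file diffs: a minimal sketch of the workflow this commit enables, assembled from the new tests further down. The toy value network, observation, and mask values are illustrative stand-ins, not part of the commit.

import torch
from tensordict.nn import TensorDictSequential
from tensordict.tensordict import TensorDict

from torchrl.data import OneHotDiscreteTensorSpec
from torchrl.modules import EGreedyModule, QValueActor

n_actions = 4
spec = OneHotDiscreteTensorSpec(n_actions, shape=(n_actions,))

# Toy value network: 4-dim observation -> 4 action values.
value_net = torch.nn.Linear(n_actions, n_actions, bias=False)

# The actor reads "action_mask" and never greedily selects a masked action.
policy = QValueActor(spec=spec, module=value_net, action_mask_key="action_mask")

# The e-greedy step also draws its random actions inside the mask.
explorative_policy = TensorDictSequential(
    policy,
    EGreedyModule(spec=spec, action_mask_key="action_mask"),
)

td = TensorDict(
    {
        "observation": torch.zeros(2, n_actions),
        # Only the first two actions are legal in this toy batch.
        "action_mask": torch.tensor([[True, True, False, False]] * 2),
    },
    batch_size=[2],
)
action = explorative_policy(td).get("action")  # one-hot, always within the mask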

docs/source/reference/modules.rst (+1, -1)

@@ -68,7 +68,7 @@ other cases, the action written in the tensordict is simply the network output.
     :template: rl_template_noinherit.rst

     AdditiveGaussianWrapper
-    EGreedyWrapper
+    EGreedyModule
     OrnsteinUhlenbeckProcessWrapper

 Probabilistic actors
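Context for the docs change: the commit deprecates EGreedyWrapper in favor of a standalone EGreedyModule appended after the policy (see test_egreedy_wrapper_deprecation below). A before/after sketch, with a toy linear actor standing in for a real policy and illustrative epsilon values:

import torch
from tensordict.nn import TensorDictSequential

from torchrl.data import BoundedTensorSpec
from torchrl.modules import Actor, EGreedyModule, EGreedyWrapper

spec = BoundedTensorSpec(1, 1, torch.Size([4]))
policy = Actor(spec=spec, module=torch.nn.Linear(4, 4, bias=False))

# Old style: wrap the policy (now emits a deprecation warning).
explorative_policy = EGreedyWrapper(policy, eps_init=0.5, eps_end=0.5)

# New style: chain a standalone exploration module after the policy.
explorative_policy = TensorDictSequential(
    policy, EGreedyModule(eps_init=0.5, eps_end=0.5, spec=spec)
)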

examples/multiagent/iql.py (+1, -1)

@@ -101,7 +101,7 @@ def train(cfg: "DictConfig"):  # noqa: F821
         eps_end=0,
         annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
         action_key=env.action_key,
-        spec=env.unbatched_action_spec[env.action_key],
+        spec=env.unbatched_action_spec,
     )

     collector = SyncDataCollector(

examples/multiagent/qmix_vdn.py (+1, -1)

@@ -102,7 +102,7 @@ def train(cfg: "DictConfig"):  # noqa: F821
         eps_end=0,
         annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
         action_key=env.action_key,
-        spec=env.unbatched_action_spec[env.action_key],
+        spec=env.unbatched_action_spec,
     )

     if cfg.loss.mixer_type == "qmix":
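Note on the two example changes above: the exploration wrapper is now handed the full env.unbatched_action_spec (a composite spec) rather than the leaf spec indexed by env.action_key. A plausible reading, consistent with the code paths exercised in the tests below, is that the e-greedy module can now resolve the action entry inside a CompositeSpec on its own, using the action_key it already receives.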

test/test_actors.py (+33)

@@ -613,6 +613,39 @@ def test_qvalue_hook_categorical_1_dim_batch(self, action_space, expected_action
         assert values.shape == in_values.shape
         assert (values == in_values).all()

+    @pytest.mark.parametrize("action_space", ["categorical", "one-hot"])
+    @pytest.mark.parametrize("action_n", [2, 3, 4, 5])
+    def test_qvalue_mask(self, action_space, action_n):
+        torch.manual_seed(0)
+        shape = (3, 4, 3, action_n)
+        action_values = torch.randn(size=shape)
+        td = TensorDict({"action_value": action_values}, [3])
+        module = QValueModule(
+            action_space=action_space,
+            action_value_key="action_value",
+            action_mask_key="action_mask",
+        )
+        with pytest.raises(KeyError, match="Action mask key "):
+            module(td)
+
+        action_mask = torch.randint(high=2, size=shape).to(torch.bool)
+        while not action_mask.any(dim=-1).all() or action_mask.all():
+            action_mask = torch.randint(high=2, size=shape).to(torch.bool)
+
+        td.set("action_mask", action_mask)
+        module(td)
+        new_action_values = td.get("action_value")
+
+        assert (new_action_values[~action_mask] != action_values[~action_mask]).all()
+        assert (new_action_values[action_mask] == action_values[action_mask]).all()
+        assert (td.get("chosen_action_value") > torch.finfo(torch.float).min).all()
+
+        if action_space == "one-hot":
+            assert (td.get("action")[action_mask]).any()
+            assert not (td.get("action")[~action_mask]).any()
+        else:
+            assert action_mask.gather(-1, td.get("action").unsqueeze(-1)).all()
+

 @pytest.mark.parametrize("device", get_default_devices())
 def test_value_based_policy(device):
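What test_qvalue_mask pins down: once "action_mask" is present, QValueModule overwrites the action values at masked positions (the assertion that chosen_action_value stays strictly above torch.finfo(torch.float).min suggests they are filled with the dtype's minimum), leaves unmasked values untouched, and the greedy action always lands inside the mask, for both categorical and one-hot action spaces. A missing mask key raises a KeyError rather than silently ignoring the mask.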

test/test_exploration.py (+155, -6)

@@ -14,12 +14,18 @@
     NestedCountingEnv,
 )
 from scipy.stats import ttest_1samp
-from tensordict.nn import InteractionType, TensorDictModule
+
+from tensordict.nn import InteractionType, TensorDictModule, TensorDictSequential
 from tensordict.tensordict import TensorDict
 from torch import nn

 from torchrl.collectors import SyncDataCollector
-from torchrl.data import BoundedTensorSpec, CompositeSpec
+from torchrl.data import (
+    BoundedTensorSpec,
+    CompositeSpec,
+    DiscreteTensorSpec,
+    OneHotDiscreteTensorSpec,
+)
 from torchrl.envs import SerialEnv
 from torchrl.envs.transforms.transforms import gSDENoise, InitTracker, TransformedEnv
 from torchrl.envs.utils import set_exploration_type
@@ -30,23 +36,37 @@
     NormalParamWrapper,
 )
 from torchrl.modules.models.exploration import LazygSDEModule
-from torchrl.modules.tensordict_module.actors import Actor, ProbabilisticActor
+from torchrl.modules.tensordict_module.actors import (
+    Actor,
+    ProbabilisticActor,
+    QValueActor,
+)
 from torchrl.modules.tensordict_module.exploration import (
     _OrnsteinUhlenbeckProcess,
     AdditiveGaussianWrapper,
+    EGreedyModule,
     EGreedyWrapper,
     OrnsteinUhlenbeckProcessWrapper,
 )


-@pytest.mark.parametrize("eps_init", [0.0, 0.5, 1.0])
 class TestEGreedy:
-    def test_egreedy(self, eps_init):
+    @pytest.mark.parametrize("eps_init", [0.0, 0.5, 1.0])
+    @pytest.mark.parametrize("module", [True, False])
+    def test_egreedy(self, eps_init, module):
         torch.manual_seed(0)
         spec = BoundedTensorSpec(1, 1, torch.Size([4]))
         module = torch.nn.Linear(4, 4, bias=False)
+
         policy = Actor(spec=spec, module=module)
-        explorative_policy = EGreedyWrapper(policy, eps_init=eps_init, eps_end=eps_init)
+        if module:
+            explorative_policy = TensorDictSequential(
+                policy, EGreedyModule(eps_init=eps_init, eps_end=eps_init, spec=spec)
+            )
+        else:
+            explorative_policy = EGreedyWrapper(
+                policy, eps_init=eps_init, eps_end=eps_init
+            )
         td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10])
         action = explorative_policy(td).get("action")
         if eps_init == 0:
@@ -58,6 +78,135 @@ def test_egreedy(self, eps_init):
         assert (action == 0).any()
         assert ((action == 1) | (action == 0)).all()

+    @pytest.mark.parametrize("eps_init", [0.0, 0.5, 1.0])
+    @pytest.mark.parametrize("module", [True, False])
+    @pytest.mark.parametrize("spec_class", ["discrete", "one_hot"])
+    def test_egreedy_masked(self, module, eps_init, spec_class):
+        torch.manual_seed(0)
+        action_size = 4
+        batch_size = (3, 4, 2)
+        module = torch.nn.Linear(action_size, action_size, bias=False)
+        if spec_class == "discrete":
+            spec = DiscreteTensorSpec(action_size)
+        else:
+            spec = OneHotDiscreteTensorSpec(
+                action_size,
+                shape=(action_size,),
+            )
+        policy = QValueActor(spec=spec, module=module, action_mask_key="action_mask")
+        if module:
+            explorative_policy = TensorDictSequential(
+                policy,
+                EGreedyModule(
+                    eps_init=eps_init,
+                    eps_end=eps_init,
+                    spec=spec,
+                    action_mask_key="action_mask",
+                ),
+            )
+        else:
+            explorative_policy = EGreedyWrapper(
+                policy,
+                eps_init=eps_init,
+                eps_end=eps_init,
+                action_mask_key="action_mask",
+            )
+
+        td = TensorDict(
+            {"observation": torch.zeros(*batch_size, action_size)},
+            batch_size=batch_size,
+        )
+        with pytest.raises(KeyError, match="Action mask key action_mask not found in"):
+            explorative_policy(td)
+
+        torch.manual_seed(0)
+        action_mask = torch.ones(*batch_size, action_size).to(torch.bool)
+        td = TensorDict(
+            {
+                "observation": torch.zeros(*batch_size, action_size),
+                "action_mask": action_mask,
+            },
+            batch_size=batch_size,
+        )
+        action = explorative_policy(td).get("action")
+
+        torch.manual_seed(0)
+        action_mask = torch.randint(high=2, size=(*batch_size, action_size)).to(
+            torch.bool
+        )
+        while not action_mask.any(dim=-1).all() or action_mask.all():
+            action_mask = torch.randint(high=2, size=(*batch_size, action_size)).to(
+                torch.bool
+            )
+
+        td = TensorDict(
+            {
+                "observation": torch.zeros(*batch_size, action_size),
+                "action_mask": action_mask,
+            },
+            batch_size=batch_size,
+        )
+        masked_action = explorative_policy(td).get("action")
+
+        if spec_class == "discrete":
+            action = spec.to_one_hot(action)
+            masked_action = spec.to_one_hot(masked_action)
+
+        assert not (action[~action_mask] == 0).all()
+        assert (masked_action[~action_mask] == 0).all()
+
+    def test_egreedy_wrapper_deprecation(self):
+        torch.manual_seed(0)
+        spec = BoundedTensorSpec(1, 1, torch.Size([4]))
+        module = torch.nn.Linear(4, 4, bias=False)
+        policy = Actor(spec=spec, module=module)
+        with pytest.deprecated_call():
+            EGreedyWrapper(policy)
+
+    def test_no_spec_error(
+        self,
+    ):
+        torch.manual_seed(0)
+        action_size = 4
+        batch_size = (3, 4, 2)
+        module = torch.nn.Linear(action_size, action_size, bias=False)
+        spec = OneHotDiscreteTensorSpec(action_size, shape=(action_size,))
+        policy = QValueActor(spec=spec, module=module)
+        explorative_policy = TensorDictSequential(
+            policy,
+            EGreedyModule(spec=None),
+        )
+        td = TensorDict(
+            {
+                "observation": torch.zeros(*batch_size, action_size),
+            },
+            batch_size=batch_size,
+        )
+
+        with pytest.raises(
+            RuntimeError, match="spec must be provided to the exploration wrapper."
+        ):
+            explorative_policy(td)
+
+    @pytest.mark.parametrize("module", [True, False])
+    def test_wrong_action_shape(self, module):
+        torch.manual_seed(0)
+        spec = BoundedTensorSpec(1, 1, torch.Size([4]))
+        module = torch.nn.Linear(4, 5, bias=False)
+
+        policy = Actor(spec=spec, module=module)
+        if module:
+            explorative_policy = TensorDictSequential(policy, EGreedyModule(spec=spec))
+        else:
+            explorative_policy = EGreedyWrapper(
+                policy,
+            )
+        td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10])
+        with pytest.raises(
+            ValueError, match="Action spec shape does not match the action shape"
+        ):
+            explorative_policy(td)
+

 @pytest.mark.parametrize("device", get_default_devices())
 class TestOrnsteinUhlenbeckProcessWrapper:

torchrl/modules/__init__.py (+1)

@@ -53,6 +53,7 @@
     DistributionalQValueActor,
     DistributionalQValueHook,
     DistributionalQValueModule,
+    EGreedyModule,
    EGreedyWrapper,
    LMHeadActorValueOperator,
    LSTMModule,

torchrl/modules/tensordict_module/__init__.py (+1)

@@ -23,6 +23,7 @@
 from .common import SafeModule, VmapModule
 from .exploration import (
     AdditiveGaussianWrapper,
+    EGreedyModule,
     EGreedyWrapper,
     OrnsteinUhlenbeckProcessWrapper,
 )
