14
14
NestedCountingEnv ,
15
15
)
16
16
from scipy .stats import ttest_1samp
17
- from tensordict .nn import InteractionType , TensorDictModule
17
+
18
+ from tensordict .nn import InteractionType , TensorDictModule , TensorDictSequential
18
19
from tensordict .tensordict import TensorDict
19
20
from torch import nn
20
21
21
22
from torchrl .collectors import SyncDataCollector
22
- from torchrl .data import BoundedTensorSpec , CompositeSpec
23
+ from torchrl .data import (
24
+ BoundedTensorSpec ,
25
+ CompositeSpec ,
26
+ DiscreteTensorSpec ,
27
+ OneHotDiscreteTensorSpec ,
28
+ )
23
29
from torchrl .envs import SerialEnv
24
30
from torchrl .envs .transforms .transforms import gSDENoise , InitTracker , TransformedEnv
25
31
from torchrl .envs .utils import set_exploration_type
30
36
NormalParamWrapper ,
31
37
)
32
38
from torchrl .modules .models .exploration import LazygSDEModule
33
- from torchrl .modules .tensordict_module .actors import Actor , ProbabilisticActor
39
+ from torchrl .modules .tensordict_module .actors import (
40
+ Actor ,
41
+ ProbabilisticActor ,
42
+ QValueActor ,
43
+ )
34
44
from torchrl .modules .tensordict_module .exploration import (
35
45
_OrnsteinUhlenbeckProcess ,
36
46
AdditiveGaussianWrapper ,
47
+ EGreedyModule ,
37
48
EGreedyWrapper ,
38
49
OrnsteinUhlenbeckProcessWrapper ,
39
50
)
40
51
41
52
42
- @pytest .mark .parametrize ("eps_init" , [0.0 , 0.5 , 1.0 ])
43
53
class TestEGreedy :
44
- def test_egreedy (self , eps_init ):
54
+ @pytest .mark .parametrize ("eps_init" , [0.0 , 0.5 , 1.0 ])
55
+ @pytest .mark .parametrize ("module" , [True , False ])
56
+ def test_egreedy (self , eps_init , module ):
45
57
torch .manual_seed (0 )
46
58
spec = BoundedTensorSpec (1 , 1 , torch .Size ([4 ]))
47
59
module = torch .nn .Linear (4 , 4 , bias = False )
60
+
48
61
policy = Actor (spec = spec , module = module )
49
- explorative_policy = EGreedyWrapper (policy , eps_init = eps_init , eps_end = eps_init )
62
+ if module :
63
+ explorative_policy = TensorDictSequential (
64
+ policy , EGreedyModule (eps_init = eps_init , eps_end = eps_init , spec = spec )
65
+ )
66
+ else :
67
+ explorative_policy = EGreedyWrapper (
68
+ policy , eps_init = eps_init , eps_end = eps_init
69
+ )
50
70
td = TensorDict ({"observation" : torch .zeros (10 , 4 )}, batch_size = [10 ])
51
71
action = explorative_policy (td ).get ("action" )
52
72
if eps_init == 0 :
@@ -58,6 +78,135 @@ def test_egreedy(self, eps_init):
58
78
assert (action == 0 ).any ()
59
79
assert ((action == 1 ) | (action == 0 )).all ()
60
80
81
@pytest.mark.parametrize("eps_init", [0.0, 0.5, 1.0])
@pytest.mark.parametrize("module", [True, False])
@pytest.mark.parametrize("spec_class", ["discrete", "one_hot"])
def test_egreedy_masked(self, module, eps_init, spec_class):
    """Check that epsilon-greedy exploration honours an action mask.

    Args:
        module: when true, exploration is built as a ``TensorDictSequential``
            of the policy followed by an ``EGreedyModule``; otherwise the
            policy is wrapped in an ``EGreedyWrapper``.
        eps_init: epsilon value, used for both ``eps_init`` and ``eps_end``.
        spec_class: ``"discrete"`` builds a ``DiscreteTensorSpec``, anything
            else a ``OneHotDiscreteTensorSpec``.

    The test expects a ``KeyError`` when the mask entry is missing, then
    compares actions produced with an all-true mask against actions
    produced with a random mask.
    """
    torch.manual_seed(0)
    action_size = 4
    batch_size = (3, 4, 2)
    # The network gets its own name. Rebinding ``module`` to the Linear layer
    # would make ``if module:`` below always truthy (an nn.Module instance has
    # no falsy state), silently skipping the EGreedyWrapper parametrization.
    net = torch.nn.Linear(action_size, action_size, bias=False)
    if spec_class == "discrete":
        spec = DiscreteTensorSpec(action_size)
    else:
        spec = OneHotDiscreteTensorSpec(
            action_size,
            shape=(action_size,),
        )
    policy = QValueActor(spec=spec, module=net, action_mask_key="action_mask")
    if module:
        explorative_policy = TensorDictSequential(
            policy,
            EGreedyModule(
                eps_init=eps_init,
                eps_end=eps_init,
                spec=spec,
                action_mask_key="action_mask",
            ),
        )
    else:
        explorative_policy = EGreedyWrapper(
            policy,
            eps_init=eps_init,
            eps_end=eps_init,
            action_mask_key="action_mask",
        )

    # No "action_mask" entry in the input: the call must fail loudly.
    td = TensorDict(
        {"observation": torch.zeros(*batch_size, action_size)},
        batch_size=batch_size,
    )
    with pytest.raises(KeyError, match="Action mask key action_mask not found in"):
        explorative_policy(td)

    # Reference run: every action is allowed.
    torch.manual_seed(0)
    action_mask = torch.ones(*batch_size, action_size).to(torch.bool)
    td = TensorDict(
        {
            "observation": torch.zeros(*batch_size, action_size),
            "action_mask": action_mask,
        },
        batch_size=batch_size,
    )
    action = explorative_policy(td).get("action")

    # Masked run: resample until every batch element keeps at least one
    # allowed action and at least one action is forbidden somewhere.
    torch.manual_seed(0)
    action_mask = torch.randint(high=2, size=(*batch_size, action_size)).to(
        torch.bool
    )
    while not action_mask.any(dim=-1).all() or action_mask.all():
        action_mask = torch.randint(high=2, size=(*batch_size, action_size)).to(
            torch.bool
        )

    td = TensorDict(
        {
            "observation": torch.zeros(*batch_size, action_size),
            "action_mask": action_mask,
        },
        batch_size=batch_size,
    )
    masked_action = explorative_policy(td).get("action")

    # Bring categorical actions to one-hot form so both spec kinds can be
    # indexed with the boolean mask.
    if spec_class == "discrete":
        action = spec.to_one_hot(action)
        masked_action = spec.to_one_hot(masked_action)

    # Without masking, some now-forbidden slot was selected at least once;
    # with masking, no forbidden slot is ever selected.
    assert not (action[~action_mask] == 0).all()
    assert (masked_action[~action_mask] == 0).all()
157
+
158
def test_egreedy_wrapper_deprecation(self):
    """Building an ``EGreedyWrapper`` is expected to emit a deprecation warning."""
    torch.manual_seed(0)
    action_spec = BoundedTensorSpec(1, 1, torch.Size([4]))
    actor = Actor(
        spec=action_spec,
        module=torch.nn.Linear(4, 4, bias=False),
    )
    # Only construction is exercised; the wrapper is never called.
    with pytest.deprecated_call():
        EGreedyWrapper(actor)
165
+
166
def test_no_spec_error(self):
    """An ``EGreedyModule`` built with ``spec=None`` must raise when called."""
    torch.manual_seed(0)
    n_actions = 4
    batch_shape = (3, 4, 2)

    q_net = torch.nn.Linear(n_actions, n_actions, bias=False)
    one_hot_spec = OneHotDiscreteTensorSpec(n_actions, shape=(n_actions,))
    actor = QValueActor(spec=one_hot_spec, module=q_net)
    # The policy has a spec, but the exploration module is given none.
    exploring_actor = TensorDictSequential(actor, EGreedyModule(spec=None))

    observation = torch.zeros(*batch_shape, n_actions)
    data = TensorDict({"observation": observation}, batch_size=batch_shape)

    expected_message = "spec must be provided to the exploration wrapper."
    with pytest.raises(RuntimeError, match=expected_message):
        exploring_actor(data)
190
+
191
@pytest.mark.parametrize("module", [True, False])
def test_wrong_action_shape(self, module):
    """Exploration must reject an action whose shape disagrees with the spec.

    Args:
        module: when true, exploration is built as a ``TensorDictSequential``
            of the policy followed by an ``EGreedyModule``; otherwise the
            policy is wrapped in an ``EGreedyWrapper``.

    The spec is built with shape ``[4]`` while the network outputs 5
    features, and the call is expected to raise ``ValueError``.
    """
    torch.manual_seed(0)
    spec = BoundedTensorSpec(1, 1, torch.Size([4]))
    # The network gets its own name. Rebinding ``module`` to the Linear layer
    # would make ``if module:`` below always truthy, so the EGreedyWrapper
    # parametrization would never be exercised.
    net = torch.nn.Linear(4, 5, bias=False)

    policy = Actor(spec=spec, module=net)
    if module:
        explorative_policy = TensorDictSequential(policy, EGreedyModule(spec=spec))
    else:
        explorative_policy = EGreedyWrapper(
            policy,
        )
    td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10])
    with pytest.raises(
        ValueError, match="Action spec shape does not match the action shape"
    ):
        explorative_policy(td)
209
+
61
210
62
211
@pytest .mark .parametrize ("device" , get_default_devices ())
63
212
class TestOrnsteinUhlenbeckProcessWrapper :
0 commit comments