forked from pytorch/rl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
models.py
1548 lines (1388 loc) · 59.4 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import dataclasses
import warnings
from numbers import Number
from typing import Dict, List, Optional, Sequence, Tuple, Type, Union
import torch
from tensordict.nn import dispatch, TensorDictModuleBase
from torch import nn
from torch.nn import functional as F
from torchrl._utils import prod
from torchrl.data.utils import DEVICE_TYPING
from torchrl.modules.models.decision_transformer import DecisionTransformer
from torchrl.modules.models.utils import (
_find_depth,
create_on_device,
LazyMapping,
SquashDims,
Squeeze2dLayer,
SqueezeLayer,
)
class MLP(nn.Sequential):
    """A multi-layer perceptron.

    If MLP receives more than one input, it concatenates them all along the last dimension before passing the
    resulting tensor through the network. This is aimed at allowing for a seamless interface with calls of the type of

        >>> model(state, action)  # compute state-action value

    In the future, this feature may be moved to the ProbabilisticTDModule, though it would require it to handle
    different cases (vectors, images, ...)

    Every layer but the last one (and the last one too if ``activate_last_layer=True``) is followed, in this
    order, by an optional ``nn.Dropout``, an optional normalization layer and the activation.

    Args:
        in_features (int, optional): number of input features. If ``None``, the lazy counterpart of
            ``layer_class`` (looked up in ``LazyMapping``) is used for the first layer;
        out_features (int, list of int): number of output features. If iterable of integers, the output is reshaped to
            the desired shape;
        depth (int, optional): depth of the network. A depth of 0 will produce a single linear layer network with the
            desired input and output size. A length of 1 will create 2 linear layers etc. If no depth is indicated,
            the depth information should be contained in the num_cells argument (see below). If num_cells is an
            iterable and depth is indicated, both should match: len(num_cells) must be equal to depth.
            If neither ``depth`` nor ``num_cells`` is given, a depth of 3 is used;
        num_cells (int or Sequence[int], optional): number of cells of every layer in between the input and output. If
            an integer is provided, every layer will have the same number of cells. If an iterable is provided,
            the linear layers out_features will match the content of num_cells.
            default: 32;
        activation_class (Type[nn.Module]): activation class to be used.
            default: nn.Tanh
        activation_kwargs (dict, optional): kwargs to be used with the activation class;
        norm_class (Type, optional): normalization class, if any.
        norm_kwargs (dict, optional): kwargs to be used with the normalization layers;
        dropout (float, optional): dropout probability. Defaults to ``None`` (no
            dropout);
        bias_last_layer (bool): if ``True``, the last Linear layer will have a bias parameter.
            default: True;
        single_bias_last_layer (bool): if ``True``, the last dimension of the bias of the last layer will be a singleton
            dimension. Not implemented: passing ``True`` raises ``NotImplementedError``.
            default: False;
        layer_class (Type[nn.Module]): class to be used for the linear layers;
        layer_kwargs (dict, optional): kwargs for the linear layers;
        activate_last_layer (bool): whether the MLP output should be activated. This is useful when the MLP output
            is used as the input for another module.
            default: False.
        device (Optional[DEVICE_TYPING]): device to create the module on.

    Raises:
        ValueError: if ``out_features`` is ``None``.
        NotImplementedError: if ``single_bias_last_layer`` is ``True``.
        RuntimeError: if ``num_cells`` is an integer and no ``depth`` is given, or if ``depth`` and
            ``len(num_cells)`` disagree.

    Examples:
        >>> # All of the following examples provide valid, working MLPs
        >>> mlp = MLP(in_features=3, out_features=6, depth=0) # MLP consisting of a single 3 x 6 linear layer
        >>> print(mlp)
        MLP(
          (0): Linear(in_features=3, out_features=6, bias=True)
        )
        >>> mlp = MLP(in_features=3, out_features=6, depth=4, num_cells=32)
        >>> print(mlp)
        MLP(
          (0): Linear(in_features=3, out_features=32, bias=True)
          (1): Tanh()
          (2): Linear(in_features=32, out_features=32, bias=True)
          (3): Tanh()
          (4): Linear(in_features=32, out_features=32, bias=True)
          (5): Tanh()
          (6): Linear(in_features=32, out_features=32, bias=True)
          (7): Tanh()
          (8): Linear(in_features=32, out_features=6, bias=True)
        )
        >>> mlp = MLP(out_features=6, depth=4, num_cells=32)  # LazyLinear for the first layer
        >>> print(mlp)
        MLP(
          (0): LazyLinear(in_features=0, out_features=32, bias=True)
          (1): Tanh()
          (2): Linear(in_features=32, out_features=32, bias=True)
          (3): Tanh()
          (4): Linear(in_features=32, out_features=32, bias=True)
          (5): Tanh()
          (6): Linear(in_features=32, out_features=32, bias=True)
          (7): Tanh()
          (8): Linear(in_features=32, out_features=6, bias=True)
        )
        >>> mlp = MLP(out_features=6, num_cells=[32, 33, 34, 35])  # defines the depth by the num_cells arg
        >>> print(mlp)
        MLP(
          (0): LazyLinear(in_features=0, out_features=32, bias=True)
          (1): Tanh()
          (2): Linear(in_features=32, out_features=33, bias=True)
          (3): Tanh()
          (4): Linear(in_features=33, out_features=34, bias=True)
          (5): Tanh()
          (6): Linear(in_features=34, out_features=35, bias=True)
          (7): Tanh()
          (8): Linear(in_features=35, out_features=6, bias=True)
        )
        >>> mlp = MLP(out_features=(6, 7), num_cells=[32, 33, 34, 35])  # returns a view of the output tensor with shape [*, 6, 7]
        >>> print(mlp)
        MLP(
          (0): LazyLinear(in_features=0, out_features=32, bias=True)
          (1): Tanh()
          (2): Linear(in_features=32, out_features=33, bias=True)
          (3): Tanh()
          (4): Linear(in_features=33, out_features=34, bias=True)
          (5): Tanh()
          (6): Linear(in_features=34, out_features=35, bias=True)
          (7): Tanh()
          (8): Linear(in_features=35, out_features=42, bias=True)
        )
        >>> from torchrl.modules import NoisyLinear
        >>> mlp = MLP(out_features=(6, 7), num_cells=[32, 33, 34, 35], layer_class=NoisyLinear)  # uses NoisyLinear layers
        >>> print(mlp)
        MLP(
          (0): NoisyLazyLinear(in_features=0, out_features=32, bias=False)
          (1): Tanh()
          (2): NoisyLinear(in_features=32, out_features=33, bias=True)
          (3): Tanh()
          (4): NoisyLinear(in_features=33, out_features=34, bias=True)
          (5): Tanh()
          (6): NoisyLinear(in_features=34, out_features=35, bias=True)
          (7): Tanh()
          (8): NoisyLinear(in_features=35, out_features=42, bias=True)
        )

    """

    def __init__(
        self,
        in_features: Optional[int] = None,
        out_features: Union[int, Sequence[int]] = None,
        depth: Optional[int] = None,
        num_cells: Optional[Union[Sequence, int]] = None,
        activation_class: Type[nn.Module] = nn.Tanh,
        activation_kwargs: Optional[dict] = None,
        norm_class: Optional[Type[nn.Module]] = None,
        norm_kwargs: Optional[dict] = None,
        dropout: Optional[float] = None,
        bias_last_layer: bool = True,
        single_bias_last_layer: bool = False,
        layer_class: Type[nn.Module] = nn.Linear,
        layer_kwargs: Optional[dict] = None,
        activate_last_layer: bool = False,
        device: Optional[DEVICE_TYPING] = None,
    ):
        if out_features is None:
            raise ValueError("out_features must be specified for MLP.")

        # Fall back to hidden layers of 32 cells; with no depth either, use 3 of them.
        default_num_cells = 32
        if num_cells is None:
            if depth is None:
                depth = 3
            num_cells = [default_num_cells] * depth

        self.in_features = in_features

        # A multi-dimensional out_features is produced by a flat last layer
        # whose output is reshaped in forward().
        _out_features_num = out_features
        if not isinstance(out_features, Number):
            _out_features_num = prod(out_features)
        self.out_features = out_features
        self._out_features_num = _out_features_num

        self.activation_class = activation_class
        self.activation_kwargs = (
            activation_kwargs if activation_kwargs is not None else {}
        )
        self.norm_class = norm_class
        self.norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
        self.dropout = dropout
        self.bias_last_layer = bias_last_layer
        self.single_bias_last_layer = single_bias_last_layer
        self.layer_class = layer_class
        self.layer_kwargs = layer_kwargs if layer_kwargs is not None else {}
        self.activate_last_layer = activate_last_layer

        if single_bias_last_layer:
            raise NotImplementedError(
                "single_bias_last_layer=True is not supported yet."
            )

        if not (isinstance(num_cells, Sequence) or depth is not None):
            # Adjacent string literals are used instead of a backslash
            # continuation so that no source indentation leaks into the message.
            raise RuntimeError(
                "If num_cells is provided as an integer, "
                "depth must be provided too."
            )
        self.num_cells = (
            list(num_cells) if isinstance(num_cells, Sequence) else [num_cells] * depth
        )
        self.depth = depth if depth is not None else len(self.num_cells)
        if not (len(self.num_cells) == depth or depth is None):
            raise RuntimeError(
                "depth and num_cells length conflict, "
                "consider matching or specifying a constant num_cells argument "
                "together with a desired depth"
            )

        # The layers are built first and only then handed to nn.Sequential.
        layers = self._make_net(device)
        super().__init__(*layers)

    def _make_net(self, device: Optional[DEVICE_TYPING]) -> List[nn.Module]:
        """Build the ordered list of layers that make up the network.

        Args:
            device: device forwarded to ``create_on_device`` for every layer.

        Returns:
            The list of modules, in execution order.

        Raises:
            KeyError: if ``in_features`` is ``None`` and ``layer_class`` has no
                entry in ``LazyMapping``.
        """
        layers = []
        in_features = [self.in_features] + self.num_cells
        out_features = self.num_cells + [self._out_features_num]
        for i, (_in, _out) in enumerate(zip(in_features, out_features)):
            # There are depth + 1 linear layers, so index `depth` is the last one.
            _bias = self.bias_last_layer if i == self.depth else True
            if _in is not None:
                layers.append(
                    create_on_device(
                        self.layer_class,
                        device,
                        _in,
                        _out,
                        bias=_bias,
                        **self.layer_kwargs,
                    )
                )
            else:
                # Unknown input size: use the lazy flavour of the layer class.
                try:
                    lazy_version = LazyMapping[self.layer_class]
                except KeyError:
                    raise KeyError(
                        f"The lazy version of {self.layer_class.__name__} is not implemented yet. "
                        "Consider providing the input feature dimensions explicitely when creating an MLP module"
                    )
                layers.append(
                    create_on_device(
                        lazy_version, device, _out, bias=_bias, **self.layer_kwargs
                    )
                )

            if i < self.depth or self.activate_last_layer:
                if self.dropout is not None:
                    layers.append(create_on_device(nn.Dropout, device, p=self.dropout))
                if self.norm_class is not None:
                    layers.append(
                        create_on_device(self.norm_class, device, **self.norm_kwargs)
                    )
                layers.append(
                    create_on_device(
                        self.activation_class, device, **self.activation_kwargs
                    )
                )

        return layers

    def forward(self, *inputs: torch.Tensor) -> torch.Tensor:
        """Run the network on one or several tensors.

        Args:
            *inputs: one or more tensors. Several tensors are concatenated
                along their last dimension before entering the network.

        Returns:
            The network output. When ``out_features`` is not a plain number, the
            last dimension is reshaped (as a view) to ``out_features``.
        """
        if len(inputs) > 1:
            inputs = (torch.cat([*inputs], -1),)

        out = super().forward(*inputs)
        if not isinstance(self.out_features, Number):
            out = out.view(*out.shape[:-1], *self.out_features)
        return out
class ConvNet(nn.Sequential):
    """A convolutional neural network.

    The network is a stack of ``depth`` 2d convolutions. Each convolution is followed by the activation and,
    if ``norm_class`` is given, by a normalization layer. An optional aggregator and an optional
    ``Squeeze2dLayer`` close the chain.

    Args:
        in_features (int, optional): number of input features. If ``None``, ``nn.LazyConv2d`` is used for the
            first layer;
        depth (int, optional): depth of the network. A depth of 1 will produce a single convolutional layer network
            with the desired input size, and with an output size equal to the last element of the num_cells argument.
            If no depth is indicated, the depth information should be contained in the num_cells argument (see below).
            If num_cells is an iterable and depth is indicated, both should match: len(num_cells) must be equal to
            the depth.
        num_cells (int or Sequence[int], optional): number of cells of every layer in between the input and output. If
            an integer is provided, every layer will have the same number of cells. If an iterable is provided,
            the convolutional layers out_features will match the content of num_cells.
            default: ``[32, 32, 32]`` or ``[32] * depth`` if depth is not ``None``;
        kernel_sizes (int, Sequence[Union[int, Sequence[int]]]): Kernel size(s) of the conv network. If iterable, the length must match the
            depth, defined by the num_cells or depth arguments.
            default: 3;
        strides (int or Sequence[int]): Stride(s) of the conv network. If iterable, the length must match the
            depth, defined by the num_cells or depth arguments.
            default: 1;
        paddings (int or Sequence[int]): Padding(s) of the conv network. If iterable, the length must match the
            depth, defined by the num_cells or depth arguments.
            default: 0;
        activation_class (Type[nn.Module]): activation class to be used.
            default: nn.ELU
        activation_kwargs (dict, optional): kwargs to be used with the activation class;
        norm_class (Type, optional): normalization class, if any;
        norm_kwargs (dict, optional): kwargs to be used with the normalization layers;
        bias_last_layer (bool): if ``True``, the last convolutional layer will have a bias parameter.
            default: True;
        aggregator_class (Type[nn.Module]): aggregator to use at the end of the chain.
            default: SquashDims;
        aggregator_kwargs (dict, optional): kwargs for the aggregator_class.
            default: ``{"ndims_in": 3}``;
        squeeze_output (bool): whether the output should be squeezed of its singleton dimensions.
            default: False.
        device (Optional[DEVICE_TYPING]): device to create the module on.

    Raises:
        ValueError: if the resolved depth is 0.
        RuntimeError: if the length of one of ``num_cells``, ``kernel_sizes``, ``strides`` or ``paddings``
            conflicts with the depth.

    Examples:
        >>> # All of the following examples provide valid, working ConvNets
        >>> cnet = ConvNet(in_features=3, depth=1, num_cells=[32,]) # ConvNet consisting of a single 3 -> 32 conv layer
        >>> print(cnet)
        ConvNet(
          (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
          (1): ELU(alpha=1.0)
          (2): SquashDims()
        )
        >>> cnet = ConvNet(in_features=3, depth=4, num_cells=32)
        >>> print(cnet)
        ConvNet(
          (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
          (1): ELU(alpha=1.0)
          (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
          (3): ELU(alpha=1.0)
          (4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
          (5): ELU(alpha=1.0)
          (6): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
          (7): ELU(alpha=1.0)
          (8): SquashDims()
        )
        >>> cnet = ConvNet(in_features=3, num_cells=[32, 33, 34, 35])  # defines the depth by the num_cells arg
        >>> print(cnet)
        ConvNet(
          (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
          (1): ELU(alpha=1.0)
          (2): Conv2d(32, 33, kernel_size=(3, 3), stride=(1, 1))
          (3): ELU(alpha=1.0)
          (4): Conv2d(33, 34, kernel_size=(3, 3), stride=(1, 1))
          (5): ELU(alpha=1.0)
          (6): Conv2d(34, 35, kernel_size=(3, 3), stride=(1, 1))
          (7): ELU(alpha=1.0)
          (8): SquashDims()
        )
        >>> cnet = ConvNet(in_features=3, num_cells=[32, 33, 34, 35], kernel_sizes=[3, 4, 5, (2, 3)])  # defines kernels, possibly rectangular
        >>> print(cnet)
        ConvNet(
          (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
          (1): ELU(alpha=1.0)
          (2): Conv2d(32, 33, kernel_size=(4, 4), stride=(1, 1))
          (3): ELU(alpha=1.0)
          (4): Conv2d(33, 34, kernel_size=(5, 5), stride=(1, 1))
          (5): ELU(alpha=1.0)
          (6): Conv2d(34, 35, kernel_size=(2, 3), stride=(1, 1))
          (7): ELU(alpha=1.0)
          (8): SquashDims()
        )

    """

    def __init__(
        self,
        in_features: Optional[int] = None,
        depth: Optional[int] = None,
        num_cells: Optional[Union[Sequence, int]] = None,
        kernel_sizes: Union[Sequence[Union[int, Sequence[int]]], int] = 3,
        strides: Union[Sequence, int] = 1,
        paddings: Union[Sequence, int] = 0,
        activation_class: Type[nn.Module] = nn.ELU,
        activation_kwargs: Optional[dict] = None,
        norm_class: Optional[Type[nn.Module]] = None,
        norm_kwargs: Optional[dict] = None,
        bias_last_layer: bool = True,
        aggregator_class: Optional[Type[nn.Module]] = SquashDims,
        aggregator_kwargs: Optional[dict] = None,
        squeeze_output: bool = False,
        device: Optional[DEVICE_TYPING] = None,
    ):
        if num_cells is None:
            # Honor an explicit depth when building the default widths (same
            # rule as Conv3dNet); a fixed 3-element default would conflict
            # with any depth other than 3.
            num_cells = [32, 32, 32] if depth is None else [32] * depth

        self.in_features = in_features
        self.activation_class = activation_class
        self.activation_kwargs = (
            activation_kwargs if activation_kwargs is not None else {}
        )
        self.norm_class = norm_class
        self.norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
        self.bias_last_layer = bias_last_layer
        self.aggregator_class = aggregator_class
        self.aggregator_kwargs = (
            aggregator_kwargs if aggregator_kwargs is not None else {"ndims_in": 3}
        )
        self.squeeze_output = squeeze_output
        # self.single_bias_last_layer = single_bias_last_layer

        depth = _find_depth(depth, num_cells, kernel_sizes, strides, paddings)
        self.depth = depth
        if depth == 0:
            raise ValueError("Null depth is not permitted with ConvNet.")

        for _field, _value in zip(
            ["num_cells", "kernel_sizes", "strides", "paddings"],
            [num_cells, kernel_sizes, strides, paddings],
        ):
            _depth = depth
            # Checked before broadcasting: `[_value] * None` would otherwise
            # fail with an unrelated TypeError.
            if not (isinstance(_value, Sequence) or _depth is not None):
                raise RuntimeError(
                    f"If {_field} is provided as an integer, "
                    "depth must be provided too."
                )
            # Stored as a list so that tuples can be concatenated with lists
            # in _make_net.
            setattr(
                self,
                _field,
                (list(_value) if isinstance(_value, Sequence) else [_value] * _depth),
            )
            if not (len(getattr(self, _field)) == _depth or _depth is None):
                raise RuntimeError(
                    f"depth={depth} and {_field}={len(getattr(self, _field))} length conflict, "
                    + f"consider matching or specifying a constant {_field} argument together with a a desired depth"
                )

        self.out_features = self.num_cells[-1]

        self.depth = len(self.kernel_sizes)
        layers = self._make_net(device)
        super().__init__(*layers)

    def _make_net(self, device: Optional[DEVICE_TYPING]) -> List[nn.Module]:
        """Build the ordered list of layers that make up the network.

        Args:
            device: device the layers are created on.

        Returns:
            The list of modules, in execution order.
        """
        layers = []
        in_features = [self.in_features] + self.num_cells[: self.depth]
        out_features = self.num_cells + [self.out_features]
        kernel_sizes = self.kernel_sizes
        strides = self.strides
        paddings = self.paddings
        # zip stops at the shortest input, i.e. after `depth` convolutions.
        last_index = self.depth - 1
        for i, (_in, _out, _kernel, _stride, _padding) in enumerate(
            zip(in_features, out_features, kernel_sizes, strides, paddings)
        ):
            # Only the last convolution may drop its bias. Comparing against
            # len(in_features) - 1 (== depth) was true for every layer and
            # made bias_last_layer a no-op.
            _bias = (i < last_index) or self.bias_last_layer
            if _in is not None:
                layers.append(
                    nn.Conv2d(
                        _in,
                        _out,
                        kernel_size=_kernel,
                        stride=_stride,
                        bias=_bias,
                        padding=_padding,
                        device=device,
                    )
                )
            else:
                layers.append(
                    nn.LazyConv2d(
                        _out,
                        kernel_size=_kernel,
                        stride=_stride,
                        bias=_bias,
                        padding=_padding,
                        device=device,
                    )
                )

            layers.append(
                create_on_device(
                    self.activation_class, device, **self.activation_kwargs
                )
            )
            if self.norm_class is not None:
                layers.append(
                    create_on_device(self.norm_class, device, **self.norm_kwargs)
                )

        if self.aggregator_class is not None:
            layers.append(
                create_on_device(
                    self.aggregator_class, device, **self.aggregator_kwargs
                )
            )

        if self.squeeze_output:
            layers.append(Squeeze2dLayer())
        return layers

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Run the network on ``inputs``.

        Args:
            inputs: tensor whose last three dimensions are unpacked as
                ``C, L, W``; any leading dimensions are treated as batch
                dimensions.

        Returns:
            The network output. When there is more than one batch dimension,
            they are flattened before the call and restored afterwards.

        Raises:
            ValueError: if ``inputs`` has fewer than 3 dimensions.
        """
        try:
            *batch, C, L, W = inputs.shape
        except ValueError as err:
            raise ValueError(
                f"The input value of {self.__class__.__name__} must have at least 3 dimensions, got {inputs.ndim} instead."
            ) from err
        if len(batch) > 1:
            inputs = inputs.flatten(0, len(batch) - 1)
        out = super().forward(inputs)
        if len(batch) > 1:
            out = out.unflatten(0, batch)
        return out
# Alias: ``Conv2dNet`` is bound to the very same class object as ``ConvNet``.
Conv2dNet = ConvNet
class Conv3dNet(nn.Sequential):
    """A 3D-convolutional neural network.

    The network is a stack of ``depth`` 3d convolutions. Each convolution is followed by the activation and,
    if ``norm_class`` is given, by a normalization layer. An optional aggregator and an optional
    ``SqueezeLayer`` close the chain.

    Args:
        in_features (int, optional): number of input features. A lazy implementation that automatically retrieves
            the input size will be used if none is provided.
        depth (int, optional): depth of the network. A depth of 1 will produce a single convolutional layer network
            with the desired input size, and with an output size equal to the last element of the num_cells argument.
            If no depth is indicated, the depth information should be contained in the num_cells argument (see below).
            If num_cells is an iterable and depth is indicated, both should match: len(num_cells) must be equal to
            the depth.
        num_cells (int or Sequence[int], optional): number of cells of every layer in between the input and output. If
            an integer is provided, every layer will have the same number of cells. If an iterable is provided,
            the convolutional layers out_features will match the content of num_cells.
            default: ``[32, 32, 32]`` or ``[32] * depth`` if depth is not ``None``.
        kernel_sizes (int, Sequence[Union[int, Sequence[int]]]): Kernel size(s) of the conv network. If iterable, the length must match the
            depth, defined by the num_cells or depth arguments.
            default: 3;
        strides (int or Sequence[int]): Stride(s) of the conv network. If iterable, the length must match the
            depth, defined by the num_cells or depth arguments.
            default: 1;
        paddings (int or Sequence[int]): Padding(s) of the conv network. If iterable, the length must match the
            depth, defined by the num_cells or depth arguments.
            default: 0;
        activation_class (Type[nn.Module]): activation class to be used.
            default: nn.ELU
        activation_kwargs (dict, optional): kwargs to be used with the activation class;
        norm_class (Type, optional): normalization class, if any;
        norm_kwargs (dict, optional): kwargs to be used with the normalization layers;
        bias_last_layer (bool): if ``True``, the last convolutional layer will have a bias parameter.
            default: True;
        aggregator_class (Type[nn.Module]): aggregator to use at the end of the chain.
            default: SquashDims;
        aggregator_kwargs (dict, optional): kwargs for the aggregator_class.
            default: ``{"ndims_in": 4}``;
        squeeze_output (bool): whether the output should be squeezed of its singleton dimensions.
            default: False.
        device (Optional[DEVICE_TYPING]): device to create the module on.

    Raises:
        ValueError: if the resolved depth is 0, or if the length of one of ``num_cells``, ``kernel_sizes``,
            ``strides`` or ``paddings`` conflicts with the depth.

    Examples:
        >>> # All of the following examples provide valid, working Conv3dNets
        >>> cnet = Conv3dNet(in_features=3, depth=1, num_cells=[32,])
        >>> print(cnet)
        Conv3dNet(
          (0): Conv3d(3, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1))
          (1): ELU(alpha=1.0)
          (2): SquashDims()
        )
        >>> cnet = Conv3dNet(in_features=3, depth=4, num_cells=32)
        >>> print(cnet)
        Conv3dNet(
          (0): Conv3d(3, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1))
          (1): ELU(alpha=1.0)
          (2): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1))
          (3): ELU(alpha=1.0)
          (4): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1))
          (5): ELU(alpha=1.0)
          (6): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1))
          (7): ELU(alpha=1.0)
          (8): SquashDims()
        )
        >>> cnet = Conv3dNet(in_features=3, num_cells=[32, 33, 34, 35])  # defines the depth by the num_cells arg
        >>> print(cnet)
        Conv3dNet(
          (0): Conv3d(3, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1))
          (1): ELU(alpha=1.0)
          (2): Conv3d(32, 33, kernel_size=(3, 3, 3), stride=(1, 1, 1))
          (3): ELU(alpha=1.0)
          (4): Conv3d(33, 34, kernel_size=(3, 3, 3), stride=(1, 1, 1))
          (5): ELU(alpha=1.0)
          (6): Conv3d(34, 35, kernel_size=(3, 3, 3), stride=(1, 1, 1))
          (7): ELU(alpha=1.0)
          (8): SquashDims()
        )
        >>> cnet = Conv3dNet(in_features=3, num_cells=[32, 33, 34, 35], kernel_sizes=[3, 4, 5, (2, 3, 4)])  # defines kernels, possibly non-cubic
        >>> print(cnet)
        Conv3dNet(
          (0): Conv3d(3, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1))
          (1): ELU(alpha=1.0)
          (2): Conv3d(32, 33, kernel_size=(4, 4, 4), stride=(1, 1, 1))
          (3): ELU(alpha=1.0)
          (4): Conv3d(33, 34, kernel_size=(5, 5, 5), stride=(1, 1, 1))
          (5): ELU(alpha=1.0)
          (6): Conv3d(34, 35, kernel_size=(2, 3, 4), stride=(1, 1, 1))
          (7): ELU(alpha=1.0)
          (8): SquashDims()
        )

    """

    def __init__(
        self,
        in_features: Optional[int] = None,
        depth: Optional[int] = None,
        num_cells: Optional[Union[Sequence, int]] = None,
        kernel_sizes: Union[Sequence[Union[int, Sequence[int]]], int] = 3,
        strides: Union[Sequence, int] = 1,
        paddings: Union[Sequence, int] = 0,
        activation_class: Type[nn.Module] = nn.ELU,
        activation_kwargs: Optional[dict] = None,
        norm_class: Optional[Type[nn.Module]] = None,
        norm_kwargs: Optional[dict] = None,
        bias_last_layer: bool = True,
        aggregator_class: Optional[Type[nn.Module]] = SquashDims,
        aggregator_kwargs: Optional[dict] = None,
        squeeze_output: bool = False,
        device: Optional[DEVICE_TYPING] = None,
    ):
        if num_cells is None:
            # Default widths follow the requested depth when there is one.
            num_cells = [32, 32, 32] if depth is None else [32] * depth

        self.in_features = in_features
        self.activation_class = activation_class
        self.activation_kwargs = (
            activation_kwargs if activation_kwargs is not None else {}
        )
        self.norm_class = norm_class
        self.norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
        self.bias_last_layer = bias_last_layer
        self.aggregator_class = aggregator_class
        self.aggregator_kwargs = (
            aggregator_kwargs if aggregator_kwargs is not None else {"ndims_in": 4}
        )
        self.squeeze_output = squeeze_output
        # self.single_bias_last_layer = single_bias_last_layer

        depth = _find_depth(depth, num_cells, kernel_sizes, strides, paddings)
        self.depth = depth
        if depth == 0:
            raise ValueError("Null depth is not permitted with Conv3dNet.")

        for _field, _value in zip(
            ["num_cells", "kernel_sizes", "strides", "paddings"],
            [num_cells, kernel_sizes, strides, paddings],
        ):
            _depth = depth
            # Stored as a list so that tuples can be concatenated with lists
            # in _make_net.
            setattr(
                self,
                _field,
                (list(_value) if isinstance(_value, Sequence) else [_value] * _depth),
            )
            if not (len(getattr(self, _field)) == _depth or _depth is None):
                raise ValueError(
                    f"depth={depth} and {_field}={len(getattr(self, _field))} length conflict, "
                    + f"consider matching or specifying a constant {_field} argument together with a a desired depth"
                )

        self.out_features = self.num_cells[-1]

        self.depth = len(self.kernel_sizes)
        layers = self._make_net(device)
        super().__init__(*layers)

    def _make_net(self, device: Optional[DEVICE_TYPING]) -> List[nn.Module]:
        """Build the ordered list of layers that make up the network.

        Args:
            device: device the layers are created on.

        Returns:
            The list of modules, in execution order.
        """
        layers = []
        in_features = [self.in_features] + self.num_cells[: self.depth]
        out_features = self.num_cells + [self.out_features]
        kernel_sizes = self.kernel_sizes
        strides = self.strides
        paddings = self.paddings
        # zip stops at the shortest input, i.e. after `depth` convolutions.
        last_index = self.depth - 1
        for i, (_in, _out, _kernel, _stride, _padding) in enumerate(
            zip(in_features, out_features, kernel_sizes, strides, paddings)
        ):
            # Only the last convolution may drop its bias. Comparing against
            # len(in_features) - 1 (== depth) was true for every layer and
            # made bias_last_layer a no-op.
            _bias = (i < last_index) or self.bias_last_layer
            if _in is not None:
                layers.append(
                    nn.Conv3d(
                        _in,
                        _out,
                        kernel_size=_kernel,
                        stride=_stride,
                        bias=_bias,
                        padding=_padding,
                        device=device,
                    )
                )
            else:
                layers.append(
                    nn.LazyConv3d(
                        _out,
                        kernel_size=_kernel,
                        stride=_stride,
                        bias=_bias,
                        padding=_padding,
                        device=device,
                    )
                )

            layers.append(
                create_on_device(
                    self.activation_class, device, **self.activation_kwargs
                )
            )
            if self.norm_class is not None:
                layers.append(
                    create_on_device(self.norm_class, device, **self.norm_kwargs)
                )

        if self.aggregator_class is not None:
            layers.append(
                create_on_device(
                    self.aggregator_class, device, **self.aggregator_kwargs
                )
            )

        if self.squeeze_output:
            layers.append(SqueezeLayer((-3, -2, -1)))
        return layers

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Run the network on ``inputs``.

        Args:
            inputs: tensor whose last four dimensions are unpacked as
                ``C, D, L, W``; any leading dimensions are treated as batch
                dimensions.

        Returns:
            The network output. When there is more than one batch dimension,
            they are flattened before the call and restored afterwards.

        Raises:
            ValueError: if ``inputs`` has fewer than 4 dimensions.
        """
        try:
            *batch, C, D, L, W = inputs.shape
        except ValueError as err:
            raise ValueError(
                f"The input value of {self.__class__.__name__} must have at least 4 dimensions, got {inputs.ndim} instead."
            ) from err
        if len(batch) > 1:
            inputs = inputs.flatten(0, len(batch) - 1)
        out = super().forward(inputs)
        if len(batch) > 1:
            out = out.unflatten(0, batch)
        return out
class DuelingMlpDQNet(nn.Module):
    """Creates a Dueling MLP Q-network.

    Presented in https://arxiv.org/abs/1511.06581

    A shared feature MLP feeds two MLP heads (advantage and value); the
    output is ``value + advantage - mean(advantage)`` where the mean is taken
    over the last dimension.

    Args:
        out_features (int): number of features for the advantage network
        out_features_value (int): number of features for the value network.
            Defaults to 1.
        mlp_kwargs_feature (dict, optional): kwargs for the feature network.
            Entries override the defaults below.
            Default is

            >>> mlp_kwargs_feature = {
            ...     'num_cells': [256, 256],
            ...     'activation_class': nn.ELU,
            ...     'out_features': 256,
            ...     'activate_last_layer': True,
            ... }

        mlp_kwargs_output (dict, optional): kwargs for the advantage and
            value networks. Entries override the defaults below.
            Default is

            >>> mlp_kwargs_output = {
            ...     "depth": 1,
            ...     "activation_class": nn.ELU,
            ...     "num_cells": 512,
            ...     "bias_last_layer": True,
            ... }

        device (Optional[DEVICE_TYPING]): device to create the module on.

    """

    def __init__(
        self,
        out_features: int,
        out_features_value: int = 1,
        mlp_kwargs_feature: Optional[dict] = None,
        mlp_kwargs_output: Optional[dict] = None,
        device: Optional[DEVICE_TYPING] = None,
    ):
        super().__init__()

        # Shared trunk: defaults first, user overrides on top.
        feature_kwargs = {
            "num_cells": [256, 256],
            "out_features": 256,
            "activation_class": nn.ELU,
            "activate_last_layer": True,
        }
        if mlp_kwargs_feature is not None:
            feature_kwargs.update(mlp_kwargs_feature)
        self.features = MLP(device=device, **feature_kwargs)

        # Both heads are built from the same keyword arguments.
        head_kwargs = {
            "depth": 1,
            "activation_class": nn.ELU,
            "num_cells": 512,
            "bias_last_layer": True,
        }
        if mlp_kwargs_output is not None:
            head_kwargs.update(mlp_kwargs_output)

        self.out_features = out_features
        self.out_features_value = out_features_value
        self.advantage = MLP(out_features=out_features, device=device, **head_kwargs)
        self.value = MLP(
            out_features=out_features_value, device=device, **head_kwargs
        )

        # Zero the bias of every Linear / Conv2d submodule that has a tensor bias.
        for submodule in self.modules():
            if not isinstance(submodule, (nn.Conv2d, nn.Linear)):
                continue
            if isinstance(submodule.bias, torch.Tensor):
                submodule.bias.data.zero_()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the dueling combination of the value and advantage heads.

        Args:
            x: input tensor passed to the feature network.

        Returns:
            ``value + advantage - advantage.mean(dim=-1, keepdim=True)``.
        """
        hidden = self.features(x)

        advantage = self.advantage(hidden)
        value = self.value(hidden)

        return value + advantage - advantage.mean(dim=-1, keepdim=True)
class DuelingCnnDQNet(nn.Module):
    """Dueling CNN Q-network.

    Presented in https://arxiv.org/abs/1511.06581

    A ConvNet feature extractor feeds two MLP heads (advantage and value);
    the output is ``value + advantage - mean(advantage)`` where the mean is
    taken over the last dimension.

    Args:
        out_features (int): number of features for the advantage network
        out_features_value (int): number of features for the value network.
            Defaults to 1.
        cnn_kwargs (dict, optional): kwargs for the feature network.
            Entries override the defaults below.
            Default is

            >>> cnn_kwargs = {
            ...     'num_cells': [32, 64, 64],
            ...     'strides': [4, 2, 1],
            ...     'kernel_sizes': [8, 4, 3],
            ... }

        mlp_kwargs (dict, optional): kwargs for the advantage and value network.
            Entries override the defaults below.
            Default is

            >>> mlp_kwargs = {
            ...     "depth": 1,
            ...     "activation_class": nn.ELU,
            ...     "num_cells": 512,
            ...     "bias_last_layer": True,
            ... }

        device (Optional[DEVICE_TYPING]): device to create the module on.

    """

    def __init__(
        self,
        out_features: int,
        out_features_value: int = 1,
        cnn_kwargs: Optional[dict] = None,
        mlp_kwargs: Optional[dict] = None,
        device: Optional[DEVICE_TYPING] = None,
    ):
        super().__init__()

        # Convolutional trunk: defaults first, user overrides on top.
        trunk_kwargs = {
            "num_cells": [32, 64, 64],
            "strides": [4, 2, 1],
            "kernel_sizes": [8, 4, 3],
        }
        if cnn_kwargs is not None:
            trunk_kwargs.update(cnn_kwargs)
        self.features = ConvNet(device=device, **trunk_kwargs)

        # Both heads are built from the same keyword arguments.
        head_kwargs = {
            "depth": 1,
            "activation_class": nn.ELU,
            "num_cells": 512,
            "bias_last_layer": True,
        }
        if mlp_kwargs is not None:
            head_kwargs.update(mlp_kwargs)

        self.out_features = out_features
        self.out_features_value = out_features_value
        self.advantage = MLP(out_features=out_features, device=device, **head_kwargs)
        self.value = MLP(
            out_features=out_features_value, device=device, **head_kwargs
        )

        # Zero the bias of every Linear / Conv2d submodule that has a tensor bias.
        for submodule in self.modules():
            if not isinstance(submodule, (nn.Conv2d, nn.Linear)):
                continue
            if isinstance(submodule.bias, torch.Tensor):
                submodule.bias.data.zero_()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the dueling combination of the value and advantage heads.

        Args:
            x: input tensor passed to the convolutional feature network.

        Returns:
            ``value + advantage - advantage.mean(dim=-1, keepdim=True)``.
        """
        hidden = self.features(x)

        advantage = self.advantage(hidden)
        value = self.value(hidden)

        return value + advantage - advantage.mean(dim=-1, keepdim=True)
class DistributionalDQNnet(TensorDictModuleBase):
    """Distributional Deep Q-Network.

    Reads each entry of ``in_keys`` from the input tensordict, optionally
    passes it through the (deprecated) wrapped network, applies a log-softmax
    along dimension ``-2`` and writes the result under the matching entry of
    ``out_keys``.

    Args:
        DQNet (nn.Module): (deprecated) Q-Network with output length equal
            to the number of atoms:
            output.shape = [*batch, atoms, actions].
        in_keys (list of str or tuples of str): input keys to the log-softmax
            operation. Defaults to ``["action_value"]``.
        out_keys (list of str or tuples of str): output keys to the log-softmax
            operation. Defaults to ``["action_value"]``.

    Raises:
        RuntimeError: at construction if ``DQNet.out_features`` is a number or
            has fewer than two entries; at call time if a value to normalize
            has fewer than two dimensions.

    """

    _wrong_out_feature_dims_error = (
        "DistributionalDQNnet requires dqn output to be at least "
        "2-dimensional, with dimensions *Batch x #Atoms x #Actions. Got {0} "
        "instead."
    )

    def __init__(
        self, DQNet: Optional[nn.Module] = None, in_keys=None, out_keys=None
    ):
        super().__init__()
        if DQNet is not None:
            warnings.warn(
                f"Passing a network to {type(self)} is going to be deprecated.",
                category=DeprecationWarning,
            )
            dqn_out_features = DQNet.out_features
            if isinstance(dqn_out_features, Number) or len(dqn_out_features) <= 1:
                # The template has a "{0}" placeholder: fill it in so the
                # message reports the offending value instead of a raw "{0}".
                raise RuntimeError(
                    self._wrong_out_feature_dims_error.format(dqn_out_features)
                )
        self.dqn = DQNet
        if in_keys is None:
            in_keys = ["action_value"]
        if out_keys is None:
            out_keys = ["action_value"]
        self.in_keys = in_keys
        self.out_keys = out_keys

    @dispatch(auto_batch_size=False)
    def forward(self, tensordict):
        """Write the log-softmax (over dim ``-2``) of each input entry.

        Args:
            tensordict: container holding the ``in_keys`` entries; it is
                updated in place with the ``out_keys`` entries.

        Returns:
            The same ``tensordict`` object.

        Raises:
            RuntimeError: if a value to normalize has fewer than 2 dimensions.
        """
        for in_key, out_key in zip(self.in_keys, self.out_keys):
            q_values = tensordict.get(in_key)
            if self.dqn is not None:
                q_values = self.dqn(q_values)
            if q_values.ndimension() < 2:
                raise RuntimeError(
                    self._wrong_out_feature_dims_error.format(q_values.shape)
                )
            tensordict.set(out_key, F.log_softmax(q_values, dim=-2))
        return tensordict
def ddpg_init_last_layer(
    module: nn.Sequential,
    scale: float = 6e-4,
    device: Optional[DEVICE_TYPING] = None,
) -> None:
    """Initializer for the last layer of DDPG.

    Presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING",
    https://arxiv.org/pdf/1509.02971.pdf

    The last ``nn.Linear`` / ``nn.Conv2d`` found when walking ``module``
    backwards has its weight (and bias, when present) overwritten in place
    with ``rand * scale - scale / 2``.

    Args:
        module: container whose last linear / 2d-convolution layer is
            re-initialized.
        scale: width of the sampling interval. Defaults to ``6e-4``.
        device: device passed to ``torch.rand_like`` when drawing the values.

    Raises:
        RuntimeError: if ``module`` holds no ``nn.Linear`` / ``nn.Conv2d``.
    """
    target = next(
        (
            candidate
            for candidate in reversed(module)
            if isinstance(candidate, (nn.Linear, nn.Conv2d))
        ),
        None,
    )
    if target is None:
        raise RuntimeError("Could not find a nn.Linear / nn.Conv2d to initialize.")

    # Weight first, then bias: the draw order is kept so that a seeded
    # generator yields the same values.
    to_reset = [target.weight]
    if target.bias is not None:
        to_reset.append(target.bias)

    for param in to_reset:
        draw = torch.rand_like(param.data, device=device)
        param.data.copy_(draw * scale - scale / 2)
class DdpgCnnActor(nn.Module):
"""DDPG Convolutional Actor class.
Presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING",
https://arxiv.org/pdf/1509.02971.pdf
The DDPG Convolutional Actor takes as input an observation (some simple transformation of the observed pixels) and
returns an action vector from it.
It is trained to maximise the value returned by the DDPG Q Value network.
Args:
action_dim (int): length of the action vector.
conv_net_kwargs (dict, optional): kwargs for the ConvNet.
default: {
'in_features': None,
"num_cells": [32, 64, 64],
"kernel_sizes": [8, 4, 3],
"strides": [4, 2, 1],
"paddings": [0, 0, 1],
'activation_class': nn.ELU,
'norm_class': None,
'aggregator_class': SquashDims,
'aggregator_kwargs': {"ndims_in": 3},
'squeeze_output': True,
}
mlp_net_kwargs: kwargs for MLP.
Default: {
'in_features': None,
'out_features': action_dim,
'depth': 2,
'num_cells': 200,
'activation_class': nn.ELU,
'bias_last_layer': True,
}
use_avg_pooling (bool, optional): if ``True``, a nn.AvgPooling layer is
used to aggregate the output. Default is ``False``.
device (Optional[DEVICE_TYPING]): device to create the module on.
"""
def __init__(
self,
action_dim: int,
conv_net_kwargs: Optional[dict] = None,
mlp_net_kwargs: Optional[dict] = None,
use_avg_pooling: bool = False,
device: Optional[DEVICE_TYPING] = None,
):
super().__init__()
conv_net_default_kwargs = {
"in_features": None,