torch_layers.py
from itertools import zip_longest
from typing import Dict, List, Tuple, Type, Union

import gym
import torch as th
from torch import nn

from stable_baselines3.common.preprocessing import get_flattened_obs_dim, is_image_space
from stable_baselines3.common.type_aliases import TensorDict
from stable_baselines3.common.utils import get_device
class BaseFeaturesExtractor(nn.Module):
    """
    Base class that represents a features extractor.

    :param observation_space:
    :param features_dim: Number of features extracted.
    """

    def __init__(self, observation_space: gym.Space, features_dim: int = 0):
        super(BaseFeaturesExtractor, self).__init__()
        assert features_dim > 0
        self._observation_space = observation_space
        self._features_dim = features_dim

    @property
    def features_dim(self) -> int:
        return self._features_dim

    def forward(self, observations: th.Tensor) -> th.Tensor:
        raise NotImplementedError()
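# A hedged sketch (not part of the original file) of how BaseFeaturesExtractor is
# meant to be subclassed; the layer sizes and the assumption of a flat Box
# observation space are illustrative only. Kept as a comment so the module's
# import-time behaviour is unchanged.
#
#   class CustomExtractor(BaseFeaturesExtractor):
#       def __init__(self, observation_space: gym.Space, features_dim: int = 64):
#           super().__init__(observation_space, features_dim)
#           self.net = nn.Sequential(
#               nn.Linear(get_flattened_obs_dim(observation_space), features_dim),
#               nn.ReLU(),
#           )
#
#       def forward(self, observations: th.Tensor) -> th.Tensor:
#           return self.net(observations)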
class FlattenExtractor(BaseFeaturesExtractor):
    """
    Feature extractor that flattens the input.
    Used as a placeholder when feature extraction is not needed.

    :param observation_space:
    """

    def __init__(self, observation_space: gym.Space):
        super(FlattenExtractor, self).__init__(observation_space, get_flattened_obs_dim(observation_space))
        self.flatten = nn.Flatten()

    def forward(self, observations: th.Tensor) -> th.Tensor:
        return self.flatten(observations)
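# A hedged sketch (not part of the original file): FlattenExtractor simply flattens
# the observation, so its features_dim equals the flattened observation size.
# The (2, 3) Box space is an assumption for illustration; kept as a comment so the
# module's import-time behaviour is unchanged.
#
#   space = gym.spaces.Box(low=-1.0, high=1.0, shape=(2, 3))
#   extractor = FlattenExtractor(space)
#   # extractor.features_dim == 6
#   flat = extractor(th.zeros(1, 2, 3))  # -> shape (1, 6)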
class NatureCNN(BaseFeaturesExtractor):
    """
    CNN from DQN Nature paper:
        Mnih, Volodymyr, et al.
        "Human-level control through deep reinforcement learning."
        Nature 518.7540 (2015): 529-533.

    :param observation_space:
    :param features_dim: Number of features extracted.
        This corresponds to the number of units of the last layer.
    """

    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512):
        super(NatureCNN, self).__init__(observation_space, features_dim)
        # We assume CxHxW images (channels first)
        # Re-ordering will be done by pre-processing or a wrapper
        assert is_image_space(observation_space, check_channels=False), (
            "You should use NatureCNN "
            f"only with images, not with {observation_space}\n"
            "(you are probably using `CnnPolicy` instead of `MlpPolicy` or `MultiInputPolicy`)\n"
            "If you are using a custom environment,\n"
            "please check it using our env checker:\n"
            "https://stable-baselines3.readthedocs.io/en/master/common/env_checker.html"
        )
        n_input_channels = observation_space.shape[0]
        self.cnn = nn.Sequential(
            nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Compute shape by doing one forward pass
        with th.no_grad():
            n_flatten = self.cnn(th.as_tensor(observation_space.sample()[None]).float()).shape[1]

        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations: th.Tensor) -> th.Tensor:
        return self.linear(self.cnn(observations))
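# A hedged usage sketch (not part of the original file): the Atari-style
# (4, 84, 84) channel-first observation shape and the numpy import are
# assumptions for illustration. Kept as a comment so the module's import-time
# behaviour is unchanged.
#
#   import numpy as np
#   space = gym.spaces.Box(low=0, high=255, shape=(4, 84, 84), dtype=np.uint8)
#   cnn = NatureCNN(space, features_dim=512)
#   obs = th.as_tensor(space.sample()[None]).float()
#   features = cnn(obs)  # -> shape (1, 512)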
def create_mlp(
    input_dim: int,
    output_dim: int,
    net_arch: List[int],
    activation_fn: Type[nn.Module] = nn.ReLU,
    squash_output: bool = False,
) -> List[nn.Module]:
    """
    Create a multi layer perceptron (MLP), which is
    a collection of fully-connected layers each followed by an activation function.

    :param input_dim: Dimension of the input vector
    :param output_dim: Dimension of the output (no output layer is created if it is 0)
    :param net_arch: Architecture of the neural net
        It represents the number of units per layer.
        The length of this list is the number of layers.
    :param activation_fn: The activation function
        to use after each layer.
    :param squash_output: Whether to squash the output using a Tanh
        activation function
    :return: The list of layers (not wrapped in a ``nn.Sequential``)
    """

    if len(net_arch) > 0:
        modules = [nn.Linear(input_dim, net_arch[0]), activation_fn()]
    else:
        modules = []

    for idx in range(len(net_arch) - 1):
        modules.append(nn.Linear(net_arch[idx], net_arch[idx + 1]))
        modules.append(activation_fn())

    if output_dim > 0:
        last_layer_dim = net_arch[-1] if len(net_arch) > 0 else input_dim
        modules.append(nn.Linear(last_layer_dim, output_dim))
    if squash_output:
        modules.append(nn.Tanh())
    return modules
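# A hedged usage sketch (not part of the original file): the layer sizes are
# arbitrary. create_mlp returns a plain list of modules, so callers typically
# wrap it in nn.Sequential. Kept as a comment so the module's import-time
# behaviour is unchanged.
#
#   layers = create_mlp(input_dim=8, output_dim=2, net_arch=[64, 64])
#   mlp = nn.Sequential(*layers)
#   out = mlp(th.zeros(1, 8))  # -> shape (1, 2)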
class MlpExtractor(nn.Module):
    """
    Constructs an MLP that receives observations as an input and outputs a latent representation for the policy and
    a value network. The ``net_arch`` parameter allows you to specify the number and size of the hidden layers and how
    many of them are shared between the policy network and the value network. It is assumed to be a list with the
    following structure:

    1. An arbitrary number (zero allowed) of integers, each specifying the number of units in a shared layer.
       If the number of ints is zero, there will be no shared layers.
    2. An optional dict, to specify the following non-shared layers for the value network and the policy network.
       It is formatted like ``dict(vf=[<value layer sizes>], pi=[<policy layer sizes>])``.
       If either key (pi or vf) is missing, zero non-shared layers (an empty list) are assumed for it.

    For example, to construct a network with one shared layer of size 55 followed by two non-shared layers of size 255
    for the value network and a single non-shared layer of size 128 for the policy network, the following ``net_arch``
    would be used: ``[55, dict(vf=[255, 255], pi=[128])]``. A simple shared network topology with two layers of size
    128 would be specified as ``[128, 128]``.

    Adapted from Stable Baselines.

    :param feature_dim: Dimension of the feature vector (can be the output of a CNN)
    :param net_arch: The specification of the policy and value networks.
        See above for details on its formatting.
    :param activation_fn: The activation function to use for the networks.
    :param device:
    """
    def __init__(
        self,
        feature_dim: int,
        net_arch: List[Union[int, Dict[str, List[int]]]],
        activation_fn: Type[nn.Module],
        device: Union[th.device, str] = "auto",
    ):
        super(MlpExtractor, self).__init__()
        device = get_device(device)
        shared_net, policy_net, value_net = [], [], []
        policy_only_layers = []  # Layer sizes of the network that only belong to the policy network
        value_only_layers = []  # Layer sizes of the network that only belong to the value network
        last_layer_dim_shared = feature_dim

        # Iterate through the shared layers and build the shared parts of the network
        for layer in net_arch:
            if isinstance(layer, int):  # Check that this is a shared layer
                # TODO: give layer a meaningful name
                shared_net.append(nn.Linear(last_layer_dim_shared, layer))  # add linear of size layer
                shared_net.append(activation_fn())
                last_layer_dim_shared = layer
            else:
                assert isinstance(layer, dict), "Error: the net_arch list can only contain ints and dicts"
                if "pi" in layer:
                    assert isinstance(layer["pi"], list), "Error: net_arch[-1]['pi'] must contain a list of integers."
                    policy_only_layers = layer["pi"]

                if "vf" in layer:
                    assert isinstance(layer["vf"], list), "Error: net_arch[-1]['vf'] must contain a list of integers."
                    value_only_layers = layer["vf"]
                break  # From here on the network splits up into the policy and the value network

        last_layer_dim_pi = last_layer_dim_shared
        last_layer_dim_vf = last_layer_dim_shared

        # Build the non-shared part of the network
        for pi_layer_size, vf_layer_size in zip_longest(policy_only_layers, value_only_layers):
            if pi_layer_size is not None:
                assert isinstance(pi_layer_size, int), "Error: net_arch[-1]['pi'] must only contain integers."
                policy_net.append(nn.Linear(last_layer_dim_pi, pi_layer_size))
                policy_net.append(activation_fn())
                last_layer_dim_pi = pi_layer_size

            if vf_layer_size is not None:
                assert isinstance(vf_layer_size, int), "Error: net_arch[-1]['vf'] must only contain integers."
                value_net.append(nn.Linear(last_layer_dim_vf, vf_layer_size))
                value_net.append(activation_fn())
                last_layer_dim_vf = vf_layer_size

        # Save dim, used to create the distributions
        self.latent_dim_pi = last_layer_dim_pi
        self.latent_dim_vf = last_layer_dim_vf

        # Create networks
        # If the list of layers is empty, the network will just act as an Identity module
        self.shared_net = nn.Sequential(*shared_net).to(device)
        self.policy_net = nn.Sequential(*policy_net).to(device)
        self.value_net = nn.Sequential(*value_net).to(device)

    def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """
        :return: latent_policy, latent_value of the specified network.
            If all layers are shared, then ``latent_policy == latent_value``
        """
        shared_latent = self.shared_net(features)
        return self.policy_net(shared_latent), self.value_net(shared_latent)
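# A hedged usage sketch (not part of the original file): the feature dimension
# and net_arch are illustrative (one shared layer of 64 units followed by
# separate policy and value heads). Kept as a comment so the module's
# import-time behaviour is unchanged.
#
#   extractor = MlpExtractor(
#       feature_dim=16,
#       net_arch=[64, dict(pi=[32], vf=[32, 32])],
#       activation_fn=nn.Tanh,
#       device="cpu",
#   )
#   latent_pi, latent_vf = extractor(th.zeros(1, 16))
#   # latent_pi.shape == (1, extractor.latent_dim_pi) == (1, 32)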
class CombinedExtractor(BaseFeaturesExtractor):
    """
    Combined feature extractor for Dict observation spaces.
    Builds a feature extractor for each key of the space. Input from each space
    is fed through a separate submodule (CNN or MLP, depending on input shape),
    and the output features are concatenated and fed through an additional MLP network ("combined").

    :param observation_space:
    :param cnn_output_dim: Number of features to output from each CNN submodule.
        Defaults to 256 to avoid exploding network sizes.
    """

    def __init__(self, observation_space: gym.spaces.Dict, cnn_output_dim: int = 256):
        # TODO we do not know features-dim here before going over all the items, so put something there. This is dirty!
        super(CombinedExtractor, self).__init__(observation_space, features_dim=1)

        extractors = {}

        total_concat_size = 0
        for key, subspace in observation_space.spaces.items():
            if is_image_space(subspace):
                extractors[key] = NatureCNN(subspace, features_dim=cnn_output_dim)
                total_concat_size += cnn_output_dim
            else:
                # The observation key is a vector, flatten it if needed
                extractors[key] = nn.Flatten()
                total_concat_size += get_flattened_obs_dim(subspace)

        self.extractors = nn.ModuleDict(extractors)

        # Update the features dim manually
        self._features_dim = total_concat_size

    def forward(self, observations: TensorDict) -> th.Tensor:
        encoded_tensor_list = []

        for key, extractor in self.extractors.items():
            encoded_tensor_list.append(extractor(observations[key]))
        return th.cat(encoded_tensor_list, dim=1)
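# A hedged usage sketch (not part of the original file): the Dict space with one
# image key and one vector key, and the numpy import, are assumptions for
# illustration. Kept as a comment so the module's import-time behaviour is
# unchanged.
#
#   import numpy as np
#   space = gym.spaces.Dict(
#       {
#           "image": gym.spaces.Box(low=0, high=255, shape=(3, 64, 64), dtype=np.uint8),
#           "vector": gym.spaces.Box(low=-1.0, high=1.0, shape=(5,), dtype=np.float32),
#       }
#   )
#   extractor = CombinedExtractor(space, cnn_output_dim=256)
#   # extractor.features_dim == 256 + 5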
def get_actor_critic_arch(net_arch: Union[List[int], Dict[str, List[int]]]) -> Tuple[List[int], List[int]]:
    """
    Get the actor and critic network architectures for off-policy actor-critic algorithms (SAC, TD3, DDPG).

    The ``net_arch`` parameter allows you to specify the number and size of the hidden layers,
    which can be different for the actor and the critic.
    It is assumed to be a list of ints or a dict.

    1. If it is a list, actor and critic networks will have the same architecture.
       The architecture is represented by a list of integers (of arbitrary length, zero allowed),
       each specifying the number of units per layer.
       If the number of ints is zero, the network will be linear.
    2. If it is a dict, it should have the following structure:
       ``dict(qf=[<critic network architecture>], pi=[<actor network architecture>])``,
       where the network architecture is a list as described in 1.

    For example, to have actor and critic that share the same network architecture,
    you only need to specify ``net_arch=[256, 256]`` (here, two hidden layers of 256 units each).
    If you want a different architecture for the actor and the critic,
    then you can specify ``net_arch=dict(qf=[400, 300], pi=[64, 64])``.

    .. note::
        Compared to their on-policy counterparts, no shared layers (other than the features extractor)
        between the actor and the critic are allowed (to prevent issues with target networks).

    :param net_arch: The specification of the actor and critic networks.
        See above for details on its formatting.
    :return: The network architectures for the actor and the critic
    """
    if isinstance(net_arch, list):
        actor_arch, critic_arch = net_arch, net_arch
    else:
        assert isinstance(net_arch, dict), "Error: the net_arch can only be a list of ints or a dict"
        assert "pi" in net_arch, "Error: no key 'pi' was provided in net_arch for the actor network"
        assert "qf" in net_arch, "Error: no key 'qf' was provided in net_arch for the critic network"
        actor_arch, critic_arch = net_arch["pi"], net_arch["qf"]
    return actor_arch, critic_arch
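# A hedged usage sketch (not part of the original file), kept as a comment so the
# module's import-time behaviour is unchanged:
#
#   actor_arch, critic_arch = get_actor_critic_arch([256, 256])
#   # actor_arch == critic_arch == [256, 256]
#
#   actor_arch, critic_arch = get_actor_critic_arch(dict(pi=[64, 64], qf=[400, 300]))
#   # actor_arch == [64, 64], critic_arch == [400, 300]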