-
Notifications
You must be signed in to change notification settings - Fork 12
/
rl_models.py
422 lines (362 loc) · 14 KB
/
rl_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
from __future__ import absolute_import
from typing import Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_scatter
def layer_norm(layer, std=1.0, bias_const=1e-6):
    """
    Initialise ``layer`` in place: orthogonal weights and a constant bias.

    Despite its name this performs no normalisation; it only (re)initialises
    the layer's parameters.

    :param layer: module exposing a ``weight`` tensor (e.g. ``nn.Linear``) and
        optionally a ``bias``; a missing or ``None`` bias is left alone.
    :param std: gain passed to ``torch.nn.init.orthogonal_``.
    :param bias_const: value written to every bias entry.
    """
    torch.nn.init.orthogonal_(layer.weight, std)
    # Layers built with bias=False carry ``bias = None``; constant_ would fail on it.
    bias = getattr(layer, "bias", None)
    if bias is not None:
        torch.nn.init.constant_(bias, bias_const)
class IMPALACNN(nn.ModuleDict):
    """
    ResNet-style image encoder from the paper:
    "IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures"
    https://arxiv.org/abs/1802.01561

    Each stage is a 3x3 conv (stride 1, padding 1) followed by a 3x3 max-pool
    with stride 2, then two residual blocks (ReLU-conv-ReLU-conv plus skip).
    The flattened result goes through ReLU -> Linear -> ReLU.

    :param input_channels: number of channels of the input image.
    :param output_dim: number of output features (units of the final linear layer).
    :param depths: output channels of each stage; defaults to ``[16, 32, 32]``.
    :param input_shape: spatial size ``[H, W]`` of the input image (list or
        tensor), used to size the final linear layer; defaults to ``[100, 100]``.
    """

    def __init__(self, input_channels, output_dim, depths=None, input_shape=None):
        if depths is None:
            depths = [16, 32, 32]
        if input_shape is None:
            input_shape = [100, 100]
        super().__init__()
        self.input_channels = input_channels
        input_shape = torch.as_tensor(input_shape)
        self.convs = []  # unused; kept so the public attribute still exists
        feat_convs, resnet1, resnet2 = [], [], []
        in_ch = input_channels
        # Creation order (stage conv, then the two residual blocks) matches the
        # original so seeded parameter initialisation is unchanged.
        for num_ch in depths:
            feat_convs.append(
                nn.Sequential(
                    nn.Conv2d(
                        in_channels=in_ch,
                        out_channels=num_ch,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                    ),
                    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                )
            )
            in_ch = num_ch
            resnet1.append(self._residual_block(num_ch))
            resnet2.append(self._residual_block(num_ch))
        self.feat_convs = nn.ModuleList(feat_convs)
        self.resnet1 = nn.ModuleList(resnet1)
        self.resnet2 = nn.ModuleList(resnet2)
        self.flatten = nn.Flatten()
        flatten_dim = self._get_flatten_dim(input_shape)
        self.fc = nn.Linear(flatten_dim, output_dim)

    @staticmethod
    def _residual_block(channels):
        """Build the ReLU-conv-ReLU-conv body of one residual block (skip is added in ``_extract``)."""
        return nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
        )

    def _extract(self, x):
        """Run all conv stages with their residual skips and flatten (no final activation)."""
        for fconv, res1, res2 in zip(self.feat_convs, self.resnet1, self.resnet2):
            x = fconv(x)
            x = res1(x) + x
            x = res2(x) + x
        return self.flatten(x)

    def _get_flatten_dim(self, input_shape):
        """Return the flattened feature size for an ``[H, W]`` input via a gradient-free dry run."""
        input_shape = torch.as_tensor(input_shape)
        probe = torch.zeros(
            1,
            self.input_channels,
            input_shape[0].int().item(),
            input_shape[1].int().item(),
        )
        with torch.no_grad():
            return self._extract(probe).shape[1]

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        """
        Encode a batch of images.

        :param batch: ``[B, input_channels, H, W]`` images matching ``input_shape``.
        :return: ``[B, output_dim]`` non-negative features.
        """
        x = F.relu(self._extract(batch))
        return F.relu(self.fc(x))
class IMPALAEncoder(nn.ModuleDict):
    """
    Encode a batch of 2-D point sets into fixed-size embeddings.

    Points are binned into a bird's-eye-view (BEV) grid. Each point is
    decorated with its offset from its cell centre and from the mean of the
    points sharing its cell, lifted to ``embedding_dim // 2`` channels by a
    point-wise 1x1 convolution, and max-pooled per occupied cell. The resulting
    ``[B, embedding_dim // 2, grid_x, grid_y]`` pseudo image is encoded by
    :class:`IMPALACNN` (the CNN from "IMPALA: Scalable Distributed Deep-RL with
    Importance Weighted Actor-Learner Architectures",
    https://arxiv.org/abs/1802.01561).

    :param input_dim: features per point; columns 0 and 1 are the x / y
        coordinates, the remaining ``input_dim - 2`` columns are passed through.
    :param embedding_dim: size of the output embedding.
    :param bev_range: ``[x_min, y_min, x_max, y_max]`` covered by the grid;
        defaults to ``[0.0, 0.0, 1.0, 1.0]``.
    :param bev_pixel_size: ``[x_size, y_size]`` of one grid cell; defaults to
        ``[0.01, 0.01]``.
    :param depths: channel counts of the IMPALA CNN stages; defaults to
        ``[16, 32, 32]``.
    :raises ValueError: if the range / pixel size yield an empty grid.
    """

    def __init__(
        self,
        input_dim,
        embedding_dim,
        bev_range=None,
        bev_pixel_size=None,
        depths=None,
    ):
        if bev_range is None:
            bev_range = [0.0, 0.0, 1.0, 1.0]
        if bev_pixel_size is None:
            bev_pixel_size = [0.01, 0.01]
        if depths is None:
            depths = [16, 32, 32]
        super().__init__()
        self.bev_range = torch.tensor(bev_range)
        self.bev_pixel_size = torch.tensor(bev_pixel_size)
        self.input_dim = input_dim
        self.embedding_dim = embedding_dim
        self.input_channels = embedding_dim // 2
        # Round once (instead of truncating later) so float error in
        # range / pixel_size cannot drop a cell or desync canvas and CNN sizes.
        bev_grid_shape = torch.round(
            (self.bev_range[2:] - self.bev_range[:2]) / self.bev_pixel_size
        ).long()
        if bool((bev_grid_shape <= 0).any()):
            raise ValueError(
                f"bev_range {bev_range} with bev_pixel_size {bev_pixel_size} "
                "gives an empty BEV grid"
            )
        # Number of grid cells along x and y, as Python ints.
        self.bev_grid_size = (int(bev_grid_shape[0]), int(bev_grid_shape[1]))
        self.input_transform = nn.Sequential(
            nn.Conv1d(
                self.input_dim + 4,
                self.input_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.InstanceNorm1d(
                self.input_channels, eps=1e-3, momentum=0.01, affine=True
            ),
            nn.ReLU(inplace=True),
        )
        self.cnn = IMPALACNN(
            self.input_channels, self.embedding_dim, depths, bev_grid_shape
        )

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        """
        Encode point sets.

        :param batch: ``[B, N, input_dim]`` tensor of points; ``batch[..., 0:2]``
            are the x / y coordinates.
        :return: the output of ``self.cnn`` for the ``[B, H, grid_x, grid_y]``
            pseudo image (``H = embedding_dim // 2``).
        :raises ValueError: if ``batch`` is not 3-D with ``input_dim`` features.
        """
        if batch.dim() != 3 or batch.shape[-1] != self.input_dim:
            raise ValueError(
                f"expected batch of shape [B, N, {self.input_dim}], "
                f"got {tuple(batch.shape)}"
            )
        B, N, _ = batch.shape
        H = self.input_channels
        grid_x, grid_y = self.bev_grid_size
        bev_range = self.bev_range.to(batch.device)
        bev_pixel_size = self.bev_pixel_size.to(batch.device)

        points = batch.flatten(0, 1)  # [B*N, input_dim]
        points_xy = points[:, 0:2]
        extra_feat = points[:, 2:]
        batch_idxs = torch.arange(B, device=batch.device).repeat_interleave(N)

        # Integer BEV cell of every point. Clamp so points on the upper border
        # (or outside the range) land in the border cells; unclamped they index
        # past the grid, aliasing into other rows / other samples or raising.
        cell = torch.floor((points_xy - bev_range[0:2]) / bev_pixel_size).long()
        bev_coords = torch.stack(
            (cell[:, 0].clamp(0, grid_x - 1), cell[:, 1].clamp(0, grid_y - 1)),
            dim=1,
        )
        # Linear cell id, unique across the batch: b * (X * Y) + x * Y + y.
        bev_merge_coords = (
            batch_idxs * (grid_x * grid_y)
            + bev_coords[:, 0] * grid_y
            + bev_coords[:, 1]
        )
        bev_unq_coords, bev_unq_inv = torch.unique(
            bev_merge_coords, return_inverse=True
        )

        # Offset of each point from the centre of its cell.
        bev_f_center = points_xy - (
            (bev_coords.to(points_xy.dtype) + 0.5) * bev_pixel_size + bev_range[0:2]
        )
        # Offset of each point from the mean position of the points in its cell.
        bev_f_mean = torch_scatter.scatter_mean(points_xy, bev_unq_inv, dim=0)
        bev_f_cluster = points_xy - bev_f_mean[bev_unq_inv]
        mvf_input = torch.cat(
            [points_xy, bev_f_center, bev_f_cluster, extra_feat], dim=1
        )  # [B*N, input_dim + 4]

        # Point-wise transform over [B, C, N], then back to one row per point.
        pt_fea_in = self.input_transform(mvf_input.view(B, N, -1).transpose(1, 2))
        pt_fea_bev = pt_fea_in.transpose(1, 2).flatten(0, 1)  # [B*N, H]
        # Max-pool the features of the points falling in each occupied cell.
        bev_fea_in = torch_scatter.scatter_max(pt_fea_bev, bev_unq_inv, dim=0)[0]

        # Scatter occupied cells onto a zero canvas whose row index is the
        # linear cell id, filling the whole batch with one indexed assignment.
        canvas = bev_fea_in.new_zeros(B * grid_x * grid_y, H)
        canvas[bev_unq_coords] = bev_fea_in
        batch_bev_features = canvas.view(B, grid_x, grid_y, H).permute(0, 3, 1, 2)
        return self.cnn(batch_bev_features)  # input is [B, H, grid_x, grid_y]
class ActorPPO(nn.Module):
    """
    Gaussian actor for **PPO** with a learnable, **state-independent** log standard deviation.

    ``self.net`` maps a state to the mean of a diagonal Gaussian over a
    continuous action vector; ``self.a_std_log`` is one trainable
    ``(1, action_dim)`` log-std shared by all states.

    :param state_dim[int]: size of the state vector.
    :param mid_dim[int]: width of the hidden layers.
    :param action_dim[int]: size of the (continuous) action vector.
    :param init_a_std_log[float]: initial value of every entry of ``a_std_log``.
    """

    def __init__(self, state_dim, mid_dim, action_dim, init_a_std_log=-0.5):
        super().__init__()
        nn_middle = nn.Sequential(
            nn.Linear(state_dim, mid_dim),
            nn.ReLU(),
            nn.Linear(mid_dim, mid_dim),
            nn.ReLU(),
        )
        self.net = nn.Sequential(
            nn_middle,
            nn.Linear(mid_dim, mid_dim),
            nn.Hardswish(),
            nn.Linear(mid_dim, action_dim),
        )
        # Trainable log of the action std. Assigning an nn.Parameter attribute
        # registers it, so no explicit register_parameter call is needed.
        self.a_std_log = nn.Parameter(
            torch.full((1, action_dim), float(init_a_std_log)), requires_grad=True
        )
        # Constant log(sqrt(2*pi)) of the Gaussian log-density, stored as a plain
        # float so it never dispatches to numpy when combined with tensors.
        self.sqrt_2pi_log = float(np.log(np.sqrt(2 * np.pi)))
        self.reset_parameter()

    def reset_parameter(self):
        """Orthogonally re-initialise every Linear layer, then the output layer with gain 0.01."""
        for module in self.net.modules():
            if isinstance(module, torch.nn.Linear):
                layer_norm(module)
        # rescale last layer
        last_layer = self.net[-1]
        assert isinstance(last_layer, torch.nn.Linear)
        layer_norm(last_layer, 0.01)

    def forward(self, state):
        """
        Deterministic action.

        :param state: state tensor fed to ``self.net``.
        :return: ``tanh`` of the action mean.
        """
        return self.net(state).tanh()

    def get_action(self, state):
        """
        Sample an exploratory action ``mean + noise * std``.

        :param state: state tensor fed to ``self.net``.
        :return: ``(action, noise)``, where ``noise`` is the standard-normal
            sample used (no ``tanh`` is applied to ``action``).
        """
        a_avg = self.net(state)
        a_std = self.a_std_log.exp()
        noise = torch.randn_like(a_avg)
        action = a_avg + noise * a_std
        return action, noise

    def get_logprob_entropy(self, state, action):
        """
        Log-probability of ``action`` under the current policy, plus an entropy term.

        :param state: state tensor fed to ``self.net``.
        :param action: action tensor of shape ``[batch, action_dim]``.
        :return: ``(logprob, neg_entropy)``: per-sample log-probability summed
            over action dims, and the NEGATIVE mean entropy of the Gaussian.
        """
        a_avg = self.net(state)
        a_std = self.a_std_log.exp()
        dist = torch.distributions.Normal(a_avg, a_std)
        logprob = dist.log_prob(action).sum(1)
        dist_entropy = -dist.entropy().mean()
        return logprob, dist_entropy

    def get_old_logprob(self, _action, noise):  # noise = (action - mean) / std
        """
        Log-probability of a sampled action, recovered from its standard-normal noise.

        :param _action: the sampled action (unused; implied by ``noise``).
        :param noise: the standard-normal noise returned by ``get_action``.
        :return: per-sample log-probability summed over action dims.
        """
        delta = noise.pow(2) * 0.5
        return -(self.a_std_log + self.sqrt_2pi_log + delta).sum(1)  # old_logprob

    def get_old_logprob_act(self, old_action, old_noise, action):
        """
        Log-probability of ``action`` under the Gaussian that produced ``old_action``.

        The mean of that Gaussian is recovered as ``old_action - old_noise * std``,
        so ``action`` is standardised without re-running the network.

        :param old_action: action returned by ``get_action``.
        :param old_noise: noise returned together with ``old_action``.
        :param action: action whose log-probability is wanted.
        :return: per-sample log-probability summed over action dims.
        """
        a_std = self.a_std_log.exp()
        # (old_action - action) / std - old_noise == (mean - action) / std
        noise = (old_action - action) / a_std - old_noise
        delta = noise.pow(2) * 0.5
        return -(self.a_std_log + self.sqrt_2pi_log + delta).sum(1)  # old_logprob
class CriticPPO(nn.Module):
    """
    State-value critic for **PPO**: an MLP mapping a state vector to one scalar.

    :param state_dim[int]: size of the state vector.
    :param mid_dim[int]: width of the hidden layers.
    :param _action_dim[int]: unused (presumably kept so the signature parallels
        ``ActorPPO``).
    """

    def __init__(self, state_dim, mid_dim, _action_dim):
        super().__init__()
        trunk = nn.Sequential(nn.Linear(state_dim, mid_dim), nn.ReLU())
        self.net = nn.Sequential(
            trunk,
            nn.Linear(mid_dim, mid_dim),
            nn.ReLU(),
            nn.Linear(mid_dim, mid_dim),
            nn.Hardswish(),
            nn.Linear(mid_dim, 1),
        )
        self.reset_parameter()

    def reset_parameter(self):
        """Re-initialise every ``nn.Linear`` inside ``self.net`` with ``layer_norm`` defaults."""
        linear_layers = [m for m in self.net.modules() if isinstance(m, torch.nn.Linear)]
        for linear in linear_layers:
            layer_norm(linear)

    def forward(self, state):
        """
        Estimate the value of ``state``.

        :param state: state tensor fed to ``self.net``.
        :return: tensor with a trailing dimension of size 1 holding the value.
        """
        value = self.net(state)
        return value