model.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from config import Config


class BaseNN(nn.Module):
    """Superclass for the Actor and Critic classes"""

    def __init__(self):
        super(BaseNN, self).__init__()
        self.config = Config()
        self.to(self.config.device)
        torch.manual_seed(self.config.seed)
        self.module_list = nn.ModuleList()

    def create_fc_layer(self, nodes_in, nodes_out):
        # Build a fully connected layer, initialise its weights and register it.
        layer = nn.Linear(nodes_in, nodes_out)
        self.reset_parameters(layer)
        self.module_list.append(layer)

    def reset_parameters(self, layer):
        # Initialise the weights uniformly in a small range around zero.
        layer.weight.data.uniform_(-3e-3, 3e-3)


class Actor(BaseNN):
    """Build an actor (policy) network that maps states -> actions."""

    def __init__(self):
        super(Actor, self).__init__()
        for nodes_in, nodes_out in self.layers_nodes():
            self.create_fc_layer(nodes_in, nodes_out)

    def layers_nodes(self):
        # Layer sizes: state_size -> hidden layers from the config -> action_size.
        nodes = []
        nodes.append(self.config.state_size)
        nodes.extend(self.config.actor_layers)
        nodes.append(self.config.action_size)
        nodes_in = nodes[:-1]
        nodes_out = nodes[1:]
        return zip(nodes_in, nodes_out)

    def forward(self, x):
        # ReLU on the hidden layers, tanh on the output to bound actions to [-1, 1].
        for layer in self.module_list[:-1]:
            x = F.relu(layer(x))
        x = self.module_list[-1](x)
        return torch.tanh(x)


class Critic(BaseNN):
    """Build a critic (value) network that maps
    (state, action) pairs -> Q-values.
    """

    def __init__(self):
        super(Critic, self).__init__()
        for nodes_in, nodes_out in self.layers_nodes():
            self.create_fc_layer(nodes_in, nodes_out)
        if self.config.batch_normalization:
            # Normalise the concatenation of hidden activations and actions
            # that is fed into the second layer.
            self.bn = nn.BatchNorm1d(self.module_list[1].in_features)

    def layers_nodes(self):
        # Layer sizes: joint state of all agents -> hidden layers from the
        # config -> a single Q-value.
        nodes = []
        nodes.append(self.config.state_size * self.config.num_agents)
        nodes.extend(self.config.critic_layers)
        nodes.append(1)
        nodes_in = nodes[:-1]
        # The joint action of all agents is concatenated in at the second layer.
        nodes_in[1] += self.config.num_agents * self.config.action_size
        nodes_out = nodes[1:]
        return zip(nodes_in, nodes_out)

    def forward(self, state, action):
        x = F.relu(self.module_list[0](state))
        x = torch.cat((x, action), dim=1)
        if self.config.batch_normalization:
            x = self.bn(x)
        for layer in self.module_list[1:-1]:
            x = F.relu(layer(x))
        x = self.module_list[-1](x)
        return torch.sigmoid(x)
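
The Config class imported at the top lives in config.py, which is not shown here; the networks only read a handful of its attributes (device, seed, state_size, action_size, num_agents, actor_layers, critic_layers, batch_normalization). Below is a minimal usage sketch, not part of model.py, with placeholder batch sizes chosen purely for illustration:

# usage_sketch.py -- illustrative only, not part of the repository.
# Assumes config.Config exposes the attributes read by model.py:
# device, seed, state_size, action_size, num_agents,
# actor_layers, critic_layers and batch_normalization.
import torch

from model import Actor, Critic

actor = Actor()    # maps one agent's state -> its action
critic = Critic()  # maps all agents' states and actions -> one Q-value

# Dummy batch; the sizes are taken from the config the networks were built with.
batch_size = 4
state_size = actor.config.state_size
action_size = actor.config.action_size
num_agents = actor.config.num_agents

state = torch.rand(batch_size, state_size)
action = actor(state)  # shape (batch_size, action_size), values in [-1, 1]

joint_state = torch.rand(batch_size, num_agents * state_size)
joint_action = torch.rand(batch_size, num_agents * action_size)
q_value = critic(joint_state, joint_action)  # shape (batch_size, 1), values in (0, 1)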