import numpy as np
import torch
from torch.distributions import Categorical
from torch import nn
from torch.nn import functional as F

from transformer import Transformer


class ActorCriticModel(nn.Module):
    def __init__(self, config, observation_space, action_space_shape, max_episode_length):
        """Model setup

        Arguments:
            config {dict} -- Configuration and hyperparameters of the environment, trainer and model.
            observation_space {box} -- Properties of the agent's observation space
            action_space_shape {tuple} -- Dimensions of the action space
            max_episode_length {int} -- The maximum number of steps in an episode
        """
        super().__init__()
        self.hidden_size = config["hidden_layer_size"]
        self.memory_layer_size = config["transformer"]["embed_dim"]
        self.observation_space_shape = observation_space.shape
        self.max_episode_length = max_episode_length

        # Observation encoder
        if len(self.observation_space_shape) > 1:
            # Case: visual observation is available
            # Visual encoder made of 3 convolutional layers
            self.conv1 = nn.Conv2d(observation_space.shape[0], 32, 8, 4, 0)
            self.conv2 = nn.Conv2d(32, 64, 4, 2, 0)
            self.conv3 = nn.Conv2d(64, 64, 3, 1, 0)
            nn.init.orthogonal_(self.conv1.weight, np.sqrt(2))
            nn.init.orthogonal_(self.conv2.weight, np.sqrt(2))
            nn.init.orthogonal_(self.conv3.weight, np.sqrt(2))
            # Compute output size of convolutional layers
            self.conv_out_size = self.get_conv_output(observation_space.shape)
            in_features_next_layer = self.conv_out_size
        else:
            # Case: vector observation is available
            in_features_next_layer = observation_space.shape[0]

        # Hidden layer
        self.lin_hidden = nn.Linear(in_features_next_layer, self.memory_layer_size)
        nn.init.orthogonal_(self.lin_hidden.weight, np.sqrt(2))

        # Transformer blocks
        self.transformer = Transformer(config["transformer"], self.memory_layer_size, self.max_episode_length)

        # Decouple policy from value
        # Hidden layer of the policy
        self.lin_policy = nn.Linear(self.memory_layer_size, self.hidden_size)
        nn.init.orthogonal_(self.lin_policy.weight, np.sqrt(2))
        # Hidden layer of the value function
        self.lin_value = nn.Linear(self.memory_layer_size, self.hidden_size)
        nn.init.orthogonal_(self.lin_value.weight, np.sqrt(2))

        # Outputs / Model heads
        # Policy (multi-discrete categorical distribution)
        self.policy_branches = nn.ModuleList()
        for num_actions in action_space_shape:
            actor_branch = nn.Linear(in_features=self.hidden_size, out_features=num_actions)
            nn.init.orthogonal_(actor_branch.weight, np.sqrt(0.01))
            self.policy_branches.append(actor_branch)
        # Value function
        self.value = nn.Linear(self.hidden_size, 1)
        nn.init.orthogonal_(self.value.weight, 1)

    def forward(self, obs: torch.Tensor, memory: torch.Tensor, memory_mask: torch.Tensor, memory_indices: torch.Tensor):
        """Forward pass of the model

        Arguments:
            obs {torch.tensor} -- Batch of observations
            memory {torch.tensor} -- Episodic memory window
            memory_mask {torch.tensor} -- Mask to prevent the model from attending to the padding
            memory_indices {torch.tensor} -- Indices to select the positional encoding that matches the memory window

        Returns:
            {list} -- Policy: list of categorical distributions, one per action dimension
            {torch.tensor} -- Value function: value estimate per sample
            {torch.tensor} -- Memory returned by the transformer blocks
        """
        # Set observation as input to the model
        h = obs
        # Forward observation encoder
        if len(self.observation_space_shape) > 1:
            batch_size = h.size()[0]
            # Propagate input through the visual encoder
            h = F.relu(self.conv1(h))
            h = F.relu(self.conv2(h))
            h = F.relu(self.conv3(h))
            # Flatten the output of the convolutional layers
            h = h.reshape((batch_size, -1))

        # Feed hidden layer
        h = F.relu(self.lin_hidden(h))

        # Forward transformer blocks
        h, memory = self.transformer(h, memory, memory_mask, memory_indices)

        # Decouple policy from value
        # Feed hidden layer (policy)
        h_policy = F.relu(self.lin_policy(h))
        # Feed hidden layer (value function)
        h_value = F.relu(self.lin_value(h))
        # Head: Value function
        value = self.value(h_value).reshape(-1)
        # Head: Policy
        pi = [Categorical(logits=branch(h_policy)) for branch in self.policy_branches]

        return pi, value, memory

    def get_conv_output(self, shape: tuple) -> int:
        """Computes the output size of the convolutional layers by feeding a dummy tensor.

        Arguments:
            shape {tuple} -- Input shape of the data feeding the first convolutional layer

        Returns:
            {int} -- Number of output features returned by the utilized convolutional layers
        """
        o = self.conv1(torch.zeros(1, *shape))
        o = self.conv2(o)
        o = self.conv3(o)
        return int(np.prod(o.size()))

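    # Illustrative note (not from the original file): with the encoder above and
    # a hypothetical 84x84 visual observation, the spatial resolution shrinks as
    #   conv1 (kernel 8, stride 4): (84 - 8) // 4 + 1 = 20
    #   conv2 (kernel 4, stride 2): (20 - 4) // 2 + 1 = 9
    #   conv3 (kernel 3, stride 1): (9 - 3) // 1 + 1 = 7
    # yielding 64 * 7 * 7 = 3136 features. get_conv_output computes this number
    # generically instead of hard-coding the arithmetic.
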
    def get_grad_norm(self):
        """Returns the norm of the gradients of the model.

        Returns:
            {dict} -- Dictionary of gradient norms grouped by layer name
        """
        grads = {}
        if len(self.observation_space_shape) > 1:
            grads["encoder"] = self._calc_grad_norm(self.conv1, self.conv2, self.conv3)
        grads["linear_layer"] = self._calc_grad_norm(self.lin_hidden)
        transformer_blocks = self.transformer.transformer_blocks
        for i, block in enumerate(transformer_blocks):
            grads["transformer_block_" + str(i)] = self._calc_grad_norm(block)
        for i, head in enumerate(self.policy_branches):
            grads["policy_head_" + str(i)] = self._calc_grad_norm(head)
        grads["lin_policy"] = self._calc_grad_norm(self.lin_policy)
        grads["value"] = self._calc_grad_norm(self.lin_value, self.value)
        # Norm over all parameters of the model
        grads["model"] = self._calc_grad_norm(self)
        return grads

    def _calc_grad_norm(self, *modules):
        """Computes the norm of the gradients of the given modules.

        Arguments:
            modules {list} -- Modules whose gradients are concatenated before computing the norm

        Returns:
            {float} -- Norm of the gradients of the given modules, or None if no gradients are available
        """
        grads = []
        for module in modules:
            for name, parameter in module.named_parameters():
                # Skip parameters that have not received a gradient yet
                if parameter.grad is not None:
                    grads.append(parameter.grad.view(-1))
        return torch.linalg.norm(torch.cat(grads)).item() if len(grads) > 0 else None
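

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file). It shows how the
    # multi-discrete policy returned by ActorCriticModel.forward is typically
    # consumed: `pi` is a list with one Categorical distribution per action
    # dimension. The dummy logits, batch size 4 and action space shape (3, 2)
    # below are arbitrary assumptions for illustration only.
    batch_size, action_space_shape = 4, (3, 2)
    pi = [Categorical(logits=torch.randn(batch_size, num_actions)) for num_actions in action_space_shape]
    # Sample one action per branch and stack them to a (batch, num_branches) tensor
    actions = torch.stack([branch.sample() for branch in pi], dim=1)
    # Joint log-probability and entropy are sums over the independent branches
    log_probs = torch.stack([branch.log_prob(actions[:, i]) for i, branch in enumerate(pi)], dim=1).sum(dim=1)
    entropies = torch.stack([branch.entropy() for branch in pi], dim=1).sum(dim=1)
    print(actions.shape, log_probs.shape, entropies.shape)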