# model.py (forked from Kaixhin/Rainbow)

import math
import torch
from torch import nn
from torch.nn import functional as F


# Factorised NoisyLinear layer with bias
class NoisyLinear(nn.Module):
  """Noisy linear layer, from DeepMind's "Noisy Networks for Exploration".

  A NoisyNet agent is a deep reinforcement learning agent with parametric
  noise added to its weights; the induced stochasticity of the agent's policy
  can be used to aid efficient exploration.
  """

  def __init__(self, in_features, out_features, std_init=0.5):
    """This module extends torch.nn.Linear with learnable noise parameters.

    Args:
      in_features: number of input features
      out_features: number of output features
      std_init: initial value sigma_0 used to scale the noise standard deviations
    """
    super(NoisyLinear, self).__init__()
    self.in_features = in_features
    self.out_features = out_features
    self.std_init = std_init
    # Learnable means and standard deviations of the noisy weights and biases
    self.weight_mu = nn.Parameter(torch.empty(out_features, in_features))
    self.weight_sigma = nn.Parameter(torch.empty(out_features, in_features))
    # Noise samples live in buffers: saved with the module, but not trained
    self.register_buffer('weight_epsilon', torch.empty(out_features, in_features))
    self.bias_mu = nn.Parameter(torch.empty(out_features))
    self.bias_sigma = nn.Parameter(torch.empty(out_features))
    self.register_buffer('bias_epsilon', torch.empty(out_features))
    self.reset_parameters()
    self.reset_noise()

  def reset_parameters(self):
    """Resets the layer's learnable parameters.

    Notes:
      For factorised noisy networks, each element mu_{i,j} is initialised by a
      sample from an independent uniform distribution U[-1/sqrt(p), +1/sqrt(p)]
      and each element sigma_{i,j} is initialised to the constant
      sigma_0/sqrt(p), where p is the number of inputs. In the paper the
      hyperparameter sigma_0 is set to 0.5 (std_init=0.5).
    """
    mu_range = 1 / math.sqrt(self.in_features)
    self.weight_mu.data.uniform_(-mu_range, mu_range)
    self.weight_sigma.data.fill_(self.std_init / math.sqrt(self.in_features))
    self.bias_mu.data.uniform_(-mu_range, mu_range)
    self.bias_sigma.data.fill_(self.std_init / math.sqrt(self.out_features))
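    # Worked example: with in_features=3136 (the conv torso's output size used
    # below in DQN), sqrt(p) = 56, so mu ~ U[-1/56, +1/56] ≈ U[-0.018, +0.018]
    # and each weight sigma starts at 0.5 / 56 ≈ 0.0089.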

  def _scale_noise(self, size):
    """Draws noise and applies the scaling f(x) = sgn(x) * sqrt(|x|) from the paper.

    Args:
      size: int, number of noise samples to draw

    Returns:
      Scaled noise vector of shape (size,)
    """
    x = torch.randn(size)
    return x.sign().mul_(x.abs().sqrt_())

  def reset_noise(self):
    """Samples fresh factorised noise.

    The noise depends on the number of input/output features: the weight noise
    is the outer product of the output and input noise vectors, so only
    in_features + out_features unit Gaussians are drawn instead of
    in_features * out_features independent ones.
    """
    epsilon_in = self._scale_noise(self.in_features)
    epsilon_out = self._scale_noise(self.out_features)
    self.weight_epsilon.copy_(epsilon_out.ger(epsilon_in))  # Outer product
    self.bias_epsilon.copy_(epsilon_out)

  def forward(self, input):
    """Overrides nn.Linear's forward.

    Args:
      input: input tensor of shape (..., in_features)

    Returns:
      nn.Linear's output, computed with the noisy weights during training and
      with the mean weights only during evaluation.
    """
    if self.training:
      return F.linear(input, self.weight_mu + self.weight_sigma * self.weight_epsilon,
                      self.bias_mu + self.bias_sigma * self.bias_epsilon)
    else:
      return F.linear(input, self.weight_mu, self.bias_mu)
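
# A minimal usage sketch: noise is held in buffers, so repeated forward passes
# reuse the same sample until reset_noise() is called, and eval mode is
# deterministic.
#   layer = NoisyLinear(4, 2)
#   x = torch.randn(1, 4)
#   layer.train()
#   assert torch.equal(layer(x), layer(x))  # same epsilon -> same output
#   layer.reset_noise()                     # draw a fresh noise sample
#   layer.eval()                            # now uses weight_mu/bias_mu only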


class DQN(nn.Module):
  """Rainbow DQN combining C51 (distributional RL), duelling streams and noisy networks.
  """

  def __init__(self, args, action_space):
    super().__init__()
    self.atoms = args.atoms
    self.action_space = action_space
    # Convolutional torso over a stack of args.history_length greyscale frames
    self.conv1 = nn.Conv2d(args.history_length, 32, 8, stride=4, padding=1)
    self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
    self.conv3 = nn.Conv2d(64, 64, 3)
    # Duelling heads: a value stream over atoms, an advantage stream over actions * atoms
    self.fc_h_v = NoisyLinear(3136, args.hidden_size, std_init=args.noisy_std)
    self.fc_h_a = NoisyLinear(3136, args.hidden_size, std_init=args.noisy_std)
    self.fc_z_v = NoisyLinear(args.hidden_size, self.atoms, std_init=args.noisy_std)
    self.fc_z_a = NoisyLinear(args.hidden_size, action_space * self.atoms, std_init=args.noisy_std)

  def forward(self, x, log=False):
    x = F.relu(self.conv1(x))
    x = F.relu(self.conv2(x))
    x = F.relu(self.conv3(x))
    x = x.view(-1, 3136)  # Flatten: 64 channels x 7 x 7 for 84x84 inputs
    v = self.fc_z_v(F.relu(self.fc_h_v(x)))  # Value stream
    a = self.fc_z_a(F.relu(self.fc_h_a(x)))  # Advantage stream
    v, a = v.view(-1, 1, self.atoms), a.view(-1, self.action_space, self.atoms)
    q = v + a - a.mean(1, keepdim=True)  # Combine streams
    if log:  # Use log softmax for numerical stability
      q = F.log_softmax(q, dim=2)  # Log probabilities with action over second dimension
    else:
      q = F.softmax(q, dim=2)  # Probabilities with action over second dimension
    return q
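
  # Note: in the standard C51 setup (an assumption here, since the readout
  # happens outside this file), the agent reduces these per-atom probabilities
  # to scalar action values via Q(s, a) = (q * support).sum(2), where `support`
  # is the fixed grid of atom values.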

  def reset_noise(self):
    # Redraw noise in every NoisyLinear head (fc_h_v, fc_h_a, fc_z_v, fc_z_a)
    for name, module in self.named_children():
      if 'fc' in name:
        module.reset_noise()
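

if __name__ == '__main__':
  # Minimal smoke test, a sketch: the attribute names below are the ones
  # DQN.__init__ reads from `args`; the values are assumed Rainbow-style
  # defaults rather than anything specified in this file.
  from types import SimpleNamespace
  args = SimpleNamespace(atoms=51, history_length=4, hidden_size=512, noisy_std=0.1)
  net = DQN(args, action_space=6)
  state = torch.zeros(1, args.history_length, 84, 84)  # one stack of Atari-sized frames
  q = net(state)
  print(q.shape)     # torch.Size([1, 6, 51]): per-action probabilities over atoms
  print(q.sum(2))    # each row sums to 1 (softmax over the atom dimension)
  net.reset_noise()  # redraw exploration noise in all NoisyLinear layers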