from collections import deque
from typing import NamedTuple, Tuple
import math
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from config import Config


class Experience(NamedTuple):
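    """A single (state, action, reward, next_state, done) transition."""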
state: np.ndarray
action: np.ndarray
reward: float
next_state: np.ndarray
done: bool
class ExperienceReplayBuffer:
def __init__(self, buffer_size: int, batch_size: int, config: Config):
self.memory = deque(maxlen=buffer_size)
self.batch_size = batch_size
self.device = config.DQN.device
def add_experience(self, experience: Experience) -> None:
self.memory.append(experience)
    def sample(self) -> Tuple[torch.Tensor, ...]:
samples = random.sample(self.memory, k=self.batch_size)
# Grab S, A, R, S', Done
# Each row is a sample
states = np.vstack([sample.state for sample in samples])
actions = np.vstack([sample.action for sample in samples])
rewards = np.vstack([sample.reward for sample in samples])
next_states = np.vstack([sample.next_state for sample in samples])
dones = np.vstack([sample.done for sample in samples])
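        # np.vstack stacks the sampled fields along the first axis, so each array
        # below has batch_size rows (actions, rewards and dones end up as
        # (batch_size, 1) columns).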
# Convert the above to tensors
states = torch.from_numpy(states).float().to(self.device)
actions = torch.from_numpy(actions).long().to(self.device)
rewards = torch.from_numpy(rewards).float().to(self.device)
next_states = torch.from_numpy(next_states).float().to(self.device)
        # Cast dones to float so the (1 - dones) mask used in learn() stays in
        # floating point
        dones = torch.from_numpy(dones.astype(np.uint8)).float().to(self.device)
return (states, actions, rewards, next_states, dones)
def __len__(self) -> int:
return len(self.memory)
class QNetwork(nn.Module):
    """
    Convolutional Q-network following the architecture of the original DQN paper:
    https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf
    """
def __init__(self, input_shape: Tuple[int, int, int], num_actions: int):
super().__init__()
# Get the dimensions (height, width, channels)
h, w, c = input_shape
# Create the network
# 16 8x8 filters with stride 4
self.conv1 = nn.Conv2d(in_channels=c, out_channels=16, kernel_size=8, stride=4)
conv1_h_out, conv1_w_out = self._get_output_shape(h, w, 8, 4)
# 32 4x4 filters with stride 2
self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2)
conv2_h_out, conv2_w_out = self._get_output_shape(conv1_h_out, conv1_w_out, 4, 2)
# Calculate the total number of neurons in the first flattened layer
self.num_flattened = conv2_h_out * conv2_w_out * 32
# Create flattened layer
self.dense1 = nn.Linear(self.num_flattened, 256)
# Output for actions
self.dense2 = nn.Linear(256, num_actions)
def forward(self, state: torch.Tensor):
# Convert from (batch size) x (height) x (width) x (channels) (NHWC)
# to (batch size) x (channels) x (height) x (width) (NCHW)
state = state.permute(0, 3, 1, 2).contiguous()
x = torch.relu(self.conv1(state))
x = torch.relu(self.conv2(x))
x = torch.relu(self.dense1(x.reshape([-1, self.num_flattened])))
x = self.dense2(x)
return x
    def _get_output_shape(self, height: int, width: int, kernel_size: int, stride: int):
        """
        Computes the output height and width of a conv layer from the input
        height/width, kernel_size and stride, keeping the defaults for padding (0)
        and dilation (1). For the general formula, see:
        https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
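        For example, an 84x84 input (the frame size used in the DQN paper) gives
        floor((84 - 7 - 1) / 4) + 1 = 20 after conv1 (kernel 8, stride 4) and
        floor((20 - 3 - 1) / 2) + 1 = 9 after conv2 (kernel 4, stride 2), so the
        flattened size is 9 * 9 * 32 = 2592.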
"""
h_out = math.floor(((height - (kernel_size - 1) - 1) / stride) + 1)
w_out = math.floor(((width - (kernel_size - 1) - 1) / stride) + 1)
return (h_out, w_out)
class DQNAgent:
def __init__(self, input_shape: Tuple[int, int, int], num_actions: int, config: Config):
self._step = 0
self.num_actions = num_actions
self.device = config.DQN.device
buffer_size = config.ExperienceReplay.memory_size
batch_size = config.ExperienceReplay.batch_size
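        # Two copies of the Q-network: the local (online) network is trained on
        # every learning step, while the target network is only moved slowly
        # towards it (see soft_update) to keep the bootstrapped targets stable.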
self.local_net = QNetwork(input_shape, num_actions).to(self.device)
self.target_net = QNetwork(input_shape, num_actions).to(self.device)
self.target_net.load_state_dict(self.local_net.state_dict())
self.tau = config.DQN.tau
self.gamma = config.DQN.gamma
self.soft_update_every_n = config.DQN.soft_update_every_n_episodes
self.loss_func = getattr(F, config.DQN.loss)
self.optimizer = getattr(torch.optim, config.DQN.optimizer)(self.local_net.parameters())
self.experience_replay = ExperienceReplayBuffer(buffer_size, batch_size, config)
    def learn(self, experiences: Tuple[torch.Tensor, ...]) -> None:
states, actions, rewards, next_states, dones = experiences
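        # Q(s, a) of the local network for the actions that were actually taken,
        # shape [batch_size, 1]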
state_action_vals = self.local_net(states).gather(1, actions)
# Get state-action values for next_states assuming greedy-policy
# unsqueeze to go from shape [batch] to [batch, 1]
state_action_vals_next_states = self.target_net(next_states).detach().max(1)[0].unsqueeze(1)
        # TD target: y = r + gamma * max_a' Q_target(s', a'), with the bootstrap
        # term zeroed out for terminal transitions (done == 1)
expected_state_action_values = rewards + (self.gamma * state_action_vals_next_states * (1 - dones))
# Clear gradient and minimize
self.local_net.train()
self.optimizer.zero_grad()
loss = self.loss_func(state_action_vals, expected_state_action_values)
loss.backward()
self.optimizer.step()
self.soft_update()
def soft_update(self) -> None:
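        # Polyak averaging of the target network towards the local network:
        # theta_target <- tau * theta_local + (1 - tau) * theta_target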
for target_param, policy_param in zip(self.target_net.parameters(), self.local_net.parameters()):
target_param.data.copy_(self.tau*policy_param.data + (1.0-self.tau)*target_param.data)
def step(self, experience: Experience) -> None:
self.experience_replay.add_experience(experience)
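        # Learn (and soft-update the target network) every soft_update_every_n
        # calls to step(), provided the replay buffer holds at least one batch.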
self._step = (self._step + 1) % self.soft_update_every_n
        if len(self.experience_replay) >= self.experience_replay.batch_size and self._step == 0:
experiences = self.experience_replay.sample()
self.learn(experiences)
    def act(self, state: np.ndarray, eps: float) -> int:
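        # Epsilon-greedy action selection: with probability eps pick a random
        # action, otherwise act greedily with respect to the local Q-network.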
        # Move the observation to the device as a float tensor; the network
        # expects a batched NHWC input (see QNetwork.forward)
        state = torch.from_numpy(state).float().to(self.device)
self.local_net.eval()
with torch.no_grad():
action_vals = self.local_net(state)
self.local_net.train()
        if random.random() < eps:
            return random.randrange(self.num_actions)
        else:
            return int(np.argmax(action_vals.cpu().data.numpy()))
def save_checkpoint(agent: DQNAgent, episode: int, eps: float, eps_end: float, eps_decay: float, path: str) -> None:
torch.save({
'episode': episode,
'eps': eps,
'eps_end': eps_end,
'eps_decay': eps_decay,
'local_model_state_dict': agent.local_net.state_dict(),
'target_model_state_dict': agent.target_net.state_dict(),
'optimizer_state_dict': agent.optimizer.state_dict()
}, path)
def load_checkpoint(agent: DQNAgent, path: str) -> Tuple[int, float, float, float]:
checkpoint = torch.load(path)
agent.local_net.load_state_dict(checkpoint['local_model_state_dict'])
agent.target_net.load_state_dict(checkpoint['target_model_state_dict'])
agent.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
return (checkpoint['episode'], checkpoint['eps'], checkpoint['eps_end'], checkpoint['eps_decay'])
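

# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): it wires the classes
# above together with a dummy random "environment" so it stays self-contained.
# The observation shape, action count, epsilon schedule, checkpoint path and
# the assumption that Config() is constructible with defaults are illustrative
# assumptions, not taken from the original code. Observations are given a
# leading batch dimension of 1, matching what QNetwork.forward and
# ExperienceReplayBuffer.sample expect.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    config = Config()                    # assumed to be constructible with defaults
    obs_shape = (84, 84, 4)              # assumed (height, width, channels)
    num_actions = 4                      # assumed size of the action space
    agent = DQNAgent(obs_shape, num_actions, config)

    eps, eps_end, eps_decay = 1.0, 0.01, 0.995   # illustrative epsilon schedule
    for episode in range(2):                     # tiny loop, just to exercise the API
        state = np.random.rand(1, *obs_shape).astype(np.float32)
        for _ in range(10):
            action = agent.act(state, eps)
            # A real environment would produce these from its step() function.
            next_state = np.random.rand(1, *obs_shape).astype(np.float32)
            reward, done = 0.0, False
            agent.step(Experience(state, action, reward, next_state, done))
            state = next_state
        eps = max(eps_end, eps * eps_decay)
        save_checkpoint(agent, episode, eps, eps_end, eps_decay, "checkpoint.pth")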