# dqn_comparison.py
# Trains several DQNAgent variants on CartPole-v0 and plots their scores.
from rainbow import *
import gym
import torch
import matplotlib.pyplot as plt
import argparse
def set_seed(seed, env):
    """Seed every random number generator used by the experiment.

    Seeds PyTorch, NumPy, Python's ``random`` module and the environment.
    When cuDNN is enabled, its benchmark mode is switched off and its
    deterministic flag is switched on.

    Args:
        seed: Integer seed applied to every generator.
        env: Environment object exposing a ``seed(seed)`` method.
    """
    # Imported locally so this function does not depend on ``np`` and
    # ``random`` leaking into the module through ``from rainbow import *``.
    import random

    import numpy as np

    torch.manual_seed(seed)
    if torch.backends.cudnn.enabled:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    np.random.seed(seed)
    random.seed(seed)
    env.seed(seed)
# Rainbow components that DQNAgent switches off individually through its
# ``no_<component>`` keyword arguments.
COMPONENTS = ("dueling", "categorical", "double", "n_step", "noise", "priority")

# (plot label, enabled components) for every agent in the comparison, listed
# in the order the agents are constructed, trained and plotted.
AGENT_VARIANTS = (
    ("DQN", ()),
    ("DDQN", ("double",)),
    ("Prioritized DDQN", ("double", "priority")),
    ("Dueling DDQN", ("double", "dueling")),
    ("Noisy DDQN", ("double", "noise")),
    ("Categorical DDQN", ("double", "categorical")),
    ("N-step DDQN", ("double", "n_step")),
    ("Rainbow", COMPONENTS),
)


def component_flags(enabled):
    """Build the ``no_<component>`` keyword arguments for one agent variant.

    Args:
        enabled: Iterable of names from ``COMPONENTS`` that should stay on.

    Returns:
        Dict mapping ``"no_<name>"`` to ``True`` for every component that is
        not in ``enabled`` and to ``False`` for every component that is.

    Raises:
        ValueError: If ``enabled`` contains a name that is not in
            ``COMPONENTS``.
    """
    unknown = set(enabled) - set(COMPONENTS)
    if unknown:
        raise ValueError("unknown components: %s" % ", ".join(sorted(unknown)))
    return {"no_" + name: name not in enabled for name in COMPONENTS}


def frames_fraction(num_frames, divisor=10):
    """Return ``num_frames // divisor`` as an int, never less than 1.

    Floor division keeps the result an integer (``/`` yields a float on
    Python 3), and the lower bound keeps very small frame budgets from
    producing a zero size or a zero update period.
    """
    return max(1, num_frames // divisor)


def parse_args(argv=None):
    """Parse the command-line options of the comparison script.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        Namespace with ``num_frames``, ``plot`` and ``plotting_interval``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("-nf", "--num_frames", type=int, default=2000,
                    help="number of training frames")
    ap.add_argument("-plt", "--plot", default=False, action='store_true',
                    help="Plot training stats during training for each network")
    ap.add_argument("-pi", "--plotting_interval", type=int, default=100,
                    help="Number of steps per plots update")
    return ap.parse_args(argv)


if __name__ == '__main__':
    # python dqn_comparison.py --num_frames 20000 --plotting_interval 1000
    args = parse_args()

    # hyper parameters
    num_frames = args.num_frames
    memory_size = frames_fraction(num_frames)
    batch_size = 32
    target_update = frames_fraction(num_frames)
    plotting_interval = args.plotting_interval
    plot = args.plot

    # seed
    seed = 777

    # make environment
    env_id = "CartPole-v0"
    env = gym.make(env_id)
    set_seed(seed, env)

    # Construct every agent before any of them starts training; all agents
    # share the same environment instance.
    labels = [label for label, _ in AGENT_VARIANTS]
    agents = [
        DQNAgent(env, memory_size, batch_size, target_update,
                 plot=plot, frame_interval=plotting_interval,
                 **component_flags(enabled))
        for _, enabled in AGENT_VARIANTS
    ]

    # train
    scores = []
    losses = []
    for label, agent in zip(labels, agents):
        print("Training agent", label)
        score, loss = agent.train(num_frames)
        scores.append(score)
        losses.append(loss)

    # create a color palette
    palette = plt.get_cmap('Set1')
    # NOTE(review): only the first of the three subplot slots is used, and
    # ``losses`` is collected but not plotted.
    plt.figure(figsize=(20, 5))
    plt.subplot(131)
    plt.title('Training frames: %s' % num_frames)
    last = len(scores) - 1
    for i, (label, score) in enumerate(zip(labels, scores)):
        # The last curve (Rainbow) is drawn thicker so it stands out.
        linewidth = 3. if i == last else 1.
        plt.plot(score, marker='', color=palette(i), linewidth=linewidth,
                 alpha=1., label=label)
    plt.legend(loc=2, ncol=1)
    plt.xlabel("Frames x " + str(plotting_interval))
    plt.ylabel("Score")
    # Without this the figure is built and then discarded when the script exits.
    plt.show()