-
-
Notifications
You must be signed in to change notification settings - Fork 404
/
test_pytorch.py
executable file
·127 lines (97 loc) · 3.39 KB
/
test_pytorch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#!/usr/bin/env python3
# E. Culurciello
# August 2017
import itertools as it
from time import sleep
import numpy as np
import skimage.color
import skimage.transform
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from vizdoom import DoomGame, Mode, ScreenFormat, ScreenResolution, os, vzd
# ---------------------------------------------------------------------------
# Hyper-parameters and options.
# NOTE(review): batch_size, test_episodes_per_epoch, save_model, load_model and
# skip_learning are never read in this file; presumably they mirror the
# training script's configuration — confirm before removing them.
# ---------------------------------------------------------------------------
batch_size = 64  # NN learning settings
test_episodes_per_epoch = 100  # training regime

# Playback parameters.
frame_repeat = 12  # advance_action() calls issued per chosen action
resolution = (30, 45)  # target size handed to skimage.transform.resize
episodes_to_watch = 10

# Checkpoint handling.
model_savefile = "./model-doom.pth"
save_model, load_model, skip_learning = True, False, False

# Scenario configuration file; swap in an alternative below to use another
# scenario.
config_file_path = os.path.join(vzd.scenarios_path, "simpler_basic.cfg")
# config_file_path = os.path.join(vzd.scenarios_path, "rocket_basic.cfg")
# config_file_path = os.path.join(vzd.scenarios_path, "basic.cfg")
def preprocess(img):
    """Down-sample a screen buffer to ``resolution`` and convert it to float32.

    Args:
        img: Screen buffer from ViZDoom. This script configures GRAY8 screens,
            so presumably a 2-D array — confirm for other screen formats.

    Returns:
        The resized image as a float32 NumPy array.
    """
    return skimage.transform.resize(img, resolution).astype(np.float32)
class Net(nn.Module):
    """Small convolutional Q-network producing one score per available action.

    The layer sizes are hard-wired for a single-channel 30x45 input (the
    ``resolution`` used by ``preprocess``): conv1 produces 8x9x14 maps and
    conv2 produces 8x4x6 maps, i.e. the 192 features consumed by ``fc1``.
    """

    def __init__(self, available_actions_count):
        super(Net, self).__init__()
        # Attribute names (and creation order) are kept stable: they define the
        # parameter names stored in saved/pickled models.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=6, stride=3)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=8, kernel_size=3, stride=2)
        self.fc1 = nn.Linear(in_features=192, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=available_actions_count)

    def forward(self, x):
        """Map a batch of screens to per-action Q-values."""
        features = F.relu(self.conv2(F.relu(self.conv1(x))))
        # Collapse the 8x4x6 feature maps into one 192-wide row per sample.
        flat = features.view(-1, 192)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)
def get_q_values(state, net=None):
    """Evaluate the Q-network on a NumPy state and return its output tensor.

    Args:
        state: NumPy array converted with ``torch.from_numpy``. The caller in
            this file passes a float32 array reshaped to
            ``(1, 1, resolution[0], resolution[1])``.
        net: Network to evaluate. Defaults to the module-level ``model``
            loaded in the ``__main__`` section; passing it explicitly lets the
            function be used when this file is imported as a module.

    Returns:
        The network output (for ``Net``: one row of Q-values per sample).
    """
    if net is None:
        net = model
    # Inference only: avoid building an autograd graph for every frame.
    with torch.no_grad():
        return net(torch.from_numpy(state))
def get_best_action(state):
    """Return the index of the highest-scoring action for ``state``.

    Only the first row of the batch is inspected. The index is returned as a
    NumPy integer and is used to select an entry of ``actions``.
    """
    q_values = get_q_values(state)
    # Only the arg-max indices are needed; the maximum values are discarded.
    _, argmax_indices = q_values.max(dim=1)
    return argmax_indices.data.numpy()[0]
def initialize_vizdoom(config_file_path):
    """Create, configure and start a ViZDoom game.

    After loading ``config_file_path``, this applies a visible window,
    ``Mode.PLAYER``, ``ScreenFormat.GRAY8`` and a 640x480 resolution, then
    calls ``init()``.

    Args:
        config_file_path: Path to the scenario ``.cfg`` file.

    Returns:
        The initialized ``DoomGame`` instance.
    """
    print("Initializing doom...")
    doom = DoomGame()
    doom.load_config(config_file_path)
    # Overrides applied on top of the config file, in this exact order.
    overrides = (
        (doom.set_window_visible, True),
        (doom.set_mode, Mode.PLAYER),
        (doom.set_screen_format, ScreenFormat.GRAY8),
        (doom.set_screen_resolution, ScreenResolution.RES_640X480),
    )
    for apply_setting, value in overrides:
        apply_setting(value)
    doom.init()
    print("Doom initialized.")
    return doom
if __name__ == "__main__":
    # Create Doom instance.
    game = initialize_vizdoom(config_file_path)

    # Action = which buttons are pressed: every on/off combination of the
    # available buttons (2 ** n actions).
    # NOTE(review): the loaded network's output size must equal len(actions);
    # confirm against the training configuration.
    n = game.get_available_buttons_size()
    actions = [list(a) for a in it.product([0, 1], repeat=n)]

    print("Loading model from: ", model_savefile)
    # The checkpoint is a fully pickled Net (not a state_dict), so it has to
    # be unpickled: only load checkpoints from a trusted source.
    # PyTorch >= 2.6 defaults to weights_only=True, which rejects this file.
    # map_location="cpu" keeps the weights on the CPU, where get_q_values
    # builds its input tensors.
    try:
        model = torch.load(
            model_savefile, map_location="cpu", weights_only=False
        )
    except TypeError:
        # Older PyTorch releases do not accept the weights_only argument.
        model = torch.load(model_savefile, map_location="cpu")
    model.eval()

    print("======================================")
    print("Testing trained neural network!")

    # Reinitialize the game with window visible. The instance returned by
    # initialize_vizdoom() is already running, so close it before changing
    # the mode and calling init() again; otherwise the new settings are not
    # applied to a fresh game.
    game.close()
    game.set_window_visible(True)
    game.set_mode(Mode.ASYNC_PLAYER)
    game.init()

    for _ in range(episodes_to_watch):
        game.new_episode()
        while not game.is_episode_finished():
            state = preprocess(game.get_state().screen_buffer)
            state = state.reshape([1, 1, resolution[0], resolution[1]])
            best_action_index = get_best_action(state)

            # Instead of make_action(a, frame_repeat) in order to make the
            # animation smooth.
            game.set_action(actions[best_action_index])
            for _ in range(frame_repeat):
                game.advance_action()
                # NOTE(review): original indentation was lost; this short
                # pause is assumed to belong to the per-tic loop.
                sleep(0.03)

        # Sleep between episodes.
        sleep(1.0)
        score = game.get_total_reward()
        print("Total score: ", score)

    # Release the game once playback is done.
    game.close()