-
Notifications
You must be signed in to change notification settings - Fork 0
/
test-trained-network.lua
executable file
·92 lines (74 loc) · 2.46 KB
/
test-trained-network.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
-- Eugenio Culurciello
-- December 2016
-- test a trained deep Q learning neural network
-- Prefix used to build every ViZDoom path below (ViZDoom's root dir).
local base_path = "../../"

require "vizdoom"
require 'nn'
require 'image'

-- Torch runtime settings: 8 threads, float tensors by default.
torch.setnumthreads(8)
torch.setdefaulttensortype('torch.FloatTensor')

-- Command-line options: the first argument is the saved network file.
local opt = { fpath = arg[1] }
if not opt.fpath then
   print('missing arg #1: th test.lua results/model-20.net')
   return
end

-- Load the trained network from disk.
local model = torch.load(opt.fpath)

-- Scenario configuration handed to ViZDoom (alternatives kept for reference):
local config_file_path = base_path .. "scenarios/simpler_basic.cfg"
-- local config_file_path = base_path.."scenarios/rocket_basic.cfg"
-- local config_file_path = base_path.."scenarios/basic.cfg"

-- Doom game actions: one one-hot tensor per action index (1..3).
local actions = {
   torch.Tensor({1, 0, 0}),
   torch.Tensor({0, 1, 0}),
   torch.Tensor({0, 0, 1}),
}

-- Other parameters.
-- Y, X sizes of rescaled state / game screen
local resolution = {30, 45}
--- Rescales a game frame to the configured `resolution`.
-- The two entries of `resolution` are passed to `image.scale` in table order.
-- NOTE(review): `resolution` is commented as {Y, X}, while image.scale is
-- documented as taking (width, height) -- confirm the intended orientation
-- before touching this; the trained network presumably saw the same layout.
-- @param inImage frame to rescale
-- @return the rescaled image returned by `image.scale`
local function preprocess(inImage)
   local first_size, second_size = resolution[1], resolution[2]
   return image.scale(inImage, first_size, second_size)
end
--- Creates, configures and starts a ViZDoom game instance.
-- The engine binary and the WAD are located relative to `base_path`; the
-- scenario is read from the given config file. After the config is loaded
-- the game is forced to PLAYER mode, GRAY8 screen format and 640x480
-- resolution, and `init()` is called.
-- @param config_file_path path of the scenario .cfg file given to `loadConfig`
-- @return the initialized `vizdoom.DoomGame` instance
function initializeViZdoom(config_file_path)
   print("Initializing doom...")
   -- Keep the instance local: a bare assignment here would silently create a
   -- global `game`. Callers receive the instance through the return value.
   local game = vizdoom.DoomGame()
   game:setViZDoomPath(base_path .. "bin/vizdoom")
   game:setDoomGamePath(base_path .. "scenarios/freedoom2.wad")
   game:loadConfig(config_file_path)
   -- These settings are applied after loadConfig, so they take precedence
   -- over whatever the config file specified for the same options.
   game:setMode(vizdoom.Mode.PLAYER)
   game:setScreenFormat(vizdoom.ScreenFormat.GRAY8)
   game:setScreenResolution(vizdoom.ScreenResolution.RES_640X480)
   game:init()
   print("Doom initialized.")
   return game
end
--- Runs the loaded network on a state.
-- @param state input handed unchanged to the network's `forward`
-- @return whatever `model:forward` returns for that input
function getQValues(state)
   local net = model
   return net:forward(state)
end
--- Picks the action with the largest network output for a state.
-- The state is converted to float and reshaped to
-- 1 x 1 x resolution[1] x resolution[2] before being fed to `getQValues`.
-- @param state preprocessed frame supporting `:float()` and `:reshape(...)`
-- @return the first entry of the index result of `torch.max(q, 1)`
-- @return the raw network output `q`
function getBestAction(state)
   local rows, cols = resolution[1], resolution[2]
   local batch = state:float():reshape(1, 1, rows, cols)
   local q = getQValues(batch)
   -- Only the position of the maximum is needed, not its value.
   local _, index = torch.max(q, 1)
   return index[1], q
end
-- `sys.sleep` is used between episodes; load the package explicitly instead
-- of relying on another require pulling it in.
require 'sys'

-- Create Doom instance:
local game = initializeViZdoom(config_file_path)

-- Reinitialize the game with window visible.
-- initializeViZdoom() has already called init(). ViZDoom documents that some
-- options cannot be changed after init(), so the running instance is closed
-- first to let the window / mode settings apply on the next init().
game:close()
game:setWindowVisible(true)
game:setMode(vizdoom.Mode.ASYNC_PLAYER)
game:init()

local episodes_to_watch = 20
-- Number of extra advanceAction() calls issued after each makeAction().
local frame_repeat = 12

for _ = 1, episodes_to_watch do
   game:newEpisode()
   while not game:isEpisodeFinished() do
      -- Screen buffer as floats divided by 255, then rescaled for the network.
      local screen = game:getState().screenBuffer:float():div(255)
      local state = preprocess(screen)
      local best_action_index = getBestAction(state)
      game:makeAction(actions[best_action_index])
      for _ = 1, frame_repeat do
         game:advanceAction()
      end
   end
   -- Sleep between episodes:
   sys.sleep(1)
   local score = game:getTotalReward()
   print("Total score: ", score)
end
game:close()