-
-
Notifications
You must be signed in to change notification settings - Fork 412
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Updated python examples, added script to test building with all gcc a…
…nd python combinations
- Loading branch information
Showing
4 changed files
with
431 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
#!/usr/bin/env python3 | ||
# -*- coding: utf-8 -*- | ||
|
||
# E. Culurciello | ||
# August 2017 | ||
|
||
from __future__ import division | ||
from __future__ import print_function | ||
from vizdoom import * | ||
import itertools as it | ||
from random import sample, randint, random | ||
from time import time, sleep | ||
import numpy as np | ||
import skimage.color, skimage.transform | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from torchvision import datasets, transforms | ||
from torch.autograd import Variable | ||
|
||
|
||
# NN learning settings | ||
batch_size = 64 | ||
|
||
# Training regime | ||
test_episodes_per_epoch = 100 | ||
|
||
# Other parameters | ||
frame_repeat = 12 | ||
resolution = (30, 45) | ||
episodes_to_watch = 10 | ||
|
||
model_savefile = "./model-doom.pth" | ||
save_model = True | ||
load_model = False | ||
skip_learning = False | ||
# Configuration file path | ||
config_file_path = "../../scenarios/simpler_basic.cfg" | ||
|
||
|
||
# config_file_path = "../../scenarios/rocket_basic.cfg" | ||
# config_file_path = "../../scenarios/basic.cfg" | ||
|
||
# Converts and down-samples the input image | ||
def preprocess(img): | ||
img = skimage.transform.resize(img, resolution) | ||
img = img.astype(np.float32) | ||
return img | ||
|
||
|
||
class Net(nn.Module): | ||
def __init__(self, available_actions_count): | ||
super(Net, self).__init__() | ||
self.conv1 = nn.Conv2d(1, 8, kernel_size=6, stride=3) | ||
self.conv2 = nn.Conv2d(8, 8, kernel_size=3, stride=2) | ||
self.fc1 = nn.Linear(192, 128) | ||
self.fc2 = nn.Linear(128, available_actions_count) | ||
|
||
def forward(self, x): | ||
x = F.relu(self.conv1(x)) | ||
x = F.relu(self.conv2(x)) | ||
x = x.view(-1, 192) | ||
x = F.relu(self.fc1(x)) | ||
return self.fc2(x) | ||
|
||
|
||
def get_q_values(state): | ||
state = torch.from_numpy(state) | ||
state = Variable(state) | ||
return model(state) | ||
|
||
def get_best_action(state): | ||
q = get_q_values(state) | ||
m, index = torch.max(q, 1) | ||
action = index.data.numpy()[0] | ||
return action | ||
|
||
|
||
# Creates and initializes ViZDoom environment. | ||
def initialize_vizdoom(config_file_path): | ||
print("Initializing doom...") | ||
game = DoomGame() | ||
game.load_config(config_file_path) | ||
game.set_window_visible(True) | ||
game.set_mode(Mode.PLAYER) | ||
game.set_screen_format(ScreenFormat.GRAY8) | ||
game.set_screen_resolution(ScreenResolution.RES_640X480) | ||
game.init() | ||
print("Doom initialized.") | ||
return game | ||
|
||
|
||
if __name__ == '__main__': | ||
# Create Doom instance | ||
game = initialize_vizdoom(config_file_path) | ||
|
||
# Action = which buttons are pressed | ||
n = game.get_available_buttons_size() | ||
actions = [list(a) for a in it.product([0, 1], repeat=n)] | ||
|
||
print("Loading model from: ", model_savefile) | ||
model = torch.load(model_savefile) | ||
|
||
print("======================================") | ||
print("Testing trained neural network!") | ||
|
||
# Reinitialize the game with window visible | ||
game.set_window_visible(True) | ||
game.set_mode(Mode.ASYNC_PLAYER) | ||
game.init() | ||
|
||
for _ in range(episodes_to_watch): | ||
game.new_episode() | ||
while not game.is_episode_finished(): | ||
state = preprocess(game.get_state().screen_buffer) | ||
state = state.reshape([1, 1, resolution[0], resolution[1]]) | ||
best_action_index = get_best_action(state) | ||
|
||
# Instead of make_action(a, frame_repeat) in order to make the animation smooth | ||
game.set_action(actions[best_action_index]) | ||
for _ in range(frame_repeat): | ||
game.advance_action() | ||
sleep(0.03) | ||
|
||
# Sleep between episodes | ||
sleep(1.0) | ||
score = game.get_total_reward() | ||
print("Total score: ", score) |
Oops, something went wrong.