Skip to content

Commit

Permalink
Corrected tf example to comply with tf 1.0.
Browse files Browse the repository at this point in the history
Former-commit-id: 0a43941
  • Loading branch information
mihahauke committed Feb 26, 2017
1 parent f473833 commit 2257561
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions examples/python/learning_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from __future__ import print_function
from vizdoom import *
import itertools as it
import pickle
from random import sample, randint, random
from time import time, sleep
import numpy as np
Expand Down Expand Up @@ -215,7 +214,7 @@ def initialize_vizdoom(config_file_path):
print("Loading model from: ", model_savefile)
saver.restore(session, model_savefile)
else:
init = tf.initialize_all_variables()
init = tf.global_variables_initializer()
session.run(init)
print("Starting the training!")

Expand All @@ -228,7 +227,7 @@ def initialize_vizdoom(config_file_path):

print("Training...")
game.new_episode()
for learning_step in trange(learning_steps_per_epoch):
for learning_step in trange(learning_steps_per_epoch, leave=False):
perform_learning_step(epoch)
if game.is_episode_finished():
score = game.get_total_reward()
Expand All @@ -246,7 +245,7 @@ def initialize_vizdoom(config_file_path):
print("\nTesting...")
test_episode = []
test_scores = []
for test_episode in trange(test_episodes_per_epoch):
for test_episode in trange(test_episodes_per_epoch, leave=False):
game.new_episode()
while not game.is_episode_finished():
state = preprocess(game.get_state().screen_buffer)
Expand All @@ -263,7 +262,6 @@ def initialize_vizdoom(config_file_path):

print("Saving the network weigths to:", model_savefile)
saver.save(session, model_savefile)
# pickle.dump(get_all_param_values(net), open('weights.dump', "wb"))

print("Total elapsed time: %.2f minutes" % ((time() - time_start) / 60.0))

Expand Down

0 comments on commit 2257561

Please sign in to comment.