Used changes from @yukang2017 in coreylynch#24 to support new keras (et al.) versions. Also updated to properly use the wrapper in eval.
enragedginger committed Aug 2, 2017
1 parent 55e6955 commit 7f14a09
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions async_dqn.py
@@ -185,7 +185,7 @@ def build_graph(num_actions):
     # Define cost and gradient update op
     a = tf.placeholder("float", [None, num_actions])
     y = tf.placeholder("float", [None])
-    action_q_values = tf.reduce_sum(q_values * a, reduction_indices=1)
+    action_q_values = tf.reduce_sum(tf.multiply(q_values, a), reduction_indices=1)
     cost = tf.reduce_mean(tf.square(y - action_q_values))
     optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
     grad_update = optimizer.minimize(cost, var_list=network_params)
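Note: tf.multiply is the canonical spelling of elementwise multiplication after the TF 1.0 API cleanup (the `*` operator overload still works in TF 1.x, so this is largely a stylistic change). Either way, the masked reduce_sum picks out the Q-value of the action actually taken. A minimal sketch (TF 1.x, hypothetical values) of what this line computes:

import tensorflow as tf

q_values = tf.constant([[1.0, 2.0, 3.0]])  # Q-values for 3 actions in one state
a = tf.constant([[0.0, 1.0, 0.0]])         # one-hot mask: action 1 was taken
# The elementwise product zeroes out the untaken actions, then the
# row-wise sum collapses each row to the chosen action's Q-value.
action_q = tf.reduce_sum(tf.multiply(q_values, a), reduction_indices=1)
with tf.Session() as sess:
    print(sess.run(action_q))              # [2.0]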
@@ -204,16 +204,16 @@ def build_graph(num_actions):
 # Set up some episode summary ops to visualize on tensorboard.
 def setup_summaries():
     episode_reward = tf.Variable(0.)
-    tf.scalar_summary("Episode Reward", episode_reward)
+    tf.summary.scalar("Episode Reward", episode_reward)
     episode_ave_max_q = tf.Variable(0.)
-    tf.scalar_summary("Max Q Value", episode_ave_max_q)
+    tf.summary.scalar("Max Q Value", episode_ave_max_q)
     logged_epsilon = tf.Variable(0.)
-    tf.scalar_summary("Epsilon", logged_epsilon)
+    tf.summary.scalar("Epsilon", logged_epsilon)
     logged_T = tf.Variable(0.)
     summary_vars = [episode_reward, episode_ave_max_q, logged_epsilon]
     summary_placeholders = [tf.placeholder("float") for i in range(len(summary_vars))]
     update_ops = [summary_vars[i].assign(summary_placeholders[i]) for i in range(len(summary_vars))]
-    summary_op = tf.merge_all_summaries()
+    summary_op = tf.summary.merge_all()
     return summary_placeholders, update_ops, summary_op
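Note: these renames track the TF 1.0 reorganization, where the free-standing summary functions moved into the tf.summary module. A minimal self-contained sketch of the same variable/placeholder/assign pattern (names here are hypothetical, TF 1.x assumed):

import tensorflow as tf

# Hypothetical stand-in for one of the episode stats in setup_summaries().
episode_reward = tf.Variable(0.)
tf.summary.scalar("Episode Reward", episode_reward)   # was tf.scalar_summary(...)
reward_ph = tf.placeholder("float")
update_reward = episode_reward.assign(reward_ph)
summary_op = tf.summary.merge_all()                   # was tf.merge_all_summaries()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(update_reward, feed_dict={reward_ph: 42.0})
    summary_str = sess.run(summary_op)                # serialized Summary protobuf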

@@ -231,6 +231,7 @@ def get_num_actions():
     return num_actions

 def train(session, graph_ops, num_actions, saver):
+    session.run(tf.initialize_all_variables())
     # Initialize target network weights
     session.run(graph_ops["reset_target_network_params"])
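Note: moving variable initialization to the top of train() means it now runs before reset_target_network_params, a group of assign ops that read the online-network variables; running an assign against uninitialized variables raises FailedPreconditionError. A toy two-variable analogue (names hypothetical):

import tensorflow as tf

online_w = tf.Variable(tf.random_normal([4]))
target_w = tf.Variable(tf.zeros([4]))
reset_target = target_w.assign(online_w)   # analogue of reset_target_network_params

with tf.Session() as sess:
    # Initialization must come first: the assign reads online_w, which
    # would fail if it were still uninitialized.
    sess.run(tf.initialize_all_variables())   # deprecated alias, kept to match the diff
    sess.run(reset_target)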

@@ -241,9 +242,8 @@ def train(session, graph_ops, num_actions, saver):
     summary_op = summary_ops[-1]

     # Initialize variables
-    session.run(tf.initialize_all_variables())
     summary_save_path = FLAGS.summary_dir + "/" + FLAGS.experiment
-    writer = tf.train.SummaryWriter(summary_save_path, session.graph)
+    writer = tf.summary.FileWriter(summary_save_path, session.graph)
     if not os.path.exists(FLAGS.checkpoint_dir):
         os.makedirs(FLAGS.checkpoint_dir)
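Note: tf.train.SummaryWriter became tf.summary.FileWriter in TF 1.0, with the constructor and add_summary signatures otherwise unchanged. A minimal end-to-end sketch (log directory and variable are placeholders):

import tensorflow as tf

x = tf.Variable(3.0)
tf.summary.scalar("x", x)
summary_op = tf.summary.merge_all()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # FileWriter is the renamed SummaryWriter; it still takes a logdir
    # and an optional graph to display in TensorBoard.
    writer = tf.summary.FileWriter("/tmp/tb_demo", sess.graph)
    writer.add_summary(sess.run(summary_op), global_step=0)
    writer.close()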

@@ -270,7 +270,7 @@ def evaluation(session, graph_ops, saver):
     saver.restore(session, FLAGS.checkpoint_path)
     print "Restored model weights from ", FLAGS.checkpoint_path
     monitor_env = gym.make(FLAGS.game)
-    monitor_env.monitor.start(FLAGS.eval_dir+"/"+FLAGS.experiment+"/eval")
+    monitor_env = gym.wrappers.Monitor(monitor_env, FLAGS.eval_dir+"/"+FLAGS.experiment+"/eval")

     # Unpack graph ops
     s = graph_ops["s"]
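Note: the old env.monitor.start(dir) API was removed from gym in favor of gym.wrappers.Monitor, which wraps the env and returns a new object, hence the reassignment to monitor_env. A minimal sketch assuming a gym version of that era (environment id and directory are placeholders):

import gym

env = gym.make("Breakout-v0")
# Monitor returns a *new* wrapped env rather than mutating the original,
# so later reset/step calls must go through the wrapper for recording to
# work; force=True overwrites recordings left over from a previous run.
env = gym.wrappers.Monitor(env, "/tmp/eval_demo", force=True)
observation = env.reset()
observation, reward, done, info = env.step(env.action_space.sample())
env.close()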
@@ -295,7 +295,8 @@
 def main(_):
     g = tf.Graph()
-    with g.as_default(), tf.Session() as session:
+    session = tf.Session(graph=g)
+    with g.as_default(), session.as_default():
         K.set_session(session)
         num_actions = get_num_actions()
         graph_ops = build_graph(num_actions)
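Note: constructing the session explicitly and keeping a reference to it (instead of the anonymous `with tf.Session() as session:`) lets Keras be pointed at the same session and graph via K.set_session, so Keras-built layers share state with the rest of the graph. A minimal sketch of the pattern:

import tensorflow as tf
from keras import backend as K

g = tf.Graph()
# An explicit handle to the session survives beyond the setup block and
# can be handed to Keras before any layers are built.
session = tf.Session(graph=g)
with g.as_default(), session.as_default():
    K.set_session(session)    # Keras ops now build into g and run in session
    x = tf.constant(1.0)
    print(session.run(x))     # 1.0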