Skip to content

Commit 286eae9

Browse files
blazejosinski
authored and kpe committed
Differentiate summaries for train and eval. (tensorflow#1256)
1 parent 560dbfe commit 286eae9

File tree

2 files changed

+10
-31
lines changed

2 files changed

+10
-31
lines changed

tensor2tensor/models/research/rl.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -178,28 +178,6 @@ def ppo_pong_ae_base():
178178
return hparams
179179

180180

181-
@registry.register_hparams
182-
def pong_model_free():
183-
"""TODO(piotrmilos): Document this."""
184-
hparams = mfrl_base()
185-
hparams.batch_size = 2
186-
hparams.ppo_eval_every_epochs = 2
187-
hparams.ppo_epochs_num = 4
188-
hparams.add_hparam("ppo_optimization_epochs", 3)
189-
hparams.add_hparam("ppo_epoch_length", 30)
190-
hparams.add_hparam("ppo_learning_rate", 8e-05)
191-
hparams.add_hparam("ppo_optimizer", "Adam")
192-
hparams.add_hparam("ppo_optimization_batch_size", 4)
193-
hparams.add_hparam("ppo_save_models_every_epochs", 1000000)
194-
env = gym_env.T2TGymEnv("PongNoFrameskip-v4", batch_size=2)
195-
env.start_new_epoch(0)
196-
hparams.add_hparam("env_fn", make_real_env_fn(env))
197-
eval_env = gym_env.T2TGymEnv("PongNoFrameskip-v4", batch_size=2)
198-
eval_env.start_new_epoch(0)
199-
hparams.add_hparam("eval_env_fn", make_real_env_fn(eval_env))
200-
return hparams
201-
202-
203181
@registry.register_hparams
204182
def dqn_atari_base():
205183
# These params are based on agents/dqn/configs/dqn.gin
@@ -242,7 +220,7 @@ def dqn_original_params():
242220
@registry.register_hparams
243221
def mfrl_original():
244222
return tf.contrib.training.HParams(
245-
game="",
223+
game="pong",
246224
base_algo="ppo",
247225
base_algo_params="ppo_original_params",
248226
batch_size=16,

tensor2tensor/rl/ppo_learner.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -460,11 +460,12 @@ def stop_condition(i, _, resets):
460460
new_memory.append(mem)
461461
memory = new_memory
462462

463-
mean_score_summary = tf.cond(
464-
tf.greater(scores_num, 0),
465-
lambda: tf.summary.scalar("mean_score_this_iter", mean_score), str)
466-
summaries = tf.summary.merge([
467-
mean_score_summary,
468-
tf.summary.scalar("episodes_finished_this_iter", scores_num)
469-
])
470-
return memory, summaries, initialization_lambda
463+
with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
464+
mean_score_summary = tf.cond(
465+
tf.greater(scores_num, 0),
466+
lambda: tf.summary.scalar("mean_score_this_iter", mean_score), str)
467+
summaries = tf.summary.merge([
468+
mean_score_summary,
469+
tf.summary.scalar("episodes_finished_this_iter", scores_num)
470+
])
471+
return memory, summaries, initialization_lambda

0 commit comments

Comments (0)