diff --git a/tensor2tensor/rl/evaluator.py b/tensor2tensor/rl/evaluator.py index 0459da703..4b8728233 100644 --- a/tensor2tensor/rl/evaluator.py +++ b/tensor2tensor/rl/evaluator.py @@ -252,7 +252,7 @@ def make_env(env_type, real_env, sim_env_kwargs): def make_agent( agent_type, env, policy_hparams, policy_dir, sampling_temp, - sim_env_kwargs=None, frame_stack_size=None, rollout_agent_type=None, + sim_env_kwargs_fn=None, frame_stack_size=None, rollout_agent_type=None, batch_size=None, inner_batch_size=None, env_type=None, **planner_kwargs ): """Factory function for Agents.""" @@ -270,7 +270,7 @@ def make_agent( batch_size, make_agent( rollout_agent_type, env, policy_hparams, policy_dir, sampling_temp, batch_size=inner_batch_size - ), make_env(env_type, env.env, sim_env_kwargs), + ), make_env(env_type, env.env, sim_env_kwargs_fn()), lambda env: rl_utils.BatchStackWrapper(env, frame_stack_size), discount_factor=policy_hparams.gae_gamma, **planner_kwargs ), @@ -302,17 +302,18 @@ def make_agent_from_hparams( planner_hparams, model_dir, policy_dir, sampling_temp, video_writers=() ): """Creates an Agent from hparams.""" - sim_env_kwargs = rl.make_simulated_env_kwargs( - base_env, loop_hparams, batch_size=planner_hparams.batch_size, - model_dir=model_dir - ) + def sim_env_kwargs_fn(): + return rl.make_simulated_env_kwargs( + base_env, loop_hparams, batch_size=planner_hparams.batch_size, + model_dir=model_dir + ) planner_kwargs = planner_hparams.values() planner_kwargs.pop("batch_size") planner_kwargs.pop("rollout_agent_type") planner_kwargs.pop("env_type") return make_agent( agent_type, stacked_env, policy_hparams, policy_dir, sampling_temp, - sim_env_kwargs, loop_hparams.frame_stack_size, + sim_env_kwargs_fn, loop_hparams.frame_stack_size, planner_hparams.rollout_agent_type, inner_batch_size=planner_hparams.batch_size, env_type=planner_hparams.env_type, diff --git a/tensor2tensor/rl/rl_utils.py b/tensor2tensor/rl/rl_utils.py index 9eaac4f48..2e4c7660f 100644 --- a/tensor2tensor/rl/rl_utils.py +++ b/tensor2tensor/rl/rl_utils.py @@ -414,9 +414,8 @@ def augment_observation( (1, 15), "f:{:3}".format(int(frame_index)), fill=(255, 0, 0) ) - header = np.asarray(img) + header = np.copy(np.asarray(img)) del img - header.setflags(write=1) if bar_color is not None: header[0, :, :] = bar_color return np.concatenate([header, observation], axis=0)