This repository was archived by the owner on Jul 7, 2023. It is now read-only.
Merged
13 changes: 13 additions & 0 deletions tensor2tensor/models/research/rl.py
@@ -233,6 +233,19 @@ def mfrl_base():
hparams = mfrl_original()
hparams.add_hparam("ppo_epochs_num", 3000)
hparams.add_hparam("ppo_eval_every_epochs", 100)
hparams.add_hparam("eval_max_num_noops", 8)
hparams.add_hparam("resize_height_factor", 2)
hparams.add_hparam("resize_width_factor", 2)
hparams.add_hparam("grayscale", 1)
hparams.add_hparam("env_timesteps_limit", -1)
return hparams


@registry.register_hparams
def mfrl_tiny():
hparams = mfrl_base()
hparams.ppo_epochs_num = 100
hparams.ppo_eval_every_epochs = 10
return hparams
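
For context, a minimal usage sketch (not part of the diff): it reads only hparams fields defined above and assumes the registered functions can be called directly from the module, as mfrl_tiny itself does with mfrl_base.

# Sketch only: exercises the hparams added in this change.
from tensor2tensor.models.research import rl

hparams = rl.mfrl_tiny()              # builds on mfrl_base()
print(hparams.ppo_epochs_num)         # 100, overridden by mfrl_tiny
print(hparams.ppo_eval_every_epochs)  # 10, overridden by mfrl_tiny
print(hparams.eval_max_num_noops)     # 8, added to mfrl_base above
print(hparams.env_timesteps_limit)    # -1, added to mfrl_base above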


125 changes: 125 additions & 0 deletions tensor2tensor/rl/rl_utils.py
@@ -0,0 +1,125 @@
# coding=utf-8
# Copyright 2018 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Utilities for RL training
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import six

from tensor2tensor.data_generators.gym_env import T2TGymEnv
from tensor2tensor.models.research import rl
from tensor2tensor.rl.dopamine_connector import DQNLearner
from tensor2tensor.rl.ppo_learner import PPOLearner
from tensor2tensor.utils import trainer_lib

import tensorflow as tf


flags = tf.flags
FLAGS = flags.FLAGS


def compute_mean_reward(rollouts, clipped):
"""Calculate mean rewards from given epoch."""
reward_name = "reward" if clipped else "unclipped_reward"
rewards = []
for rollout in rollouts:
if rollout[-1].done:
rollout_reward = sum(getattr(frame, reward_name) for frame in rollout)
rewards.append(rollout_reward)
if rewards:
mean_rewards = np.mean(rewards)
else:
mean_rewards = 0
return mean_rewards


def get_metric_name(stochastic, max_num_noops, clipped):
return "mean_reward/eval/stochastic_{}_max_noops_{}_{}".format(
stochastic, max_num_noops, "clipped" if clipped else "unclipped")


def evaluate_single_config(hparams, stochastic, max_num_noops,
agent_model_dir):
"""Evaluate the PPO agent in the real environment."""
eval_hparams = trainer_lib.create_hparams(hparams.base_algo_params)
env = setup_env(
hparams, batch_size=hparams.eval_batch_size, max_num_noops=max_num_noops
)
env.start_new_epoch(0)
env_fn = rl.make_real_env_fn(env)
learner = LEARNERS[hparams.base_algo](
hparams.frame_stack_size, base_event_dir=None,
agent_model_dir=agent_model_dir
)
learner.evaluate(env_fn, eval_hparams, stochastic)
rollouts = env.current_epoch_rollouts()
env.close()

return tuple(
compute_mean_reward(rollouts, clipped) for clipped in (True, False)
)


def evaluate_all_configs(hparams, agent_model_dir):
"""Evaluate the agent with multiple eval configurations."""
metrics = {}
# Iterate over all combinations of picking actions by sampling/mode and
# whether to do initial no-ops.
for stochastic in (True, False):
for max_num_noops in (hparams.eval_max_num_noops, 0):
scores = evaluate_single_config(
hparams, stochastic, max_num_noops, agent_model_dir
)
for (score, clipped) in zip(scores, (True, False)):
metric_name = get_metric_name(stochastic, max_num_noops, clipped)
metrics[metric_name] = score

return metrics


LEARNERS = {
"ppo": PPOLearner,
"dqn": DQNLearner,
}


def setup_env(hparams, batch_size, max_num_noops):
"""Setup."""
game_mode = "Deterministic-v4"
camel_game_name = "".join(
[w[0].upper() + w[1:] for w in hparams.game.split("_")])
camel_game_name += game_mode
env_name = camel_game_name

env = T2TGymEnv(base_env_name=env_name,
batch_size=batch_size,
grayscale=hparams.grayscale,
resize_width_factor=hparams.resize_width_factor,
resize_height_factor=hparams.resize_height_factor,
base_env_timesteps_limit=hparams.env_timesteps_limit,
max_num_noops=max_num_noops)
return env

def update_hparams_from_hparams(target_hparams, source_hparams, prefix):
"""Copy a subset of hparams to target_hparams."""
for (param_name, param_value) in six.iteritems(source_hparams.values()):
if param_name.startswith(prefix):
target_hparams.set_hparam(param_name[len(prefix):], param_value)
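
For illustration, a minimal, self-contained sketch (not part of the diff) of compute_mean_reward and get_metric_name on toy data. The Frame namedtuple is a hypothetical stand-in for the rollout frames produced by T2TGymEnv and only carries the attributes these helpers read; importing rl_utils assumes its module-level PPOLearner/DQNLearner dependencies are installed.

# Sketch only: Frame is a hypothetical stand-in for real rollout frames.
import collections
from tensor2tensor.rl import rl_utils

Frame = collections.namedtuple("Frame", ["done", "reward", "unclipped_reward"])

# Two finished rollouts (last frame done) and one unfinished rollout,
# which compute_mean_reward skips.
rollouts = [
    [Frame(False, 1.0, 2.0), Frame(True, 0.0, 0.0)],   # clipped sum 1.0, unclipped 2.0
    [Frame(False, 1.0, 3.0), Frame(True, 1.0, 1.0)],   # clipped sum 2.0, unclipped 4.0
    [Frame(False, 1.0, 1.0), Frame(False, 0.0, 0.0)],  # ignored: last frame not done
]

print(rl_utils.compute_mean_reward(rollouts, clipped=True))   # 1.5
print(rl_utils.compute_mean_reward(rollouts, clipped=False))  # 3.0
print(rl_utils.get_metric_name(stochastic=True, max_num_noops=8, clipped=False))
# mean_reward/eval/stochastic_True_max_noops_8_unclipped
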
109 changes: 9 additions & 100 deletions tensor2tensor/rl/trainer_model_based.py
@@ -37,12 +37,10 @@
import six

from tensor2tensor.bin import t2t_trainer # pylint: disable=unused-import
from tensor2tensor.data_generators.gym_env import T2TGymEnv
from tensor2tensor.layers import common_video
from tensor2tensor.models.research import rl
from tensor2tensor.rl import rl_utils
from tensor2tensor.rl import trainer_model_based_params
from tensor2tensor.rl.dopamine_connector import DQNLearner
from tensor2tensor.rl.ppo_learner import PPOLearner
from tensor2tensor.utils import trainer_lib

import tensorflow as tf
@@ -52,19 +50,6 @@
FLAGS = flags.FLAGS


LEARNERS = {
"ppo": PPOLearner,
"dqn": DQNLearner,
}


def update_hparams_from_hparams(target_hparams, source_hparams, prefix):
"""Copy a subset of hparams to target_hparams."""
for (param_name, param_value) in six.iteritems(source_hparams.values()):
if param_name.startswith(prefix):
target_hparams.set_hparam(param_name[len(prefix):], param_value)


def real_env_step_increment(hparams):
"""Real env step increment."""
return int(math.ceil(
@@ -207,7 +192,7 @@ def initial_frame_chooser(batch_size):
base_algo_str = hparams.base_algo
train_hparams = trainer_lib.create_hparams(hparams.base_algo_params)

update_hparams_from_hparams(
rl_utils.update_hparams_from_hparams(
train_hparams, hparams, base_algo_str + "_"
)

@@ -223,7 +208,7 @@ def train_agent_real_env(env, learner, hparams, epoch):
base_algo_str = hparams.base_algo

train_hparams = trainer_lib.create_hparams(hparams.base_algo_params)
update_hparams_from_hparams(
rl_utils.update_hparams_from_hparams(
train_hparams, hparams, "real_" + base_algo_str + "_"
)

@@ -263,82 +248,6 @@ def train_world_model(
return world_model_steps_num


def setup_env(hparams, batch_size, max_num_noops):
"""Setup."""
game_mode = "Deterministic-v4"
camel_game_name = "".join(
[w[0].upper() + w[1:] for w in hparams.game.split("_")])
camel_game_name += game_mode
env_name = camel_game_name

env = T2TGymEnv(base_env_name=env_name,
batch_size=batch_size,
grayscale=hparams.grayscale,
resize_width_factor=hparams.resize_width_factor,
resize_height_factor=hparams.resize_height_factor,
base_env_timesteps_limit=hparams.env_timesteps_limit,
max_num_noops=max_num_noops)
return env


def evaluate_single_config(hparams, stochastic, max_num_noops, agent_model_dir):
"""Evaluate the PPO agent in the real environment."""
eval_hparams = trainer_lib.create_hparams(hparams.base_algo_params)
env = setup_env(
hparams, batch_size=hparams.eval_batch_size, max_num_noops=max_num_noops
)
env.start_new_epoch(0)
env_fn = rl.make_real_env_fn(env)
learner = LEARNERS[hparams.base_algo](
hparams.frame_stack_size, base_event_dir=None,
agent_model_dir=agent_model_dir
)
learner.evaluate(env_fn, eval_hparams, stochastic)
rollouts = env.current_epoch_rollouts()
env.close()

return tuple(
compute_mean_reward(rollouts, clipped) for clipped in (True, False)
)


def get_metric_name(stochastic, max_num_noops, clipped):
return "mean_reward/eval/stochastic_{}_max_noops_{}_{}".format(
stochastic, max_num_noops, "clipped" if clipped else "unclipped")


def evaluate_all_configs(hparams, agent_model_dir):
"""Evaluate the agent with multiple eval configurations."""
metrics = {}
# Iterate over all combinations of picking actions by sampling/mode and
# whether to do initial no-ops.
for stochastic in (True, False):
for max_num_noops in (hparams.eval_max_num_noops, 0):
scores = evaluate_single_config(
hparams, stochastic, max_num_noops, agent_model_dir
)
for (score, clipped) in zip(scores, (True, False)):
metric_name = get_metric_name(stochastic, max_num_noops, clipped)
metrics[metric_name] = score

return metrics


def compute_mean_reward(rollouts, clipped):
"""Calculate mean rewards from given epoch."""
reward_name = "reward" if clipped else "unclipped_reward"
rewards = []
for rollout in rollouts:
if rollout[-1].done:
rollout_reward = sum(getattr(frame, reward_name) for frame in rollout)
rewards.append(rollout_reward)
if rewards:
mean_rewards = np.mean(rewards)
else:
mean_rewards = 0
return mean_rewards


def evaluate_world_model(real_env, hparams, world_model_dir, debug_video_path):
"""Evaluate the world model (reward accuracy)."""
frame_stack_size = hparams.frame_stack_size
@@ -485,13 +394,13 @@ def training_loop(hparams, output_dir, report_fn=None, report_metric=None):

epoch = -1
data_dir = directories["data"]
env = setup_env(
env = rl_utils.setup_env(
hparams, batch_size=hparams.real_batch_size,
max_num_noops=hparams.max_num_noops
)
env.start_new_epoch(epoch, data_dir)

learner = LEARNERS[hparams.base_algo](
learner = rl_utils.LEARNERS[hparams.base_algo](
hparams.frame_stack_size, directories["policy"],
directories["policy"]
)
@@ -507,7 +416,7 @@ def training_loop(hparams, output_dir, report_fn=None, report_metric=None):
policy_model_dir = directories["policy"]
tf.logging.info("Initial training of the policy in real environment.")
train_agent_real_env(env, learner, hparams, epoch)
metrics["mean_reward/train/clipped"] = compute_mean_reward(
metrics["mean_reward/train/clipped"] = rl_utils.compute_mean_reward(
env.current_epoch_rollouts(), clipped=True
)
tf.logging.info("Mean training reward (initial): {}".format(
@@ -555,14 +464,14 @@ def training_loop(hparams, output_dir, report_fn=None, report_metric=None):
# we'd overwrite them with wrong data.
log("Metrics found for this epoch, skipping evaluation.")
else:
metrics["mean_reward/train/clipped"] = compute_mean_reward(
metrics["mean_reward/train/clipped"] = rl_utils.compute_mean_reward(
env.current_epoch_rollouts(), clipped=True
)
log("Mean training reward: {}".format(
metrics["mean_reward/train/clipped"]
))

eval_metrics = evaluate_all_configs(hparams, policy_model_dir)
eval_metrics = rl_utils.evaluate_all_configs(hparams, policy_model_dir)
log("Agent eval metrics:\n{}".format(pprint.pformat(eval_metrics)))
metrics.update(eval_metrics)

Expand All @@ -582,7 +491,7 @@ def training_loop(hparams, output_dir, report_fn=None, report_metric=None):
# Report metrics
if report_fn:
if report_metric == "mean_reward":
metric_name = get_metric_name(
metric_name = rl_utils.get_metric_name(
stochastic=True, max_num_noops=hparams.eval_max_num_noops,
clipped=False
)
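
As a side note on the helper this refactor routes everything through, a toy sketch (not part of the diff) of rl_utils.update_hparams_from_hparams; the HParams objects below are hypothetical and assume the TF 1.x tf.contrib.training.HParams API used elsewhere in T2T.

# Sketch only: toy HParams objects to show the prefix-copy behaviour.
import tensorflow as tf
from tensor2tensor.rl import rl_utils

target = tf.contrib.training.HParams(learning_rate=0.1, epochs_num=10)
source = tf.contrib.training.HParams(
    ppo_learning_rate=0.0001, ppo_epochs_num=3000, unrelated=7)

# Every "ppo_"-prefixed hparam is copied onto target with the prefix stripped;
# "unrelated" is left alone. This is how train_hparams above picks up the
# base_algo-prefixed settings from the top-level hparams.
rl_utils.update_hparams_from_hparams(target, source, "ppo_")
print(target.learning_rate)  # 0.0001
print(target.epochs_num)     # 3000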