diff --git a/python/ray/rllib/agents/a3c/a3c.py b/python/ray/rllib/agents/a3c/a3c.py
index 836d9f074999..eb384058de80 100644
--- a/python/ray/rllib/agents/a3c/a3c.py
+++ b/python/ray/rllib/agents/a3c/a3c.py
@@ -49,8 +49,8 @@ class A3CTrainer(Trainer):
     def _init(self, config, env_creator):
         if config["use_pytorch"]:
             from ray.rllib.agents.a3c.a3c_torch_policy_graph import \
-                A3CTorchPolicyGraph
-            policy_cls = A3CTorchPolicyGraph
+                A3CTorchPolicy
+            policy_cls = A3CTorchPolicy
         else:
             policy_cls = self._policy_graph
diff --git a/python/ray/rllib/agents/a3c/a3c_torch_policy_graph.py b/python/ray/rllib/agents/a3c/a3c_torch_policy_graph.py
index d35aabe0d667..fa6f857f9eca 100644
--- a/python/ray/rllib/agents/a3c/a3c_torch_policy_graph.py
+++ b/python/ray/rllib/agents/a3c/a3c_torch_policy_graph.py
@@ -7,109 +7,84 @@ from torch import nn
 import ray
-from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.evaluation.postprocessing import compute_advantages, \
     Postprocessing
-from ray.rllib.evaluation.policy_graph import PolicyGraph
 from ray.rllib.evaluation.sample_batch import SampleBatch
-from ray.rllib.evaluation.torch_policy_graph import TorchPolicyGraph
-from ray.rllib.utils.annotations import override
-
-
-class A3CLoss(nn.Module):
-    def __init__(self, dist_class, vf_loss_coeff=0.5, entropy_coeff=0.01):
-        nn.Module.__init__(self)
-        self.dist_class = dist_class
-        self.vf_loss_coeff = vf_loss_coeff
-        self.entropy_coeff = entropy_coeff
-
-    def forward(self, policy_model, observations, actions, advantages,
-                value_targets):
-        logits, _, values, _ = policy_model({
-            SampleBatch.CUR_OBS: observations
-        }, [])
-        dist = self.dist_class(logits)
-        log_probs = dist.logp(actions)
-        self.entropy = dist.entropy().mean()
-        self.pi_err = -advantages.dot(log_probs.reshape(-1))
-        self.value_err = F.mse_loss(values.reshape(-1), value_targets)
-        overall_err = sum([
-            self.pi_err,
-            self.vf_loss_coeff * self.value_err,
-            -self.entropy_coeff * self.entropy,
-        ])
-
-        return overall_err
-
-
-class A3CPostprocessing(object):
-    """Adds the VF preds and advantages fields to the trajectory."""
-
-    @override(TorchPolicyGraph)
-    def extra_action_out(self, model_out):
-        return {SampleBatch.VF_PREDS: model_out[2].cpu().numpy()}
-
-    @override(PolicyGraph)
-    def postprocess_trajectory(self,
-                               sample_batch,
-                               other_agent_batches=None,
-                               episode=None):
-        completed = sample_batch[SampleBatch.DONES][-1]
-        if completed:
-            last_r = 0.0
-        else:
-            last_r = self._value(sample_batch[SampleBatch.NEXT_OBS][-1])
-        return compute_advantages(sample_batch, last_r, self.config["gamma"],
-                                  self.config["lambda"])
-
-
-class A3CTorchPolicyGraph(A3CPostprocessing, TorchPolicyGraph):
-    """A simple, non-recurrent PyTorch policy example."""
-
-    def __init__(self, obs_space, action_space, config):
-        config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config)
-        self.config = config
-        dist_class, self.logit_dim = ModelCatalog.get_action_dist(
-            action_space, self.config["model"], torch=True)
-        model = ModelCatalog.get_torch_model(obs_space, self.logit_dim,
-                                             self.config["model"])
-        loss = A3CLoss(dist_class, self.config["vf_loss_coeff"],
-                       self.config["entropy_coeff"])
-        TorchPolicyGraph.__init__(
-            self,
-            obs_space,
-            action_space,
-            model,
-            loss,
-            loss_inputs=[
-                SampleBatch.CUR_OBS, SampleBatch.ACTIONS,
-                Postprocessing.ADVANTAGES, Postprocessing.VALUE_TARGETS
-            ],
-            action_distribution_cls=dist_class)
-
-    @override(TorchPolicyGraph)
-    def optimizer(self):
-        return torch.optim.Adam(self._model.parameters(), lr=self.config["lr"])
-
-    @override(TorchPolicyGraph)
-    def extra_grad_process(self):
-        info = {}
-        if self.config["grad_clip"]:
-            total_norm = nn.utils.clip_grad_norm_(self._model.parameters(),
-                                                  self.config["grad_clip"])
-            info["grad_gnorm"] = total_norm
-        return info
-
-    @override(TorchPolicyGraph)
-    def extra_grad_info(self):
-        return {
-            "policy_entropy": self._loss.entropy.item(),
-            "policy_loss": self._loss.pi_err.item(),
-            "vf_loss": self._loss.value_err.item()
-        }
-
+from ray.rllib.evaluation.torch_policy_template import build_torch_policy
+
+
+def actor_critic_loss(policy, batch_tensors):
+    logits, _, values, _ = policy.model({
+        SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS]
+    }, [])
+    dist = policy.dist_class(logits)
+    log_probs = dist.logp(batch_tensors[SampleBatch.ACTIONS])
+    policy.entropy = dist.entropy().mean()
+    policy.pi_err = -batch_tensors[Postprocessing.ADVANTAGES].dot(
+        log_probs.reshape(-1))
+    policy.value_err = F.mse_loss(
+        values.reshape(-1), batch_tensors[Postprocessing.VALUE_TARGETS])
+    overall_err = sum([
+        policy.pi_err,
+        policy.config["vf_loss_coeff"] * policy.value_err,
+        -policy.config["entropy_coeff"] * policy.entropy,
+    ])
+    return overall_err
+
+
+def loss_and_entropy_stats(policy, batch_tensors):
+    return {
+        "policy_entropy": policy.entropy.item(),
+        "policy_loss": policy.pi_err.item(),
+        "vf_loss": policy.value_err.item(),
+    }
+
+
+def add_advantages(policy,
+                   sample_batch,
+                   other_agent_batches=None,
+                   episode=None):
+    completed = sample_batch[SampleBatch.DONES][-1]
+    if completed:
+        last_r = 0.0
+    else:
+        last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1])
+    return compute_advantages(sample_batch, last_r, policy.config["gamma"],
+                              policy.config["lambda"])
+
+
+def model_value_predictions(policy, model_out):
+    return {SampleBatch.VF_PREDS: model_out[2].cpu().numpy()}
+
+
+def apply_grad_clipping(policy):
+    info = {}
+    if policy.config["grad_clip"]:
+        total_norm = nn.utils.clip_grad_norm_(policy.model.parameters(),
+                                              policy.config["grad_clip"])
+        info["grad_gnorm"] = total_norm
+    return info
+
+
+def torch_optimizer(policy, config):
+    return torch.optim.Adam(policy.model.parameters(), lr=config["lr"])
+
+
+class ValueNetworkMixin(object):
     def _value(self, obs):
         with self.lock:
             obs = torch.from_numpy(obs).float().unsqueeze(0).to(self.device)
-            _, _, vf, _ = self._model({"obs": obs}, [])
+            _, _, vf, _ = self.model({"obs": obs}, [])
             return vf.detach().cpu().numpy().squeeze()
+
+
+A3CTorchPolicy = build_torch_policy(
+    name="A3CTorchPolicy",
+    get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG,
+    loss_fn=actor_critic_loss,
+    stats_fn=loss_and_entropy_stats,
+    postprocess_fn=add_advantages,
+    extra_action_out_fn=model_value_predictions,
+    extra_grad_process_fn=apply_grad_clipping,
+    optimizer_fn=torch_optimizer,
+    mixins=[ValueNetworkMixin])
diff --git a/python/ray/rllib/agents/pg/pg.py b/python/ray/rllib/agents/pg/pg.py
index e70fdcc8b2c6..ffbb899d1b9e 100644
--- a/python/ray/rllib/agents/pg/pg.py
+++ b/python/ray/rllib/agents/pg/pg.py
@@ -2,11 +2,9 @@ from __future__ import division
 from __future__ import print_function
 
-from ray.rllib.agents.trainer import Trainer, with_common_config
-from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
-
-from ray.rllib.optimizers import SyncSamplesOptimizer
-from ray.rllib.utils.annotations import override
+from ray.rllib.agents.trainer import with_common_config
+from ray.rllib.agents.trainer_template import build_trainer
+from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy
 
 # yapf: disable
# __sphinx_doc_begin__ @@ -22,40 +20,16 @@ # yapf: enable -class PGTrainer(Trainer): - """Simple policy gradient agent. - - This is an example agent to show how to implement algorithms in RLlib. - In most cases, you will probably want to use the PPO agent instead. - """ - - _name = "PG" - _default_config = DEFAULT_CONFIG - _policy_graph = PGPolicyGraph +def get_policy_class(config): + if config["use_pytorch"]: + from ray.rllib.agents.pg.torch_pg_policy_graph import PGTorchPolicy + return PGTorchPolicy + else: + return PGTFPolicy - @override(Trainer) - def _init(self, config, env_creator): - if config["use_pytorch"]: - from ray.rllib.agents.pg.torch_pg_policy_graph import \ - PGTorchPolicyGraph - policy_cls = PGTorchPolicyGraph - else: - policy_cls = self._policy_graph - self.local_evaluator = self.make_local_evaluator( - env_creator, policy_cls) - self.remote_evaluators = self.make_remote_evaluators( - env_creator, policy_cls, config["num_workers"]) - optimizer_config = dict( - config["optimizer"], - **{"train_batch_size": config["train_batch_size"]}) - self.optimizer = SyncSamplesOptimizer( - self.local_evaluator, self.remote_evaluators, **optimizer_config) - @override(Trainer) - def _train(self): - prev_steps = self.optimizer.num_steps_sampled - self.optimizer.step() - result = self.collect_metrics() - result.update(timesteps_this_iter=self.optimizer.num_steps_sampled - - prev_steps) - return result +PGTrainer = build_trainer( + name="PG", + default_config=DEFAULT_CONFIG, + default_policy=PGTFPolicy, + get_policy_class=get_policy_class) diff --git a/python/ray/rllib/agents/pg/pg_policy_graph.py b/python/ray/rllib/agents/pg/pg_policy_graph.py index a55af79b1e61..54fcd041cc72 100644 --- a/python/ray/rllib/agents/pg/pg_policy_graph.py +++ b/python/ray/rllib/agents/pg/pg_policy_graph.py @@ -3,102 +3,33 @@ from __future__ import print_function import ray -from ray.rllib.models.catalog import ModelCatalog from ray.rllib.evaluation.postprocessing import compute_advantages, \ Postprocessing -from ray.rllib.evaluation.policy_graph import PolicyGraph +from ray.rllib.evaluation.tf_policy_template import build_tf_policy from ray.rllib.evaluation.sample_batch import SampleBatch -from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph -from ray.rllib.utils.annotations import override from ray.rllib.utils import try_import_tf tf = try_import_tf() -class PGLoss(object): - """The basic policy gradient loss.""" +# The basic policy gradients loss +def policy_gradient_loss(policy, batch_tensors): + actions = batch_tensors[SampleBatch.ACTIONS] + advantages = batch_tensors[Postprocessing.ADVANTAGES] + return -tf.reduce_mean(policy.action_dist.logp(actions) * advantages) - def __init__(self, action_dist, actions, advantages): - self.loss = -tf.reduce_mean(action_dist.logp(actions) * advantages) +# This adds the "advantages" column to the sample batch. 
+def postprocess_advantages(policy, + sample_batch, + other_agent_batches=None, + episode=None): + return compute_advantages( + sample_batch, 0.0, policy.config["gamma"], use_gae=False) -class PGPostprocessing(object): - """Adds the advantages field to the trajectory.""" - @override(PolicyGraph) - def postprocess_trajectory(self, - sample_batch, - other_agent_batches=None, - episode=None): - # This adds the "advantages" column to the sample batch - return compute_advantages( - sample_batch, 0.0, self.config["gamma"], use_gae=False) - - -class PGPolicyGraph(PGPostprocessing, TFPolicyGraph): - """Simple policy gradient example of defining a policy graph.""" - - def __init__(self, obs_space, action_space, config): - config = dict(ray.rllib.agents.pg.pg.DEFAULT_CONFIG, **config) - self.config = config - - # Setup placeholders - obs = tf.placeholder(tf.float32, shape=[None] + list(obs_space.shape)) - dist_class, self.logit_dim = ModelCatalog.get_action_dist( - action_space, self.config["model"]) - prev_actions = ModelCatalog.get_action_placeholder(action_space) - prev_rewards = tf.placeholder(tf.float32, [None], name="prev_reward") - - # Create the model network and action outputs - self.model = ModelCatalog.get_model({ - "obs": obs, - "prev_actions": prev_actions, - "prev_rewards": prev_rewards, - "is_training": self._get_is_training_placeholder(), - }, obs_space, action_space, self.logit_dim, self.config["model"]) - action_dist = dist_class(self.model.outputs) # logit for each action - - # Setup policy loss - actions = ModelCatalog.get_action_placeholder(action_space) - advantages = tf.placeholder(tf.float32, [None], name="adv") - loss = PGLoss(action_dist, actions, advantages).loss - - # Mapping from sample batch keys to placeholders. These keys will be - # read from postprocessed sample batches and fed into the specified - # placeholders during loss computation. 
- loss_in = [ - (SampleBatch.CUR_OBS, obs), - (SampleBatch.ACTIONS, actions), - (SampleBatch.PREV_ACTIONS, prev_actions), - (SampleBatch.PREV_REWARDS, prev_rewards), - (Postprocessing.ADVANTAGES, advantages), - ] - - # Initialize TFPolicyGraph - sess = tf.get_default_session() - TFPolicyGraph.__init__( - self, - obs_space, - action_space, - sess, - obs_input=obs, - action_sampler=action_dist.sample(), - action_prob=action_dist.sampled_action_prob(), - loss=loss, - loss_inputs=loss_in, - model=self.model, - state_inputs=self.model.state_in, - state_outputs=self.model.state_out, - prev_action_input=prev_actions, - prev_reward_input=prev_rewards, - seq_lens=self.model.seq_lens, - max_seq_len=config["model"]["max_seq_len"]) - sess.run(tf.global_variables_initializer()) - - @override(PolicyGraph) - def get_initial_state(self): - return self.model.state_init - - @override(TFPolicyGraph) - def optimizer(self): - return tf.train.AdamOptimizer(learning_rate=self.config["lr"]) +PGTFPolicy = build_tf_policy( + name="PGTFPolicy", + get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG, + postprocess_fn=postprocess_advantages, + loss_fn=policy_gradient_loss) diff --git a/python/ray/rllib/agents/pg/torch_pg_policy_graph.py b/python/ray/rllib/agents/pg/torch_pg_policy_graph.py index 746ef1bca42f..cda1b6eb5057 100644 --- a/python/ray/rllib/agents/pg/torch_pg_policy_graph.py +++ b/python/ray/rllib/agents/pg/torch_pg_policy_graph.py @@ -2,82 +2,41 @@ from __future__ import division from __future__ import print_function -import torch -from torch import nn - import ray -from ray.rllib.models.catalog import ModelCatalog from ray.rllib.evaluation.postprocessing import compute_advantages, \ Postprocessing -from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.sample_batch import SampleBatch -from ray.rllib.evaluation.torch_policy_graph import TorchPolicyGraph -from ray.rllib.utils.annotations import override - - -class PGLoss(nn.Module): - def __init__(self, dist_class): - nn.Module.__init__(self) - self.dist_class = dist_class - - def forward(self, policy_model, observations, actions, advantages): - logits, _, values, _ = policy_model({ - SampleBatch.CUR_OBS: observations - }, []) - dist = self.dist_class(logits) - log_probs = dist.logp(actions) - self.pi_err = -advantages.dot(log_probs.reshape(-1)) - return self.pi_err - - -class PGPostprocessing(object): - """Adds the value func output and advantages field to the trajectory.""" +from ray.rllib.evaluation.torch_policy_template import build_torch_policy - @override(TorchPolicyGraph) - def extra_action_out(self, model_out): - return {SampleBatch.VF_PREDS: model_out[2].cpu().numpy()} - @override(PolicyGraph) - def postprocess_trajectory(self, - sample_batch, - other_agent_batches=None, - episode=None): - return compute_advantages( - sample_batch, 0.0, self.config["gamma"], use_gae=False) +def pg_torch_loss(policy, batch_tensors): + logits, _, values, _ = policy.model({ + SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS] + }, []) + action_dist = policy.dist_class(logits) + log_probs = action_dist.logp(batch_tensors[SampleBatch.ACTIONS]) + # save the error in the policy object + policy.pi_err = -batch_tensors[Postprocessing.ADVANTAGES].dot( + log_probs.reshape(-1)) + return policy.pi_err -class PGTorchPolicyGraph(PGPostprocessing, TorchPolicyGraph): - def __init__(self, obs_space, action_space, config): - config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config) - self.config = config - dist_class, 
self.logit_dim = ModelCatalog.get_action_dist( - action_space, self.config["model"], torch=True) - model = ModelCatalog.get_torch_model(obs_space, self.logit_dim, - self.config["model"]) - loss = PGLoss(dist_class) +def postprocess_advantages(policy, + sample_batch, + other_agent_batches=None, + episode=None): + return compute_advantages( + sample_batch, 0.0, policy.config["gamma"], use_gae=False) - TorchPolicyGraph.__init__( - self, - obs_space, - action_space, - model, - loss, - loss_inputs=[ - SampleBatch.CUR_OBS, SampleBatch.ACTIONS, - Postprocessing.ADVANTAGES - ], - action_distribution_cls=dist_class) - @override(TorchPolicyGraph) - def optimizer(self): - return torch.optim.Adam(self._model.parameters(), lr=self.config["lr"]) +def pg_loss_stats(policy, batch_tensors): + # the error is recorded when computing the loss + return {"policy_loss": policy.pi_err.item()} - @override(TorchPolicyGraph) - def extra_grad_info(self): - return {"policy_loss": self._loss.pi_err.item()} - def _value(self, obs): - with self.lock: - obs = torch.from_numpy(obs).float().unsqueeze(0).to(self.device) - _, _, vf, _ = self.model({"obs": obs}, []) - return vf.detach().cpu().numpy().squeeze() +PGTorchPolicy = build_torch_policy( + name="PGTorchPolicy", + get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, + loss_fn=pg_torch_loss, + stats_fn=pg_loss_stats, + postprocess_fn=postprocess_advantages) diff --git a/python/ray/rllib/agents/ppo/appo.py b/python/ray/rllib/agents/ppo/appo.py index ac3251775d52..b32531dd7d5c 100644 --- a/python/ray/rllib/agents/ppo/appo.py +++ b/python/ray/rllib/agents/ppo/appo.py @@ -2,7 +2,7 @@ from __future__ import division from __future__ import print_function -from ray.rllib.agents.ppo.appo_policy_graph import AsyncPPOPolicyGraph +from ray.rllib.agents.ppo.appo_policy_graph import AsyncPPOTFPolicy from ray.rllib.agents.trainer import with_base_config from ray.rllib.agents import impala from ray.rllib.utils.annotations import override @@ -57,8 +57,8 @@ class APPOTrainer(impala.ImpalaTrainer): _name = "APPO" _default_config = DEFAULT_CONFIG - _policy_graph = AsyncPPOPolicyGraph + _policy_graph = AsyncPPOTFPolicy @override(impala.ImpalaTrainer) def _get_policy_graph(self): - return AsyncPPOPolicyGraph + return AsyncPPOTFPolicy diff --git a/python/ray/rllib/agents/ppo/appo_policy_graph.py b/python/ray/rllib/agents/ppo/appo_policy_graph.py index caaaf512bcb1..5aa76913194f 100644 --- a/python/ray/rllib/agents/ppo/appo_policy_graph.py +++ b/python/ray/rllib/agents/ppo/appo_policy_graph.py @@ -12,14 +12,11 @@ import ray from ray.rllib.agents.impala import vtrace -from ray.rllib.evaluation.policy_graph import PolicyGraph -from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY -from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \ - LearningRateSchedule -from ray.rllib.models.catalog import ModelCatalog -from ray.rllib.utils.annotations import override +from ray.rllib.evaluation.postprocessing import Postprocessing +from ray.rllib.evaluation.sample_batch import SampleBatch +from ray.rllib.evaluation.tf_policy_template import build_tf_policy +from ray.rllib.evaluation.tf_policy_graph import LearningRateSchedule from ray.rllib.utils.explained_variance import explained_variance -from ray.rllib.models.action_dist import MultiCategorical from ray.rllib.evaluation.postprocessing import compute_advantages from ray.rllib.utils import try_import_tf @@ -27,6 +24,8 @@ logger = logging.getLogger(__name__) +BEHAVIOUR_LOGITS = "behaviour_logits" + class 
PPOSurrogateLoss(object): """Loss used when V-trace is disabled. @@ -163,333 +162,235 @@ def __init__(self, self.entropy * entropy_coeff) -class APPOPostprocessing(object): - """Adds the policy logits, VF preds, and advantages to the trajectory.""" - - @override(TFPolicyGraph) - def extra_compute_action_fetches(self): - out = {"behaviour_logits": self.model.outputs} - if not self.config["vtrace"]: - out["vf_preds"] = self.value_function - return dict(TFPolicyGraph.extra_compute_action_fetches(self), **out) - - @override(PolicyGraph) - def postprocess_trajectory(self, - sample_batch, - other_agent_batches=None, - episode=None): - if not self.config["vtrace"]: - completed = sample_batch["dones"][-1] - if completed: - last_r = 0.0 - else: - next_state = [] - for i in range(len(self.model.state_in)): - next_state.append( - [sample_batch["state_out_{}".format(i)][-1]]) - last_r = self.value(sample_batch["new_obs"][-1], *next_state) - batch = compute_advantages( - sample_batch, - last_r, - self.config["gamma"], - self.config["lambda"], - use_gae=self.config["use_gae"]) - else: - batch = sample_batch - del batch.data["new_obs"] # not used, so save some bandwidth - return batch - +def _make_time_major(policy, tensor, drop_last=False): + """Swaps batch and trajectory axis. -class AsyncPPOPolicyGraph(LearningRateSchedule, APPOPostprocessing, - TFPolicyGraph): - def __init__(self, - observation_space, - action_space, - config, - existing_inputs=None): - config = dict(ray.rllib.agents.impala.impala.DEFAULT_CONFIG, **config) - assert config["batch_mode"] == "truncate_episodes", \ - "Must use `truncate_episodes` batch mode with V-trace." - self.config = config - self.sess = tf.get_default_session() - self.grads = None - - if isinstance(action_space, gym.spaces.Discrete): - is_multidiscrete = False - output_hidden_shape = [action_space.n] - elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete): - is_multidiscrete = True - output_hidden_shape = action_space.nvec.astype(np.int32) - else: - is_multidiscrete = False - output_hidden_shape = 1 - - # Policy network model - dist_class, logit_dim = ModelCatalog.get_action_dist( - action_space, self.config["model"]) - - # Create input placeholders - if existing_inputs: - if self.config["vtrace"]: - actions, dones, behaviour_logits, rewards, observations, \ - prev_actions, prev_rewards = existing_inputs[:7] - existing_state_in = existing_inputs[7:-1] - existing_seq_lens = existing_inputs[-1] - else: - actions, dones, behaviour_logits, rewards, observations, \ - prev_actions, prev_rewards, adv_ph, value_targets = \ - existing_inputs[:9] - existing_state_in = existing_inputs[9:-1] - existing_seq_lens = existing_inputs[-1] + Arguments: + policy: Policy reference + tensor: A tensor or list of tensors to reshape. + drop_last: A bool indicating whether to drop the last + trajectory item. + + Returns: + res: A tensor with swapped axes or a list of tensors with + swapped axes. + """ + if isinstance(tensor, list): + return [_make_time_major(policy, t, drop_last) for t in tensor] + + if policy.model.state_init: + B = tf.shape(policy.model.seq_lens)[0] + T = tf.shape(tensor)[0] // B + else: + # Important: chop the tensor into batches at known episode cut + # boundaries. 
TODO(ekl) this is kind of a hack + T = policy.config["sample_batch_size"] + B = tf.shape(tensor)[0] // T + rs = tf.reshape(tensor, tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0)) + + # swap B and T axes + res = tf.transpose( + rs, [1, 0] + list(range(2, 1 + int(tf.shape(tensor).shape[0])))) + + if drop_last: + return res[:-1] + return res + + +def build_appo_surrogate_loss(policy, batch_tensors): + if isinstance(policy.action_space, gym.spaces.Discrete): + is_multidiscrete = False + output_hidden_shape = [policy.action_space.n] + elif isinstance(policy.action_space, + gym.spaces.multi_discrete.MultiDiscrete): + is_multidiscrete = True + output_hidden_shape = policy.action_space.nvec.astype(np.int32) + else: + is_multidiscrete = False + output_hidden_shape = 1 + + def make_time_major(*args, **kw): + return _make_time_major(policy, *args, **kw) + + actions = batch_tensors[SampleBatch.ACTIONS] + dones = batch_tensors[SampleBatch.DONES] + rewards = batch_tensors[SampleBatch.REWARDS] + behaviour_logits = batch_tensors[BEHAVIOUR_LOGITS] + unpacked_behaviour_logits = tf.split( + behaviour_logits, output_hidden_shape, axis=1) + unpacked_outputs = tf.split( + policy.model.outputs, output_hidden_shape, axis=1) + prev_dist_inputs = unpacked_behaviour_logits if is_multidiscrete else \ + behaviour_logits + action_dist = policy.action_dist + prev_action_dist = policy.dist_class(prev_dist_inputs) + values = policy.value_function + + if policy.model.state_in: + max_seq_len = tf.reduce_max(policy.model.seq_lens) - 1 + mask = tf.sequence_mask(policy.model.seq_lens, max_seq_len) + mask = tf.reshape(mask, [-1]) + else: + mask = tf.ones_like(rewards) + + if policy.config["vtrace"]: + logger.info("Using V-Trace surrogate loss (vtrace=True)") + + # Prepare actions for loss + loss_actions = actions if is_multidiscrete else tf.expand_dims( + actions, axis=1) + + policy.loss = VTraceSurrogateLoss( + actions=make_time_major(loss_actions, drop_last=True), + prev_actions_logp=make_time_major( + prev_action_dist.logp(actions), drop_last=True), + actions_logp=make_time_major( + action_dist.logp(actions), drop_last=True), + action_kl=prev_action_dist.kl(action_dist), + actions_entropy=make_time_major( + action_dist.entropy(), drop_last=True), + dones=make_time_major(dones, drop_last=True), + behaviour_logits=make_time_major( + unpacked_behaviour_logits, drop_last=True), + target_logits=make_time_major(unpacked_outputs, drop_last=True), + discount=policy.config["gamma"], + rewards=make_time_major(rewards, drop_last=True), + values=make_time_major(values, drop_last=True), + bootstrap_value=make_time_major(values)[-1], + dist_class=policy.dist_class, + valid_mask=make_time_major(mask, drop_last=True), + vf_loss_coeff=policy.config["vf_loss_coeff"], + entropy_coeff=policy.config["entropy_coeff"], + clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"], + clip_pg_rho_threshold=policy.config[ + "vtrace_clip_pg_rho_threshold"], + clip_param=policy.config["clip_param"]) + else: + logger.info("Using PPO surrogate loss (vtrace=False)") + policy.loss = PPOSurrogateLoss( + prev_actions_logp=make_time_major(prev_action_dist.logp(actions)), + actions_logp=make_time_major(action_dist.logp(actions)), + action_kl=prev_action_dist.kl(action_dist), + actions_entropy=make_time_major(action_dist.entropy()), + values=make_time_major(values), + valid_mask=make_time_major(mask), + advantages=make_time_major( + batch_tensors[Postprocessing.ADVANTAGES]), + value_targets=make_time_major( + batch_tensors[Postprocessing.VALUE_TARGETS]), + 
vf_loss_coeff=policy.config["vf_loss_coeff"], + entropy_coeff=policy.config["entropy_coeff"], + clip_param=policy.config["clip_param"]) + + return policy.loss.total_loss + + +def stats(policy, batch_tensors): + values_batched = _make_time_major( + policy, policy.value_function, drop_last=policy.config["vtrace"]) + + return { + "cur_lr": tf.cast(policy.cur_lr, tf.float64), + "policy_loss": policy.loss.pi_loss, + "entropy": policy.loss.entropy, + "var_gnorm": tf.global_norm(policy.var_list), + "vf_loss": policy.loss.vf_loss, + "vf_explained_var": explained_variance( + tf.reshape(policy.loss.value_targets, [-1]), + tf.reshape(values_batched, [-1])), + } + + +def grad_stats(policy, grads): + return { + "grad_gnorm": tf.global_norm(grads), + } + + +def postprocess_trajectory(policy, + sample_batch, + other_agent_batches=None, + episode=None): + if not policy.config["vtrace"]: + completed = sample_batch["dones"][-1] + if completed: + last_r = 0.0 else: - actions = ModelCatalog.get_action_placeholder(action_space) - dones = tf.placeholder(tf.bool, [None], name="dones") - rewards = tf.placeholder(tf.float32, [None], name="rewards") - behaviour_logits = tf.placeholder( - tf.float32, [None, logit_dim], name="behaviour_logits") - observations = tf.placeholder( - tf.float32, [None] + list(observation_space.shape)) - existing_state_in = None - existing_seq_lens = None - - if not self.config["vtrace"]: - adv_ph = tf.placeholder( - tf.float32, name="advantages", shape=(None, )) - value_targets = tf.placeholder( - tf.float32, name="value_targets", shape=(None, )) - self.observations = observations - - # Unpack behaviour logits - unpacked_behaviour_logits = tf.split( - behaviour_logits, output_hidden_shape, axis=1) - - # Setup the policy - dist_class, logit_dim = ModelCatalog.get_action_dist( - action_space, self.config["model"]) - prev_actions = ModelCatalog.get_action_placeholder(action_space) - prev_rewards = tf.placeholder(tf.float32, [None], name="prev_reward") - self.model = ModelCatalog.get_model( - { - "obs": observations, - "prev_actions": prev_actions, - "prev_rewards": prev_rewards, - "is_training": self._get_is_training_placeholder(), - }, - observation_space, - action_space, - logit_dim, - self.config["model"], - state_in=existing_state_in, - seq_lens=existing_seq_lens) - unpacked_outputs = tf.split( - self.model.outputs, output_hidden_shape, axis=1) - - dist_inputs = unpacked_outputs if is_multidiscrete else \ - self.model.outputs - prev_dist_inputs = unpacked_behaviour_logits if is_multidiscrete else \ - behaviour_logits - - action_dist = dist_class(dist_inputs) - prev_action_dist = dist_class(prev_dist_inputs) - - values = self.model.value_function() - self.value_function = values + next_state = [] + for i in range(len(policy.model.state_in)): + next_state.append([sample_batch["state_out_{}".format(i)][-1]]) + last_r = policy.value(sample_batch["new_obs"][-1], *next_state) + batch = compute_advantages( + sample_batch, + last_r, + policy.config["gamma"], + policy.config["lambda"], + use_gae=policy.config["use_gae"]) + else: + batch = sample_batch + del batch.data["new_obs"] # not used, so save some bandwidth + return batch + + +def add_values_and_logits(policy): + out = {BEHAVIOUR_LOGITS: policy.model.outputs} + if not policy.config["vtrace"]: + out[SampleBatch.VF_PREDS] = policy.value_function + return out + + +def validate_config(policy, obs_space, action_space, config): + assert config["batch_mode"] == "truncate_episodes", \ + "Must use `truncate_episodes` batch mode with V-trace." 
+ + +def choose_optimizer(policy, config): + if policy.config["opt_type"] == "adam": + return tf.train.AdamOptimizer(policy.cur_lr) + else: + return tf.train.RMSPropOptimizer(policy.cur_lr, config["decay"], + config["momentum"], config["epsilon"]) + + +def clip_gradients(policy, optimizer, loss): + grads = tf.gradients(loss, policy.var_list) + policy.grads, _ = tf.clip_by_global_norm(grads, policy.config["grad_clip"]) + clipped_grads = list(zip(policy.grads, policy.var_list)) + return clipped_grads + + +class ValueNetworkMixin(object): + def __init__(self): + self.value_function = self.model.value_function() self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name) - def make_time_major(tensor, drop_last=False): - """Swaps batch and trajectory axis. - Args: - tensor: A tensor or list of tensors to reshape. - drop_last: A bool indicating whether to drop the last - trajectory item. - Returns: - res: A tensor with swapped axes or a list of tensors with - swapped axes. - """ - if isinstance(tensor, list): - return [make_time_major(t, drop_last) for t in tensor] - - if self.model.state_init: - B = tf.shape(self.model.seq_lens)[0] - T = tf.shape(tensor)[0] // B - else: - # Important: chop the tensor into batches at known episode cut - # boundaries. TODO(ekl) this is kind of a hack - T = self.config["sample_batch_size"] - B = tf.shape(tensor)[0] // T - rs = tf.reshape(tensor, - tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0)) - - # swap B and T axes - res = tf.transpose( - rs, - [1, 0] + list(range(2, 1 + int(tf.shape(tensor).shape[0])))) - - if drop_last: - return res[:-1] - return res - - if self.model.state_in: - max_seq_len = tf.reduce_max(self.model.seq_lens) - 1 - mask = tf.sequence_mask(self.model.seq_lens, max_seq_len) - mask = tf.reshape(mask, [-1]) - else: - mask = tf.ones_like(rewards) - - # Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc. 
- if self.config["vtrace"]: - logger.info("Using V-Trace surrogate loss (vtrace=True)") - - # Prepare actions for loss - loss_actions = actions if is_multidiscrete else tf.expand_dims( - actions, axis=1) - - self.loss = VTraceSurrogateLoss( - actions=make_time_major(loss_actions, drop_last=True), - prev_actions_logp=make_time_major( - prev_action_dist.logp(actions), drop_last=True), - actions_logp=make_time_major( - action_dist.logp(actions), drop_last=True), - action_kl=prev_action_dist.kl(action_dist), - actions_entropy=make_time_major( - action_dist.entropy(), drop_last=True), - dones=make_time_major(dones, drop_last=True), - behaviour_logits=make_time_major( - unpacked_behaviour_logits, drop_last=True), - target_logits=make_time_major( - unpacked_outputs, drop_last=True), - discount=config["gamma"], - rewards=make_time_major(rewards, drop_last=True), - values=make_time_major(values, drop_last=True), - bootstrap_value=make_time_major(values)[-1], - dist_class=dist_class, - valid_mask=make_time_major(mask, drop_last=True), - vf_loss_coeff=self.config["vf_loss_coeff"], - entropy_coeff=self.config["entropy_coeff"], - clip_rho_threshold=self.config["vtrace_clip_rho_threshold"], - clip_pg_rho_threshold=self.config[ - "vtrace_clip_pg_rho_threshold"], - clip_param=self.config["clip_param"]) - else: - logger.info("Using PPO surrogate loss (vtrace=False)") - self.loss = PPOSurrogateLoss( - prev_actions_logp=make_time_major( - prev_action_dist.logp(actions)), - actions_logp=make_time_major(action_dist.logp(actions)), - action_kl=prev_action_dist.kl(action_dist), - actions_entropy=make_time_major(action_dist.entropy()), - values=make_time_major(values), - valid_mask=make_time_major(mask), - advantages=make_time_major(adv_ph), - value_targets=make_time_major(value_targets), - vf_loss_coeff=self.config["vf_loss_coeff"], - entropy_coeff=self.config["entropy_coeff"], - clip_param=self.config["clip_param"]) - - # KL divergence between worker and learner logits for debugging - model_dist = MultiCategorical(unpacked_outputs) - behaviour_dist = MultiCategorical(unpacked_behaviour_logits) - - kls = model_dist.kl(behaviour_dist) - if len(kls) > 1: - self.KL_stats = {} - - for i, kl in enumerate(kls): - self.KL_stats.update({ - "mean_KL_{}".format(i): tf.reduce_mean(kl), - "max_KL_{}".format(i): tf.reduce_max(kl), - }) - else: - self.KL_stats = { - "mean_KL": tf.reduce_mean(kls[0]), - "max_KL": tf.reduce_max(kls[0]), - } - - # Initialize TFPolicyGraph - loss_in = [ - ("actions", actions), - ("dones", dones), - ("behaviour_logits", behaviour_logits), - ("rewards", rewards), - ("obs", observations), - ("prev_actions", prev_actions), - ("prev_rewards", prev_rewards), - ] - if not self.config["vtrace"]: - loss_in.append(("advantages", adv_ph)) - loss_in.append(("value_targets", value_targets)) - LearningRateSchedule.__init__(self, self.config["lr"], - self.config["lr_schedule"]) - TFPolicyGraph.__init__( - self, - observation_space, - action_space, - self.sess, - obs_input=observations, - action_sampler=action_dist.sample(), - action_prob=action_dist.sampled_action_prob(), - loss=self.loss.total_loss, - model=self.model, - loss_inputs=loss_in, - state_inputs=self.model.state_in, - state_outputs=self.model.state_out, - prev_action_input=prev_actions, - prev_reward_input=prev_rewards, - seq_lens=self.model.seq_lens, - max_seq_len=self.config["model"]["max_seq_len"], - batch_divisibility_req=self.config["sample_batch_size"]) - - self.sess.run(tf.global_variables_initializer()) - - values_batched = make_time_major( 
- values, drop_last=self.config["vtrace"]) - self.stats_fetches = { - LEARNER_STATS_KEY: dict({ - "cur_lr": tf.cast(self.cur_lr, tf.float64), - "policy_loss": self.loss.pi_loss, - "entropy": self.loss.entropy, - "grad_gnorm": tf.global_norm(self._grads), - "var_gnorm": tf.global_norm(self.var_list), - "vf_loss": self.loss.vf_loss, - "vf_explained_var": explained_variance( - tf.reshape(self.loss.value_targets, [-1]), - tf.reshape(values_batched, [-1])), - }, **self.KL_stats), - } - - def optimizer(self): - if self.config["opt_type"] == "adam": - return tf.train.AdamOptimizer(self.cur_lr) - else: - return tf.train.RMSPropOptimizer(self.cur_lr, self.config["decay"], - self.config["momentum"], - self.config["epsilon"]) - - def gradients(self, optimizer, loss): - grads = tf.gradients(loss, self.var_list) - self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"]) - clipped_grads = list(zip(self.grads, self.var_list)) - return clipped_grads - - def extra_compute_grad_fetches(self): - return self.stats_fetches - def value(self, ob, *args): - feed_dict = {self.observations: [ob], self.model.seq_lens: [1]} + feed_dict = {self._obs_input: [ob], self.model.seq_lens: [1]} assert len(args) == len(self.model.state_in), \ (args, self.model.state_in) for k, v in zip(self.model.state_in, args): feed_dict[k] = v - vf = self.sess.run(self.value_function, feed_dict) + vf = self._sess.run(self.value_function, feed_dict) return vf[0] - def get_initial_state(self): - return self.model.state_init - def copy(self, existing_inputs): - return AsyncPPOPolicyGraph( - self.observation_space, - self.action_space, - self.config, - existing_inputs=existing_inputs) +def setup_mixins(policy, obs_space, action_space, config): + LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"]) + ValueNetworkMixin.__init__(policy) + + +AsyncPPOTFPolicy = build_tf_policy( + name="AsyncPPOTFPolicy", + get_default_config=lambda: ray.rllib.agents.impala.impala.DEFAULT_CONFIG, + loss_fn=build_appo_surrogate_loss, + stats_fn=stats, + grad_stats_fn=grad_stats, + postprocess_fn=postprocess_trajectory, + optimizer_fn=choose_optimizer, + gradients_fn=clip_gradients, + extra_action_fetches_fn=add_values_and_logits, + before_init=validate_config, + before_loss_init=setup_mixins, + mixins=[LearningRateSchedule, ValueNetworkMixin], + get_batch_divisibility_req=lambda p: p.config["sample_batch_size"]) diff --git a/python/ray/rllib/agents/ppo/ppo.py b/python/ray/rllib/agents/ppo/ppo.py index 8f69c91149e7..d3f5abdaa95c 100644 --- a/python/ray/rllib/agents/ppo/ppo.py +++ b/python/ray/rllib/agents/ppo/ppo.py @@ -4,10 +4,10 @@ import logging -from ray.rllib.agents import Trainer, with_common_config -from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph +from ray.rllib.agents import with_common_config +from ray.rllib.agents.ppo.ppo_policy_graph import PPOTFPolicy +from ray.rllib.agents.trainer_template import build_trainer from ray.rllib.optimizers import SyncSamplesOptimizer, LocalMultiGPUOptimizer -from ray.rllib.utils.annotations import override logger = logging.getLogger(__name__) @@ -63,110 +63,104 @@ # yapf: enable -class PPOTrainer(Trainer): - """Multi-GPU optimized implementation of PPO in TensorFlow.""" - - _name = "PPO" - _default_config = DEFAULT_CONFIG - _policy_graph = PPOPolicyGraph - - @override(Trainer) - def _init(self, config, env_creator): - self._validate_config() - self.local_evaluator = self.make_local_evaluator( - env_creator, self._policy_graph) - self.remote_evaluators = 
self.make_remote_evaluators( - env_creator, self._policy_graph, config["num_workers"]) - if config["simple_optimizer"]: - self.optimizer = SyncSamplesOptimizer( - self.local_evaluator, - self.remote_evaluators, - num_sgd_iter=config["num_sgd_iter"], - train_batch_size=config["train_batch_size"]) - else: - self.optimizer = LocalMultiGPUOptimizer( - self.local_evaluator, - self.remote_evaluators, - sgd_batch_size=config["sgd_minibatch_size"], - num_sgd_iter=config["num_sgd_iter"], - num_gpus=config["num_gpus"], - sample_batch_size=config["sample_batch_size"], - num_envs_per_worker=config["num_envs_per_worker"], - train_batch_size=config["train_batch_size"], - standardize_fields=["advantages"], - straggler_mitigation=config["straggler_mitigation"]) - - @override(Trainer) - def _train(self): - if "observation_filter" not in self.raw_user_config: - # TODO(ekl) remove this message after a few releases - logger.info( - "Important! Since 0.7.0, observation normalization is no " - "longer enabled by default. To enable running-mean " - "normalization, set 'observation_filter': 'MeanStdFilter'. " - "You can ignore this message if your environment doesn't " - "require observation normalization.") - prev_steps = self.optimizer.num_steps_sampled - fetches = self.optimizer.step() - if "kl" in fetches: - # single-agent - self.local_evaluator.for_policy( - lambda pi: pi.update_kl(fetches["kl"])) - else: - - def update(pi, pi_id): - if pi_id in fetches: - pi.update_kl(fetches[pi_id]["kl"]) - else: - logger.debug( - "No data for {}, not updating kl".format(pi_id)) - - # multi-agent - self.local_evaluator.foreach_trainable_policy(update) - res = self.collect_metrics() - res.update( - timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps, - info=res.get("info", {})) - - # Warn about bad clipping configs - if self.config["vf_clip_param"] <= 0: - rew_scale = float("inf") - elif res["policy_reward_mean"]: - rew_scale = 0 # punt on handling multiagent case - else: - rew_scale = round( - abs(res["episode_reward_mean"]) / self.config["vf_clip_param"], - 0) - if rew_scale > 200: - logger.warning( - "The magnitude of your environment rewards are more than " - "{}x the scale of `vf_clip_param`. ".format(rew_scale) + - "This means that it will take more than " - "{} iterations for your value ".format(rew_scale) + - "function to converge. If this is not intended, consider " - "increasing `vf_clip_param`.") - return res - - def _validate_config(self): - if self.config["entropy_coeff"] < 0: - raise DeprecationWarning("entropy_coeff must be >= 0") - if self.config["sgd_minibatch_size"] > self.config["train_batch_size"]: - raise ValueError( - "Minibatch size {} must be <= train batch size {}.".format( - self.config["sgd_minibatch_size"], - self.config["train_batch_size"])) - if (self.config["batch_mode"] == "truncate_episodes" - and not self.config["use_gae"]): - raise ValueError( - "Episode truncation is not supported without a value " - "function. Consider setting batch_mode=complete_episodes.") - if (self.config["multiagent"]["policy_graphs"] - and not self.config["simple_optimizer"]): - logger.info( - "In multi-agent mode, policies will be optimized sequentially " - "by the multi-GPU optimizer. 
Consider setting " - "simple_optimizer=True if this doesn't work for you.") - if not self.config["vf_share_layers"]: - logger.warning( - "FYI: By default, the value function will not share layers " - "with the policy model ('vf_share_layers': False).") +def make_optimizer(local_evaluator, remote_evaluators, config): + if config["simple_optimizer"]: + return SyncSamplesOptimizer( + local_evaluator, + remote_evaluators, + num_sgd_iter=config["num_sgd_iter"], + train_batch_size=config["train_batch_size"]) + + return LocalMultiGPUOptimizer( + local_evaluator, + remote_evaluators, + sgd_batch_size=config["sgd_minibatch_size"], + num_sgd_iter=config["num_sgd_iter"], + num_gpus=config["num_gpus"], + sample_batch_size=config["sample_batch_size"], + num_envs_per_worker=config["num_envs_per_worker"], + train_batch_size=config["train_batch_size"], + standardize_fields=["advantages"], + straggler_mitigation=config["straggler_mitigation"]) + + +def update_kl(trainer, fetches): + if "kl" in fetches: + # single-agent + trainer.local_evaluator.for_policy( + lambda pi: pi.update_kl(fetches["kl"])) + else: + + def update(pi, pi_id): + if pi_id in fetches: + pi.update_kl(fetches[pi_id]["kl"]) + else: + logger.debug("No data for {}, not updating kl".format(pi_id)) + + # multi-agent + trainer.local_evaluator.foreach_trainable_policy(update) + + +def warn_about_obs_filter(trainer): + if "observation_filter" not in trainer.raw_user_config: + # TODO(ekl) remove this message after a few releases + logger.info( + "Important! Since 0.7.0, observation normalization is no " + "longer enabled by default. To enable running-mean " + "normalization, set 'observation_filter': 'MeanStdFilter'. " + "You can ignore this message if your environment doesn't " + "require observation normalization.") + + +def warn_about_bad_reward_scales(trainer, result): + # Warn about bad clipping configs + if trainer.config["vf_clip_param"] <= 0: + rew_scale = float("inf") + elif result["policy_reward_mean"]: + rew_scale = 0 # punt on handling multiagent case + else: + rew_scale = round( + abs(result["episode_reward_mean"]) / + trainer.config["vf_clip_param"], 0) + if rew_scale > 200: + logger.warning( + "The magnitude of your environment rewards are more than " + "{}x the scale of `vf_clip_param`. ".format(rew_scale) + + "This means that it will take more than " + "{} iterations for your value ".format(rew_scale) + + "function to converge. If this is not intended, consider " + "increasing `vf_clip_param`.") + + +def validate_config(config): + if config["entropy_coeff"] < 0: + raise DeprecationWarning("entropy_coeff must be >= 0") + if config["sgd_minibatch_size"] > config["train_batch_size"]: + raise ValueError( + "Minibatch size {} must be <= train batch size {}.".format( + config["sgd_minibatch_size"], config["train_batch_size"])) + if (config["batch_mode"] == "truncate_episodes" and not config["use_gae"]): + raise ValueError( + "Episode truncation is not supported without a value " + "function. Consider setting batch_mode=complete_episodes.") + if (config["multiagent"]["policy_graphs"] + and not config["simple_optimizer"]): + logger.info( + "In multi-agent mode, policies will be optimized sequentially " + "by the multi-GPU optimizer. 
Consider setting " + "simple_optimizer=True if this doesn't work for you.") + if not config["vf_share_layers"]: + logger.warning( + "FYI: By default, the value function will not share layers " + "with the policy model ('vf_share_layers': False).") + + +PPOTrainer = build_trainer( + name="PPO", + default_config=DEFAULT_CONFIG, + default_policy=PPOTFPolicy, + make_policy_optimizer=make_optimizer, + validate_config=validate_config, + after_optimizer_step=update_kl, + before_train_step=warn_about_obs_filter, + after_train_result=warn_about_bad_reward_scales) diff --git a/python/ray/rllib/agents/ppo/ppo_policy_graph.py b/python/ray/rllib/agents/ppo/ppo_policy_graph.py index 61aced1db740..334ca788c936 100644 --- a/python/ray/rllib/agents/ppo/ppo_policy_graph.py +++ b/python/ray/rllib/agents/ppo/ppo_policy_graph.py @@ -7,13 +7,10 @@ import ray from ray.rllib.evaluation.postprocessing import compute_advantages, \ Postprocessing -from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY -from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.sample_batch import SampleBatch -from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \ - LearningRateSchedule +from ray.rllib.evaluation.tf_policy_graph import LearningRateSchedule +from ray.rllib.evaluation.tf_policy_template import build_tf_policy from ray.rllib.models.catalog import ModelCatalog -from ray.rllib.utils.annotations import override from ray.rllib.utils.explained_variance import explained_variance from ray.rllib.utils import try_import_tf @@ -107,119 +104,106 @@ def reduce_mean_valid(t): self.loss = loss -class PPOPostprocessing(object): +def ppo_surrogate_loss(policy, batch_tensors): + if policy.model.state_in: + max_seq_len = tf.reduce_max(policy.model.seq_lens) + mask = tf.sequence_mask(policy.model.seq_lens, max_seq_len) + mask = tf.reshape(mask, [-1]) + else: + mask = tf.ones_like( + batch_tensors[Postprocessing.ADVANTAGES], dtype=tf.bool) + + policy.loss_obj = PPOLoss( + policy.action_space, + batch_tensors[Postprocessing.VALUE_TARGETS], + batch_tensors[Postprocessing.ADVANTAGES], + batch_tensors[SampleBatch.ACTIONS], + batch_tensors[BEHAVIOUR_LOGITS], + batch_tensors[SampleBatch.VF_PREDS], + policy.action_dist, + policy.value_function, + policy.kl_coeff, + mask, + entropy_coeff=policy.config["entropy_coeff"], + clip_param=policy.config["clip_param"], + vf_clip_param=policy.config["vf_clip_param"], + vf_loss_coeff=policy.config["vf_loss_coeff"], + use_gae=policy.config["use_gae"]) + + return policy.loss_obj.loss + + +def kl_and_loss_stats(policy, batch_tensors): + policy.explained_variance = explained_variance( + batch_tensors[Postprocessing.VALUE_TARGETS], policy.value_function) + + stats_fetches = { + "cur_kl_coeff": policy.kl_coeff, + "cur_lr": tf.cast(policy.cur_lr, tf.float64), + "total_loss": policy.loss_obj.loss, + "policy_loss": policy.loss_obj.mean_policy_loss, + "vf_loss": policy.loss_obj.mean_vf_loss, + "vf_explained_var": policy.explained_variance, + "kl": policy.loss_obj.mean_kl, + "entropy": policy.loss_obj.mean_entropy, + } + + return stats_fetches + + +def vf_preds_and_logits_fetches(policy): + """Adds value function and logits outputs to experience batches.""" + return { + SampleBatch.VF_PREDS: policy.value_function, + BEHAVIOUR_LOGITS: policy.model.outputs, + } + + +def postprocess_ppo_gae(policy, + sample_batch, + other_agent_batches=None, + episode=None): """Adds the policy logits, VF preds, and advantages to the trajectory.""" - @override(TFPolicyGraph) - def 
extra_compute_action_fetches(self): - return dict( - TFPolicyGraph.extra_compute_action_fetches(self), **{ - SampleBatch.VF_PREDS: self.value_function, - BEHAVIOUR_LOGITS: self.logits - }) - - @override(PolicyGraph) - def postprocess_trajectory(self, - sample_batch, - other_agent_batches=None, - episode=None): - completed = sample_batch["dones"][-1] - if completed: - last_r = 0.0 - else: - next_state = [] - for i in range(len(self.model.state_in)): - next_state.append([sample_batch["state_out_{}".format(i)][-1]]) - last_r = self._value(sample_batch[SampleBatch.NEXT_OBS][-1], - sample_batch[SampleBatch.ACTIONS][-1], - sample_batch[SampleBatch.REWARDS][-1], - *next_state) - batch = compute_advantages( - sample_batch, - last_r, - self.config["gamma"], - self.config["lambda"], - use_gae=self.config["use_gae"]) - return batch - - -class PPOPolicyGraph(LearningRateSchedule, PPOPostprocessing, TFPolicyGraph): - def __init__(self, - observation_space, - action_space, - config, - existing_inputs=None): - """ - Arguments: - observation_space: Environment observation space specification. - action_space: Environment action space specification. - config (dict): Configuration values for PPO graph. - existing_inputs (list): Optional list of tuples that specify the - placeholders upon which the graph should be built upon. - """ - config = dict(ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG, **config) - self.sess = tf.get_default_session() - self.action_space = action_space - self.config = config - self.kl_coeff_val = self.config["kl_coeff"] - self.kl_target = self.config["kl_target"] - dist_cls, logit_dim = ModelCatalog.get_action_dist( - action_space, self.config["model"]) - - if existing_inputs: - obs_ph, value_targets_ph, adv_ph, act_ph, \ - logits_ph, vf_preds_ph, prev_actions_ph, prev_rewards_ph = \ - existing_inputs[:8] - existing_state_in = existing_inputs[8:-1] - existing_seq_lens = existing_inputs[-1] - else: - obs_ph = tf.placeholder( - tf.float32, - name="obs", - shape=(None, ) + observation_space.shape) - adv_ph = tf.placeholder( - tf.float32, name="advantages", shape=(None, )) - act_ph = ModelCatalog.get_action_placeholder(action_space) - logits_ph = tf.placeholder( - tf.float32, name="logits", shape=(None, logit_dim)) - vf_preds_ph = tf.placeholder( - tf.float32, name="vf_preds", shape=(None, )) - value_targets_ph = tf.placeholder( - tf.float32, name="value_targets", shape=(None, )) - prev_actions_ph = ModelCatalog.get_action_placeholder(action_space) - prev_rewards_ph = tf.placeholder( - tf.float32, [None], name="prev_reward") - existing_state_in = None - existing_seq_lens = None - self.observations = obs_ph - self.prev_actions = prev_actions_ph - self.prev_rewards = prev_rewards_ph - - self.loss_in = [ - (SampleBatch.CUR_OBS, obs_ph), - (Postprocessing.VALUE_TARGETS, value_targets_ph), - (Postprocessing.ADVANTAGES, adv_ph), - (SampleBatch.ACTIONS, act_ph), - (BEHAVIOUR_LOGITS, logits_ph), - (SampleBatch.VF_PREDS, vf_preds_ph), - (SampleBatch.PREV_ACTIONS, prev_actions_ph), - (SampleBatch.PREV_REWARDS, prev_rewards_ph), - ] - self.model = ModelCatalog.get_model( - { - "obs": obs_ph, - "prev_actions": prev_actions_ph, - "prev_rewards": prev_rewards_ph, - "is_training": self._get_is_training_placeholder(), - }, - observation_space, - action_space, - logit_dim, - self.config["model"], - state_in=existing_state_in, - seq_lens=existing_seq_lens) - + completed = sample_batch["dones"][-1] + if completed: + last_r = 0.0 + else: + next_state = [] + for i in range(len(policy.model.state_in)): + 
next_state.append([sample_batch["state_out_{}".format(i)][-1]]) + last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1], + sample_batch[SampleBatch.ACTIONS][-1], + sample_batch[SampleBatch.REWARDS][-1], + *next_state) + batch = compute_advantages( + sample_batch, + last_r, + policy.config["gamma"], + policy.config["lambda"], + use_gae=policy.config["use_gae"]) + return batch + + +def clip_gradients(policy, optimizer, loss): + if policy.config["grad_clip"] is not None: + policy.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, + tf.get_variable_scope().name) + grads = tf.gradients(loss, policy.var_list) + policy.grads, _ = tf.clip_by_global_norm(grads, + policy.config["grad_clip"]) + clipped_grads = list(zip(policy.grads, policy.var_list)) + return clipped_grads + else: + return optimizer.compute_gradients( + loss, colocate_gradients_with_ops=True) + + +class KLCoeffMixin(object): + def __init__(self, config): # KL Coefficient + self.kl_coeff_val = config["kl_coeff"] + self.kl_target = config["kl_target"] self.kl_coeff = tf.get_variable( initializer=tf.constant_initializer(self.kl_coeff_val), name="kl_coeff", @@ -227,14 +211,22 @@ def __init__(self, trainable=False, dtype=tf.float32) - self.logits = self.model.outputs - curr_action_dist = dist_cls(self.logits) - self.sampler = curr_action_dist.sample() - if self.config["use_gae"]: - if self.config["vf_share_layers"]: + def update_kl(self, sampled_kl): + if sampled_kl > 2.0 * self.kl_target: + self.kl_coeff_val *= 1.5 + elif sampled_kl < 0.5 * self.kl_target: + self.kl_coeff_val *= 0.5 + self.kl_coeff.load(self.kl_coeff_val, session=self._sess) + return self.kl_coeff_val + + +class ValueNetworkMixin(object): + def __init__(self, obs_space, action_space, config): + if config["use_gae"]: + if config["vf_share_layers"]: self.value_function = self.model.value_function() else: - vf_config = self.config["model"].copy() + vf_config = config["model"].copy() # Do not split the last layer of the value function into # mean parameters and standard deviation parameters and # do not make the standard deviations free variables. 
@@ -249,122 +241,43 @@ def __init__(self, "value_function() method.") with tf.variable_scope("value_function"): self.value_function = ModelCatalog.get_model({ - "obs": obs_ph, - "prev_actions": prev_actions_ph, - "prev_rewards": prev_rewards_ph, + "obs": self._obs_input, + "prev_actions": self._prev_action_input, + "prev_rewards": self._prev_reward_input, "is_training": self._get_is_training_placeholder(), - }, observation_space, action_space, 1, vf_config).outputs + }, obs_space, action_space, 1, vf_config).outputs self.value_function = tf.reshape(self.value_function, [-1]) else: - self.value_function = tf.zeros(shape=tf.shape(obs_ph)[:1]) - - if self.model.state_in: - max_seq_len = tf.reduce_max(self.model.seq_lens) - mask = tf.sequence_mask(self.model.seq_lens, max_seq_len) - mask = tf.reshape(mask, [-1]) - else: - mask = tf.ones_like(adv_ph, dtype=tf.bool) - - self.loss_obj = PPOLoss( - action_space, - value_targets_ph, - adv_ph, - act_ph, - logits_ph, - vf_preds_ph, - curr_action_dist, - self.value_function, - self.kl_coeff, - mask, - entropy_coeff=self.config["entropy_coeff"], - clip_param=self.config["clip_param"], - vf_clip_param=self.config["vf_clip_param"], - vf_loss_coeff=self.config["vf_loss_coeff"], - use_gae=self.config["use_gae"]) - - LearningRateSchedule.__init__(self, self.config["lr"], - self.config["lr_schedule"]) - TFPolicyGraph.__init__( - self, - observation_space, - action_space, - self.sess, - obs_input=obs_ph, - action_sampler=self.sampler, - action_prob=curr_action_dist.sampled_action_prob(), - loss=self.loss_obj.loss, - model=self.model, - loss_inputs=self.loss_in, - state_inputs=self.model.state_in, - state_outputs=self.model.state_out, - prev_action_input=prev_actions_ph, - prev_reward_input=prev_rewards_ph, - seq_lens=self.model.seq_lens, - max_seq_len=config["model"]["max_seq_len"]) - - self.sess.run(tf.global_variables_initializer()) - self.explained_variance = explained_variance(value_targets_ph, - self.value_function) - self.stats_fetches = { - "cur_kl_coeff": self.kl_coeff, - "cur_lr": tf.cast(self.cur_lr, tf.float64), - "total_loss": self.loss_obj.loss, - "policy_loss": self.loss_obj.mean_policy_loss, - "vf_loss": self.loss_obj.mean_vf_loss, - "vf_explained_var": self.explained_variance, - "kl": self.loss_obj.mean_kl, - "entropy": self.loss_obj.mean_entropy - } - - @override(TFPolicyGraph) - def copy(self, existing_inputs): - """Creates a copy of self using existing input placeholders.""" - return PPOPolicyGraph( - self.observation_space, - self.action_space, - self.config, - existing_inputs=existing_inputs) - - @override(TFPolicyGraph) - def gradients(self, optimizer, loss): - if self.config["grad_clip"] is not None: - self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, - tf.get_variable_scope().name) - grads = tf.gradients(loss, self.var_list) - self.grads, _ = tf.clip_by_global_norm(grads, - self.config["grad_clip"]) - clipped_grads = list(zip(self.grads, self.var_list)) - return clipped_grads - else: - return optimizer.compute_gradients( - loss, colocate_gradients_with_ops=True) - - @override(PolicyGraph) - def get_initial_state(self): - return self.model.state_init - - @override(TFPolicyGraph) - def extra_compute_grad_fetches(self): - return {LEARNER_STATS_KEY: self.stats_fetches} - - def update_kl(self, sampled_kl): - if sampled_kl > 2.0 * self.kl_target: - self.kl_coeff_val *= 1.5 - elif sampled_kl < 0.5 * self.kl_target: - self.kl_coeff_val *= 0.5 - self.kl_coeff.load(self.kl_coeff_val, session=self.sess) - return 
self.kl_coeff_val + self.value_function = tf.zeros(shape=tf.shape(self._obs_input)[:1]) def _value(self, ob, prev_action, prev_reward, *args): feed_dict = { - self.observations: [ob], - self.prev_actions: [prev_action], - self.prev_rewards: [prev_reward], + self._obs_input: [ob], + self._prev_action_input: [prev_action], + self._prev_reward_input: [prev_reward], self.model.seq_lens: [1] } assert len(args) == len(self.model.state_in), \ (args, self.model.state_in) for k, v in zip(self.model.state_in, args): feed_dict[k] = v - vf = self.sess.run(self.value_function, feed_dict) + vf = self._sess.run(self.value_function, feed_dict) return vf[0] + + +def setup_mixins(policy, obs_space, action_space, config): + ValueNetworkMixin.__init__(policy, obs_space, action_space, config) + KLCoeffMixin.__init__(policy, config) + LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"]) + + +PPOTFPolicy = build_tf_policy( + name="PPOTFPolicy", + get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG, + loss_fn=ppo_surrogate_loss, + stats_fn=kl_and_loss_stats, + extra_action_fetches_fn=vf_preds_and_logits_fetches, + postprocess_fn=postprocess_ppo_gae, + gradients_fn=clip_gradients, + before_loss_init=setup_mixins, + mixins=[LearningRateSchedule, KLCoeffMixin, ValueNetworkMixin]) diff --git a/python/ray/rllib/agents/trainer_template.py b/python/ray/rllib/agents/trainer_template.py new file mode 100644 index 000000000000..618bc3b30ace --- /dev/null +++ b/python/ray/rllib/agents/trainer_template.py @@ -0,0 +1,97 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from ray.rllib.agents.trainer import Trainer +from ray.rllib.optimizers import SyncSamplesOptimizer +from ray.rllib.utils.annotations import override, DeveloperAPI + + +@DeveloperAPI +def build_trainer(name, + default_policy, + default_config=None, + make_policy_optimizer=None, + validate_config=None, + get_policy_class=None, + before_train_step=None, + after_optimizer_step=None, + after_train_result=None): + """Helper function for defining a custom trainer. + + Arguments: + name (str): name of the trainer (e.g., "PPO") + default_policy (cls): the default PolicyGraph class to use + default_config (dict): the default config dict of the algorithm, + otherwises uses the Trainer default config + make_policy_optimizer (func): optional function that returns a + PolicyOptimizer instance given + (local_evaluator, remote_evaluators, config) + validate_config (func): optional callback that checks a given config + for correctness. It may mutate the config as needed. + get_policy_class (func): optional callback that takes a config and + returns the policy graph class to override the default with + before_train_step (func): optional callback to run before each train() + call. It takes the trainer instance as an argument. + after_optimizer_step (func): optional callback to run after each + step() call to the policy optimizer. It takes the trainer instance + and the policy gradient fetches as arguments. + after_train_result (func): optional callback to run at the end of each + train() call. It takes the trainer instance and result dict as + arguments, and may mutate the result dict as needed. + + Returns: + a Trainer instance that uses the specified args. 
+ """ + + if name.endswith("Trainer"): + raise ValueError("Algorithm name should not include *Trainer suffix", + name) + + class trainer_cls(Trainer): + _name = name + _default_config = default_config or Trainer.COMMON_CONFIG + _policy_graph = default_policy + + def _init(self, config, env_creator): + if validate_config: + validate_config(config) + if get_policy_class is None: + policy_graph = default_policy + else: + policy_graph = get_policy_class(config) + self.local_evaluator = self.make_local_evaluator( + env_creator, policy_graph) + self.remote_evaluators = self.make_remote_evaluators( + env_creator, policy_graph, config["num_workers"]) + if make_policy_optimizer: + self.optimizer = make_policy_optimizer( + self.local_evaluator, self.remote_evaluators, config) + else: + optimizer_config = dict( + config["optimizer"], + **{"train_batch_size": config["train_batch_size"]}) + self.optimizer = SyncSamplesOptimizer(self.local_evaluator, + self.remote_evaluators, + **optimizer_config) + + @override(Trainer) + def _train(self): + if before_train_step: + before_train_step(self) + prev_steps = self.optimizer.num_steps_sampled + fetches = self.optimizer.step() + if after_optimizer_step: + after_optimizer_step(self, fetches) + res = self.collect_metrics() + res.update( + timesteps_this_iter=self.optimizer.num_steps_sampled - + prev_steps, + info=res.get("info", {})) + if after_train_result: + after_train_result(self, res) + return res + + trainer_cls.__name__ = name + "Trainer" + trainer_cls.__qualname__ = name + "Trainer" + return trainer_cls diff --git a/python/ray/rllib/evaluation/dynamic_tf_policy_graph.py b/python/ray/rllib/evaluation/dynamic_tf_policy_graph.py new file mode 100644 index 000000000000..73e08fcf9093 --- /dev/null +++ b/python/ray/rllib/evaluation/dynamic_tf_policy_graph.py @@ -0,0 +1,275 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import OrderedDict +import logging +import numpy as np + +from ray.rllib.evaluation.policy_graph import PolicyGraph +from ray.rllib.evaluation.sample_batch import SampleBatch +from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph +from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.utils.annotations import override +from ray.rllib.utils import try_import_tf +from ray.rllib.utils.debug import log_once, summarize +from ray.rllib.utils.tracking_dict import UsageTrackingDict + +tf = try_import_tf() + +logger = logging.getLogger(__name__) + + +class DynamicTFPolicyGraph(TFPolicyGraph): + """A TFPolicyGraph that auto-defines placeholders dynamically at runtime. + + Initialization of this class occurs in two phases. + * Phase 1: the model is created and model variables are initialized. + * Phase 2: a fake batch of data is created, sent to the trajectory + postprocessor, and then used to create placeholders for the loss + function. The loss and stats functions are initialized with these + placeholders. + """ + + def __init__(self, + obs_space, + action_space, + config, + loss_fn, + stats_fn=None, + grad_stats_fn=None, + before_loss_init=None, + make_action_sampler=None, + existing_inputs=None, + get_batch_divisibility_req=None): + """Initialize a dynamic TF policy graph. + + Arguments: + observation_space (gym.Space): Observation space of the policy. + action_space (gym.Space): Action space of the policy. + config (dict): Policy-specific configuration data. 
+ loss_fn (func): function that returns a loss tensor the policy + graph, and dict of experience tensor placeholders + stats_fn (func): optional function that returns a dict of + TF fetches given the policy graph and batch input tensors + grad_stats_fn (func): optional function that returns a dict of + TF fetches given the policy graph and loss gradient tensors + before_loss_init (func): optional function to run prior to loss + init that takes the same arguments as __init__ + make_action_sampler (func): optional function that returns a + tuple of action and action prob tensors. The function takes + (policy, input_dict, obs_space, action_space, config) as its + arguments + existing_inputs (OrderedDict): when copying a policy graph, this + specifies an existing dict of placeholders to use instead of + defining new ones + get_batch_divisibility_req (func): optional function that returns + the divisibility requirement for sample batches + """ + self.config = config + self._loss_fn = loss_fn + self._stats_fn = stats_fn + self._grad_stats_fn = grad_stats_fn + + # Setup standard placeholders + if existing_inputs is not None: + obs = existing_inputs[SampleBatch.CUR_OBS] + prev_actions = existing_inputs[SampleBatch.PREV_ACTIONS] + prev_rewards = existing_inputs[SampleBatch.PREV_REWARDS] + else: + obs = tf.placeholder( + tf.float32, + shape=[None] + list(obs_space.shape), + name="observation") + prev_actions = ModelCatalog.get_action_placeholder(action_space) + prev_rewards = tf.placeholder( + tf.float32, [None], name="prev_reward") + + input_dict = { + "obs": obs, + "prev_actions": prev_actions, + "prev_rewards": prev_rewards, + "is_training": self._get_is_training_placeholder(), + } + + # Create the model network and action outputs + if make_action_sampler: + assert not existing_inputs, \ + "Cloning not supported with custom action sampler" + self.model = None + self.dist_class = None + self.action_dist = None + action_sampler, action_prob = make_action_sampler( + self, input_dict, obs_space, action_space, config) + else: + self.dist_class, logit_dim = ModelCatalog.get_action_dist( + action_space, self.config["model"]) + if existing_inputs: + existing_state_in = [ + v for k, v in existing_inputs.items() + if k.startswith("state_in_") + ] + if existing_state_in: + existing_seq_lens = existing_inputs["seq_lens"] + else: + existing_seq_lens = None + else: + existing_state_in = [] + existing_seq_lens = None + self.model = ModelCatalog.get_model( + input_dict, + obs_space, + action_space, + logit_dim, + self.config["model"], + state_in=existing_state_in, + seq_lens=existing_seq_lens) + self.action_dist = self.dist_class(self.model.outputs) + action_sampler = self.action_dist.sample() + action_prob = self.action_dist.sampled_action_prob() + + # Phase 1 init + sess = tf.get_default_session() + if get_batch_divisibility_req: + batch_divisibility_req = get_batch_divisibility_req(self) + else: + batch_divisibility_req = 1 + TFPolicyGraph.__init__( + self, + obs_space, + action_space, + sess, + obs_input=obs, + action_sampler=action_sampler, + action_prob=action_prob, + loss=None, # dynamically initialized on run + loss_inputs=[], + model=self.model, + state_inputs=self.model and self.model.state_in, + state_outputs=self.model and self.model.state_out, + prev_action_input=prev_actions, + prev_reward_input=prev_rewards, + seq_lens=self.model and self.model.seq_lens, + max_seq_len=config["model"]["max_seq_len"], + batch_divisibility_req=batch_divisibility_req) + + # Phase 2 init + before_loss_init(self, 
obs_space, action_space, config) + if not existing_inputs: + self._initialize_loss() + + @override(TFPolicyGraph) + def copy(self, existing_inputs): + """Creates a copy of self using existing input placeholders.""" + + # Note that there might be RNN state inputs at the end of the list + if self._state_inputs: + num_state_inputs = len(self._state_inputs) + 1 + else: + num_state_inputs = 0 + if len(self._loss_inputs) + num_state_inputs != len(existing_inputs): + raise ValueError("Tensor list mismatch", self._loss_inputs, + self._state_inputs, existing_inputs) + for i, (k, v) in enumerate(self._loss_inputs): + if v.shape.as_list() != existing_inputs[i].shape.as_list(): + raise ValueError("Tensor shape mismatch", i, k, v.shape, + existing_inputs[i].shape) + # By convention, the loss inputs are followed by state inputs and then + # the seq len tensor + rnn_inputs = [] + for i in range(len(self._state_inputs)): + rnn_inputs.append(("state_in_{}".format(i), + existing_inputs[len(self._loss_inputs) + i])) + if rnn_inputs: + rnn_inputs.append(("seq_lens", existing_inputs[-1])) + input_dict = OrderedDict( + [(k, existing_inputs[i]) + for i, (k, _) in enumerate(self._loss_inputs)] + rnn_inputs) + instance = self.__class__( + self.observation_space, + self.action_space, + self.config, + existing_inputs=input_dict) + loss = instance._loss_fn(instance, input_dict) + if instance._stats_fn: + instance._stats_fetches.update( + instance._stats_fn(instance, input_dict)) + TFPolicyGraph._initialize_loss( + instance, loss, [(k, existing_inputs[i]) + for i, (k, _) in enumerate(self._loss_inputs)]) + if instance._grad_stats_fn: + instance._stats_fetches.update( + instance._grad_stats_fn(instance, instance._grads)) + return instance + + @override(PolicyGraph) + def get_initial_state(self): + if self.model: + return self.model.state_init + else: + return [] + + def _initialize_loss(self): + def fake_array(tensor): + shape = tensor.shape.as_list() + shape[0] = 1 + return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype) + + dummy_batch = { + SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input), + SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input), + SampleBatch.CUR_OBS: fake_array(self._obs_input), + SampleBatch.NEXT_OBS: fake_array(self._obs_input), + SampleBatch.ACTIONS: fake_array(self._prev_action_input), + SampleBatch.REWARDS: np.array([0], dtype=np.float32), + SampleBatch.DONES: np.array([False], dtype=np.bool), + } + state_init = self.get_initial_state() + for i, h in enumerate(state_init): + dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0) + dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0) + if state_init: + dummy_batch["seq_lens"] = np.array([1], dtype=np.int32) + for k, v in self.extra_compute_action_fetches().items(): + dummy_batch[k] = fake_array(v) + + # postprocessing might depend on variable init, so run it first here + self._sess.run(tf.global_variables_initializer()) + postprocessed_batch = self.postprocess_trajectory( + SampleBatch(dummy_batch)) + + batch_tensors = UsageTrackingDict({ + SampleBatch.PREV_ACTIONS: self._prev_action_input, + SampleBatch.PREV_REWARDS: self._prev_reward_input, + SampleBatch.CUR_OBS: self._obs_input, + }) + loss_inputs = [ + (SampleBatch.PREV_ACTIONS, self._prev_action_input), + (SampleBatch.PREV_REWARDS, self._prev_reward_input), + (SampleBatch.CUR_OBS, self._obs_input), + ] + + for k, v in postprocessed_batch.items(): + if k in batch_tensors: + continue + elif v.dtype == np.object: + continue # can't handle arbitrary objects 
in TF + shape = (None, ) + v.shape[1:] + dtype = np.float32 if v.dtype == np.float64 else v.dtype + placeholder = tf.placeholder(dtype, shape=shape, name=k) + batch_tensors[k] = placeholder + + if log_once("loss_init"): + logger.info( + "Initializing loss function with dummy input:\n\n{}\n".format( + summarize(batch_tensors))) + + loss = self._loss_fn(self, batch_tensors) + if self._stats_fn: + self._stats_fetches.update(self._stats_fn(self, batch_tensors)) + for k in sorted(batch_tensors.accessed_keys): + loss_inputs.append((k, batch_tensors[k])) + TFPolicyGraph._initialize_loss(self, loss, loss_inputs) + if self._grad_stats_fn: + self._stats_fetches.update(self._grad_stats_fn(self, self._grads)) + self._sess.run(tf.global_variables_initializer()) diff --git a/python/ray/rllib/evaluation/policy_evaluator.py b/python/ray/rllib/evaluation/policy_evaluator.py index f6761122156e..48e19dfcb96e 100644 --- a/python/ray/rllib/evaluation/policy_evaluator.py +++ b/python/ray/rllib/evaluation/policy_evaluator.py @@ -65,7 +65,7 @@ class PolicyEvaluator(EvaluatorInterface): >>> # Create a policy evaluator and using it to collect experiences. >>> evaluator = PolicyEvaluator( ... env_creator=lambda _: gym.make("CartPole-v0"), - ... policy_graph=PGPolicyGraph) + ... policy_graph=PGTFPolicy) >>> print(evaluator.sample()) SampleBatch({ "obs": [[...]], "actions": [[...]], "rewards": [[...]], @@ -76,7 +76,7 @@ class PolicyEvaluator(EvaluatorInterface): ... evaluator_cls=PolicyEvaluator, ... evaluator_args={ ... "env_creator": lambda _: gym.make("CartPole-v0"), - ... "policy_graph": PGPolicyGraph, + ... "policy_graph": PGTFPolicy, ... }, ... num_workers=10) >>> for _ in range(10): optimizer.step() @@ -87,12 +87,12 @@ class PolicyEvaluator(EvaluatorInterface): ... policy_graphs={ ... # Use an ensemble of two policies for car agents ... "car_policy1": - ... (PGPolicyGraph, Box(...), Discrete(...), {"gamma": 0.99}), + ... (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.99}), ... "car_policy2": - ... (PGPolicyGraph, Box(...), Discrete(...), {"gamma": 0.95}), + ... (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.95}), ... # Use a single shared policy for all traffic lights ... "traffic_light_policy": - ... (PGPolicyGraph, Box(...), Discrete(...), {}), + ... (PGTFPolicy, Box(...), Discrete(...), {}), ... }, ... policy_mapping_fn=lambda agent_id: ... 
random.choice(["car_policy1", "car_policy2"]) diff --git a/python/ray/rllib/evaluation/tf_policy_graph.py b/python/ray/rllib/evaluation/tf_policy_graph.py index 2b1eca9e8d5b..b921e6cfb0d1 100644 --- a/python/ray/rllib/evaluation/tf_policy_graph.py +++ b/python/ray/rllib/evaluation/tf_policy_graph.py @@ -112,24 +112,45 @@ def __init__(self, self._prev_action_input = prev_action_input self._prev_reward_input = prev_reward_input self._sampler = action_sampler - self._loss_inputs = loss_inputs - self._loss_input_dict = dict(self._loss_inputs) self._is_training = self._get_is_training_placeholder() self._action_prob = action_prob self._state_inputs = state_inputs or [] self._state_outputs = state_outputs or [] - for i, ph in enumerate(self._state_inputs): - self._loss_input_dict["state_in_{}".format(i)] = ph self._seq_lens = seq_lens self._max_seq_len = max_seq_len self._batch_divisibility_req = batch_divisibility_req + self._update_ops = update_ops + self._stats_fetches = {} + + if loss is not None: + self._initialize_loss(loss, loss_inputs) + else: + self._loss = None + + if len(self._state_inputs) != len(self._state_outputs): + raise ValueError( + "Number of state input and output tensors must match, got: " + "{} vs {}".format(self._state_inputs, self._state_outputs)) + if len(self.get_initial_state()) != len(self._state_inputs): + raise ValueError( + "Length of initial state must match number of state inputs, " + "got: {} vs {}".format(self.get_initial_state(), + self._state_inputs)) + if self._state_inputs and self._seq_lens is None: + raise ValueError( + "seq_lens tensor must be given if state inputs are defined") + + def _initialize_loss(self, loss, loss_inputs): + self._loss_inputs = loss_inputs + self._loss_input_dict = dict(self._loss_inputs) + for i, ph in enumerate(self._state_inputs): + self._loss_input_dict["state_in_{}".format(i)] = ph if self.model: self._loss = self.model.custom_loss(loss, self._loss_input_dict) - self._stats_fetches = {"model": self.model.custom_stats()} + self._stats_fetches.update({"model": self.model.custom_stats()}) else: self._loss = loss - self._stats_fetches = {} self._optimizer = self.optimizer() self._grads_and_vars = [ @@ -141,9 +162,7 @@ def __init__(self, self._loss, self._sess) # gather update ops for any batch norm layers - if update_ops: - self._update_ops = update_ops - else: + if not self._update_ops: self._update_ops = tf.get_collection( tf.GraphKeys.UPDATE_OPS, scope=tf.get_variable_scope().name) if self._update_ops: @@ -153,21 +172,12 @@ def __init__(self, self._apply_op = self.build_apply_op(self._optimizer, self._grads_and_vars) - if len(self._state_inputs) != len(self._state_outputs): - raise ValueError( - "Number of state input and output tensors must match, got: " - "{} vs {}".format(self._state_inputs, self._state_outputs)) - if len(self.get_initial_state()) != len(self._state_inputs): - raise ValueError( - "Length of initial state must match number of state inputs, " - "got: {} vs {}".format(self.get_initial_state(), - self._state_inputs)) - if self._state_inputs and self._seq_lens is None: - raise ValueError( - "seq_lens tensor must be given if state inputs are defined") + if log_once("loss_used"): + logger.debug( + "These tensors were used in the loss_fn:\n\n{}\n".format( + summarize(self._loss_input_dict))) - logger.debug("Created {} with loss inputs: {}".format( - self, self._loss_input_dict)) + self._sess.run(tf.global_variables_initializer()) @override(PolicyGraph) def compute_actions(self, @@ -186,18 +196,21 @@ def 
compute_actions(self, @override(PolicyGraph) def compute_gradients(self, postprocessed_batch): + assert self._loss is not None, "Loss not initialized" builder = TFRunBuilder(self._sess, "compute_gradients") fetches = self._build_compute_gradients(builder, postprocessed_batch) return builder.get(fetches) @override(PolicyGraph) def apply_gradients(self, gradients): + assert self._loss is not None, "Loss not initialized" builder = TFRunBuilder(self._sess, "apply_gradients") fetches = self._build_apply_gradients(builder, gradients) builder.get(fetches) @override(PolicyGraph) def learn_on_batch(self, postprocessed_batch): + assert self._loss is not None, "Loss not initialized" builder = TFRunBuilder(self._sess, "learn_on_batch") fetches = self._build_learn_on_batch(builder, postprocessed_batch) return builder.get(fetches) @@ -271,7 +284,10 @@ def extra_compute_grad_fetches(self): @DeveloperAPI def optimizer(self): """TF optimizer to use for policy optimization.""" - return tf.train.AdamOptimizer() + if hasattr(self, "config"): + return tf.train.AdamOptimizer(self.config["lr"]) + else: + return tf.train.AdamOptimizer() @DeveloperAPI def gradients(self, optimizer, loss): diff --git a/python/ray/rllib/evaluation/tf_policy_template.py b/python/ray/rllib/evaluation/tf_policy_template.py new file mode 100644 index 000000000000..b2549e973a65 --- /dev/null +++ b/python/ray/rllib/evaluation/tf_policy_template.py @@ -0,0 +1,146 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from ray.rllib.evaluation.dynamic_tf_policy_graph import DynamicTFPolicyGraph +from ray.rllib.evaluation.policy_graph import PolicyGraph +from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph +from ray.rllib.utils.annotations import override, DeveloperAPI + + +@DeveloperAPI +def build_tf_policy(name, + loss_fn, + get_default_config=None, + stats_fn=None, + grad_stats_fn=None, + extra_action_fetches_fn=None, + postprocess_fn=None, + optimizer_fn=None, + gradients_fn=None, + before_init=None, + before_loss_init=None, + after_init=None, + make_action_sampler=None, + mixins=None, + get_batch_divisibility_req=None): + """Helper function for creating a dynamic tf policy at runtime. + + Arguments: + name (str): name of the graph (e.g., "PPOPolicy") + loss_fn (func): function that returns a loss tensor the policy, + and dict of experience tensor placeholders + get_default_config (func): optional function that returns the default + config to merge with any overrides + stats_fn (func): optional function that returns a dict of + TF fetches given the policy and batch input tensors + grad_stats_fn (func): optional function that returns a dict of + TF fetches given the policy and loss gradient tensors + extra_action_fetches_fn (func): optional function that returns + a dict of TF fetches given the policy object + postprocess_fn (func): optional experience postprocessing function + that takes the same args as PolicyGraph.postprocess_trajectory() + optimizer_fn (func): optional function that returns a tf.Optimizer + given the policy and config + gradients_fn (func): optional function that returns a list of gradients + given a tf optimizer and loss tensor. 
If not specified, this + defaults to optimizer.compute_gradients(loss) + before_init (func): optional function to run at the beginning of + policy init that takes the same arguments as the policy constructor + before_loss_init (func): optional function to run prior to loss + init that takes the same arguments as the policy constructor + after_init (func): optional function to run at the end of policy init + that takes the same arguments as the policy constructor + make_action_sampler (func): optional function that returns a + tuple of action and action prob tensors. The function takes + (policy, input_dict, obs_space, action_space, config) as its + arguments + mixins (list): list of any class mixins for the returned policy class. + These mixins will be applied in order and will have higher + precedence than the DynamicTFPolicyGraph class + get_batch_divisibility_req (func): optional function that returns + the divisibility requirement for sample batches + + Returns: + a DynamicTFPolicyGraph instance that uses the specified args + """ + + if not name.endswith("TFPolicy"): + raise ValueError("Name should match *TFPolicy", name) + + base = DynamicTFPolicyGraph + while mixins: + + class new_base(mixins.pop(), base): + pass + + base = new_base + + class graph_cls(base): + def __init__(self, + obs_space, + action_space, + config, + existing_inputs=None): + if get_default_config: + config = dict(get_default_config(), **config) + + if before_init: + before_init(self, obs_space, action_space, config) + + def before_loss_init_wrapper(policy, obs_space, action_space, + config): + if before_loss_init: + before_loss_init(policy, obs_space, action_space, config) + if extra_action_fetches_fn is None: + self._extra_action_fetches = {} + else: + self._extra_action_fetches = extra_action_fetches_fn(self) + + DynamicTFPolicyGraph.__init__( + self, + obs_space, + action_space, + config, + loss_fn, + stats_fn=stats_fn, + grad_stats_fn=grad_stats_fn, + before_loss_init=before_loss_init_wrapper, + existing_inputs=existing_inputs) + + if after_init: + after_init(self, obs_space, action_space, config) + + @override(PolicyGraph) + def postprocess_trajectory(self, + sample_batch, + other_agent_batches=None, + episode=None): + if not postprocess_fn: + return sample_batch + return postprocess_fn(self, sample_batch, other_agent_batches, + episode) + + @override(TFPolicyGraph) + def optimizer(self): + if optimizer_fn: + return optimizer_fn(self, self.config) + else: + return TFPolicyGraph.optimizer(self) + + @override(TFPolicyGraph) + def gradients(self, optimizer, loss): + if gradients_fn: + return gradients_fn(self, optimizer, loss) + else: + return TFPolicyGraph.gradients(self, optimizer, loss) + + @override(TFPolicyGraph) + def extra_compute_action_fetches(self): + return dict( + TFPolicyGraph.extra_compute_action_fetches(self), + **self._extra_action_fetches) + + graph_cls.__name__ = name + graph_cls.__qualname__ = name + return graph_cls diff --git a/python/ray/rllib/evaluation/torch_policy_graph.py b/python/ray/rllib/evaluation/torch_policy_graph.py index fb5c879a1ab8..ccf1b9eeb81d 100644 --- a/python/ray/rllib/evaluation/torch_policy_graph.py +++ b/python/ray/rllib/evaluation/torch_policy_graph.py @@ -15,6 +15,7 @@ from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.utils.annotations import override +from ray.rllib.utils.tracking_dict import UsageTrackingDict class TorchPolicyGraph(PolicyGraph): @@ -30,7 +31,7 @@ class 
TorchPolicyGraph(PolicyGraph): """ def __init__(self, observation_space, action_space, model, loss, - loss_inputs, action_distribution_cls): + action_distribution_cls): """Build a policy graph from policy and loss torch modules. Note that model will be placed on GPU device if CUDA_VISIBLE_DEVICES @@ -42,13 +43,8 @@ def __init__(self, observation_space, action_space, model, loss, model (nn.Module): PyTorch policy module. Given observations as input, this module must return a list of outputs where the first item is action logits, and the rest can be any value. - loss (nn.Module): Loss defined as a PyTorch module. The inputs for - this module are defined by the `loss_inputs` param. This module - returns a single scalar loss. Note that this module should - internally be using the model module. - loss_inputs (list): List of SampleBatch columns that will be - passed to the loss module's forward() function when computing - the loss. For example, ["obs", "action", "advantages"]. + loss (func): Function that takes (policy_graph, batch_tensors) + and returns a single scalar loss. action_distribution_cls (ActionDistribution): Class for action distribution. """ @@ -60,7 +56,6 @@ def __init__(self, observation_space, action_space, model, loss, else torch.device("cpu")) self._model = model.to(self.device) self._loss = loss - self._loss_inputs = loss_inputs self._optimizer = self.optimizer() self._action_dist_cls = action_distribution_cls @@ -87,30 +82,26 @@ def compute_actions(self, @override(PolicyGraph) def learn_on_batch(self, postprocessed_batch): + batch_tensors = self._lazy_tensor_dict(postprocessed_batch) + with self.lock: - loss_in = [] - for key in self._loss_inputs: - loss_in.append( - torch.from_numpy(postprocessed_batch[key]).to(self.device)) - loss_out = self._loss(self._model, *loss_in) + loss_out = self._loss(self, batch_tensors) self._optimizer.zero_grad() loss_out.backward() grad_process_info = self.extra_grad_process() self._optimizer.step() - grad_info = self.extra_grad_info() + grad_info = self.extra_grad_info(batch_tensors) grad_info.update(grad_process_info) return {LEARNER_STATS_KEY: grad_info} @override(PolicyGraph) def compute_gradients(self, postprocessed_batch): + batch_tensors = self._lazy_tensor_dict(postprocessed_batch) + with self.lock: - loss_in = [] - for key in self._loss_inputs: - loss_in.append( - torch.from_numpy(postprocessed_batch[key]).to(self.device)) - loss_out = self._loss(self._model, *loss_in) + loss_out = self._loss(self, batch_tensors) self._optimizer.zero_grad() loss_out.backward() @@ -125,7 +116,7 @@ def compute_gradients(self, postprocessed_batch): else: grads.append(None) - grad_info = self.extra_grad_info() + grad_info = self.extra_grad_info(batch_tensors) grad_info.update(grad_process_info) return grads, {LEARNER_STATS_KEY: grad_info} @@ -163,11 +154,21 @@ def extra_action_out(self, model_out): model_out (list): Outputs of the policy model module.""" return {} - def extra_grad_info(self): + def extra_grad_info(self, batch_tensors): """Return dict of extra grad info.""" return {} def optimizer(self): """Custom PyTorch optimizer to use.""" - return torch.optim.Adam(self._model.parameters()) + if hasattr(self, "config"): + return torch.optim.Adam( + self._model.parameters(), lr=self.config["lr"]) + else: + return torch.optim.Adam(self._model.parameters()) + + def _lazy_tensor_dict(self, postprocessed_batch): + batch_tensors = UsageTrackingDict(postprocessed_batch) + batch_tensors.set_get_interceptor( + lambda arr: torch.from_numpy(arr).to(self.device)) + 
return batch_tensors diff --git a/python/ray/rllib/evaluation/torch_policy_template.py b/python/ray/rllib/evaluation/torch_policy_template.py new file mode 100644 index 000000000000..7f65c2b963b8 --- /dev/null +++ b/python/ray/rllib/evaluation/torch_policy_template.py @@ -0,0 +1,133 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from ray.rllib.evaluation.policy_graph import PolicyGraph +from ray.rllib.evaluation.torch_policy_graph import TorchPolicyGraph +from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.utils.annotations import override, DeveloperAPI + + +@DeveloperAPI +def build_torch_policy(name, + loss_fn, + get_default_config=None, + stats_fn=None, + postprocess_fn=None, + extra_action_out_fn=None, + extra_grad_process_fn=None, + optimizer_fn=None, + before_init=None, + after_init=None, + make_model_and_action_dist=None, + mixins=None): + """Helper function for creating a torch policy at runtime. + + Arguments: + name (str): name of the graph (e.g., "PPOPolicy") + loss_fn (func): function that returns a loss tensor the policy, + and dict of experience tensor placeholders + get_default_config (func): optional function that returns the default + config to merge with any overrides + stats_fn (func): optional function that returns a dict of + values given the policy and batch input tensors + postprocess_fn (func): optional experience postprocessing function + that takes the same args as PolicyGraph.postprocess_trajectory() + extra_action_out_fn (func): optional function that returns + a dict of extra values to include in experiences + extra_grad_process_fn (func): optional function that is called after + gradients are computed and returns processing info + optimizer_fn (func): optional function that returns a torch optimizer + given the policy and config + before_init (func): optional function to run at the beginning of + policy init that takes the same arguments as the policy constructor + after_init (func): optional function to run at the end of policy init + that takes the same arguments as the policy constructor + make_model_and_action_dist (func): optional func that takes the same + arguments as policy init and returns a tuple of model instance and + torch action distribution class. If not specified, the default + model and action dist from the catalog will be used + mixins (list): list of any class mixins for the returned policy class. 
+ These mixins will be applied in order and will have higher + precedence than the TorchPolicyGraph class + + Returns: + a TorchPolicyGraph instance that uses the specified args + """ + + if not name.endswith("TorchPolicy"): + raise ValueError("Name should match *TorchPolicy", name) + + base = TorchPolicyGraph + while mixins: + + class new_base(mixins.pop(), base): + pass + + base = new_base + + class graph_cls(base): + def __init__(self, obs_space, action_space, config): + if get_default_config: + config = dict(get_default_config(), **config) + self.config = config + + if before_init: + before_init(self, obs_space, action_space, config) + + if make_model_and_action_dist: + self.model, self.dist_class = make_model_and_action_dist( + self, obs_space, action_space, config) + else: + self.dist_class, logit_dim = ModelCatalog.get_action_dist( + action_space, self.config["model"], torch=True) + self.model = ModelCatalog.get_torch_model( + obs_space, logit_dim, self.config["model"]) + + TorchPolicyGraph.__init__(self, obs_space, action_space, + self.model, loss_fn, self.dist_class) + + if after_init: + after_init(self, obs_space, action_space, config) + + @override(PolicyGraph) + def postprocess_trajectory(self, + sample_batch, + other_agent_batches=None, + episode=None): + if not postprocess_fn: + return sample_batch + return postprocess_fn(self, sample_batch, other_agent_batches, + episode) + + @override(TorchPolicyGraph) + def extra_grad_process(self): + if extra_grad_process_fn: + return extra_grad_process_fn(self) + else: + return TorchPolicyGraph.extra_grad_process(self) + + @override(TorchPolicyGraph) + def extra_action_out(self, model_out): + if extra_action_out_fn: + return extra_action_out_fn(self, model_out) + else: + return TorchPolicyGraph.extra_action_out(self, model_out) + + @override(TorchPolicyGraph) + def optimizer(self): + if optimizer_fn: + return optimizer_fn(self, self.config) + else: + return TorchPolicyGraph.optimizer(self) + + @override(TorchPolicyGraph) + def extra_grad_info(self, batch_tensors): + if stats_fn: + return stats_fn(self, batch_tensors) + else: + return TorchPolicyGraph.extra_grad_info(self, batch_tensors) + + graph_cls.__name__ = name + graph_cls.__qualname__ = name + return graph_cls diff --git a/python/ray/rllib/examples/multiagent_two_trainers.py b/python/ray/rllib/examples/multiagent_two_trainers.py index 2c18f2bf4b96..1d4257e4eb9d 100644 --- a/python/ray/rllib/examples/multiagent_two_trainers.py +++ b/python/ray/rllib/examples/multiagent_two_trainers.py @@ -18,7 +18,7 @@ from ray.rllib.agents.dqn.dqn import DQNTrainer from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph from ray.rllib.agents.ppo.ppo import PPOTrainer -from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph +from ray.rllib.agents.ppo.ppo_policy_graph import PPOTFPolicy from ray.rllib.tests.test_multi_agent_env import MultiCartpole from ray.tune.logger import pretty_print from ray.tune.registry import register_env @@ -39,7 +39,7 @@ # You can also have multiple policy graphs per trainer, but here we just # show one each for PPO and DQN. 
policy_graphs = { - "ppo_policy": (PPOPolicyGraph, obs_space, act_space, {}), + "ppo_policy": (PPOTFPolicy, obs_space, act_space, {}), "dqn_policy": (DQNPolicyGraph, obs_space, act_space, {}), } diff --git a/python/ray/rllib/optimizers/multi_gpu_impl.py b/python/ray/rllib/optimizers/multi_gpu_impl.py index d892dbe7dbac..8d1bbd4fb54d 100644 --- a/python/ray/rllib/optimizers/multi_gpu_impl.py +++ b/python/ray/rllib/optimizers/multi_gpu_impl.py @@ -255,7 +255,7 @@ def optimize(self, sess, batch_index): fetches = {"train": self._train_op} for tower in self._towers: - fetches.update(tower.loss_graph.extra_compute_grad_fetches()) + fetches.update(tower.loss_graph._get_grad_and_stats_fetches()) return sess.run(fetches, feed_dict=feed_dict) diff --git a/python/ray/rllib/optimizers/multi_gpu_optimizer.py b/python/ray/rllib/optimizers/multi_gpu_optimizer.py index 45df865e43ff..de2671e6a932 100644 --- a/python/ray/rllib/optimizers/multi_gpu_optimizer.py +++ b/python/ray/rllib/optimizers/multi_gpu_optimizer.py @@ -222,6 +222,6 @@ def stats(self): def _averaged(kv): out = {} for k, v in kv.items(): - if v[0] is not None: + if v[0] is not None and not isinstance(v[0], dict): out[k] = np.mean(v) return out diff --git a/python/ray/rllib/tests/test_external_multi_agent_env.py b/python/ray/rllib/tests/test_external_multi_agent_env.py index e5e182b38655..c01e6fa0b7ae 100644 --- a/python/ray/rllib/tests/test_external_multi_agent_env.py +++ b/python/ray/rllib/tests/test_external_multi_agent_env.py @@ -8,7 +8,7 @@ import unittest import ray -from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph +from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy from ray.rllib.optimizers import SyncSamplesOptimizer from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv @@ -67,7 +67,7 @@ def testTrainExternalMultiCartpoleManyPolicies(self): obs_space = single_env.observation_space policies = {} for i in range(20): - policies["pg_{}".format(i)] = (PGPolicyGraph, obs_space, act_space, + policies["pg_{}".format(i)] = (PGTFPolicy, obs_space, act_space, {}) policy_ids = list(policies.keys()) ev = PolicyEvaluator( diff --git a/python/ray/rllib/tests/test_io.py b/python/ray/rllib/tests/test_io.py index 9f92c9107c4e..0706be1019cc 100644 --- a/python/ray/rllib/tests/test_io.py +++ b/python/ray/rllib/tests/test_io.py @@ -15,7 +15,7 @@ import ray from ray.rllib.agents.pg import PGTrainer -from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph +from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy from ray.rllib.evaluation import SampleBatch from ray.rllib.offline import IOContext, JsonWriter, JsonReader from ray.rllib.offline.json_writer import _to_json @@ -159,7 +159,7 @@ def testMultiAgent(self): def gen_policy(): obs_space = single_env.observation_space act_space = single_env.action_space - return (PGPolicyGraph, obs_space, act_space, {}) + return (PGTFPolicy, obs_space, act_space, {}) pg = PGTrainer( env="multi_cartpole", diff --git a/python/ray/rllib/tests/test_multi_agent_env.py b/python/ray/rllib/tests/test_multi_agent_env.py index eccb9aa82fb8..72130712d555 100644 --- a/python/ray/rllib/tests/test_multi_agent_env.py +++ b/python/ray/rllib/tests/test_multi_agent_env.py @@ -8,7 +8,7 @@ import ray from ray.rllib.agents.pg import PGTrainer -from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph +from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph 
from ray.rllib.optimizers import (SyncSamplesOptimizer, SyncReplayOptimizer, AsyncGradientsOptimizer) @@ -470,7 +470,7 @@ def get_initial_state(self): self.assertEqual(batch["state_out_0"][1], h) def testReturningModelBasedRolloutsData(self): - class ModelBasedPolicyGraph(PGPolicyGraph): + class ModelBasedPolicyGraph(PGTFPolicy): def compute_actions(self, obs_batch, state_batches, @@ -584,7 +584,7 @@ def _testWithOptimizer(self, optimizer_cls): } else: policies = { - "p1": (PGPolicyGraph, obs_space, act_space, {}), + "p1": (PGTFPolicy, obs_space, act_space, {}), "p2": (DQNPolicyGraph, obs_space, act_space, dqn_config), } ev = PolicyEvaluator( @@ -640,7 +640,7 @@ def testTrainMultiCartpoleManyPolicies(self): obs_space = env.observation_space policies = {} for i in range(20): - policies["pg_{}".format(i)] = (PGPolicyGraph, obs_space, act_space, + policies["pg_{}".format(i)] = (PGTFPolicy, obs_space, act_space, {}) policy_ids = list(policies.keys()) ev = PolicyEvaluator( diff --git a/python/ray/rllib/tests/test_nested_spaces.py b/python/ray/rllib/tests/test_nested_spaces.py index e4285e42287c..b70bd9a2908e 100644 --- a/python/ray/rllib/tests/test_nested_spaces.py +++ b/python/ray/rllib/tests/test_nested_spaces.py @@ -12,7 +12,7 @@ import ray from ray.rllib.agents.a3c import A2CTrainer from ray.rllib.agents.pg import PGTrainer -from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph +from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy from ray.rllib.env import MultiAgentEnv from ray.rllib.env.base_env import BaseEnv from ray.rllib.env.vector_env import VectorEnv @@ -333,10 +333,10 @@ def testMultiAgentComplexSpaces(self): "multiagent": { "policy_graphs": { "tuple_policy": ( - PGPolicyGraph, TUPLE_SPACE, act_space, + PGTFPolicy, TUPLE_SPACE, act_space, {"model": {"custom_model": "tuple_spy"}}), "dict_policy": ( - PGPolicyGraph, DICT_SPACE, act_space, + PGTFPolicy, DICT_SPACE, act_space, {"model": {"custom_model": "dict_spy"}}), }, "policy_mapping_fn": lambda a: { diff --git a/python/ray/rllib/tests/test_optimizers.py b/python/ray/rllib/tests/test_optimizers.py index 9c9e6b56b426..5436baeafa90 100644 --- a/python/ray/rllib/tests/test_optimizers.py +++ b/python/ray/rllib/tests/test_optimizers.py @@ -9,7 +9,7 @@ import ray from ray.rllib.agents.ppo import PPOTrainer -from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph +from ray.rllib.agents.ppo.ppo_policy_graph import PPOTFPolicy from ray.rllib.evaluation import SampleBatch from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator from ray.rllib.optimizers import AsyncGradientsOptimizer, AsyncSamplesOptimizer @@ -240,12 +240,12 @@ def make_sess(): local = PolicyEvaluator( env_creator=lambda _: gym.make("CartPole-v0"), - policy_graph=PPOPolicyGraph, + policy_graph=PPOTFPolicy, tf_session_creator=make_sess) remotes = [ PolicyEvaluator.as_remote().remote( env_creator=lambda _: gym.make("CartPole-v0"), - policy_graph=PPOPolicyGraph, + policy_graph=PPOTFPolicy, tf_session_creator=make_sess) ] return local, remotes diff --git a/python/ray/rllib/utils/tracking_dict.py b/python/ray/rllib/utils/tracking_dict.py new file mode 100644 index 000000000000..c0f145734e78 --- /dev/null +++ b/python/ray/rllib/utils/tracking_dict.py @@ -0,0 +1,32 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +class UsageTrackingDict(dict): + """Dict that tracks which keys have been accessed. 
+ + It can also intercept gets and allow an arbitrary callback to be applied + (i.e., to lazily convert numpy arrays to Tensors). + + We make the simplifying assumption that only __getitem__ is used to access + values. + """ + + def __init__(self, *args, **kwargs): + dict.__init__(self, *args, **kwargs) + self.accessed_keys = set() + self.intercepted_values = {} + self.get_interceptor = None + + def set_get_interceptor(self, fn): + self.get_interceptor = fn + + def __getitem__(self, key): + self.accessed_keys.add(key) + value = dict.__getitem__(self, key) + if self.get_interceptor: + if key not in self.intercepted_values: + self.intercepted_values[key] = self.get_interceptor(value) + value = self.intercepted_values[key] + return value
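# --- Illustrative sketch (not part of the patch) ---------------------------
# Minimal usage example for UsageTrackingDict, mirroring how
# DynamicTFPolicyGraph._initialize_loss and TorchPolicyGraph._lazy_tensor_dict
# rely on it: keys read by the loss function are recorded in accessed_keys,
# and the interceptor lazily converts values on first access (then caches
# them). The batch contents and the .tolist() interceptor are invented for
# illustration; the import path assumes this patch is applied.
import numpy as np

from ray.rllib.utils.tracking_dict import UsageTrackingDict

batch = UsageTrackingDict({
    "obs": np.zeros((2, 4), dtype=np.float32),
    "actions": np.zeros((2, ), dtype=np.int64),
    "rewards": np.zeros((2, ), dtype=np.float32),
})
# Stand-in for torch.from_numpy(...) or TF placeholder creation.
batch.set_get_interceptor(lambda arr: arr.tolist())

# A toy "loss" that only reads two of the three columns.
_ = (batch["obs"], batch["rewards"])

print(sorted(batch.accessed_keys))  # ['obs', 'rewards'] -> these become loss inputs
print(type(batch["obs"]))           # <class 'list'> (interceptor applied once, then cached)
# ---------------------------------------------------------------------------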