diff --git a/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py b/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py
index 706f6824ee7e..faf22f602a7f 100644
--- a/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py
+++ b/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py
@@ -49,7 +49,6 @@ def __init__(self, observation_space, action_space, config):
             [-1])
         self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                           tf.get_variable_scope().name)
-        is_training = tf.placeholder_with_default(True, ())

         # Setup the policy loss
         if isinstance(action_space, gym.spaces.Box):
@@ -74,16 +73,13 @@ def __init__(self, observation_space, action_space, config):
             ("advantages", advantages),
             ("value_targets", v_target),
         ]
-        for i, ph in enumerate(self.model.state_in):
-            loss_in.append(("state_in_{}".format(i), ph))
         self.state_in = self.model.state_in
         self.state_out = self.model.state_out
         TFPolicyGraph.__init__(
             self, observation_space, action_space, self.sess,
             obs_input=self.observations, action_sampler=action_dist.sample(),
             loss=self.loss.total_loss, loss_inputs=loss_in,
-            is_training=is_training, state_inputs=self.state_in,
-            state_outputs=self.state_out,
+            state_inputs=self.state_in, state_outputs=self.state_out,
             seq_lens=self.model.seq_lens,
             max_seq_len=self.config["model"]["max_seq_len"])

diff --git a/python/ray/rllib/agents/agent.py b/python/ray/rllib/agents/agent.py
index cdd6bdee715d..f53923030039 100644
--- a/python/ray/rllib/agents/agent.py
+++ b/python/ray/rllib/agents/agent.py
@@ -46,6 +46,8 @@
         "gpu_options": {
             "allow_growth": True,
         },
+        "log_device_placement": False,
+        "device_count": {"CPU": 1},
         "allow_soft_placement": True,  # required by PPO multi-gpu
     },
     # Whether to LZ4 compress observations
diff --git a/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py b/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py
index a8a44980b705..1dd8941b9768 100644
--- a/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py
+++ b/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py
@@ -262,12 +262,11 @@ def _build_action_network(p_values, stochastic, eps):
             ("dones", self.done_mask),
             ("weights", self.importance_weights),
         ]
-        self.is_training = tf.placeholder_with_default(True, ())
         TFPolicyGraph.__init__(
             self, observation_space, action_space, self.sess,
             obs_input=self.cur_observations,
             action_sampler=self.output_actions, loss=self.loss.total_loss,
-            loss_inputs=self.loss_inputs, is_training=self.is_training)
+            loss_inputs=self.loss_inputs)
         self.sess.run(tf.global_variables_initializer())

         # Note that this encompasses both the policy and Q-value networks and
diff --git a/python/ray/rllib/agents/dqn/dqn_policy_graph.py b/python/ray/rllib/agents/dqn/dqn_policy_graph.py
index f94cc16a32da..7905935ce817 100644
--- a/python/ray/rllib/agents/dqn/dqn_policy_graph.py
+++ b/python/ray/rllib/agents/dqn/dqn_policy_graph.py
@@ -171,12 +171,11 @@ def _build_q_network(obs):
             ("dones", self.done_mask),
             ("weights", self.importance_weights),
         ]
-        self.is_training = tf.placeholder_with_default(True, ())
         TFPolicyGraph.__init__(
             self, observation_space, action_space, self.sess,
             obs_input=self.cur_observations,
             action_sampler=self.output_actions, loss=self.loss.loss,
-            loss_inputs=self.loss_inputs, is_training=self.is_training)
+            loss_inputs=self.loss_inputs)
         self.sess.run(tf.global_variables_initializer())

     def optimizer(self):
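Note on the deletions above: the per-graph `is_training` placeholders are not dropped, they move into `TFPolicyGraph.__init__` itself (see the tf_policy_graph.py hunk further down). A minimal sketch of how such a default-True placeholder behaves, for reference only (not part of the patch):

    import tensorflow as tf

    # Evaluates to True unless the caller explicitly feeds False, so
    # inference-time callers never need to know the flag exists.
    is_training = tf.placeholder_with_default(True, (), name="is_training")

    with tf.Session() as sess:
        print(sess.run(is_training))                                  # True
        print(sess.run(is_training, feed_dict={is_training: False}))  # False
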
diff --git a/python/ray/rllib/agents/pg/pg_policy_graph.py b/python/ray/rllib/agents/pg/pg_policy_graph.py
index 42124e3d1284..cbd9b274598f 100644
--- a/python/ray/rllib/agents/pg/pg_policy_graph.py
+++ b/python/ray/rllib/agents/pg/pg_policy_graph.py
@@ -41,16 +41,10 @@ def __init__(self, obs_space, action_space, config):
             ("advantages", advantages),
         ]

-        # LSTM support
-        for i, ph in enumerate(self.model.state_in):
-            loss_in.append(("state_in_{}".format(i), ph))
-
-        is_training = tf.placeholder_with_default(True, ())
         TFPolicyGraph.__init__(
             self, obs_space, action_space, sess, obs_input=obs,
             action_sampler=action_dist.sample(), loss=loss,
-            loss_inputs=loss_in, is_training=is_training,
-            state_inputs=self.model.state_in,
+            loss_inputs=loss_in, state_inputs=self.model.state_in,
             state_outputs=self.model.state_out, seq_lens=self.model.seq_lens,
             max_seq_len=config["model"]["max_seq_len"])

diff --git a/python/ray/rllib/agents/ppo/ppo.py b/python/ray/rllib/agents/ppo/ppo.py
index d1e3cde75519..2f8b403aaaa4 100644
--- a/python/ray/rllib/agents/ppo/ppo.py
+++ b/python/ray/rllib/agents/ppo/ppo.py
@@ -50,7 +50,7 @@
     "simple_optimizer": False,
     # Override model config
     "model": {
-        # Use LSTM model (note: requires simple optimizer for now).
+        # Whether to use LSTM model
        "use_lstm": False,
         # Max seq length for LSTM training.
         "max_seq_len": 20,
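With `use_lstm` no longer tied to the simple optimizer, an LSTM PPO run can be launched with the default multi-GPU optimizer. A usage sketch, assuming the `PPOAgent` entry point of this RLlib version (the class name is an assumption, it does not appear in this patch):

    import ray
    from ray.rllib.agents.ppo import PPOAgent  # assumed entry point

    ray.init()
    agent = PPOAgent(
        env="CartPole-v1",
        config={
            # "simple_optimizer" defaults to False, so this exercises the
            # multi-GPU optimizer together with the recurrent model.
            "num_sgd_iter": 2,
            "model": {"use_lstm": True, "max_seq_len": 20},
        })
    print(agent.train())
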
diff --git a/python/ray/rllib/agents/ppo/ppo_policy_graph.py b/python/ray/rllib/agents/ppo/ppo_policy_graph.py
index ecddb2f993ce..2bc6d5507b1e 100644
--- a/python/ray/rllib/agents/ppo/ppo_policy_graph.py
+++ b/python/ray/rllib/agents/ppo/ppo_policy_graph.py
@@ -92,9 +92,10 @@ def __init__(self, observation_space, action_space,
         dist_cls, logit_dim = ModelCatalog.get_action_dist(action_space)

         if existing_inputs:
-            self.loss_in = existing_inputs
             obs_ph, value_targets_ph, adv_ph, act_ph, \
-                logits_ph, vf_preds_ph = [ph for _, ph in existing_inputs]
+                logits_ph, vf_preds_ph = existing_inputs[:6]
+            existing_state_in = existing_inputs[6:-1]
+            existing_seq_lens = existing_inputs[-1]
         else:
             obs_ph = tf.placeholder(
                 tf.float32, name="obs", shape=(None,)+observation_space.shape)
@@ -107,23 +108,20 @@ def __init__(self, observation_space, action_space,
                 tf.float32, name="vf_preds", shape=(None,))
             value_targets_ph = tf.placeholder(
                 tf.float32, name="value_targets", shape=(None,))
-
-            self.loss_in = [
-                ("obs", obs_ph),
-                ("value_targets", value_targets_ph),
-                ("advantages", adv_ph),
-                ("actions", act_ph),
-                ("logits", logits_ph),
-                ("vf_preds", vf_preds_ph),
-            ]
-
+            existing_state_in = None
+            existing_seq_lens = None
+
+        self.loss_in = [
+            ("obs", obs_ph),
+            ("value_targets", value_targets_ph),
+            ("advantages", adv_ph),
+            ("actions", act_ph),
+            ("logits", logits_ph),
+            ("vf_preds", vf_preds_ph),
+        ]
         self.model = ModelCatalog.get_model(
-            obs_ph, logit_dim, self.config["model"])
-
-        # LSTM support
-        if not existing_inputs:
-            for i, ph in enumerate(self.model.state_in):
-                self.loss_in.append(("state_in_{}".format(i), ph))
+            obs_ph, logit_dim, self.config["model"],
+            state_in=existing_state_in, seq_lens=existing_seq_lens)

         # KL Coefficient
         self.kl_coeff = tf.get_variable(
@@ -155,15 +153,14 @@ def __init__(self, observation_space, action_space,
             clip_param=self.config["clip_param"],
             vf_loss_coeff=self.config["kl_target"],
             use_gae=self.config["use_gae"])
-        self.is_training = tf.placeholder_with_default(True, ())
         TFPolicyGraph.__init__(
             self, observation_space, action_space, self.sess,
             obs_input=obs_ph,
             action_sampler=self.sampler, loss=self.loss_obj.loss,
-            loss_inputs=self.loss_in, is_training=self.is_training,
-            state_inputs=self.model.state_in,
-            state_outputs=self.model.state_out, seq_lens=self.model.seq_lens)
+            loss_inputs=self.loss_in, state_inputs=self.model.state_in,
+            state_outputs=self.model.state_out, seq_lens=self.model.seq_lens,
+            max_seq_len=config["model"]["max_seq_len"])

         self.sess.run(tf.global_variables_initializer())

diff --git a/python/ray/rllib/evaluation/tf_policy_graph.py b/python/ray/rllib/evaluation/tf_policy_graph.py
index 0df9d9935f0c..58f5e3cacb42 100644
--- a/python/ray/rllib/evaluation/tf_policy_graph.py
+++ b/python/ray/rllib/evaluation/tf_policy_graph.py
@@ -3,6 +3,7 @@
 from __future__ import print_function

 import tensorflow as tf
+import numpy as np

 import ray
 from ray.rllib.evaluation.policy_graph import PolicyGraph
@@ -36,9 +37,8 @@ class TFPolicyGraph(PolicyGraph):

     def __init__(
             self, observation_space, action_space, sess, obs_input,
-            action_sampler, loss, loss_inputs, is_training,
-            state_inputs=None, state_outputs=None, seq_lens=None,
-            max_seq_len=20):
+            action_sampler, loss, loss_inputs, state_inputs=None,
+            state_outputs=None, seq_lens=None, max_seq_len=20):
         """Initialize the policy graph.

         Arguments:
@@ -54,10 +54,8 @@ def __init__(
                 input argument. Each placeholder name must correspond to a
                 SampleBatch column key returned by postprocess_trajectory(),
                 and has shape [BATCH_SIZE, data...].
-            is_training (Tensor): input placeholder for whether we are
-                currently training the policy.
-            state_inputs (list): list of RNN state output Tensors.
-            state_outputs (list): list of initial state values.
+            state_inputs (list): list of RNN state input Tensors.
+            state_outputs (list): list of RNN state output Tensors.
             seq_lens (Tensor): placeholder for RNN sequence lengths, of shape
                 [NUM_SEQUENCES]. Note that NUM_SEQUENCES << BATCH_SIZE. See
                 models/lstm.py for more information.
@@ -72,9 +70,11 @@ def __init__(
         self._loss = loss
         self._loss_inputs = loss_inputs
         self._loss_input_dict = dict(self._loss_inputs)
-        self._is_training = is_training
+        self._is_training = tf.placeholder_with_default(True, ())
         self._state_inputs = state_inputs or []
         self._state_outputs = state_outputs or []
+        for i, ph in enumerate(self._state_inputs):
+            self._loss_input_dict["state_in_{}".format(i)] = ph
         self._seq_lens = seq_lens
         self._max_seq_len = max_seq_len
         self._optimizer = self.optimizer()
@@ -99,6 +99,8 @@ def build_compute_actions(
             (self._state_inputs, state_batches)
         builder.add_feed_dict(self.extra_compute_action_feed_dict())
         builder.add_feed_dict({self._obs_input: obs_batch})
+        if state_batches:
+            builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))})
         builder.add_feed_dict({self._is_training: is_training})
         builder.add_feed_dict(dict(zip(self._state_inputs, state_batches)))
         fetches = builder.add_fetches(
@@ -123,10 +125,9 @@ def _get_loss_inputs_dict(self, batch):
             return feed_dict

         # RNN case
-        feature_keys = [
-            k for k, v in self._loss_inputs if not k.startswith("state_in_")]
+        feature_keys = [k for k, v in self._loss_inputs]
         state_keys = [
-            k for k, v in self._loss_inputs if k.startswith("state_in_")]
+            "state_in_{}".format(i) for i in range(len(self._state_inputs))]
         feature_sequences, initial_states, seq_lens = chop_into_sequences(
             batch["t"],
             [batch[k] for k in feature_keys],
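The `_get_loss_inputs_dict` change above means every feature column is fed in padded-sequence form while the RNN state columns keep one row per sequence. A rough numpy sketch of the shapes involved; `pad_to_sequences` is a hypothetical stand-in for `chop_into_sequences` from models/lstm.py, not the real implementation:

    import numpy as np

    def pad_to_sequences(column, episode_lens, max_seq_len):
        """Pads a flat column into fixed-size sequence chunks."""
        seqs, seq_lens = [], []
        start = 0
        for ep_len in episode_lens:
            for offset in range(0, ep_len, max_seq_len):
                chunk = column[start + offset:
                               start + offset + min(max_seq_len, ep_len - offset)]
                seq_lens.append(len(chunk))
                padding = np.zeros(max_seq_len - len(chunk), column.dtype)
                seqs.append(np.concatenate([chunk, padding]))
            start += ep_len
        return np.concatenate(seqs), np.array(seq_lens)

    obs = np.arange(25)                       # two episodes: lengths 11 and 14
    padded, seq_lens = pad_to_sequences(obs, [11, 14], max_seq_len=5)
    assert len(padded) == len(seq_lens) * 5   # BATCH_SIZE == NUM_SEQUENCES * T
    print(seq_lens)                           # [5 5 1 5 5 4]
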
diff --git a/python/ray/rllib/models/catalog.py b/python/ray/rllib/models/catalog.py
index 0da5dc2dbf5f..db5717a83e76 100644
--- a/python/ray/rllib/models/catalog.py
+++ b/python/ray/rllib/models/catalog.py
@@ -138,41 +138,47 @@ def get_action_placeholder(action_space):
                 " not supported".format(action_space))

     @staticmethod
-    def get_model(inputs, num_outputs, options=None):
+    def get_model(
+            inputs, num_outputs, options=None, state_in=None, seq_lens=None):
         """Returns a suitable model conforming to given input and output
         specs.

         Args:
             inputs (Tensor): The input tensor to the model.
             num_outputs (int): The size of the output vector of the model.
             options (dict): Optional args to pass to the model constructor.
+            state_in (list): Optional RNN state in tensors.
+            seq_lens (Tensor): Optional RNN sequence length tensor.

         Returns:
             model (Model): Neural network model.
         """
         options = options or {}
-        model = ModelCatalog._get_model(inputs, num_outputs, options)
+        model = ModelCatalog._get_model(
+            inputs, num_outputs, options, state_in, seq_lens)
         if options.get("use_lstm"):
-            model = LSTM(model.last_layer, num_outputs, options)
+            model = LSTM(
+                model.last_layer, num_outputs, options, state_in, seq_lens)
         return model

     @staticmethod
-    def _get_model(inputs, num_outputs, options):
+    def _get_model(inputs, num_outputs, options, state_in, seq_lens):
         if "custom_model" in options:
             model = options["custom_model"]
             print("Using custom model {}".format(model))
             return _global_registry.get(RLLIB_MODEL, model)(
-                inputs, num_outputs, options)
+                inputs, num_outputs, options,
+                state_in=state_in, seq_lens=seq_lens)

         obs_rank = len(inputs.shape) - 1

         # num_outputs > 1 used to avoid hitting this with the value function
         if isinstance(options.get("custom_options", {}).get(
                 "multiagent_fcnet_hiddens", 1), list) and num_outputs > 1:
-            return MultiAgentFullyConnectedNetwork(inputs,
-                                                   num_outputs, options)
+            return MultiAgentFullyConnectedNetwork(
+                inputs, num_outputs, options)

         if obs_rank > 1:
             return VisionNetwork(inputs, num_outputs, options)
diff --git a/python/ray/rllib/models/lstm.py b/python/ray/rllib/models/lstm.py
index 304f3470ea6f..55a9626cbd82 100644
--- a/python/ray/rllib/models/lstm.py
+++ b/python/ray/rllib/models/lstm.py
@@ -41,8 +41,8 @@ def add_time_dimension(padded_inputs, seq_lens):
     # Sequence lengths have to be specified for LSTM batch inputs. The
     # input batch must be padded to the max seq length given here. That is,
     # batch_size == len(seq_lens) * max(seq_lens)
-    max_seq_len = tf.reduce_max(seq_lens)
     padded_batch_size = tf.shape(padded_inputs)[0]
+    max_seq_len = padded_batch_size // tf.shape(seq_lens)[0]

     # Dynamically reshape the padded batch to introduce a time dimension.
     new_batch_size = padded_batch_size // max_seq_len
@@ -155,9 +155,14 @@ def _build_layers(self, inputs, num_outputs, options):
                            np.zeros(lstm.state_size.h, np.float32)]

         # Setup LSTM inputs
-        c_in = tf.placeholder(tf.float32, [None, lstm.state_size.c], name="c")
-        h_in = tf.placeholder(tf.float32, [None, lstm.state_size.h], name="h")
-        self.state_in = [c_in, h_in]
+        if self.state_in:
+            c_in, h_in = self.state_in
+        else:
+            c_in = tf.placeholder(
+                tf.float32, [None, lstm.state_size.c], name="c")
+            h_in = tf.placeholder(
+                tf.float32, [None, lstm.state_size.h], name="h")
+            self.state_in = [c_in, h_in]

         # Setup LSTM outputs
         if use_tf100_api:
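The `add_time_dimension` change above is what makes fixed-size per-device slices workable: the padding length is recovered from the batch shape rather than from `tf.reduce_max(seq_lens)`, which for a slice that happens to contain only short sequences would under-estimate the padding and break the reshape. A small numpy illustration (not part of the patch):

    import numpy as np

    padded_inputs = np.zeros((160, 4))                 # 160 padded rows, obs dim 4
    seq_lens = np.array([5, 3, 20, 9, 20, 12, 20, 1])  # 8 sequences, padded to 20
    max_seq_len = padded_inputs.shape[0] // len(seq_lens)   # 160 // 8 == 20
    # reduce_max(seq_lens) also gives 20 here, but a slice holding only the
    # short sequences would yield e.g. 12 and the reshape below would fail.
    batch_major = padded_inputs.reshape((len(seq_lens), max_seq_len, -1))
    print(batch_major.shape)                           # (8, 20, 4)
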
""" - def __init__(self, inputs, num_outputs, options): + def __init__( + self, inputs, num_outputs, options, state_in=None, seq_lens=None): self.inputs = inputs # Default attribute values for the non-RNN case self.state_init = [] - self.state_in = [] + self.state_in = state_in or [] self.state_out = [] - self.seq_lens = tf.placeholder_with_default( - tf.ones( # reshape needed for older tf versions - tf.reshape(tf.shape(inputs)[0], [1]), dtype=tf.int32), - [None], name="seq_lens") + if seq_lens is not None: + self.seq_lens = seq_lens + else: + self.seq_lens = tf.placeholder( + dtype=tf.int32, shape=[None], name="seq_lens") if options.get("free_log_std", False): assert num_outputs % 2 == 0 diff --git a/python/ray/rllib/optimizers/multi_gpu_impl.py b/python/ray/rllib/optimizers/multi_gpu_impl.py index 6a694e1e5147..844dc11fbedb 100644 --- a/python/ray/rllib/optimizers/multi_gpu_impl.py +++ b/python/ray/rllib/optimizers/multi_gpu_impl.py @@ -3,9 +3,7 @@ from __future__ import print_function from collections import namedtuple -import os -from tensorflow.python.client import timeline import tensorflow as tf @@ -34,9 +32,11 @@ class LocalSyncParallelOptimizer(object): Args: optimizer: Delegate TensorFlow optimizer object. devices: List of the names of TensorFlow devices to parallelize over. - input_placeholders: List of (name, input_placeholder) - for the loss function. Tensors of these shapes will be passed - to build_graph() in order to define the per-device loss ops. + input_placeholders: List of input_placeholders for the loss function. + Tensors of these shapes will be passed to build_graph() in order + to define the per-device loss ops. + rnn_inputs: Extra input placeholders for RNN inputs. These will have + shape [BATCH_SIZE // MAX_SEQ_LEN, ...]. per_device_batch_size: Number of tuples to optimize over at a time per device. In each call to `optimize()`, `len(devices) * per_device_batch_size` tuples of data will be @@ -47,7 +47,7 @@ class LocalSyncParallelOptimizer(object): grad_norm_clipping: None or int stdev to clip grad norms by """ - def __init__(self, optimizer, devices, input_placeholders, + def __init__(self, optimizer, devices, input_placeholders, rnn_inputs, per_device_batch_size, build_graph, logdir, grad_norm_clipping=None): # TODO(rliaw): remove logdir @@ -55,27 +55,31 @@ def __init__(self, optimizer, devices, input_placeholders, self.devices = devices self.batch_size = per_device_batch_size * len(devices) self.per_device_batch_size = per_device_batch_size - self.loss_inputs = input_placeholders + self.loss_inputs = input_placeholders + rnn_inputs self.build_graph = build_graph self.logdir = logdir # First initialize the shared loss network with tf.name_scope(TOWER_SCOPE_NAME): - self._shared_loss = build_graph(input_placeholders) + self._shared_loss = build_graph(self.loss_inputs) # Then setup the per-device loss graphs that use the shared weights - self._batch_index = tf.placeholder(tf.int32) + self._batch_index = tf.placeholder(tf.int32, name="batch_index") + + # When loading RNN input, we dynamically determine the max seq len + self._max_seq_len = tf.placeholder(tf.int32, name="max_seq_len") + self._loaded_max_seq_len = 1 # Split on the CPU in case the data doesn't fit in GPU memory. 
with tf.device("/cpu:0"): - names, placeholders = zip(*input_placeholders) data_splits = zip( - *[tf.split(ph, len(devices)) for ph in placeholders]) + *[tf.split(ph, len(devices)) for ph in self.loss_inputs]) self._towers = [] for device, device_placeholders in zip(self.devices, data_splits): self._towers.append( - self._setup_device(device, zip(names, device_placeholders))) + self._setup_device( + device, device_placeholders, len(input_placeholders))) avg = average_gradients([t.grads for t in self._towers]) if grad_norm_clipping: @@ -84,7 +88,7 @@ def __init__(self, optimizer, devices, input_placeholders, avg[i] = (tf.clip_by_norm(grad, grad_norm_clipping), var) self._train_op = self.optimizer.apply_gradients(avg) - def load_data(self, sess, inputs, full_trace=False): + def load_data(self, sess, inputs, state_inputs): """Bulk loads the specified inputs into device memory. The shape of the inputs must conform to the shapes of the input @@ -95,37 +99,47 @@ def load_data(self, sess, inputs, full_trace=False): Args: sess: TensorFlow session. - inputs: List of Tensors matching the input placeholders specified - at construction time of this optimizer. - full_trace: Whether to profile data loading. + inputs: List of arrays matching the input placeholders, of shape + [BATCH_SIZE, ...]. + state_inputs: List of RNN input arrays. These arrays have size + [BATCH_SIZE / MAX_SEQ_LEN, ...]. Returns: The number of tuples loaded per device. """ feed_dict = {} - assert len(self.loss_inputs) == len(inputs) - for (name, ph), arr in zip(self.loss_inputs, inputs): - truncated_arr = make_divisible_by(arr, self.batch_size) - feed_dict[ph] = truncated_arr - truncated_len = len(truncated_arr) - - if full_trace: - run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) + assert len(self.loss_inputs) == len(inputs + state_inputs), \ + (self.loss_inputs, inputs, state_inputs) + + # The RNN truncation case is more complicated + if len(state_inputs) > 0: + seq_len = len(inputs[0]) // len(state_inputs[0]) + self._loaded_max_seq_len = seq_len + assert len(state_inputs[0]) * seq_len == len(inputs[0]) + # Make sure the shorter state inputs arrays are evenly divisible + state_inputs = [ + make_divisible_by(arr, self.batch_size) + for arr in state_inputs + ] + # Then truncate the data inputs to match + inputs = [ + arr[:len(state_inputs[0]) * seq_len] + for arr in inputs + ] + assert len(state_inputs[0]) * seq_len == len(inputs[0]) + assert len(state_inputs[0]) % self.batch_size == 0 + for ph, arr in zip(self.loss_inputs, inputs + state_inputs): + feed_dict[ph] = arr + truncated_len = len(inputs[0]) else: - run_options = tf.RunOptions(trace_level=tf.RunOptions.NO_TRACE) - run_metadata = tf.RunMetadata() + for ph, arr in zip(self.loss_inputs, inputs + state_inputs): + truncated_arr = make_divisible_by(arr, self.batch_size) + feed_dict[ph] = truncated_arr + truncated_len = len(truncated_arr) sess.run( - [t.init_op for t in self._towers], - feed_dict=feed_dict, - options=run_options, - run_metadata=run_metadata) - if full_trace: - trace = timeline.Timeline(step_stats=run_metadata.step_stats) - trace_file = open(os.path.join(self.logdir, "timeline-load.json"), - "w") - trace_file.write(trace.generate_chrome_trace_format()) + [t.init_op for t in self._towers], feed_dict=feed_dict) tuples_per_device = truncated_len / len(self.devices) assert tuples_per_device > 0, \ @@ -136,7 +150,7 @@ def load_data(self, sess, inputs, full_trace=False): assert tuples_per_device % self.per_device_batch_size == 0 return tuples_per_device 
- def optimize(self, sess, batch_index, file_writer=None): + def optimize(self, sess, batch_index): """Run a single step of SGD. Runs a SGD step over a slice of the preloaded batch with size given by @@ -151,19 +165,14 @@ def optimize(self, sess, batch_index, file_writer=None): batch_index: Offset into the preloaded data. This value must be between `0` and `tuples_per_device`. The amount of data to process is always fixed to `per_device_batch_size`. - file_writer: If specified, tf metrics will be written out using - this. Returns: The outputs of extra_ops evaluated over the batch. """ - if file_writer: - run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) - else: - run_options = tf.RunOptions(trace_level=tf.RunOptions.NO_TRACE) - run_metadata = tf.RunMetadata() - - feed_dict = {self._batch_index: batch_index} + feed_dict = { + self._batch_index: batch_index, + self._max_seq_len: self._loaded_max_seq_len, + } for tower in self._towers: feed_dict.update(tower.loss_graph.extra_compute_grad_feed_dict()) feed_dict.update(tower.loss_graph.extra_apply_grad_feed_dict()) @@ -173,21 +182,7 @@ def optimize(self, sess, batch_index, file_writer=None): fetches.update(tower.loss_graph.extra_compute_grad_fetches()) fetches.update(tower.loss_graph.extra_apply_grad_fetches()) - outs = sess.run( - fetches, - feed_dict=feed_dict, - options=run_options, - run_metadata=run_metadata) - - if file_writer: - trace = timeline.Timeline(step_stats=run_metadata.step_stats) - trace_file = open(os.path.join(self.logdir, "timeline-sgd.json"), - "w") - trace_file.write(trace.generate_chrome_trace_format()) - file_writer.add_run_metadata( - run_metadata, "sgd_train_{}".format(batch_index)) - - return outs + return sess.run(fetches, feed_dict=feed_dict) def get_common_loss(self): return self._shared_loss @@ -195,23 +190,31 @@ def get_common_loss(self): def get_device_losses(self): return [t.loss_graph for t in self._towers] - def _setup_device(self, device, device_input_placeholders): + def _setup_device(self, device, device_input_placeholders, num_data_in): + assert num_data_in <= len(device_input_placeholders) with tf.device(device): with tf.name_scope(TOWER_SCOPE_NAME): device_input_batches = [] device_input_slices = [] - for name, ph in device_input_placeholders: + for i, ph in enumerate(device_input_placeholders): current_batch = tf.Variable( ph, trainable=False, validate_shape=False, collections=[]) device_input_batches.append(current_batch) + if i < num_data_in: + scale = self._max_seq_len + granularity = self._max_seq_len + else: + scale = self._max_seq_len + granularity = 1 current_slice = tf.slice( current_batch, - [self._batch_index] + [0] * len(ph.shape[1:]), - ([self.per_device_batch_size] + [-1] * - len(ph.shape[1:]))) + ([self._batch_index // scale * granularity] + + [0] * len(ph.shape[1:])), + ([self.per_device_batch_size // scale * granularity] + + [-1] * len(ph.shape[1:]))) current_slice.set_shape(ph.shape) - device_input_slices.append((name, current_slice)) + device_input_slices.append(current_slice) graph_obj = self.build_graph(device_input_slices) device_grads = graph_obj.gradients(self.optimizer) return Tower( diff --git a/python/ray/rllib/optimizers/multi_gpu_optimizer.py b/python/ray/rllib/optimizers/multi_gpu_optimizer.py index ee348e362285..0c39aab7a678 100644 --- a/python/ray/rllib/optimizers/multi_gpu_optimizer.py +++ b/python/ray/rllib/optimizers/multi_gpu_optimizer.py @@ -55,12 +55,12 @@ def _init(self, sgd_batch_size=128, sgd_stepsize=5e-5, num_sgd_iter=10, 
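To make the slicing arithmetic in `_setup_device` concrete, here is the bookkeeping for a hypothetical run with two devices, `per_device_batch_size=10` and `max_seq_len=5` (plain-Python illustration, not part of the patch):

    # Two devices, per-device minibatch of 10 timestep rows, sequences of 5.
    num_devices = 2
    per_device_batch_size = 10
    max_seq_len = 5
    batch_size = per_device_batch_size * num_devices        # 20

    # load_data() receives 100 padded timestep rows plus RNN state arrays
    # with 100 // 5 == 20 rows (one per sequence); 20 is divisible by
    # batch_size, so nothing is truncated.  Each device is loaded with
    # 50 timestep rows and 10 state rows.
    for batch_index in range(0, 50, per_device_batch_size):
        # data inputs (i < num_data_in): scale == granularity == max_seq_len
        data_offset = batch_index // max_seq_len * max_seq_len          # 0, 10, ...
        data_size = per_device_batch_size // max_seq_len * max_seq_len  # 10
        # state inputs: scale == max_seq_len, granularity == 1
        state_offset = batch_index // max_seq_len                       # 0, 2, ...
        state_size = per_device_batch_size // max_seq_len               # 2
        print((data_offset, data_size), (state_offset, state_size))

Each SGD minibatch therefore covers two whole padded sequences per device, which is why the data and state slices stay aligned.
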
print("LocalMultiGPUOptimizer devices", self.devices) assert set(self.local_evaluator.policy_map.keys()) == {"default"}, \ - "Multi-agent is not supported" + ("Multi-agent is not supported with multi-GPU. Try using the " + "simple optimizer instead.") self.policy = self.local_evaluator.policy_map["default"] assert isinstance(self.policy, TFPolicyGraph), \ - "Only TF policies are supported" - assert len(self.policy.get_initial_state()) == 0, \ - "No RNN support yet for multi-gpu. Try the simple optimizer." + ("Only TF policies are supported with multi-GPU. Try using the " + "simple optimizer instead.") # per-GPU graph copies created below must share vars with the policy # reuse is set to AUTO_REUSE because Adam nodes are created after @@ -68,10 +68,16 @@ def _init(self, sgd_batch_size=128, sgd_stepsize=5e-5, num_sgd_iter=10, with self.local_evaluator.tf_sess.graph.as_default(): with self.local_evaluator.tf_sess.as_default(): with tf.variable_scope("default", reuse=tf.AUTO_REUSE): + if self.policy._state_inputs: + rnn_inputs = self.policy._state_inputs + [ + self.policy._seq_lens] + else: + rnn_inputs = [] self.par_opt = LocalSyncParallelOptimizer( tf.train.AdamOptimizer(self.sgd_stepsize), self.devices, - self.policy.loss_inputs(), + [v for _, v in self.policy.loss_inputs()], + rnn_inputs, self.per_device_batch_size, self.policy.copy, os.getcwd()) @@ -103,9 +109,17 @@ def step(self): samples.shuffle() with self.load_timer: + tuples = self.policy._get_loss_inputs_dict(samples) + data_keys = [ph for _, ph in self.policy.loss_inputs()] + if self.policy._state_inputs: + state_keys = ( + self.policy._state_inputs + [self.policy._seq_lens]) + else: + state_keys = [] tuples_per_device = self.par_opt.load_data( self.sess, - samples.columns([key for key, _ in self.policy.loss_inputs()])) + [tuples[k] for k in data_keys], + [tuples[k] for k in state_keys]) with self.grad_timer: num_batches = ( diff --git a/test/jenkins_tests/run_multi_node_tests.sh b/test/jenkins_tests/run_multi_node_tests.sh index 236d31ef90a1..7fd239645954 100755 --- a/test/jenkins_tests/run_multi_node_tests.sh +++ b/test/jenkins_tests/run_multi_node_tests.sh @@ -30,7 +30,14 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ --env CartPole-v1 \ --run PPO \ --stop '{"training_iteration": 2}' \ - --config '{"simple_optimizer": true, "model": {"use_lstm": true}}' + --config '{"simple_optimizer": false, "num_sgd_iter": 2, "model": {"use_lstm": true}}' + +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ + python /ray/python/ray/rllib/train.py \ + --env CartPole-v1 \ + --run PPO \ + --stop '{"training_iteration": 2}' \ + --config '{"simple_optimizer": true, "num_sgd_iter": 2, "model": {"use_lstm": true}}' docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \