From 99caf1b8c55df151b2b065b61654743c2182b191 Mon Sep 17 00:00:00 2001
From: Alok Singh
Date: Wed, 9 May 2018 23:23:04 +0000
Subject: [PATCH] Add magic methods for rollouts

Some nice syntactic sugar: SampleBatch and the rollout container in
rllib/utils/sampler.py now forward dict-style access (__getitem__,
__setitem__, keys, items, __iter__, __contains__) to their underlying
data dict, so callers can write batch[key] instead of batch.data[key].
---
 python/ray/rllib/optimizers/sample_batch.py | 33 +++++++---
 python/ray/rllib/test/test_optimizers.py    | 14 ++---
 python/ray/rllib/utils/process_rollout.py   | 16 ++---
 python/ray/rllib/utils/sampler.py           | 67 ++++++++++++++-------
 4 files changed, 83 insertions(+), 47 deletions(-)

diff --git a/python/ray/rllib/optimizers/sample_batch.py b/python/ray/rllib/optimizers/sample_batch.py
index e3b1c0b98861..5e5e1e95b0b3 100644
--- a/python/ray/rllib/optimizers/sample_batch.py
+++ b/python/ray/rllib/optimizers/sample_batch.py
@@ -36,8 +36,8 @@ def __init__(self, *args, **kwargs):
     @staticmethod
     def concat_samples(samples):
         out = {}
-        for k in samples[0].data.keys():
-            out[k] = np.concatenate([s.data[k] for s in samples])
+        for k in samples[0].keys():
+            out[k] = np.concatenate([s[k] for s in samples])
         return SampleBatch(out)

     def concat(self, other):
@@ -50,10 +50,10 @@ def concat(self, other):
             {"a": [1, 2, 3, 4, 5]}
         """

-        assert self.data.keys() == other.data.keys(), "must have same columns"
+        assert self.keys() == other.keys(), "must have same columns"
         out = {}
-        for k in self.data.keys():
-            out[k] = np.concatenate([self.data[k], other.data[k]])
+        for k in self.keys():
+            out[k] = np.concatenate([self[k], other[k]])
         return SampleBatch(out)

     def rows(self):
@@ -70,7 +70,7 @@ def rows(self):

         for i in range(self.count):
             row = {}
-            for k in self.data.keys():
+            for k in self.keys():
                 row[k] = self[k][i]
             yield row

@@ -85,19 +85,34 @@ def columns(self, keys):

         out = []
         for k in keys:
-            out.append(self.data[k])
+            out.append(self[k])
         return out

     def shuffle(self):
         permutation = np.random.permutation(self.count)
-        for key, val in self.data.items():
-            self.data[key] = val[permutation]
+        for key, val in self.items():
+            self[key] = val[permutation]

     def __getitem__(self, key):
         return self.data[key]

+    def __setitem__(self, key, item):
+        self.data[key] = item
+
     def __str__(self):
         return "SampleBatch({})".format(str(self.data))

     def __repr__(self):
         return "SampleBatch({})".format(str(self.data))
+
+    def keys(self):
+        return self.data.keys()
+
+    def items(self):
+        return self.data.items()
+
+    def __iter__(self):
+        return self.data.__iter__()
+
+    def __contains__(self, x):
+        return x in self.data
diff --git a/python/ray/rllib/test/test_optimizers.py b/python/ray/rllib/test/test_optimizers.py
index cfb606101db9..3118f7956746 100644
--- a/python/ray/rllib/test/test_optimizers.py
+++ b/python/ray/rllib/test/test_optimizers.py
@@ -12,7 +12,6 @@


 class AsyncOptimizerTest(unittest.TestCase):
-
     def tearDown(self):
         ray.worker.cleanup()

@@ -21,8 +20,9 @@ def testBasic(self):
         local = _MockEvaluator()
         remotes = ray.remote(_MockEvaluator)
         remote_evaluators = [remotes.remote() for i in range(5)]
-        test_optimizer = AsyncOptimizer(
-            {"grads_per_step": 10}, local, remote_evaluators)
+        test_optimizer = AsyncOptimizer({
+            "grads_per_step": 10
+        }, local, remote_evaluators)
         test_optimizer.step()
         self.assertTrue(all(local.get_weights() == 0))

@@ -33,11 +33,11 @@ def testConcat(self):
         b2 = SampleBatch({"a": np.array([1]), "b": np.array([4])})
         b3 = SampleBatch({"a": np.array([1]), "b": np.array([5])})
         b12 = b1.concat(b2)
-        self.assertEqual(b12.data["a"].tolist(), [1, 2, 3, 1])
-        self.assertEqual(b12.data["b"].tolist(), [4, 5, 6, 4])
+        self.assertEqual(b12["a"].tolist(), [1, 2, 3, 1])
+        self.assertEqual(b12["b"].tolist(), [4, 5, 6, 4])
         b = SampleBatch.concat_samples([b1, b2, b3])
-        self.assertEqual(b.data["a"].tolist(), [1, 2, 3, 1, 1])
-        self.assertEqual(b.data["b"].tolist(), [4, 5, 6, 4, 5])
+        self.assertEqual(b["a"].tolist(), [1, 2, 3, 1, 1])
+        self.assertEqual(b["b"].tolist(), [4, 5, 6, 4, 5])


 if __name__ == '__main__':
diff --git a/python/ray/rllib/utils/process_rollout.py b/python/ray/rllib/utils/process_rollout.py
index 2232135780d2..b2d52fddabb3 100644
--- a/python/ray/rllib/utils/process_rollout.py
+++ b/python/ray/rllib/utils/process_rollout.py
@@ -26,22 +26,22 @@ def process_rollout(rollout, reward_filter, gamma, lambda_=1.0, use_gae=True):
         processed rewards."""

     traj = {}
-    trajsize = len(rollout.data["actions"])
-    for key in rollout.data:
-        traj[key] = np.stack(rollout.data[key])
+    trajsize = len(rollout["actions"])
+    for key in rollout:
+        traj[key] = np.stack(rollout[key])

     if use_gae:
-        assert "vf_preds" in rollout.data, "Values not found!"
-        vpred_t = np.stack(
-            rollout.data["vf_preds"] + [np.array(rollout.last_r)]).squeeze()
+        assert "vf_preds" in rollout, "Values not found!"
+        vpred_t = np.stack(rollout["vf_preds"] +
+                           [np.array(rollout.last_r)]).squeeze()
         delta_t = traj["rewards"] + gamma * vpred_t[1:] - vpred_t[:-1]
         # This formula for the advantage comes
         # "Generalized Advantage Estimation": https://arxiv.org/abs/1506.02438
         traj["advantages"] = discount(delta_t, gamma * lambda_)
         traj["value_targets"] = traj["advantages"] + traj["vf_preds"]
     else:
-        rewards_plus_v = np.stack(
-            rollout.data["rewards"] + [np.array(rollout.last_r)]).squeeze()
+        rewards_plus_v = np.stack(rollout["rewards"] +
+                                  [np.array(rollout.last_r)]).squeeze()
         traj["advantages"] = discount(rewards_plus_v, gamma)[:-1]

     for i in range(traj["advantages"].shape[0]):
diff --git a/python/ray/rllib/utils/sampler.py b/python/ray/rllib/utils/sampler.py
index c522806d6699..8adc99605735 100644
--- a/python/ray/rllib/utils/sampler.py
+++ b/python/ray/rllib/utils/sampler.py
@@ -56,9 +56,30 @@ def is_terminal(self):
             terminal (bool): if rollout has terminated."""
         return self.data["dones"][-1]

+    def __getitem__(self, key):
+        return self.data[key]

-CompletedRollout = namedtuple(
-    "CompletedRollout", ["episode_length", "episode_reward"])
+    def __setitem__(self, key, item):
+        self.data[key] = item
+
+    def keys(self):
+        return self.data.keys()
+
+    def items(self):
+        return self.data.items()
+
+    def __iter__(self):
+        return self.data.__iter__()
+
+    def __next__(self):
+        return self.data.__next__()
+
+    def __contains__(self, x):
+        return x in self.data
+
+
+CompletedRollout = namedtuple("CompletedRollout",
+                              ["episode_length", "episode_reward"])


 class SyncSampler(object):
@@ -71,16 +92,15 @@ class SyncSampler(object):
         thread."""
     async = False

-    def __init__(self, env, policy, obs_filter,
-                 num_local_steps, horizon=None):
+    def __init__(self, env, policy, obs_filter, num_local_steps, horizon=None):
         self.num_local_steps = num_local_steps
         self.horizon = horizon
         self.env = env
         self.policy = policy
         self._obs_filter = obs_filter
-        self.rollout_provider = _env_runner(
-            self.env, self.policy, self.num_local_steps, self.horizon,
-            self._obs_filter)
+        self.rollout_provider = _env_runner(self.env, self.policy,
+                                            self.num_local_steps, self.horizon,
+                                            self._obs_filter)
         self.metrics_queue = queue.Queue()

     def get_data(self):
@@ -108,10 +128,10 @@ class AsyncSampler(threading.Thread):
         accumulate and the gradient can be calculated on up to 5 batches."""
     async = True

-    def __init__(self, env, policy, obs_filter,
-                 num_local_steps, horizon=None):
-        assert getattr(obs_filter, "is_concurrent", False), (
-            "Observation Filter must support concurrent updates.")
+    def __init__(self, env, policy, obs_filter, num_local_steps, horizon=None):
+        assert getattr(
+            obs_filter, "is_concurrent",
+            False), ("Observation Filter must support concurrent updates.")
         threading.Thread.__init__(self)
         self.queue = queue.Queue(5)
         self.metrics_queue = queue.Queue()
@@ -132,9 +152,9 @@ def run(self):
             raise e

     def _run(self):
-        rollout_provider = _env_runner(
-            self.env, self.policy, self.num_local_steps,
-            self.horizon, self._obs_filter)
+        rollout_provider = _env_runner(self.env, self.policy,
+                                       self.num_local_steps, self.horizon,
+                                       self._obs_filter)
         while True:
             # The timeout variable exists because apparently, if one worker
             # dies, the other workers won't die with it, unless the timeout is
@@ -232,13 +252,14 @@ def _env_runner(env, policy, num_local_steps, horizon, obs_filter):
             action = np.concatenate(action, axis=0).flatten()

         # Collect the experience.
-        rollout.add(obs=last_observation,
-                    actions=action,
-                    rewards=reward,
-                    dones=terminal,
-                    features=last_features,
-                    new_obs=observation,
-                    **pi_info)
+        rollout.add(
+            obs=last_observation,
+            actions=action,
+            rewards=reward,
+            dones=terminal,
+            features=last_features,
+            new_obs=observation,
+            **pi_info)

         last_observation = observation
         last_features = features
@@ -247,8 +268,8 @@ def _env_runner(env, policy, num_local_steps, horizon, obs_filter):
             terminal_end = True
             yield CompletedRollout(length, rewards)

-            if (length >= horizon or
-                    not env.metadata.get("semantics.autoreset")):
+            if (length >= horizon
+                    or not env.metadata.get("semantics.autoreset")):
                 last_observation = obs_filter(env.reset())
                 if hasattr(policy, "get_initial_features"):
                     last_features = policy.get_initial_features()