33 changes: 24 additions & 9 deletions python/ray/rllib/optimizers/sample_batch.py
@@ -36,8 +36,8 @@ def __init__(self, *args, **kwargs):
     @staticmethod
     def concat_samples(samples):
         out = {}
-        for k in samples[0].data.keys():
-            out[k] = np.concatenate([s.data[k] for s in samples])
+        for k in samples[0].keys():
+            out[k] = np.concatenate([s[k] for s in samples])
         return SampleBatch(out)
 
     def concat(self, other):
@@ -50,10 +50,10 @@ def concat(self, other):
             {"a": [1, 2, 3, 4, 5]}
         """
 
-        assert self.data.keys() == other.data.keys(), "must have same columns"
+        assert self.keys() == other.keys(), "must have same columns"
         out = {}
-        for k in self.data.keys():
-            out[k] = np.concatenate([self.data[k], other.data[k]])
+        for k in self.keys():
+            out[k] = np.concatenate([self[k], other[k]])
         return SampleBatch(out)
 
     def rows(self):
@@ -70,7 +70,7 @@ def rows(self):
 
         for i in range(self.count):
             row = {}
-            for k in self.data.keys():
+            for k in self.keys():
                 row[k] = self[k][i]
             yield row
 
@@ -85,19 +85,34 @@ def columns(self, keys):
 
         out = []
         for k in keys:
-            out.append(self.data[k])
+            out.append(self[k])
         return out
 
     def shuffle(self):
         permutation = np.random.permutation(self.count)
-        for key, val in self.data.items():
-            self.data[key] = val[permutation]
+        for key, val in self.items():
+            self[key] = val[permutation]
 
     def __getitem__(self, key):
         return self.data[key]
 
     def __setitem__(self, key, item):
         self.data[key] = item
 
     def __str__(self):
         return "SampleBatch({})".format(str(self.data))
+
+    def __repr__(self):
+        return "SampleBatch({})".format(str(self.data))
+
+    def keys(self):
+        return self.data.keys()
+
+    def items(self):
+        return self.data.items()
+
+    def __iter__(self):
+        return self.data.__iter__()
+
+    def __contains__(self, x):
+        return x in self.data
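
Taken together, these methods let SampleBatch behave like a read/write mapping over its columns, which is why the call sites in the other files below drop the explicit .data indirection. A minimal usage sketch of the interface shown above; the column names and values are invented for illustration, and it assumes ray is installed so the module path resolves:

    import numpy as np
    from ray.rllib.optimizers.sample_batch import SampleBatch

    # Columns are parallel arrays; the names here are arbitrary.
    b1 = SampleBatch({"obs": np.array([1, 2, 3]), "rew": np.array([0.1, 0.2, 0.3])})
    b2 = SampleBatch({"obs": np.array([4]), "rew": np.array([0.4])})

    # Dict-style access replaces the old b1.data["obs"] pattern.
    assert b1["obs"].tolist() == [1, 2, 3]
    assert "rew" in b1 and sorted(b1.keys()) == ["obs", "rew"]

    # Concatenation is column-wise over the shared keys.
    merged = SampleBatch.concat_samples([b1, b2])
    assert merged["obs"].tolist() == [1, 2, 3, 4]

    # Rows come back as plain dicts; shuffle permutes all columns together.
    for row in merged.rows():
        print(row["obs"], row["rew"])
    merged.shuffle()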
14 changes: 7 additions & 7 deletions python/ray/rllib/test/test_optimizers.py
@@ -12,7 +12,6 @@
 
 
 class AsyncOptimizerTest(unittest.TestCase):
-
     def tearDown(self):
         ray.worker.cleanup()
 
@@ -21,8 +20,9 @@ def testBasic(self):
         local = _MockEvaluator()
         remotes = ray.remote(_MockEvaluator)
         remote_evaluators = [remotes.remote() for i in range(5)]
-        test_optimizer = AsyncOptimizer(
-            {"grads_per_step": 10}, local, remote_evaluators)
+        test_optimizer = AsyncOptimizer({
+            "grads_per_step": 10
+        }, local, remote_evaluators)
         test_optimizer.step()
         self.assertTrue(all(local.get_weights() == 0))
 
@@ -33,11 +33,11 @@ def testConcat(self):
         b2 = SampleBatch({"a": np.array([1]), "b": np.array([4])})
         b3 = SampleBatch({"a": np.array([1]), "b": np.array([5])})
         b12 = b1.concat(b2)
-        self.assertEqual(b12.data["a"].tolist(), [1, 2, 3, 1])
-        self.assertEqual(b12.data["b"].tolist(), [4, 5, 6, 4])
+        self.assertEqual(b12["a"].tolist(), [1, 2, 3, 1])
+        self.assertEqual(b12["b"].tolist(), [4, 5, 6, 4])
         b = SampleBatch.concat_samples([b1, b2, b3])
-        self.assertEqual(b.data["a"].tolist(), [1, 2, 3, 1, 1])
-        self.assertEqual(b.data["b"].tolist(), [4, 5, 6, 4, 5])
+        self.assertEqual(b["a"].tolist(), [1, 2, 3, 1, 1])
+        self.assertEqual(b["b"].tolist(), [4, 5, 6, 4, 5])
 
 
 if __name__ == '__main__':
16 changes: 8 additions & 8 deletions python/ray/rllib/utils/process_rollout.py
@@ -26,22 +26,22 @@ def process_rollout(rollout, reward_filter, gamma, lambda_=1.0, use_gae=True):
             processed rewards."""
 
     traj = {}
-    trajsize = len(rollout.data["actions"])
-    for key in rollout.data:
-        traj[key] = np.stack(rollout.data[key])
+    trajsize = len(rollout["actions"])
+    for key in rollout:
+        traj[key] = np.stack(rollout[key])
 
     if use_gae:
-        assert "vf_preds" in rollout.data, "Values not found!"
-        vpred_t = np.stack(
-            rollout.data["vf_preds"] + [np.array(rollout.last_r)]).squeeze()
+        assert "vf_preds" in rollout, "Values not found!"
+        vpred_t = np.stack(rollout["vf_preds"] +
+                           [np.array(rollout.last_r)]).squeeze()
         delta_t = traj["rewards"] + gamma * vpred_t[1:] - vpred_t[:-1]
         # This formula for the advantage comes
         # "Generalized Advantage Estimation": https://arxiv.org/abs/1506.02438
         traj["advantages"] = discount(delta_t, gamma * lambda_)
         traj["value_targets"] = traj["advantages"] + traj["vf_preds"]
     else:
-        rewards_plus_v = np.stack(
-            rollout.data["rewards"] + [np.array(rollout.last_r)]).squeeze()
+        rewards_plus_v = np.stack(rollout["rewards"] +
+                                  [np.array(rollout.last_r)]).squeeze()
         traj["advantages"] = discount(rewards_plus_v, gamma)[:-1]
 
     for i in range(traj["advantages"].shape[0]):
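
The use_gae branch above implements Generalized Advantage Estimation (Schulman et al., 2015). As a reference for what discount(delta_t, gamma * lambda_) computes, here is a self-contained numpy sketch of the same recurrence; the discount helper below is a plain-Python stand-in for rllib's own implementation, and the reward and value numbers are invented:

    import numpy as np

    def discount(x, gamma):
        # y[t] = x[t] + gamma * y[t + 1], i.e. a reversed, decayed cumulative sum.
        out = np.zeros_like(x, dtype=np.float64)
        running = 0.0
        for t in reversed(range(len(x))):
            running = x[t] + gamma * running
            out[t] = running
        return out

    gamma, lambda_ = 0.99, 1.0
    rewards = np.array([1.0, 0.0, 1.0])
    vf_preds = np.array([0.5, 0.6, 0.4])  # value estimates, one per step
    last_r = 0.0                          # bootstrap value after the last step

    vpred_t = np.concatenate([vf_preds, [last_r]])
    delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1]  # TD residuals
    advantages = discount(delta_t, gamma * lambda_)         # GAE
    value_targets = advantages + vf_preds                   # critic regression targets
    print(advantages, value_targets)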
67 changes: 44 additions & 23 deletions python/ray/rllib/utils/sampler.py
@@ -56,9 +56,30 @@ def is_terminal(self):
             terminal (bool): if rollout has terminated."""
         return self.data["dones"][-1]
 
+    def __getitem__(self, key):
+        return self.data[key]
+
+    def __setitem__(self, key, item):
+        self.data[key] = item
+
+    def keys(self):
+        return self.data.keys()
+
+    def items(self):
+        return self.data.items()
+
+    def __iter__(self):
+        return self.data.__iter__()
+
+    def __next__(self):
+        return self.data.__next__()
+
+    def __contains__(self, x):
+        return x in self.data
 
-CompletedRollout = namedtuple(
-    "CompletedRollout", ["episode_length", "episode_reward"])
+
+CompletedRollout = namedtuple("CompletedRollout",
+                              ["episode_length", "episode_reward"])
 
 
 class SyncSampler(object):
@@ -71,16 +92,15 @@ class SyncSampler(object):
     thread."""
     async = False
 
-    def __init__(self, env, policy, obs_filter,
-                 num_local_steps, horizon=None):
+    def __init__(self, env, policy, obs_filter, num_local_steps, horizon=None):
         self.num_local_steps = num_local_steps
         self.horizon = horizon
         self.env = env
         self.policy = policy
         self._obs_filter = obs_filter
-        self.rollout_provider = _env_runner(
-            self.env, self.policy, self.num_local_steps, self.horizon,
-            self._obs_filter)
+        self.rollout_provider = _env_runner(self.env, self.policy,
+                                            self.num_local_steps, self.horizon,
+                                            self._obs_filter)
         self.metrics_queue = queue.Queue()
 
     def get_data(self):
@@ -108,10 +128,10 @@ class AsyncSampler(threading.Thread):
     accumulate and the gradient can be calculated on up to 5 batches."""
     async = True
 
-    def __init__(self, env, policy, obs_filter,
-                 num_local_steps, horizon=None):
-        assert getattr(obs_filter, "is_concurrent", False), (
-            "Observation Filter must support concurrent updates.")
+    def __init__(self, env, policy, obs_filter, num_local_steps, horizon=None):
+        assert getattr(
+            obs_filter, "is_concurrent",
+            False), ("Observation Filter must support concurrent updates.")
         threading.Thread.__init__(self)
         self.queue = queue.Queue(5)
         self.metrics_queue = queue.Queue()
@@ -132,9 +152,9 @@ def run(self):
             raise e
 
     def _run(self):
-        rollout_provider = _env_runner(
-            self.env, self.policy, self.num_local_steps,
-            self.horizon, self._obs_filter)
+        rollout_provider = _env_runner(self.env, self.policy,
+                                       self.num_local_steps, self.horizon,
+                                       self._obs_filter)
         while True:
             # The timeout variable exists because apparently, if one worker
             # dies, the other workers won't die with it, unless the timeout is
@@ -232,13 +252,14 @@ def _env_runner(env, policy, num_local_steps, horizon, obs_filter):
             action = np.concatenate(action, axis=0).flatten()
 
             # Collect the experience.
-            rollout.add(obs=last_observation,
-                        actions=action,
-                        rewards=reward,
-                        dones=terminal,
-                        features=last_features,
-                        new_obs=observation,
-                        **pi_info)
+            rollout.add(
+                obs=last_observation,
+                actions=action,
+                rewards=reward,
+                dones=terminal,
+                features=last_features,
+                new_obs=observation,
+                **pi_info)
 
             last_observation = observation
             last_features = features
@@ -247,8 +268,8 @@ def _env_runner(env, policy, num_local_steps, horizon, obs_filter):
                 terminal_end = True
                 yield CompletedRollout(length, rewards)
 
-                if (length >= horizon or
-                        not env.metadata.get("semantics.autoreset")):
+                if (length >= horizon
+                        or not env.metadata.get("semantics.autoreset")):
                     last_observation = obs_filter(env.reset())
                     if hasattr(policy, "get_initial_features"):
                         last_features = policy.get_initial_features()
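
Both sampler classes drive _env_runner as a generator that interleaves experience with CompletedRollout metrics. The sketch below illustrates that producer/consumer shape in isolation; the toy environment step, the episode-termination probability, and the isinstance dispatch are stand-ins for illustration, not the actual rllib code:

    import random
    from collections import namedtuple

    CompletedRollout = namedtuple("CompletedRollout",
                                  ["episode_length", "episode_reward"])

    def env_runner(num_local_steps):
        # Yield batches of transitions, plus a CompletedRollout whenever an
        # episode finishes, mirroring the generator consumed by the samplers.
        length, reward_sum, batch = 0, 0.0, []
        while True:
            reward = random.random()            # stand-in for stepping a real env
            terminal = random.random() < 0.05   # stand-in for an episode ending
            batch.append({"rewards": reward, "dones": terminal})
            length += 1
            reward_sum += reward
            if terminal:
                yield CompletedRollout(length, reward_sum)
                length, reward_sum = 0, 0.0
            if terminal or len(batch) >= num_local_steps:
                yield batch
                batch = []

    # The consumer separates metrics from experience, much as the samplers above
    # keep a metrics_queue alongside the data they hand to the optimizer.
    metrics, batches = [], []
    runner = env_runner(num_local_steps=10)
    for item in (next(runner) for _ in range(50)):
        (metrics if isinstance(item, CompletedRollout) else batches).append(item)
    print(len(metrics), "episodes,", len(batches), "batches")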