7 changes: 0 additions & 7 deletions rllib/BUILD
@@ -1074,13 +1074,6 @@ py_test(
# Tag: models
# --------------------------------------------------------------------

py_test(
name = "test_attention_nets",
tags = ["models"],
size = "small",
srcs = ["models/tests/test_attention_nets.py"]
)

py_test(
name = "test_convtranspose2d_stack",
tags = ["models"],
4 changes: 2 additions & 2 deletions rllib/agents/callbacks.py
@@ -191,8 +191,8 @@ def on_learn_on_batch(self, *, policy: Policy, train_batch: SampleBatch,
**kwargs) -> None:
"""Called at the beginning of Policy.learn_on_batch().

Note: This is called before the Model's `preprocess_train_batch()`
is called.
Note: This is called before 0-padding via
`pad_batch_to_sequences_of_same_size`.

Args:
policy (Policy): Reference to the current Policy object.
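A minimal sketch of a user callback hooking in at this point, i.e. before any RNN 0-padding has happened (assumptions: standard `DefaultCallbacks` subclassing; the class name and printed metric are illustrative only):

```python
from ray.rllib.agents.callbacks import DefaultCallbacks


class PrePaddingStatsCallbacks(DefaultCallbacks):
    """Illustrative only; not part of this PR."""

    def on_learn_on_batch(self, *, policy, train_batch, **kwargs) -> None:
        # Runs before `pad_batch_to_sequences_of_same_size`, so the batch
        # still has its raw, un-padded timestep count here.
        print("pre-padding train batch size:", train_batch.count)
```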
3 changes: 2 additions & 1 deletion rllib/agents/ppo/ppo_tf_policy.py
@@ -198,7 +198,8 @@ def postprocess_ppo_gae(
# input_dict.
if policy.config["_use_trajectory_view_api"]:
# Create an input dict according to the Model's requirements.
input_dict = policy.model.get_input_dict(sample_batch, index=-1)
input_dict = policy.model.get_input_dict(
sample_batch, index="last")
Contributor: 👍 love it.

Contributor: Thinking ahead, is it possible we unify model.from_batch() and model.get_input_dict()? It seems currently only PPO is using this get_input_dict method, so maybe we should remove it.

Thinking... model.from_batch(sample_batch, batch_index="last") would be pretty clean. This is probably out of scope of this PR though.

last_r = policy._value(**input_dict)
# TODO: (sven) Remove once trajectory view API is all-algo default.
else:
171 changes: 138 additions & 33 deletions rllib/evaluation/collectors/simple_list_collector.py
@@ -1,6 +1,7 @@
import collections
from gym.spaces import Space
import logging
import math
import numpy as np
from typing import Any, List, Dict, Tuple, TYPE_CHECKING, Union

@@ -34,6 +35,9 @@ def to_float_np_array(v: List[Any]) -> np.ndarray:
return arr


_INIT_COLS = [SampleBatch.OBS]
Contributor: Comment?

Contributor Author: +1



class _AgentCollector:
"""Collects samples for one agent in one trajectory (episode).

@@ -45,9 +49,18 @@ class _AgentCollector:

_next_unroll_id = 0 # disambiguates unrolls within a single episode

def __init__(self, shift_before: int = 0):
self.shift_before = max(shift_before, 1)
def __init__(self, view_reqs):
# Determine the size of the buffer we need for data before the actual
# episode starts. This is used for 0-buffering of e.g. prev-actions,
# or internal state inputs.
self.shift_before = -min(
(int(vr.shift.split(":")[0])
if isinstance(vr.shift, str) else vr.shift) +
(-1 if vr.data_col in _INIT_COLS or k in _INIT_COLS else 0)
for k, vr in view_reqs.items())
# The actual data buffers (lists holding each timestep's data).
self.buffers: Dict[str, List] = {}
# The episode ID for the agent for which we collect data.
self.episode_id = None
# The simple timestep count for this agent. Gets increased by one
# each time a (non-initial!) observation is added.
@@ -137,31 +150,88 @@ def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch:
# -> skip.
if data_col not in self.buffers:
continue

# OBS are already shifted by -1 (the initial obs starts one ts
# before all other data columns).
shift = view_req.shift - \
(1 if data_col == SampleBatch.OBS else 0)
obs_shift = -1 if data_col == SampleBatch.OBS else 0

# Keep an np-array cache so we don't have to regenerate the
# np-array for different view_cols that map to the same data_col.
if data_col not in np_data:
np_data[data_col] = to_float_np_array(self.buffers[data_col])
# Shift is exactly 0: Send trajectory as is.
if shift == 0:
data = np_data[data_col][self.shift_before:]
# Shift is positive: We still need to 0-pad at the end here.
elif shift > 0:
data = to_float_np_array(
self.buffers[data_col][self.shift_before + shift:] + [
np.zeros(
shape=view_req.space.shape,
dtype=view_req.space.dtype) for _ in range(shift)

Contributor Author: This is becoming more complex as we now allow for more sophisticated view requirements (ranges of timesteps).

# Range of indices on time-axis, e.g. "-50:-1". Together with
# the `batch_repeat_value`, this determines the data produced.
# Example:
# batch_repeat_value=10, shift_from=-3, shift_to=-1
# buffer=[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
# resulting data=[[-3, -2, -1], [7, 8, 9]]
# Range of 3 consecutive items repeats every 10 timesteps.
if view_req.shift_from is not None:
if view_req.batch_repeat_value > 1:
count = int(
math.ceil((len(np_data[data_col]) - self.shift_before)
/ view_req.batch_repeat_value))
data = np.asarray([
np_data[data_col][self.shift_before +
(i * view_req.batch_repeat_value) +
view_req.shift_from +
obs_shift:self.shift_before +
(i * view_req.batch_repeat_value) +
view_req.shift_to + 1 + obs_shift]
for i in range(count)
])
# Shift is negative: Shift into the already existing and 0-padded
# "before" area of our buffers.
else:
data = np_data[data_col][self.shift_before +
view_req.shift_from +
obs_shift:self.shift_before +
view_req.shift_to + 1 + obs_shift]
# Set of (probably non-consecutive) indices.
# Example:
# shift=[-3, 0]
# buffer=[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
# resulting data=[[-3, 0], [-2, 1], [-1, 2], [0, 3], [1, 4], ...]
elif isinstance(view_req.shift, np.ndarray):
data = np_data[data_col][self.shift_before + obs_shift +
view_req.shift]
# Single shift int value. Use the trajectory as-is, and if
# `shift` != 0: shifted by that value.
else:
data = np_data[data_col][self.shift_before + shift:shift]
shift = view_req.shift + obs_shift

# Batch repeat (only provide a value every n timesteps).
if view_req.batch_repeat_value > 1:
count = int(
math.ceil((len(np_data[data_col]) - self.shift_before)
/ view_req.batch_repeat_value))
data = np.asarray([
np_data[data_col][self.shift_before + (
i * view_req.batch_repeat_value) + shift]
for i in range(count)
])
# Shift is exactly 0: Use trajectory as is.
elif shift == 0:
data = np_data[data_col][self.shift_before:]
# Shift is positive: We still need to 0-pad at the end.
elif shift > 0:
data = to_float_np_array(
self.buffers[data_col][self.shift_before + shift:] + [
np.zeros(
shape=view_req.space.shape,
dtype=view_req.space.dtype)
for _ in range(shift)
])
# Shift is negative: Shift into the already existing and
# 0-padded "before" area of our buffers.
else:
data = np_data[data_col][self.shift_before + shift:shift]
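As a sanity check on the range handling above, a stand-alone numpy sketch (not RLlib code) that reproduces the example from the comment further up: batch_repeat_value=10, shift_from=-3, shift_to=-1.

```python
import math

import numpy as np

buf = np.arange(-3, 13)   # buffer = [-3, -2, ..., 12]; values equal timesteps
shift_before = 3          # three zero-buffered pre-episode entries
shift_from, shift_to = -3, -1
batch_repeat_value = 10

count = int(math.ceil((len(buf) - shift_before) / batch_repeat_value))
data = np.asarray([
    buf[shift_before + i * batch_repeat_value + shift_from:
        shift_before + i * batch_repeat_value + shift_to + 1]
    for i in range(count)
])
print(data)  # -> [[-3 -2 -1] [ 7  8  9]]: one 3-item range every 10 timesteps
```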

if len(data) > 0:
batch_data[view_col] = data

batch = SampleBatch(batch_data)
# Due to possible batch-repeats > 1, columns in the resulting batch
# may not all have the same batch size.
batch = SampleBatch(batch_data, _dont_check_lens=True)

# Add EPS_ID and UNROLL_ID to batch.
batch.data[SampleBatch.EPS_ID] = np.repeat(self.episode_id,
@@ -230,15 +300,22 @@ class _PolicyCollector:
appended to this policy's buffers.
"""

def __init__(self):
"""Initializes a _PolicyCollector instance."""
def __init__(self, policy):
"""Initializes a _PolicyCollector instance.

Args:
policy (Policy): The policy object.
"""

self.buffers: Dict[str, List] = collections.defaultdict(list)
self.policy = policy
# The total timestep count for all agents that use this policy.
# NOTE: This is not an env-step count (across n agents). AgentA and
# agentB, both using this policy, acting in the same episode and both
# doing n steps would increase the count by 2*n.
self.agent_steps = 0
# Seq-lens list of already added agent batches.
self.seq_lens = [] if policy.is_recurrent() else None

def add_postprocessed_batch_for_training(
self, batch: SampleBatch,
@@ -257,11 +334,18 @@ def add_postprocessed_batch_for_training(
# 1) If col is not in view_requirements, we must have a direct
# child of the base Policy that doesn't do auto-view req creation.
# 2) Col is in view-reqs and needed for training.
if view_col not in view_requirements or \
view_requirements[view_col].used_for_training:
view_req = view_requirements.get(view_col)
if view_req is None or view_req.used_for_training:
self.buffers[view_col].extend(data)
# Add the agent's trajectory length to our count.
self.agent_steps += batch.count
# Adjust the seq-lens array depending on the incoming agent sequences.
if self.seq_lens is not None:
max_seq_len = self.policy.config["model"]["max_seq_len"]
count = batch.count
while count > 0:
self.seq_lens.append(min(count, max_seq_len))
count -= max_seq_len
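A tiny illustration of the seq-lens chunking above (the numbers are made up): an agent batch of 230 timesteps with max_seq_len=100 contributes three sequences.

```python
max_seq_len, count, seq_lens = 100, 230, []
while count > 0:
    seq_lens.append(min(count, max_seq_len))
    count -= max_seq_len
print(seq_lens)  # -> [100, 100, 30]
```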

def build(self):
"""Builds a SampleBatch for this policy from the collected data.
@@ -273,20 +357,22 @@ def build(self):
this policy.
"""
# Create batch from our buffers.
batch = SampleBatch(self.buffers)
assert SampleBatch.UNROLL_ID in batch.data
batch = SampleBatch(
Contributor Author: Internal states are already time-chunked (only one internal state value every max_seq_len timesteps). That's why we shouldn't compare column sizes in SampleBatch (e.g., the obs column will have a different batch dim than the state_in_0 one).

self.buffers, _seq_lens=self.seq_lens, _dont_check_lens=True)
# Clear buffers for future samples.
self.buffers.clear()
# Reset agent steps to 0.
# Reset agent steps to 0 and seq-lens to empty list.
self.agent_steps = 0
if self.seq_lens is not None:
self.seq_lens = []
return batch
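For intuition on the comment above (shapes are assumed for illustration, not taken from this PR): with max_seq_len=100 and 200 collected timesteps, the time-major columns and the internal-state columns legitimately differ in batch dimension, which is why the length check is skipped here.

```python
import numpy as np

batch_data = {
    "obs": np.zeros((200, 4)),         # one row per timestep
    "state_in_0": np.zeros((2, 256)),  # one row per max_seq_len chunk
}
seq_lens = [100, 100]                  # two sequences of 100 timesteps each
```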


class _PolicyCollectorGroup:
def __init__(self, policy_map):
self.policy_collectors = {
pid: _PolicyCollector()
for pid in policy_map.keys()
pid: _PolicyCollector(policy)
for pid, policy in policy_map.items()
}
# Total env-steps (1 env-step=up to N agents stepped).
self.env_steps = 0
@@ -396,11 +482,14 @@ def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
self.agent_key_to_policy_id[agent_key] = policy_id
else:
assert self.agent_key_to_policy_id[agent_key] == policy_id
policy = self.policy_map[policy_id]
view_reqs = policy.model.inference_view_requirements if \
getattr(policy, "model", None) else policy.view_requirements

# Add initial obs to Trajectory.
assert agent_key not in self.agent_collectors
# TODO: determine exact shift-before based on the view-req shifts.
self.agent_collectors[agent_key] = _AgentCollector()
self.agent_collectors[agent_key] = _AgentCollector(view_reqs)
self.agent_collectors[agent_key].add_init_obs(
episode_id=episode.episode_id,
agent_index=episode._agent_index(agent_id),
@@ -466,11 +555,19 @@ def get_inference_input_dict(self, policy_id: PolicyID) -> \
for view_col, view_req in view_reqs.items():
# Create the batch of data from the different buffers.
data_col = view_req.data_col or view_col
time_indices = \
view_req.shift - (
1 if data_col in [SampleBatch.OBS, "t", "env_id",
SampleBatch.AGENT_INDEX] else 0)
delta = -1 if data_col in [
Contributor Author: Also more complex now to build an action input dict: inputs may now span a range of time steps.

SampleBatch.OBS, "t", "env_id", SampleBatch.EPS_ID,
SampleBatch.AGENT_INDEX
] else 0
# Range of shifts, e.g. "-100:0". Note: This includes index 0!
if view_req.shift_from is not None:
time_indices = (view_req.shift_from + delta,
view_req.shift_to + delta)
# Single shift (e.g. -1) or list of shifts, e.g. [-4, -1, 0].
else:
time_indices = view_req.shift + delta
data_list = []
# Loop through agents and add-up their data (batch).
for k in keys:
if data_col == SampleBatch.EPS_ID:
data_list.append(self.agent_collectors[k].episode_id)
@@ -482,7 +579,15 @@ def get_inference_input_dict(self, policy_id: PolicyID) -> \
self.agent_collectors[k]._build_buffers({
data_col: fill_value
})
data_list.append(buffers[k][data_col][time_indices])
if isinstance(time_indices, tuple):
if time_indices[1] == -1:
data_list.append(
buffers[k][data_col][time_indices[0]:])
else:
data_list.append(buffers[k][data_col][time_indices[
0]:time_indices[1] + 1])
else:
data_list.append(buffers[k][data_col][time_indices])
input_dict[view_col] = np.array(data_list)

self._reset_inference_calls(policy_id)
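A stand-alone numpy sketch (not RLlib API; buffer contents are illustrative) of the tuple-vs-scalar indexing used above when pulling the most recent timestep(s) out of an agent's buffers:

```python
import numpy as np

buf = np.arange(10)  # stand-in for one agent's buffer of a single column


def fetch(buf, time_indices):
    if isinstance(time_indices, tuple):      # range of shifts, e.g. (-3, -1)
        lo, hi = time_indices
        return buf[lo:] if hi == -1 else buf[lo:hi + 1]
    return buf[time_indices]                 # single shift or array of shifts


print(fetch(buf, -1))                  # last timestep        -> 9
print(fetch(buf, (-3, -1)))            # last three timesteps -> [7 8 9]
print(fetch(buf, np.array([-4, -1])))  # selected shifts      -> [6 9]
```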
4 changes: 0 additions & 4 deletions rllib/evaluation/postprocessing.py
@@ -50,8 +50,6 @@ def compute_advantages(rollout: SampleBatch,
processed rewards.
"""

rollout_size = len(rollout[SampleBatch.ACTIONS])

assert SampleBatch.VF_PREDS in rollout or not use_critic, \
"use_critic=True but values not found"
assert use_critic or not use_gae, \
@@ -90,6 +88,4 @@ def compute_advantages(rollout: SampleBatch,
rollout[Postprocessing.ADVANTAGES] = rollout[
Postprocessing.ADVANTAGES].astype(np.float32)

assert all(val.shape[0] == rollout_size for key, val in rollout.items()), \
"Rollout stacked incorrectly!"
return rollout
10 changes: 6 additions & 4 deletions rllib/examples/attention_net.py
@@ -39,6 +39,7 @@

config = {
"env": args.env,
# This env_config is only used for the RepeatAfterMeEnv env.
"env_config": {
"repeat_delay": 2,
},
@@ -48,17 +49,18 @@
"num_workers": 0,
"num_envs_per_worker": 20,
"entropy_coeff": 0.001,
"num_sgd_iter": 5,
"num_sgd_iter": 10,
"vf_loss_coeff": 1e-5,
"model": {
"custom_model": GTrXLNet,
"max_seq_len": 50,
"custom_model_config": {
"num_transformer_units": 1,
"attn_dim": 64,
"num_heads": 2,
"memory_tau": 50,
"memory_inference": 100,
"memory_training": 50,
"head_dim": 32,
"num_heads": 2,
"ff_hidden_dim": 32,
},
},
@@ -71,7 +73,7 @@
"episode_reward_mean": args.stop_reward,
}

results = tune.run(args.run, config=config, stop=stop, verbose=1)
results = tune.run(args.run, config=config, stop=stop, verbose=2)

if args.as_test:
check_learning_achieved(results, args.stop_reward)
2 changes: 1 addition & 1 deletion rllib/examples/cartpole_lstm.py
@@ -59,7 +59,7 @@
"episode_reward_mean": args.stop_reward,
}

results = tune.run(args.run, config=config, stop=stop, verbose=1)
results = tune.run(args.run, config=config, stop=stop, verbose=2)

if args.as_test:
check_learning_achieved(results, args.stop_reward)