[RLlib] Attention Nets: tf #12753
Changes from all commits
@@ -1,6 +1,7 @@
import collections
from gym.spaces import Space
import logging
import math
import numpy as np
from typing import Any, List, Dict, Tuple, TYPE_CHECKING, Union
@@ -34,6 +35,9 @@ def to_float_np_array(v: List[Any]) -> np.ndarray:
    return arr


_INIT_COLS = [SampleBatch.OBS]
Contributor (Author): +1

class _AgentCollector:
    """Collects samples for one agent in one trajectory (episode).

@@ -45,9 +49,18 @@ class _AgentCollector:

    _next_unroll_id = 0  # disambiguates unrolls within a single episode

    def __init__(self, shift_before: int = 0):
        self.shift_before = max(shift_before, 1)
    def __init__(self, view_reqs):
        # Determine the size of the buffer we need for data before the actual
        # episode starts. This is used for 0-buffering of e.g. prev-actions,
        # or internal state inputs.
        self.shift_before = -min(
            (int(vr.shift.split(":")[0])
             if isinstance(vr.shift, str) else vr.shift) +
            (-1 if vr.data_col in _INIT_COLS or k in _INIT_COLS else 0)
            for k, vr in view_reqs.items())
        # The actual data buffers (lists holding each timestep's data).
        self.buffers: Dict[str, List] = {}
        # The episode ID for the agent for which we collect data.
        self.episode_id = None
        # The simple timestep count for this agent. Gets increased by one
        # each time a (non-initial!) observation is added.
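The new constructor derives `shift_before` directly from the view requirements: the most negative shift found anywhere becomes the size of the zero-padded "before" area. A minimal, runnable sketch of just that expression, using a hypothetical `FakeViewReq` stand-in (the field names, column names, and shift values below are illustrative assumptions, not RLlib's actual objects):

```python
from collections import namedtuple

# Hypothetical stand-in for a ViewRequirement; only the two fields the
# expression above reads are modeled here.
FakeViewReq = namedtuple("FakeViewReq", ["data_col", "shift"])

_INIT_COLS = ["obs"]  # mirrors [SampleBatch.OBS]

view_reqs = {
    "obs": FakeViewReq(data_col=None, shift=0),
    "prev_actions": FakeViewReq(data_col="actions", shift=-1),
    "state_in_0": FakeViewReq(data_col="state_out_0", shift="-50:-1"),
}

# Same expression as in __init__ above: string shifts like "-50:-1" contribute
# their lower bound; OBS-related columns get an extra -1 for the initial obs.
shift_before = -min(
    (int(vr.shift.split(":")[0]) if isinstance(vr.shift, str) else vr.shift) +
    (-1 if vr.data_col in _INIT_COLS or k in _INIT_COLS else 0)
    for k, vr in view_reqs.items())

print(shift_before)  # -> 50 (driven by the "-50:-1" state-in range)
```

With only the plain columns here (shifts 0 and -1), the same expression yields 1, i.e. a single zero-padded slot, matching the old `max(shift_before, 1)` behavior.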

@@ -137,31 +150,88 @@ def build(self, view_requirements: ViewRequirementsDict) -> SampleBatch:
            # -> skip.
            if data_col not in self.buffers:
                continue

            # OBS are already shifted by -1 (the initial obs starts one ts
            # before all other data columns).
            shift = view_req.shift - \
                (1 if data_col == SampleBatch.OBS else 0)
            obs_shift = -1 if data_col == SampleBatch.OBS else 0

            # Keep an np-array cache so we don't have to regenerate the
            # np-array for different view_cols referring to the same data_col.
            if data_col not in np_data:
                np_data[data_col] = to_float_np_array(self.buffers[data_col])
            # Shift is exactly 0: Send trajectory as is.
            if shift == 0:
                data = np_data[data_col][self.shift_before:]
            # Shift is positive: We still need to 0-pad at the end here.
            elif shift > 0:
                data = to_float_np_array(
                    self.buffers[data_col][self.shift_before + shift:] + [
                        np.zeros(
                            shape=view_req.space.shape,
                            dtype=view_req.space.dtype) for _ in range(shift)
                    ])
Contributor (Author): This is becoming more complex as we now allow for more sophisticated view requirements (ranges of timesteps).
            # Range of indices on time-axis, e.g. "-50:-1". Together with
            # the `batch_repeat_value`, this determines the data produced.
            # Example:
            #  batch_repeat_value=10, shift_from=-3, shift_to=-1
            #  buffer=[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
            #  resulting data=[[-3, -2, -1], [7, 8, 9]]
            #  Range of 3 consecutive items repeats every 10 timesteps.
            if view_req.shift_from is not None:
                if view_req.batch_repeat_value > 1:
                    count = int(
                        math.ceil((len(np_data[data_col]) - self.shift_before)
                                  / view_req.batch_repeat_value))
                    data = np.asarray([
                        np_data[data_col][self.shift_before +
                                          (i * view_req.batch_repeat_value) +
                                          view_req.shift_from +
                                          obs_shift:self.shift_before +
                                          (i * view_req.batch_repeat_value) +
                                          view_req.shift_to + 1 + obs_shift]
                        for i in range(count)
                    ])
            # Shift is negative: Shift into the already existing and 0-padded
            # "before" area of our buffers.
                else:
                    data = np_data[data_col][self.shift_before +
                                             view_req.shift_from +
                                             obs_shift:self.shift_before +
                                             view_req.shift_to + 1 + obs_shift]
            # Set of (probably non-consecutive) indices.
            # Example:
            #  shift=[-3, 0]
            #  buffer=[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
            #  resulting data=[[-3, 0], [-2, 1], [-1, 2], [0, 3], [1, 4], ...]
            elif isinstance(view_req.shift, np.ndarray):
                data = np_data[data_col][self.shift_before + obs_shift +
                                         view_req.shift]
            # Single shift int value. Use the trajectory as-is, and if
            # `shift` != 0: shifted by that value.
            else:
                data = np_data[data_col][self.shift_before + shift:shift]
                shift = view_req.shift + obs_shift

                # Batch repeat (only provide a value every n timesteps).
                if view_req.batch_repeat_value > 1:
                    count = int(
                        math.ceil((len(np_data[data_col]) - self.shift_before)
                                  / view_req.batch_repeat_value))
                    data = np.asarray([
                        np_data[data_col][self.shift_before + (
                            i * view_req.batch_repeat_value) + shift]
                        for i in range(count)
                    ])
                # Shift is exactly 0: Use trajectory as is.
                elif shift == 0:
                    data = np_data[data_col][self.shift_before:]
                # Shift is positive: We still need to 0-pad at the end.
                elif shift > 0:
                    data = to_float_np_array(
                        self.buffers[data_col][self.shift_before + shift:] + [
                            np.zeros(
                                shape=view_req.space.shape,
                                dtype=view_req.space.dtype)
                            for _ in range(shift)
                        ])
                # Shift is negative: Shift into the already existing and
                # 0-padded "before" area of our buffers.
                else:
                    data = np_data[data_col][self.shift_before + shift:shift]

            if len(data) > 0:
                batch_data[view_col] = data

        batch = SampleBatch(batch_data)
        # Due to possible batch-repeats > 1, columns in the resulting batch
        # may not all have the same batch size.
        batch = SampleBatch(batch_data, _dont_check_lens=True)

        # Add EPS_ID and UNROLL_ID to batch.
        batch.data[SampleBatch.EPS_ID] = np.repeat(self.episode_id,
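To make the range-of-timesteps case above concrete, here is a small standalone sketch reproducing the slicing arithmetic from the `shift_from`/`shift_to` + `batch_repeat_value` branch, run on the toy buffer from the code comments (variable names mirror the snippet; the buffer and shift values are just the commented example, not real trajectory data):

```python
import math
import numpy as np

shift_before = 3          # size of the zero-padded "before" area
obs_shift = 0             # would be -1 for the OBS column
shift_from, shift_to = -3, -1
batch_repeat_value = 10

# 3 "before" items followed by timesteps 0..12 (same buffer as in the comment).
buffer = np.arange(-3, 13)

count = int(math.ceil((len(buffer) - shift_before) / batch_repeat_value))
data = np.asarray([
    buffer[shift_before + i * batch_repeat_value + shift_from + obs_shift:
           shift_before + i * batch_repeat_value + shift_to + 1 + obs_shift]
    for i in range(count)
])
print(data)
# [[-3 -2 -1]
#  [ 7  8  9]]   -> one 3-step range every 10 timesteps, as described above.
```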

@@ -230,15 +300,22 @@ class _PolicyCollector:
        appended to this policy's buffers.
    """

    def __init__(self):
        """Initializes a _PolicyCollector instance."""
    def __init__(self, policy):
        """Initializes a _PolicyCollector instance.

        Args:
            policy (Policy): The policy object.
        """

        self.buffers: Dict[str, List] = collections.defaultdict(list)
        self.policy = policy
        # The total timestep count for all agents that use this policy.
        # NOTE: This is not an env-step count (across n agents). AgentA and
        # agentB, both using this policy, acting in the same episode and both
        # doing n steps would increase the count by 2*n.
        self.agent_steps = 0
        # Seq-lens list of already added agent batches.
        self.seq_lens = [] if policy.is_recurrent() else None

    def add_postprocessed_batch_for_training(
            self, batch: SampleBatch,

@@ -257,11 +334,18 @@ def add_postprocessed_batch_for_training(
            # 1) If col is not in view_requirements, we must have a direct
            # child of the base Policy that doesn't do auto-view req creation.
            # 2) Col is in view-reqs and needed for training.
            if view_col not in view_requirements or \
                    view_requirements[view_col].used_for_training:
            view_req = view_requirements.get(view_col)
            if view_req is None or view_req.used_for_training:
                self.buffers[view_col].extend(data)
        # Add the agent's trajectory length to our count.
        self.agent_steps += batch.count
        # Adjust the seq-lens array depending on the incoming agent sequences.
        if self.seq_lens is not None:
            max_seq_len = self.policy.config["model"]["max_seq_len"]
            count = batch.count
            while count > 0:
                self.seq_lens.append(min(count, max_seq_len))
                count -= max_seq_len
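The seq-lens bookkeeping above simply chops each incoming agent trajectory into chunks of at most `max_seq_len`. A tiny sketch with assumed numbers:

```python
# Assumed values: a 45-step agent batch and max_seq_len=20 (illustrative only).
max_seq_len = 20
seq_lens = []

count = 45  # batch.count of the incoming agent trajectory
while count > 0:
    seq_lens.append(min(count, max_seq_len))
    count -= max_seq_len

print(seq_lens)  # [20, 20, 5]
```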

    def build(self):
        """Builds a SampleBatch for this policy from the collected data.

@@ -273,20 +357,22 @@ def build(self):
                this policy.
        """
        # Create batch from our buffers.
        batch = SampleBatch(self.buffers)
        assert SampleBatch.UNROLL_ID in batch.data
        batch = SampleBatch(
Contributor (Author): internal states are already time-chunked (only one internal state value every `max_seq_len` timesteps).
            self.buffers, _seq_lens=self.seq_lens, _dont_check_lens=True)
        # Clear buffers for future samples.
        self.buffers.clear()
        # Reset agent steps to 0.
        # Reset agent steps to 0 and seq-lens to empty list.
        self.agent_steps = 0
        if self.seq_lens is not None:
            self.seq_lens = []
        return batch


class _PolicyCollectorGroup:
    def __init__(self, policy_map):
        self.policy_collectors = {
            pid: _PolicyCollector()
            for pid in policy_map.keys()
            pid: _PolicyCollector(policy)
            for pid, policy in policy_map.items()
        }
        # Total env-steps (1 env-step=up to N agents stepped).
        self.env_steps = 0

@@ -396,11 +482,14 @@ def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
            self.agent_key_to_policy_id[agent_key] = policy_id
        else:
            assert self.agent_key_to_policy_id[agent_key] == policy_id
        policy = self.policy_map[policy_id]
        view_reqs = policy.model.inference_view_requirements if \
            getattr(policy, "model", None) else policy.view_requirements

        # Add initial obs to Trajectory.
        assert agent_key not in self.agent_collectors
        # TODO: determine exact shift-before based on the view-req shifts.
        self.agent_collectors[agent_key] = _AgentCollector()
        self.agent_collectors[agent_key] = _AgentCollector(view_reqs)
        self.agent_collectors[agent_key].add_init_obs(
            episode_id=episode.episode_id,
            agent_index=episode._agent_index(agent_id),

@@ -466,11 +555,19 @@ def get_inference_input_dict(self, policy_id: PolicyID) -> \
        for view_col, view_req in view_reqs.items():
            # Create the batch of data from the different buffers.
            data_col = view_req.data_col or view_col
            time_indices = \
                view_req.shift - (
                    1 if data_col in [SampleBatch.OBS, "t", "env_id",
                                      SampleBatch.AGENT_INDEX] else 0)
            delta = -1 if data_col in [
Contributor (Author): Also more complex now to build an action input dict: Inputs could be over some range of time steps.
                SampleBatch.OBS, "t", "env_id", SampleBatch.EPS_ID,
                SampleBatch.AGENT_INDEX
            ] else 0
            # Range of shifts, e.g. "-100:0". Note: This includes index 0!
            if view_req.shift_from is not None:
                time_indices = (view_req.shift_from + delta,
                                view_req.shift_to + delta)
            # Single shift (e.g. -1) or list of shifts, e.g. [-4, -1, 0].
            else:
                time_indices = view_req.shift + delta
            data_list = []
            # Loop through agents and add-up their data (batch).
            for k in keys:
                if data_col == SampleBatch.EPS_ID:
                    data_list.append(self.agent_collectors[k].episode_id)

@@ -482,7 +579,15 @@
                    self.agent_collectors[k]._build_buffers({
                        data_col: fill_value
                    })
                    data_list.append(buffers[k][data_col][time_indices])
                    if isinstance(time_indices, tuple):
                        if time_indices[1] == -1:
                            data_list.append(
                                buffers[k][data_col][time_indices[0]:])
                        else:
                            data_list.append(buffers[k][data_col][time_indices[
                                0]:time_indices[1] + 1])
                    else:
                        data_list.append(buffers[k][data_col][time_indices])
            input_dict[view_col] = np.array(data_list)

        self._reset_inference_calls(policy_id)
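The tuple-vs-int handling of `time_indices` above can be hard to read in isolation. A small sketch of just that slicing on an assumed toy per-agent buffer (negative indices count back from the newest entry; the helper name is made up for illustration):

```python
import numpy as np

def slice_buffer(buf, time_indices):
    # Mirrors the tuple-vs-int branch used when building the inference input.
    if isinstance(time_indices, tuple):
        if time_indices[1] == -1:
            return buf[time_indices[0]:]
        return buf[time_indices[0]:time_indices[1] + 1]
    return buf[time_indices]

buf = np.arange(10)                 # 10 most recent values for one column
print(slice_buffer(buf, -1))        # 9         (single shift, e.g. shift=-1)
print(slice_buffer(buf, (-4, -1)))  # [6 7 8 9] (e.g. shift_from=-3, shift_to=0
                                    #            on OBS, after delta=-1)
```

The special case for `time_indices[1] == -1` is needed because a plain `buf[a:-1+1]` would be `buf[a:0]`, i.e. empty, rather than "up to and including the newest entry".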
👍 love it.
Thinking ahead, is it possible we unify model.from_batch() and model.get_input_dict()? It seems currently only PPO is using this get_input_dict method, so maybe we should remove it. Thinking... `model.from_batch(sample_batch, batch_index="last")` would be pretty clean. This is probably out of scope of this PR though.
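For what it's worth, a purely hypothetical sketch of the unification floated here; the `batch_index` keyword and the plain-dict handling are assumptions for illustration, not an existing RLlib API:

```python
import numpy as np

def from_batch(sample_batch: dict, batch_index: str = "all") -> dict:
    """Hypothetical unified accessor: build a model input dict from a batch.

    batch_index="last" keeps only the most recent timestep per column, which
    is what inference on a live trajectory needs; "all" returns everything.
    """
    if batch_index == "last":
        return {k: np.asarray(v)[-1:] for k, v in sample_batch.items()}
    return {k: np.asarray(v) for k, v in sample_batch.items()}
```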