[RLlib] Fix performance and functionality flaws in attention nets (via Trajectory view API). #11729
Conversation
```diff
-    def __init__(self, shift_before: int = 0):
-        self.shift_before = max(shift_before, 1)
+    def __init__(self, view_reqs):
+        self.shift_before = -min(
```
Might want to add a comment to describe what this code does!
done
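For anyone skimming this thread: a minimal sketch of what the new constructor plausibly computes. `ViewReq` and the values below are made up for illustration; the real logic lives in RLlib's collector classes.

```python
from dataclasses import dataclass

@dataclass
class ViewReq:
    """Hypothetical stand-in for RLlib's ViewRequirement."""
    data_rel_pos_from: int

# Each view requirement may reach back in time (negative relative
# positions); the collector must therefore keep that many timesteps of
# padding in front of each trajectory.
view_reqs = {"obs": ViewReq(0), "prev_rewards": ViewReq(-1),
             "prev_obs": ViewReq(-3)}
shift_before = -min(vr.data_rel_pos_from for vr in view_reqs.values())
assert shift_before == 3  # the deepest look-back wins
```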
```python
    def add_init_obs(self, episode_id: EpisodeID, agent_index: int,
                     env_id: EnvID, t: int, init_obs: TensorType,
                     view_requirements: Dict[str, ViewRequirement]) -> None:
        """Adds an initial observation (after reset) to the Agent's trajectory.
```
Change the description. It adds more than a single observation.
No, it doesn't; it really just adds a single one. Same as it used to work with SampleBatchBuilder.
```python
            / view_req.batch_repeat_value))
    repeat_count = (view_req.data_rel_pos_to -
                    view_req.data_rel_pos_from + 1)
    data = np.asarray([
```
Same as above, a bit confused. Add comments on what these lines of code do.
done
Provided an example.
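The example in question isn't shown inline, so here is a hedged toy version of what `repeat_count` captures: a range view (`data_rel_pos_from` .. `data_rel_pos_to`) pulls a fixed-size window of past timesteps per step. All values below are illustrative.

```python
import numpy as np

# A view requirement asking for a *range* of past timesteps, e.g.
# "the last 3 observations" -> data_rel_pos_from=-3, data_rel_pos_to=-1.
data_rel_pos_from, data_rel_pos_to = -3, -1
repeat_count = data_rel_pos_to - data_rel_pos_from + 1  # 3 rows per step

traj = np.arange(6)  # toy 1-D trajectory: one value per timestep t=0..5
data = np.asarray([
    traj[t + data_rel_pos_from:t + data_rel_pos_to + 1]
    for t in range(3, 6)  # steps that already have 3 steps of history
])
print(repeat_count, data.shape)  # 3 (3, 3)
```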
```python
    shift = view_req.data_rel_pos + obs_shift
    # Shift is exactly 0: Use trajectory as is.
    if shift == 0:
        data = np_data[data_col][self.shift_before:]
```
Same as before
Provided an example.
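Likewise, a hedged toy version of the shift logic above: a shift of exactly 0 means the stored trajectory is used as-is, while a look-back shift zero-pads the first steps that have no history yet. `shifted_view` is a made-up helper, not RLlib code.

```python
import numpy as np

def shifted_view(traj: np.ndarray, shift: int, shift_before: int = 1):
    """Toy illustration of building a shifted column from a trajectory."""
    if shift == 0:
        # Shift is exactly 0: use the trajectory as-is (drop pre-padding).
        return traj[shift_before:]
    # Look-back shift: the first |shift| timesteps have no history yet,
    # so they are filled with zeros of the right shape/dtype.
    pad = [np.zeros(shape=traj.shape[1:], dtype=traj.dtype)
           for _ in range(-shift)]
    return np.asarray(pad + list(traj[shift_before:shift]))

obs = np.arange(1.0, 6.0)        # obs at t=-1..3 (index 0 is the pad slot)
print(shifted_view(obs, 0))      # [2. 3. 4. 5.]
print(shifted_view(obs, -1))     # [0. 2. 3. 4.]  ("previous obs" column)
```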
```python
        [np.zeros(shape=shape, dtype=dtype)
         for _ in range(shift)]
    # ...
    def _get_input_dict(self, view_reqs, abs_pos: int = -1) -> \
```
Add a description of what this method does.
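Until that description lands, a hedged guess at the method's shape: build a model-input dict by slicing one position (default: the latest) out of every collected buffer. `get_input_dict` below is illustrative only.

```python
import numpy as np

def get_input_dict(buffers: dict, abs_pos: int = -1) -> dict:
    """Hypothetical sketch: slice one timestep (default: the latest) out
    of every collected buffer, yielding a batch-of-1 dict for a forward
    pass."""
    end = None if abs_pos == -1 else abs_pos + 1
    return {col: buf[abs_pos:end] for col, buf in buffers.items()}

buffers = {"obs": np.arange(4.0), "prev_actions": np.array([0, 0, 1, 1])}
print(get_input_dict(buffers))
# {'obs': array([3.]), 'prev_actions': array([1])}
```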
```diff
-        batch = SampleBatch(self.buffers)
-        assert SampleBatch.UNROLL_ID in batch.data
+        batch = SampleBatch(
+            self.buffers, _seq_lens=self.seq_lens, _dont_check_lens=True)
```
What is _dont_check_lens?
Added an explanation.
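For readers with the same question: the flag appears to tell the `SampleBatch` constructor to skip its length-consistency check. A sketch of why such a check must be skippable with RNN data (assumed semantics, not RLlib's code):

```python
import numpy as np

# With sequence-major RNN data, per-timestep columns and per-sequence
# columns legitimately have different lengths:
seq_lens = [3, 2]              # two sequences, 5 timesteps total
obs = np.zeros((5, 4))         # one row per timestep
state_in_0 = np.zeros((2, 8))  # one row per *sequence*

# A naive "all columns equally long" assertion would fail here, which is
# presumably why the constructor can be told not to check lengths.
assert obs.shape[0] == sum(seq_lens)
assert state_in_0.shape[0] == len(seq_lens)
```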
```python
        return batch
# ...
class _PolicyCollectorGroup:
```
Probably add a comment on what this is.
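A hedged reading while that comment is pending: in multi-agent rollouts each policy gets its own collector, and the group bundles them per rollout worker. The sketch below is illustrative, not the actual class:

```python
class _PolicyCollector:
    """Hypothetical stand-in: accumulates one policy's sample columns."""
    def __init__(self):
        self.buffers = {}

class _PolicyCollectorGroup:
    """Sketch of the idea: every policy gets its own collector; the group
    bundles them together with a shared env-step counter."""
    def __init__(self, policy_ids):
        self.policy_collectors = {
            pid: _PolicyCollector() for pid in policy_ids}
        self.count = 0  # env steps gathered since the last build

group = _PolicyCollectorGroup(["policy_0", "policy_1"])
```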
rllib/policy/sample_batch.py (outdated)
```python
        for i, seq_len in enumerate(self.seq_lens):
            count += seq_len
            if count >= end:
                data["state_in_0"] = self.data["state_in_0"][state_start:
```
Add a comment on what this does.
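For context, the loop appears to map a timestep range onto a sequence range: state columns hold one row per sequence, so `seq_lens` must be accumulated to find the matching slice. A toy version (illustrative values):

```python
import numpy as np

seq_lens = [3, 2, 4]  # timestep counts of three stored sequences
end = 5               # we want all data up to timestep 5

# Walk sequences until their cumulative timestep count covers `end`;
# state columns are per-sequence, so they get sliced by sequence index.
count = 0
for i, seq_len in enumerate(seq_lens):
    count += seq_len
    if count >= end:
        state_end = i + 1
        break

state_in_0 = np.zeros((len(seq_lens), 8))
print(state_in_0[:state_end].shape)  # (2, 8): first two sequences cover t<5
```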
```python
    # Range of indices on time-axis, make sure to create
    if view_req.data_rel_pos_from is not None:
        ret[view_col] = np.zeros_like([[
            view_req.space.sample()
```
Same; add a comment here.
```python
        return x
# ...
class PositionalEmbedding(tf.keras.layers.Layer):
```
Add comments on what this does (how it initializes the embedding per position based on some cos/sin scheme).
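For reference, the name suggests the classic sinusoidal scheme from Vaswani et al. (2017): even channels get sine, odd channels cosine, with geometrically increasing wavelengths over the channel axis. A hedged NumPy sketch (not RLlib's actual layer, which may use relative positional encodings instead):

```python
import numpy as np

def positional_embedding(seq_len: int, dim: int) -> np.ndarray:
    """Classic sinusoidal embedding: even channels get sin, odd get cos,
    with wavelengths forming a geometric progression over the channels."""
    pos = np.arange(seq_len)[:, None]              # (seq_len, 1)
    i = np.arange(dim // 2)[None, :]               # (1, dim/2)
    angles = pos / np.power(10000.0, 2 * i / dim)  # (seq_len, dim/2)
    emb = np.zeros((seq_len, dim))
    emb[:, 0::2] = np.sin(angles)
    emb[:, 1::2] = np.cos(angles)
    return emb

print(positional_embedding(8, 16).shape)  # (8, 16)
```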
Moved here:
Why are these changes needed?
Related issue number
Checks
I've run scripts/format.sh to lint the changes in this PR.