Conversation

Contributor

@sven1977 sven1977 commented Oct 31, 2020

  • RLlib's attention nets (GTrXL) have so far been forced to run "inside" RLlib's RNN API (previous internal states are passed back in as new state-ins at subsequent timesteps). This is not a good fit for attention nets, which need different handling and time-slicing of past internal states (the attention net's memory). The trajectory view API allows specifying the exact timestep ranges needed for forward passes and batched train passes through attention nets (see the sketch right after this list).
  • Besides the above, the handling of the attention nets' tau-memory was also incorrect. This PR fixes those existing bugs.
  • In a follow-up PR, the torch version of GTrXL will be fully included in the testing as well (to make sure it is 100% on par with the tf version).
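To make the first point concrete, here is a purely illustrative sketch (plain NumPy, not RLlib's actual API; `memory_tau` and the dict keys are made-up names) of the kind of time-slicing an attention net needs: each forward pass at timestep t consumes the current observation plus a window of the most recent memory outputs, rather than a single recurrent state handed forward step by step.

```python
import numpy as np

# Illustrative only; names are assumptions, not RLlib's actual fields.
memory_tau = 50
T, obs_dim, mem_dim = 200, 4, 32

obs = np.random.randn(T, obs_dim)
memory_out = np.random.randn(T, mem_dim)  # memory outputs from earlier passes

t = 123
# The attention net's input at timestep t: current obs + last <=50 memory outs.
model_input = {
    "obs": obs[t],
    "memory_in": memory_out[max(0, t - memory_tau):t],
}
```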

Why are these changes needed?

Related issue number

Checks

  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

…ectory_view_api_attention_nets

# Conflicts:
#	rllib/agents/ppo/ppo_tf_policy.py
#	rllib/evaluation/tests/test_trajectory_view_api.py
#	rllib/policy/tf_policy_template.py
#	rllib/utils/tf_ops.py
…on_nets

# Conflicts:
#	rllib/agents/ppo/ppo_tf_policy.py
#	rllib/evaluation/collectors/simple_list_collector.py
#	rllib/evaluation/tests/test_trajectory_view_api.py
#	rllib/policy/dynamic_tf_policy.py
#	rllib/policy/policy.py
#	rllib/policy/sample_batch.py
#	rllib/policy/view_requirement.py
def __init__(self, shift_before: int = 0):
    self.shift_before = max(shift_before, 1)

def __init__(self, view_reqs):
    self.shift_before = -min(
Contributor

Might want to add a comment to describe what this code does!

Contributor Author

done

def add_init_obs(self, episode_id: EpisodeID, agent_index: int,
                 env_id: EnvID, t: int, init_obs: TensorType,
                 view_requirements: Dict[str, ViewRequirement]) -> None:
    """Adds an initial observation (after reset) to the Agent's trajectory.
Contributor

Change the description. It adds more than a single observation.

Contributor Author

No, it doesn't. It really just adds a single one, same as it used to work with SampleBatchBuilder.
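To illustrate the granularity being discussed, a tiny self-contained stand-in (not RLlib's actual collector, just a stub with the same signature) where one call stores exactly one initial observation:

```python
from typing import Any, Dict


class _FakeCollector:
    """Stand-in only, to show that one call == one stored initial obs."""

    def __init__(self):
        self.initial_obs = {}

    def add_init_obs(self, episode_id: int, agent_index: int, env_id: int,
                     t: int, init_obs: Any,
                     view_requirements: Dict[str, Any]) -> None:
        self.initial_obs[(episode_id, agent_index)] = init_obs


collector = _FakeCollector()
collector.add_init_obs(episode_id=0, agent_index=0, env_id=0, t=-1,
                       init_obs=[0.0, 0.0], view_requirements={})
assert len(collector.initial_obs) == 1
```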

/ view_req.batch_repeat_value))
repeat_count = (view_req.data_rel_pos_to -
                view_req.data_rel_pos_from + 1)
data = np.asarray([
Contributor

Same as above, a bit confused. Add comments on what these lines of code do.

Contributor Author

done
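A minimal sketch of what the `repeat_count` computation above yields, assuming a view requirement that spans the 50 most recent timesteps (the concrete values are made up for illustration):

```python
# Hypothetical window for an attention net's memory.
data_rel_pos_from = -50  # oldest timestep in the requested window
data_rel_pos_to = -1     # most recent past timestep

# Number of past timesteps pulled into each row of the batch.
repeat_count = data_rel_pos_to - data_rel_pos_from + 1
assert repeat_count == 50
```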

Contributor Author

Provided an example.

shift = view_req.data_rel_pos + obs_shift
# Shift is exactly 0: Use trajectory as is.
if shift == 0:
    data = np_data[data_col][self.shift_before:]
Contributor

Same as before

Contributor Author

Provided an example.
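For readers of this thread, a minimal sketch of the shift == 0 case (names are illustrative, not the collector's exact fields): the buffer keeps `shift_before` extra entries in front of the episode, so the view is just the trajectory with that padding cut off.

```python
import numpy as np

shift_before = 1
buffer_obs = np.array([0.0, 1.0, 2.0, 3.0])  # index 0 is the padding entry

# shift == 0: use the trajectory as-is, minus the leading padding.
data = buffer_obs[shift_before:]
print(data)  # [1. 2. 3.]
```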

[np.zeros(shape=shape, dtype=dtype)
 for _ in range(shift)]

def _get_input_dict(self, view_reqs, abs_pos: int = -1) -> \
Contributor

Add a description of what this method does.

batch = SampleBatch(self.buffers)
assert SampleBatch.UNROLL_ID in batch.data
batch = SampleBatch(
    self.buffers, _seq_lens=self.seq_lens, _dont_check_lens=True)
Contributor

What is _dont_check_lens?

Contributor Author

added explanation.
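A hedged illustration of why a strict length check is skipped here (column names and sizes are assumptions, not RLlib's exact internals): per-timestep columns have sum(seq_lens) rows, while RNN/attention state columns only carry one row per *sequence*, so their lengths differ on purpose.

```python
import numpy as np

seq_lens = np.array([3, 2])            # two sequences: 3 + 2 = 5 timesteps
buffers = {
    "obs": np.zeros((5, 4)),           # one row per timestep
    "state_in_0": np.zeros((2, 8)),    # one row per sequence
}
assert len(buffers["obs"]) == seq_lens.sum()
assert len(buffers["state_in_0"]) == len(seq_lens)
```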

return batch


class _PolicyCollectorGroup:
Contributor

Probably add comments on what this is

for i, seq_len in enumerate(self.seq_lens):
    count += seq_len
    if count >= end:
        data["state_in_0"] = self.data["state_in_0"][state_start:
Contributor

Add comment on what this does
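A rough sketch of the idea (variable names are assumptions, not the exact collector code): a per-timestep range [start, end) has to be mapped onto the sequences covering it, because "state_in_0" stores one entry per sequence rather than one per timestep.

```python
import numpy as np

seq_lens = np.array([3, 2, 4])          # 3 sequences, 9 timesteps total
state_in_0 = np.array([10, 11, 12])     # one state entry per sequence
start, end = 3, 9                       # requested timestep range

bounds = np.concatenate([[0], np.cumsum(seq_lens)])         # [0, 3, 5, 9]
first_seq = int(np.searchsorted(bounds, start, side="right") - 1)
last_seq = int(np.searchsorted(bounds, end, side="left") - 1)
print(state_in_0[first_seq:last_seq + 1])                    # [11 12]
```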

# Range of indices on time-axis, make sure to create
if view_req.data_rel_pos_from is not None:
    ret[view_col] = np.zeros_like([[
        view_req.space.sample()
Contributor

Same, add a comment here.

return x


class PositionalEmbedding(tf.keras.layers.Layer):
Contributor

Add comments on what this does (how it initializes embedding per position based on cos/sin something)
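For reference, a minimal hedged sketch of the usual sinusoidal scheme such a layer implements (standard Transformer-XL-style sin/cos frequencies; this is an illustration under that assumption, not necessarily the exact code added in this PR):

```python
import numpy as np
import tensorflow as tf


class SinusoidalPositionalEmbedding(tf.keras.layers.Layer):
    """Illustrative sin/cos positional embedding (not RLlib's exact layer).

    Position t gets a vector whose first half is sin(t * inv_freq) and whose
    second half is cos(t * inv_freq), with inv_freq spaced as
    1 / 10000^(2i / out_dim).
    """

    def __init__(self, out_dim: int, **kwargs):
        super().__init__(**kwargs)
        self.inverse_freq = tf.constant(
            1.0 / (10000 ** (np.arange(0, out_dim, 2.0) / out_dim)),
            dtype=tf.float32)

    def call(self, seq_len):
        positions = tf.cast(tf.range(seq_len), tf.float32)           # [T]
        angles = tf.einsum("i,j->ij", positions, self.inverse_freq)  # [T, D/2]
        return tf.concat([tf.sin(angles), tf.cos(angles)], axis=-1)  # [T, D]


emb = SinusoidalPositionalEmbedding(out_dim=8)
print(emb(tf.constant(5)).shape)  # (5, 8)
```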

…ectory_view_api_attention_nets

# Conflicts:
#	rllib/agents/trainer.py
#	rllib/evaluation/collectors/simple_list_collector.py
#	rllib/evaluation/tests/test_trajectory_view_api.py
#	rllib/models/tf/attention_net.py
#	rllib/policy/policy.py
#	rllib/policy/torch_policy_template.py
#	rllib/policy/view_requirement.py
#	src/ray/raylet/node_manager.cc
…ectory_view_api_attention_nets

# Conflicts:
#	rllib/agents/trainer.py
#	rllib/evaluation/collectors/simple_list_collector.py
#	rllib/evaluation/tests/test_trajectory_view_api.py
#	rllib/models/tf/attention_net.py
#	rllib/policy/policy.py
#	rllib/policy/torch_policy_template.py
#	rllib/policy/view_requirement.py
#	src/ray/raylet/node_manager.cc
…ectory_view_api_attention_nets

# Conflicts:
#	rllib/agents/ppo/appo_tf_policy.py
#	rllib/agents/ppo/ppo_torch_policy.py
#	rllib/agents/qmix/model.py
#	rllib/evaluation/collectors/simple_list_collector.py
#	rllib/evaluation/rollout_worker.py
#	rllib/evaluation/tests/test_trajectory_view_api.py
#	rllib/models/modelv2.py
#	rllib/policy/dynamic_tf_policy.py
#	rllib/policy/policy.py
#	rllib/policy/sample_batch.py
#	rllib/policy/view_requirement.py
@sven1977 sven1977 closed this Dec 10, 2020
@sven1977
Contributor Author

Moved here:
#12753

Development

Successfully merging this pull request may close these issues.

[RLlib Trajectory View API] Trajectory View API works with our Attention Nets.
