Update priority #28

Open
tphanson opened this issue Jan 22, 2021 · 7 comments

@tphanson

I am following PER (prioritized experience replay) as described in this paper, so I believe there must be a way to update priorities each time the TD error is computed on a minibatch. However, I looked carefully through all your examples and could not find a function that does this.
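For context, here is a minimal sketch of the quantities the question refers to, assuming the proportional variant of PER: each item's priority is the absolute TD error plus a small constant, and sampling probabilities are the normalized priorities raised to an exponent. The function names, `epsilon`, and the default `alpha` are illustrative assumptions and not part of any library API.

```python
from typing import List, Sequence


def td_errors_to_priorities(td_errors: Sequence[float],
                            epsilon: float = 1e-6) -> List[float]:
    """Proportional PER priority: p_i = |delta_i| + epsilon (hypothetical helper)."""
    return [abs(delta) + epsilon for delta in td_errors]


def sampling_probabilities(priorities: Sequence[float],
                           alpha: float = 0.6) -> List[float]:
    """P(i) = p_i**alpha / sum_k p_k**alpha; alpha = 0 gives uniform sampling."""
    scaled = [p ** alpha for p in priorities]
    total = sum(scaled)
    return [s / total for s in scaled]


# TD errors for a minibatch of three transitions (made-up numbers).
priorities = td_errors_to_priorities([0.5, -2.0, 0.0])
probabilities = sampling_probabilities(priorities)
```

Recomputing the priorities after every minibatch is what makes a per-key update call on the replay table necessary.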

@acassirer
Contributor

Hey Tu,

I was surprised to see that we indeed do not have any example of how to update priorities. The method you are looking for is Client.mutate_priorities. There is also an in-graph version, TFClient.update_priorities, but you probably want to focus on the former.

I'll keep this issue open until the examples have been updated, but updating the priorities could look something like this:

```python
# Assumes `dataset` yields 3-step sequences sampled from the Reverb table
# 'my_table' and `client` is a reverb.Client connected to the same server.

# Batches 2 sequences together.
# Shapes of items is now [2, 3, 10, 10].
dataset = dataset.batch(2)

for sample in dataset.take(1):
  # Results in the following format.
  print(sample.info.key)          # ([2, 3], uint64)
  print(sample.info.probability)  # ([2, 3], float64)

  observation, action = sample.data
  print(observation)              # ([2, 3, 10, 10], uint8)
  print(action)                   # ([2, 3, 2], float32)

  # Set the priority of all sampled items to 0.5.
  client.mutate_priorities('my_table', {
      int(key[0]): 0.5            # key[0] => uint64 scalar.
      for key in sample.info.key  # for ([3], uint64) in ([2, 3], uint64)
  })
```
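To set priorities from TD errors rather than a constant, the dictionary passed to `Client.mutate_priorities` can be built from the sampled keys and the per-item errors. The following is a minimal sketch under stated assumptions: `build_priority_updates` and `epsilon` are hypothetical names, the priority rule `|TD error| + epsilon` is the proportional PER rule rather than anything Reverb prescribes, and keeping the larger priority when a key is sampled twice is one possible choice, not the library's behavior.

```python
from typing import Dict, Iterable


def build_priority_updates(keys: Iterable[int],
                           td_errors: Iterable[float],
                           epsilon: float = 1e-6) -> Dict[int, float]:
    """Maps each sampled item key to a new priority (hypothetical helper)."""
    updates: Dict[int, float] = {}
    for key, delta in zip(keys, td_errors):
        priority = abs(float(delta)) + epsilon
        key = int(key)  # Plain Python ints, as in the example above.
        # An item may be sampled more than once per batch; keep the larger value.
        updates[key] = max(priority, updates.get(key, 0.0))
    return updates


# One key and one TD error per sampled item (made-up numbers).
updates = build_priority_updates(keys=[11, 42, 11],
                                 td_errors=[0.5, -2.0, 1.5],
                                 epsilon=0.0)

# With a real client this would be followed by:
#   client.mutate_priorities('my_table', updates)
```

For the sequence dataset shown above, the keys argument would be the first key of each sequence (`key[0]`), since the example treats all steps of a sequence as sharing one key.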

@sontuphan

Thank you @acassirer. That's very helpful to me.

I just noticed that I used another GitHub account to reply to you :) It's still me who posted the question. Thanks again.

@ebrevdo
Collaborator

ebrevdo commented May 25, 2021

Leaving this open until Albin updates examples.

@abhishekbisht1429

abhishekbisht1429 commented Sep 4, 2022

Hi @acassirer, I am using ReverbReplayBuffer from tf_agents and obtained the dataset with its as_dataset function. However, accessing sample.info raises the error "'tuple' object has no attribute 'info'". Could you please tell me how I can obtain the keys needed to update priorities in tf_agents.replay_buffers.ReverbReplayBuffer?

@acassirer
Contributor

Hey @abhishekbisht1429,

I can't speak for the transformations that tf_agents applies to the output of the Reverb dataset, but my guess is that the data container types are lost along the way. If that is the case, I would expect the keys to be found in sample[0][0] (as this would be the same as sample.info.key on the named tuple).

@abhishekbisht1429

abhishekbisht1429 commented Sep 5, 2022

Thanks @acassirer for your comment.

After some searching and experimenting I figured it out; I'm leaving this comment here in case someone else stumbles upon this problem. Each item in the dataset returned by ReverbReplayBuffer.as_dataset() is a 2-tuple: the first element is the actual data, and the second is a named tuple containing auxiliary information about the data. To access the keys corresponding to a data item, use sample[1].key, where sample is an item of the Dataset object returned by ReverbReplayBuffer.as_dataset(). The keys come as a tensor of the same shape as the data, so one might need to reshape it before passing it to ReverbReplayBuffer.update_priorities.
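The reshaping step mentioned above could look roughly like the sketch below. It assumes the keys arrive with shape [batch, num_steps] and that every step of a sampled sequence carries the same key, as in the Reverb example earlier in this thread; neither assumption is confirmed here for tf_agents. `first_key_per_sequence` is a hypothetical helper, and a nested Python list stands in for the key tensor.

```python
from typing import List, Sequence


def first_key_per_sequence(keys: Sequence[Sequence[int]]) -> List[int]:
    """Collapses [batch, num_steps] keys to one key per sampled sequence."""
    flat: List[int] = []
    for row in keys:
        row = [int(k) for k in row]
        # Guard against the assumption being wrong for a given dataset.
        if any(k != row[0] for k in row):
            raise ValueError('Expected every step of a sequence to share one key.')
        flat.append(row[0])
    return flat


# Stand-in for sample[1].key converted to a nested list (made-up values).
batched_keys = [[7, 7, 7], [9, 9, 9]]
keys = first_key_per_sequence(batched_keys)
```

The resulting flat list has one entry per sampled sequence, which can then be paired with one priority per sequence before calling the buffer's update method.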
