
[Feature] Skip existing for advantage modules #1070

Merged 1 commit into main on Apr 18, 2023
Conversation

@vmoens (Contributor) commented Apr 18, 2023

Allows asking the advantage modules to skip the value network call when its output is already present.

@matteobettini @albertbou92
Curious to hear your thoughts about this:
Should all modules skip existing keys by default? This could have some unexpected consequences if someone does this:

collector = SyncDataCollector(env, policy=actor_critic, ...)  # the collector will also return the value!
for data in collector:
    loss(data)  # the critic is not queried because the key is already there

@albertbou92 (Contributor)

What about the other way around: not skipping by default?

In that case, we should make sure that the advantage module computes the value without gradients.
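A minimal sketch of that idea (the helper and the "state_value" key name are hypothetical, not the PR's API), assuming value_net is a TensorDictModule-style callable that writes its output into the tensordict in place:

import torch

def compute_value_for_advantage(td, value_net, value_key="state_value", skip_existing=False):
    # Only query the value network when its output is missing (or when skipping is disabled)...
    if skip_existing and value_key in td.keys():
        return td
    # ...and compute it without tracking gradients: here the value is used as a
    # target, so it should not backpropagate into the critic.
    with torch.no_grad():
        value_net(td)
    return td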

@matteobettini (Contributor) commented Apr 18, 2023

Could you explain a bit more why your example constitutes a problem?
Wouldn't the value computed during rollout be the same if we just compute it later in the loss?

Also, I wonder if this is a small step in the direction of #1000, because here we are allowing values to be precomputed (maybe on bigger batches) outside the loss forward, and in #1000 we want to also allow precomputing targets (maybe on bigger batches) outside the loss forward.

PS: in #1000, advantage modules would not even have access to neural networks; only losses will, which will make all this even simpler.

@vmoens (Contributor, Author) commented Apr 18, 2023

> Could you explain a bit more why your example constitutes a problem? Wouldn't the value computed during rollout be the same if we just compute it later in the loss?

> Also, I wonder if this is a small step in the direction of #1000, because here we are allowing values to be precomputed (maybe on bigger batches) outside the loss forward, and in #1000 we want to also allow precomputing targets (maybe on bigger batches) outside the loss forward.

> PS: in #1000, advantage modules would not even have access to neural networks; only losses will, which will make all this even simpler.

You can perfectly well not pass a network and it will work fine.
I'd like to keep the networks because it's much easier; otherwise we are decomposing things and we'll be asking users to do

value_net(td)
value_net(td.get("next"))
gae(td)

instead of

gae(td)

@matteobettini (Contributor) commented Apr 18, 2023

Yep, I didn't see that we could pass None; that works perfectly.

Thinking about this more, if losses start skipping existing keys by default, this may have some bad consequences if collected data (from collectors or buffers) has some of the loss keys of interest (because they might be old).

But thinking of #1000, we definitely need a way to precompute keys for losses (values and targets) and make the losses use those precomputed keys.

I think it would be fine to use skip_existing by default also in the losses, and if users feed the losses data with old values it is their problem.

batch = rb.sample(60_000)
loss_module.compute_value_target(batch)  # if I didn't write this line and somehow the user had old targets in their data, it's their fault
for _ in range(n):
    minibatch = subsample(batch, 1000)
    loss_vals = loss_module(minibatch) # by default uses the precomputed targets

Alternatively, what @albertbou92 says also makes sense: the losses could recompute everything by default, and we would need to

batch = rb.sample(60_000)
loss_module.compute_value_target(batch)
for _ in range(n):
    minibatch = subsample(batch, 1000)
    loss_vals = loss_module(minibatch, skip_existing=True)  # we need to enforce this so the precomputed targets are not recomputed

I am leaning more towards the first option, though.

@vmoens (Contributor, Author) commented Apr 18, 2023

> Thinking about this more, if losses start skipping existing keys by default, this may have some bad consequences if collected data (from collectors or buffers) has some of the loss keys of interest (because they might be old).

Yes, this is why I set it to False, and why I was asking you guys what you were thinking. I guess we'll leave it as it is.

> But thinking of #1000, we definitely need a way to precompute keys for losses (values and targets) and make the losses use those precomputed keys.

If you don't pass the value net to the module, an error will be raised because the entry is missing.

> I think it would be fine to use skip_existing by default also in the losses, and if users feed the losses data with old values it is their problem.

Yep, that's the idea.
There are two things you could do:

with set_skip_existing(True):
    loss_fn(td)

or

loss_fn.make_value_estimator(..., skip_existing=True)
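For instance, putting the context-manager option into the precompute-then-minibatch pattern sketched above (rb, gae, subsample, loss_module and n are the same hypothetical placeholders as in the earlier snippets; set_skip_existing is assumed here to be importable from tensordict.nn):

from tensordict.nn import set_skip_existing

batch = rb.sample(60_000)
gae(batch)  # precompute value / advantage / value target once, on the big batch

with set_skip_existing(True):
    for _ in range(n):
        minibatch = subsample(batch, 1000)
        loss_vals = loss_module(minibatch)  # reuses the precomputed entries instead of recomputing them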

@matteobettini (Contributor)

I am more in favour of the first of the two, since you are not bound to something at init time.

Regarding the default of skip_existing, I am leaning more towards True, but I am fine with anything.

@vmoens (Contributor, Author) commented Apr 18, 2023

> I am more in favour of the first of the two, since you are not bound to something at init time.

Then use that one 😉

> Regarding the default of skip_existing, I am leaning more towards True, but I am fine with anything.

Let's put it this way:
- worst that can happen when True: wrong value -> no training -> torchrl sucks
- worst that can happen when False: slow training -> good training, but slow -> issue

@matteobettini (Contributor)

> Then use that one 😉

Ahah, I thought we were choosing one or the other, but having both is even better ahaha.

> - worst that can happen when True: wrong value -> no training -> torchrl sucks
> - worst that can happen when False: slow training -> good training, but slow -> issue

OK, I'm sold.

@albertbou92 (Contributor) commented Apr 18, 2023

> I am more in favour of the first of the two, since you are not bound to something at init time.

> Then use that one 😉

> Regarding the default of skip_existing, I am leaning more towards True, but I am fine with anything.

> Let's put it this way:
> - worst that can happen when True: wrong value -> no training -> torchrl sucks
> - worst that can happen when False: slow training -> good training, but slow -> issue

I like the idea of having something that works out of the box, even if there is a margin for improvement in terms of performance. Beginner users will be able to train, and more advanced users will be able to adjust the code for better performance.

@vmoens merged commit 55976e4 into main on Apr 18, 2023.
@vmoens deleted the skip_existing_adv branch on Apr 18, 2023 at 14:55.
albertbou92 pushed a commit to PyTorchRL/rl that referenced this pull request on Apr 19, 2023.