
[Feature Request] Suggestion: Tutorial on Model Ensembling #876

Open

btx0424 opened this issue Jan 28, 2023 · 2 comments
Labels
enhancement New feature or request

Comments


btx0424 (Contributor) commented Jan 28, 2023

Motivation

Model ensembling is appealing in the RL context, with a range of use cases such as critic ensembles and parallel inference of multiple agents that share the same actor architecture. The design of torchrl appears to treat model ensembling via functorch as an important feature (judging from the docs and the design of TensorDictModule). However, the docs/examples/tutorials currently give little guidance on the best or suggested practice for actually implementing it.

Tutorials or examples on either use case would be helpful.

Solution

A comprehensive example or tutorial on how to leverage model ensembling for efficient training and inference.

One tricky thing here is how to perform parallel optimization of an ensembled functional model, since functorch does not (yet) provide a direct solution. A viable approach I found is to use torchopt's functional optimizer API and do something like:

import torch
import torch.nn as nn
import functorch
import torchopt
from tensordict.nn import TensorDictModule, make_functional

actors = nn.ModuleList([
    # in_keys/out_keys here are illustrative placeholders
    TensorDictModule(MyActorModule(), in_keys=["observation"], out_keys=["action"])
    for _ in range(num_agents)
]).to(device)

# this was the functorch way;
# things go similarly with tensordict.nn.make_functional
fmodel, params, buffers = functorch.combine_state_for_ensemble(actors)
actor_opt = torchopt.adam(lr=cfg.lr)
actor_opt_states = functorch.vmap(actor_opt.init)(params)

# perform one optimization step for the whole ensemble
def opt_step(batch, params, actor_opt_states):
    actor_loss = functorch.vmap(actor_loss_fn)(params, batch)
    # sum over the ensemble dimension so autograd.grad receives a scalar
    grads = torch.autograd.grad(actor_loss.sum(), params)
    updates, actor_opt_states = functorch.vmap(actor_opt.update)(grads, actor_opt_states)
    params = torchopt.apply_updates(params, updates)
    return actor_loss, params, actor_opt_states

for batch in some_collected_data:
    _, params, actor_opt_states = opt_step(batch, params, actor_opt_states)

But I am very uncertain about the efficiency of the above code (I also tried computing gradients with functorch.grad but found it to be slightly slower), and I have been wondering what the best practice is, especially when we want to use torchrl.
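For comparison, here is a minimal self-contained sketch of the same idea without torchopt, using `torch.func` (functorch's successor in PyTorch 2.0) and a hand-written SGD update in place of a real optimizer; the module, sizes, and learning rate are illustrative assumptions, not torchrl API:

```python
import copy

import torch
import torch.nn as nn
from torch.func import functional_call, grad, stack_module_state, vmap

# an ensemble of four same-architecture models with independent weights
models = [nn.Linear(3, 1) for _ in range(4)]
params, buffers = stack_module_state(models)
# a "meta"-device copy supplies the module structure for functional_call
base = copy.deepcopy(models[0]).to("meta")

def loss_fn(p, b, x, y):
    pred = functional_call(base, (p, b), (x,))
    return ((pred - y) ** 2).mean()

x = torch.randn(64, 3)
y = torch.randn(64, 1)

# per-model gradients for the whole ensemble in one vmapped call
grads = vmap(grad(loss_fn), in_dims=(0, 0, None, None))(params, buffers, x, y)

# hand-rolled SGD step applied to every ensemble member in parallel
lr = 0.05
new_params = {k: params[k] - lr * grads[k] for k in params}
```

Whether this is actually faster than a plain Python loop over small models is exactly the kind of question a tutorial could answer.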

Additional Context

Some additional examples of critic ensembling, e.g., double Q-functions, would also be of great help.
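As a sketch of what such an example might cover, here is a hedged, self-contained version of a two-member Q-function ensemble evaluated in parallel with `torch.func`, taking the elementwise minimum as in clipped double Q-learning; the network and sizes are illustrative assumptions, not torchrl API:

```python
import copy

import torch
import torch.nn as nn
from torch.func import functional_call, stack_module_state, vmap

class QNet(nn.Module):
    """Toy state-action value network (illustrative only)."""
    def __init__(self, obs_dim=4, act_dim=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim + act_dim, 32), nn.ReLU(), nn.Linear(32, 1)
        )

    def forward(self, obs, act):
        return self.net(torch.cat([obs, act], dim=-1))

critics = [QNet() for _ in range(2)]
params, buffers = stack_module_state(critics)
# "meta"-device skeleton providing the structure for functional_call
base = copy.deepcopy(critics[0]).to("meta")

def q_value(p, b, obs, act):
    return functional_call(base, (p, b), (obs, act))

obs = torch.randn(8, 4)
act = torch.randn(8, 2)
# evaluate both critics in one vmapped call -> shape [2, 8, 1]
qs = vmap(q_value, in_dims=(0, 0, None, None))(params, buffers, obs, act)
# pessimistic estimate over the ensemble dimension
q_min = qs.min(dim=0).values
```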

Checklist

  • I have checked that there is no similar issue in the repo (required)
@btx0424 btx0424 added the enhancement New feature or request label Jan 28, 2023

vmoens (Contributor) commented Jan 30, 2023

This is a very good idea, thanks for the suggestion!


smorad (Contributor) commented Jun 15, 2023

I would also find this very helpful! convert_to_functional seems like a very good starting point; it's a shame it is only usable inside LossModules. Would it make sense to have a TensorDictModuleEnsemble class that could be applied to any nn.Module or TensorDictModule?
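For discussion's sake, a rough sketch of what such a wrapper might look like (the class name, constructor, and behaviour are entirely hypothetical, built on plain `torch.func`, not an existing torchrl API):

```python
import copy

import torch
import torch.nn as nn
from torch.func import functional_call, stack_module_state, vmap

class ModuleEnsemble(nn.Module):
    """Hypothetical wrapper running N copies of a module in parallel via vmap.

    Buffers are ignored in this sketch, and the copies start from identical
    weights; a real implementation would re-initialize each copy.
    """

    def __init__(self, module, num_copies):
        super().__init__()
        copies = [copy.deepcopy(module) for _ in range(num_copies)]
        stacked, _ = stack_module_state(copies)
        # register the stacked parameters so optimizers can find them
        # (ParameterDict keys cannot contain dots)
        self.stacked_params = nn.ParameterDict(
            {k.replace(".", "_"): nn.Parameter(v.detach().clone())
             for k, v in stacked.items()}
        )
        self._keys = list(stacked.keys())
        # "meta"-device skeleton providing the structure for functional_call
        self.base = copy.deepcopy(module).to("meta")

    def forward(self, x):
        p = {k: self.stacked_params[k.replace(".", "_")] for k in self._keys}

        def call_one(pp, xx):
            return functional_call(self.base, pp, (xx,))

        # map over the ensemble dim of the params, broadcast the input
        return vmap(call_one, in_dims=(0, None))(p, x)

ens = ModuleEnsemble(nn.Linear(3, 2), num_copies=5)
out = ens(torch.randn(7, 3))  # shape [5, 7, 2]
```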
