Motivation
Model ensembling is appealing in the RL context, with a range of use cases, e.g., critic ensembles and parallel inference of multiple agents sharing the same actor structure. I believe the design of torchrl treats model ensembling via functorch as an important feature (judging from the docs and the design of TensorDictModule). However, there is currently little guidance in the docs/examples/tutorials on the best or suggested practice for actually implementing it.
Tutorials or examples on either use case would be helpful.
Solution
A comprehensive example or tutorial on how to leverage model ensembling for efficient training and inference.
One tricky thing here is how to perform parallel optimization with an ensembled functional model, since functorch does not (yet) provide a direct solution. A viable approach I figured out was to use torchopt's functional optimizer API and do something like:
```python
import torch
import torch.nn as nn
import functorch
import torchopt
from tensordict.nn import TensorDictModule

# in_keys/out_keys are placeholders for illustration
actors = nn.ModuleList([
    TensorDictModule(MyActorModule(), in_keys=["observation"], out_keys=["action"])
    for _ in range(num_agents)
]).to(device)

# this is the functorch way; things go similarly with tensordict.nn.make_functional
fmodel, params, buffers = functorch.combine_state_for_ensemble(actors)

actor_opt = torchopt.adam(lr=cfg.lr)
actor_opt_states = functorch.vmap(actor_opt.init)(params)  # one opt state per member

# to perform an optimization step
def opt_step(batch, params, actor_opt_states):
    actor_loss = functorch.vmap(actor_loss_fn)(params, batch)  # per-member losses
    # the losses are independent across members, so summing still yields per-member grads
    grads = torch.autograd.grad(actor_loss.sum(), params)
    updates, actor_opt_states = functorch.vmap(actor_opt.update)(grads, actor_opt_states)
    params = torchopt.apply_updates(params, updates)
    return actor_loss, params, actor_opt_states

for batch in some_collected_data:
    _, params, actor_opt_states = opt_step(batch, params, actor_opt_states)
```
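For reference, here is roughly how I imagine the tensordict.nn.make_functional route would look: an untested sketch where the ensemble is just a stack of parameter TensorDicts (the in_keys/out_keys and the vmap-over-stacked-params pattern are my assumptions, not an official recipe):
```python
import torch
import functorch
from tensordict.nn import TensorDictModule, make_functional

actor = TensorDictModule(MyActorModule(), in_keys=["observation"], out_keys=["action"])
# extract the parameters into a TensorDict, turning the module functional
params = make_functional(actor)
# build the ensemble by stacking per-agent parameter TensorDicts
params_ensemble = torch.stack([params.clone() for _ in range(num_agents)], 0)

# map over the parameter dimension only; the input tensordict is shared (in_dims=None)
td_out = functorch.vmap(actor, (None, 0))(td_in, params_ensemble)
```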
But I am very uncertain about the efficiency of the above code (I also tried getting gradients via functorch.grad, but found it to be slightly slower) and have been wondering what the best practice is, especially if we want to use torchrl.
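For completeness, the functorch.grad variant I tried looked roughly like this (a sketch, assuming actor_loss_fn(params, batch) returns a scalar loss for a single ensemble member):
```python
import functorch

# per-member gradients and losses in a single vmapped call
grad_and_loss_fn = functorch.vmap(functorch.grad_and_value(actor_loss_fn))
grads, actor_loss = grad_and_loss_fn(params, batch)
```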
Additional Context
Some additional examples of critic ensembling, e.g., double Q-functions, would also be of great help.
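To illustrate, such an example could show something like the following clipped double-Q pattern (MyQModule, the input names, and the min-reduction target are just my assumptions of what it would cover):
```python
import torch
import torch.nn as nn
import functorch

critics = nn.ModuleList([MyQModule() for _ in range(2)]).to(device)
fcritic, q_params, q_buffers = functorch.combine_state_for_ensemble(critics)

# evaluate both Q-functions on the same (obs, action) batch:
# map over the stacked params/buffers, broadcast the inputs
q_values = functorch.vmap(fcritic, (0, 0, None, None))(q_params, q_buffers, obs, action)
# clipped double-Q target: element-wise minimum over the ensemble dimension
q_target = q_values.min(dim=0).values
```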
Checklist
I have checked that there is no similar issue in the repo (required)
I would also find this very helpful! convert_to_functional seems like a very good starting point; it's a shame it is only usable in LossModules. Would it make sense to have a TensorDictModuleEnsemble class that could be applied to any nn.Module or TensorDictModule? See the sketch below for the kind of thing I have in mind.
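Purely as a hypothetical sketch (no such class exists in torchrl today; this is just one way such a wrapper could be built on functorch.combine_state_for_ensemble):
```python
import copy
import torch.nn as nn
import functorch

class TensorDictModuleEnsemble(nn.Module):
    """Hypothetical wrapper: vmaps N copies of a module over stacked parameters."""

    def __init__(self, module: nn.Module, num_copies: int):
        super().__init__()
        copies = [copy.deepcopy(module) for _ in range(num_copies)]
        self.fmodel, params, buffers = functorch.combine_state_for_ensemble(copies)
        # register the stacked parameters so they follow .to() / state_dict()
        self.params = nn.ParameterList(nn.Parameter(p) for p in params)
        self._buffers_stacked = list(buffers)

    def forward(self, *args):
        # map over the stacked params/buffers only; broadcast the module inputs
        in_dims = (0, 0) + (None,) * len(args)
        return functorch.vmap(self.fmodel, in_dims)(
            tuple(self.params), tuple(self._buffers_stacked), *args
        )
```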