JAX code #30
Hi,
I would like to ask whether there is a JAX-based version of this code, and whether you have any recommendations for JAX-based offline RL algorithms.
Thanks!
Comments
Releasing the JAX code might take some time, but it should be easy to modify the existing dopamine agents. In the meantime, here are some tips to get started with JAX agents:
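For example, dopamine already ships a JAX DQN agent that plugs into the standard runner. A minimal sketch of running it (the gin file path and game binding are assumptions based on the dopamine repo layout, so check them against your checkout):

```python
# Minimal sketch: run dopamine's built-in JAX DQN agent with the standard
# runner. The gin path and binding below follow the dopamine repo layout.
from dopamine.discrete_domains import run_experiment

base_dir = '/tmp/jax_dqn'  # hypothetical output directory
gin_files = ['dopamine/jax/agents/dqn/configs/dqn.gin']
gin_bindings = ['atari_lib.create_atari_environment.game_name = "Pong"']

run_experiment.load_gin_configs(gin_files, gin_bindings)
runner = run_experiment.create_runner(base_dir)
runner.run_experiment()
```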
Thank you very much! May I ask if I can also run the code on a TPU VM with JAX? Best,
I think so -- you'd probably want to use the TFDS datasets, or use much larger batch sizes with the dopamine codebase.
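For example, the RL Unplugged Atari data is available through TFDS; here is a sketch of streaming it with large batches (the dataset/config name 'rlu_atari/Pong_run_1' is an assumption -- verify it against the TFDS catalog):

```python
# Sketch: stream logged Atari transitions from TFDS with large batches.
# The dataset name below is an assumption; check the TFDS catalog.
import tensorflow_datasets as tfds

episodes = tfds.load('rlu_atari/Pong_run_1', split='train')
steps = episodes.flat_map(lambda ep: ep['steps'])  # episodes -> transitions
batches = steps.batch(256).prefetch(2)  # large batches help TPU utilization

for batch in batches.take(1):
    print(batch['observation'].shape)  # e.g. (256, 84, 84, 1)
```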
Thank you very much! I'll give it a try.
Dear agarwl, I tried to follow the code you provided to reproduce the offline DQN results with JAX, but I found that training in JAX is quite slow compared with TensorFlow. May I ask what the possible reason is? I changed these parts of the vanilla dopamine code:
(2) created the offline buffer: fixed_replay_buffer.py
(3) created an OfflineJaxDQNAgent (a rough sketch of this is shown after this comment)
(4) compared the JAX code with the vanilla TF code and found that they use different replay buffers (FixedReplayBuffer in JAX and WrappedFixedReplayBuffer in TF); I'm not sure whether this is the main reason.
Best
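For reference, a rough sketch of what such an OfflineJaxDQNAgent can look like, assuming batch_rl's FixedReplayBuffer and dopamine's JaxDQNAgent (import paths and constructor arguments are assumptions and may differ from the actual labs/offline_rl code):

```python
# Rough sketch of an offline JAX DQN agent: build the replay buffer from
# logged data and never write new transitions. Names are illustrative.
import gin
from batch_rl.fixed_replay.replay_memory import fixed_replay_buffer
from dopamine.jax.agents.dqn import dqn_agent


@gin.configurable
class OfflineJaxDQNAgent(dqn_agent.JaxDQNAgent):
  """JaxDQNAgent trained purely from logged transitions."""

  def __init__(self, num_actions, replay_data_dir=None, replay_suffix=None,
               **kwargs):
    # Set before super().__init__, which calls _build_replay_buffer below.
    self._replay_data_dir = replay_data_dir
    self._replay_suffix = replay_suffix
    super().__init__(num_actions, **kwargs)

  def _build_replay_buffer(self):
    # Load fixed replay checkpoints instead of an empty online buffer;
    # replay_capacity/batch_size are expected to come from gin, as in
    # dopamine's own _build_replay_buffer.
    return fixed_replay_buffer.FixedReplayBuffer(
        data_dir=self._replay_data_dir,
        replay_suffix=self._replay_suffix,
        observation_shape=self.observation_shape,
        stack_size=self.stack_size,
        update_horizon=self.update_horizon,
        gamma=self.gamma,
        observation_dtype=self.observation_dtype)

  def _store_transition(self, *args, **kwargs):
    pass  # offline: the buffer is fixed, so drop new transitions
```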
Hi, I found that update_period is 1 in the JAX code while the TF code uses 4. Maybe that is the main reason.
Yeah, update_period = 1 corresponds to 1 gradient step every environment step (the default is 4, which corresponds to 1 gradient step every 4 environment steps). In each iteration we do 62.5K gradient steps, so with update_period = 1 we can also set num_training_steps to 62.5K.
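Concretely, with dopamine's default of 250K training steps per iteration, 250,000 / 4 = 62,500 gradient steps per iteration, so the two schedules perform the same number of gradient updates. A sketch of the equivalent gin bindings (Runner.training_steps is dopamine's name for num_training_steps above):

```python
# Sketch: keep gradient updates per iteration fixed while moving from
# update_period = 4 to update_period = 1.
import gin

steps_per_iteration = 250_000           # dopamine default, update_period = 4
grad_steps = steps_per_iteration // 4   # -> 62_500 gradient steps

# Equivalent schedule with one gradient step per training step:
gin.parse_config("""
JaxDQNAgent.update_period = 1
Runner.training_steps = 62500
""")
```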
Hi @agarwl, thanks for your reply -- I will try it. By the way, can I run the TF code on a TPU VM? I find TF is still a little bit faster.
Sure -- you may not see much benefit from using TPUs (due to the small batch size and the dopamine replay buffer), but the code can be run on a TPU.
Here's some JAX code for reference: https://github.com/google/dopamine/tree/master/dopamine/labs/offline_rl
@agarwl Thank you very much! I will give it a try.