JAX TD3 prototype #225
Conversation
Thanks for the PR. Looks great!
cleanrl/td3_continuous_action_jax.py
Outdated
# TODO: maybe generate a lot of random keys right at the beginning;
# also check https://jax.readthedocs.io/en/latest/jax.random.html
key, noise_key = jax.random.split(key, 2)
# Clipped Gaussian noise for target-policy smoothing
# (arguments assumed per the standard TD3 formulation)
clipped_noise = jnp.clip(
    jax.random.normal(noise_key, actions.shape) * args.policy_noise,
    -args.noise_clip,
    args.noise_clip,
)
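For reference, a minimal sketch of what the TODO's "generate a lot of random keys right in the beginning" idea could look like; the seed, key count, and loop structure are assumptions for illustration, not the PR's code:

```python
import jax

key = jax.random.PRNGKey(0)  # seed assumed
# Split once into one subkey per training step, instead of splitting every step.
noise_keys = jax.random.split(key, num=10_000)  # num = total training steps (assumed)

# Inside the training loop, step i would then simply index into the array:
# noise = jax.random.normal(noise_keys[i], actions.shape)
```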
`clipped_noise` is not actually used. Also, maybe generating it with numpy is a little bit faster? With `jnp` we would probably need to jit the function. Would you mind doing a speed test like `%timeit`?
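A quick way to run that comparison; the shapes and hyperparameters below are assumed for illustration, not taken from the PR:

```python
import numpy as np
import jax
import jax.numpy as jnp

action_shape = (256, 6)              # assumed batch/action dims
policy_noise, noise_clip = 0.2, 0.5  # assumed TD3 defaults

def numpy_noise():
    # Plain numpy: eager, no compilation
    return np.clip(np.random.normal(0.0, policy_noise, size=action_shape), -noise_clip, noise_clip)

@jax.jit
def jax_noise(key):
    # JAX: needs explicit key handling; jit to avoid per-call dispatch overhead
    key, noise_key = jax.random.split(key)
    noise = jnp.clip(jax.random.normal(noise_key, action_shape) * policy_noise, -noise_clip, noise_clip)
    return noise, key

# In IPython:
#   %timeit numpy_noise()
#   key = jax.random.PRNGKey(0)
#   %timeit jax_noise(key)[0].block_until_ready()  # block so we time actual compute
```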
Added it and fixed the clipping; will test what's the most efficient way to generate noise soonish.
cleanrl/td3_continuous_action_jax.py
Outdated
(qf1_loss_value, qf1_a_values), grads1 = jax.value_and_grad(mse_loss, has_aux=True)(qf1_state.params, qf1)
(qf2_loss_value, qf2_a_values), grads2 = jax.value_and_grad(mse_loss, has_aux=True)(qf2_state.params, qf2)
qf1_state = qf1_state.apply_gradients(grads=grads1)
qf2_state = qf2_state.apply_gradients(grads=grads2)
We are doing two separate grad passes. Would it be faster to have them share the same optimizer, as done here?
cleanrl/cleanrl/ppo_continuous_action_envpool_jax.py
Lines 213 to 217 in 399f9a3
agent_params = AgentParams(
    actor_params,
    critic_params,
)
agent_optimizer_state = agent_optimizer.init(agent_params)
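A rough sketch of that suggestion applied to TD3's twin Q-networks; the container class, loss, learning rate, and the names `qf1`, `qf2`, `observations`, `actions`, and `target_q` are assumptions based on the surrounding code, not the final implementation:

```python
import flax
import jax
import optax

@flax.struct.dataclass
class QNetworksParams:
    qf1_params: flax.core.FrozenDict
    qf2_params: flax.core.FrozenDict

# Bundle both critics' params into one pytree so a single optimizer handles them.
q_params = QNetworksParams(qf1_state.params, qf2_state.params)
q_optimizer = optax.adam(learning_rate=3e-4)  # learning rate assumed
q_opt_state = q_optimizer.init(q_params)

def joint_mse_loss(params):
    # One loss over both critics, so a single value_and_grad pass covers them.
    qf1_a_values = qf1.apply(params.qf1_params, observations, actions).squeeze()
    qf2_a_values = qf2.apply(params.qf2_params, observations, actions).squeeze()
    return ((qf1_a_values - target_q) ** 2).mean() + ((qf2_a_values - target_q) ** 2).mean()

loss, grads = jax.value_and_grad(joint_mse_loss)(q_params)
updates, q_opt_state = q_optimizer.update(grads, q_opt_state)
q_params = optax.apply_updates(q_params, updates)
```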
Hey, I added some changes. Overall it looks pretty good, but I don't know why the experiments do not work yet...
Great addition.
Looking good on my side
@joaogui1 could you take a final look at https://cleanrl-git-fork-joaogui1-master-vwxyzjn.vercel.app/rl-algorithms/td3/#td3_continuous_action_jaxpy to see if there is anything missing?
LGTM!
Thanks so much for this contribution!
Description
Closes #218
Initial implementation, needs testing.
Types of changes
Checklist:
- I have ensured `pre-commit run --all-files` passes (required).
- I have updated the documentation and previewed the changes via `mkdocs serve`.

If you are adding new algorithms or your change could result in performance difference, you may need to (re-)run tracked experiments. See #137 as an example PR.
- I have tracked applicable experiments with the `--capture-video` flag toggled on (required).
- I have added documentation and previewed the changes via `mkdocs serve`.
- I have added the learning curves (in PNG format with `width=500` and `height=300`).