
JAX TD3 prototype #225

Merged 23 commits into vwxyzjn:master on Jul 31, 2022
Conversation

joaogui1
Collaborator

@joaogui1 joaogui1 commented Jun 29, 2022

Description

Closes #218
Initial implementation; needs testing.

Types of changes

  • New feature
  • New algorithm

Checklist:

  • I've read the CONTRIBUTION guide (required).
  • I have ensured pre-commit run --all-files passes (required).
  • I have updated the documentation and previewed the changes via mkdocs serve.
  • I have updated the tests accordingly (if applicable).

If you are adding new algorithms or your change could result in a performance difference, you may need to (re-)run tracked experiments. See #137 as an example PR.

  • I have contacted @vwxyzjn to obtain access to the openrlbenchmark W&B team (required).
  • I have tracked applicable experiments in openrlbenchmark/cleanrl with --capture-video flag toggled on (required).
  • I have added additional documentation and previewed the changes via mkdocs serve.
    • I have explained note-worthy implementation details.
    • I have explained the logged metrics.
    • I have added links to the original paper and related papers (if applicable).
    • I have added links to the PR related to the algorithm.
    • I have created a table comparing my results against those from reputable sources (i.e., the original paper or other reference implementation).
    • I have added the learning curves (in PNG format with width=500 and height=300).
    • I have added links to the tracked experiments.
    • I have updated the overview sections at the docs and the repo
  • I have updated the tests accordingly (if applicable).

@vercel

vercel bot commented Jun 29, 2022

The latest updates on your projects. Learn more about Vercel for Git ↗︎

Name    | Status           | Preview       | Updated
cleanrl | ✅ Ready (Inspect) | Visit Preview | Jul 31, 2022 at 7:09PM (UTC)

@vwxyzjn vwxyzjn changed the title JAX TD3 prototypw JAX TD3 prototype Jun 30, 2022
Owner

@vwxyzjn vwxyzjn left a comment


Thanks for the PR. Looks great!

# TODO Maybe generate a lot of random keys right in the beginning
# also check https://jax.readthedocs.io/en/latest/jax.random.html
key, noise_key = jax.random.split(key, 2)
clipped_noise = jnp.clip((jax.random.normal(
Owner

clipped_noise is not actually used. Also, maybe generating it with numpy is a little bit faster? With jnp we would probably need to jit the function. Would you mind doing a speed test with %timeit?

Collaborator Author

Added it and fixed the clipping; I'll test what's the most efficient way to generate noise soon.
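A minimal sketch of the approach the reviewer suggests above: sampling the TD3 target-policy smoothing noise with plain numpy instead of jax.random, since the jnp version would likely need to be jitted to be competitive. The shapes and hyperparameters here are hypothetical, not the PR's final values.

```python
import numpy as np

policy_noise = 0.2       # stddev of the smoothing noise (TD3 paper default)
noise_clip = 0.5         # clipping range for the noise (TD3 paper default)
action_shape = (256, 6)  # (batch_size, action_dim), hypothetical

# Sample Gaussian noise on the host with numpy, then clip it to the
# [-noise_clip, noise_clip] range before adding it to the target action.
noise = np.random.normal(0.0, policy_noise, size=action_shape)
clipped_noise = np.clip(noise, -noise_clip, noise_clip)
```

The numpy array can then be handed to the jitted update function, where JAX will convert it to a device array at the boundary.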

cleanrl/td3_continuous_action_jax.py — outdated review thread (resolved)
cleanrl/td3_continuous_action_jax.py — outdated review thread (resolved)
Comment on lines 212 to 217
(qf1_loss_value, qf1_a_values), grads1 = jax.value_and_grad(mse_loss, has_aux=True)(qf1_state.params, qf1)
(qf2_loss_value, qf2_a_values), grads2 = jax.value_and_grad(mse_loss, has_aux=True)(qf2_state.params, qf2)
qf1_state = qf1_state.apply_gradients(grads=grads1)
qf2_state = qf2_state.apply_gradients(grads=grads2)
Owner

We are doing two separate grad passes. Would it be faster to have them share the same optimizer, as done here?

agent_params = AgentParams(
actor_params,
critic_params,
)
agent_optimizer_state = agent_optimizer.init(agent_params)
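The reviewer's suggestion can be sketched as follows, assuming jax is available: put both critics' parameters into one pytree and compute both TD losses in a single loss function, so that one `value_and_grad` pass (and one optimizer) covers both critics, mirroring the `AgentParams` pattern linked above. The linear "critics" here are hypothetical stand-ins for illustration, not the PR's networks.

```python
import jax
import jax.numpy as jnp

def critic_loss(params, x, target):
    # `params` is one pytree holding both critics' parameters.
    q1 = x @ params["qf1"]
    q2 = x @ params["qf2"]
    loss1 = ((q1 - target) ** 2).mean()
    loss2 = ((q2 - target) ** 2).mean()
    # Sum the losses so a single backward pass produces gradients for both.
    return loss1 + loss2, (q1.mean(), q2.mean())

params = {"qf1": jnp.ones((3, 1)), "qf2": jnp.zeros((3, 1))}
x = jnp.ones((4, 3))
target = jnp.zeros((4, 1))

# One value_and_grad call yields gradients for both critics' params at once.
(loss, (q1_mean, q2_mean)), grads = jax.value_and_grad(critic_loss, has_aux=True)(
    params, x, target
)
```

`grads` has the same pytree structure as `params`, so a single optax optimizer state initialized on the combined pytree can apply both updates in one step.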

Owner

@vwxyzjn vwxyzjn left a comment


Hey, I added some changes. Overall it looks pretty good, but I don't know why the experiments do not work yet...

cleanrl/td3_continuous_action_jax.py — outdated review thread (resolved)
cleanrl/td3_continuous_action_jax.py — outdated review thread (resolved)
@joaogui1 joaogui1 marked this pull request as ready for review July 21, 2022 16:15
@vwxyzjn vwxyzjn requested a review from dosssman July 22, 2022 16:07
Collaborator

@dosssman dosssman left a comment


Great addition.

Looking good on my side

@vwxyzjn
Owner

vwxyzjn commented Jul 27, 2022

@joaogui1 could you take a final look at https://cleanrl-git-fork-joaogui1-master-vwxyzjn.vercel.app/rl-algorithms/td3/#td3_continuous_action_jaxpy to see if there is anything missing?

@joaogui1
Collaborator Author

LGTM!

Owner

@vwxyzjn vwxyzjn left a comment


Thanks so much for this contribution!

@vwxyzjn vwxyzjn merged commit 5bfdd45 into vwxyzjn:master Jul 31, 2022
@vwxyzjn vwxyzjn mentioned this pull request Aug 24, 2022
Successfully merging this pull request may close these issues.

JAX Integration with CleanRL