
Using jax scan for PPO + atari + envpool XLA #328

Merged Dec 21, 2022 (17 commits)

Conversation

@51616 (Collaborator) commented Dec 6, 2022

Description

This PR modifies the code to use jax.lax.scan for faster compile times and a small runtime speed improvement.

The loss metrics of this pull request (blue) are consistent with the original version (green).

The performance is similar to the original with a slight speed improvement.

The command used is python cleanrl/ppo_atari_envpool_xla_jax_scan.py --env-id Breakout-v5 --total-timesteps 10000000 --num-envs 32 --seed 111 (blue) and python cleanrl/ppo_atari_envpool_xla_jax.py --env-id Breakout-v5 --total-timesteps 10000000 --num-envs 32 --seed 111 (green).

Types of changes

  • Bug fix
  • New feature
  • New algorithm
  • Documentation

Checklist:

  • I've read the CONTRIBUTION guide (required).
  • I have ensured pre-commit run --all-files passes (required).
  • I have updated the documentation and previewed the changes via mkdocs serve.
  • I have updated the tests accordingly (if applicable).

If you are adding new algorithm variants or your change could result in performance difference, you may need to (re-)run tracked experiments. See #137 as an example PR.

  • I have contacted vwxyzjn to obtain access to the openrlbenchmark W&B team (required).
  • I have tracked applicable experiments in openrlbenchmark/cleanrl with --capture-video flag toggled on (required).
  • I have added additional documentation and previewed the changes via mkdocs serve.
    • I have explained note-worthy implementation details.
    • I have explained the logged metrics.
    • I have added links to the original paper and related papers (if applicable).
    • I have added links to the PR related to the algorithm variant.
    • I have created a table comparing my results against those from reputable sources (i.e., the original paper or other reference implementation).
    • I have added the learning curves (in PNG format).
    • I have added links to the tracked experiments.
    • I have updated the overview sections at the docs and the repo
  • I have updated the tests accordingly (if applicable).


@vwxyzjn (Owner) commented Dec 6, 2022

Thanks for preparing this PR @51616. Out of curiosity, what's the speed difference when running with the following?

python cleanrl/ppo_atari_envpool_xla_jax_scan.py --env-id Breakout-v5 --total-timesteps 10000000 --num-envs 8
python cleanrl/ppo_atari_envpool_xla_jax.py --env-id Breakout-v5 --total-timesteps 10000000 --num-envs 8

@51616 (Collaborator, Author) commented Dec 7, 2022


Here's the training time comparison. I don't think we can compare code speed purely from this graph, because wall-clock time also depends on how fast the agent learns: environment resets are relatively expensive, so the faster the agent improves, the fewer resets are called. The rate at which the agent learns depends on exploration, which we don't fully control, and explicitly setting the random seed still cannot precisely reproduce runs. In any case, I don't think we should expect any speed difference between the two versions (jax-ml/jax#402 (comment)). The benefit of this change is mostly the reduced compilation time.

The compilation time for the default num_minibatches=4 and update_epochs=4 decreases significantly using jax.lax.scan, from almost a minute to a few seconds. Using scan also does not increase compilation time when using higher values, whereas the python loop does.
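To illustrate the point with a toy sketch (this is not code from the PR): a Python loop inside a jitted function is unrolled at trace time, so the compiled program grows with the number of iterations, while jax.lax.scan traces its body once and compile time stays roughly flat.

```python
import jax
import jax.numpy as jnp

def cumsum_loop(xs):
    # The Python loop is unrolled during tracing, so the compiled
    # graph (and compile time) grows linearly with len(xs).
    total = jnp.zeros((), xs.dtype)
    out = []
    for x in xs:
        total = total + x
        out.append(total)
    return jnp.stack(out)

def cumsum_scan(xs):
    # jax.lax.scan traces the step function once, so compile time
    # stays roughly constant regardless of the sequence length.
    def step(total, x):
        total = total + x
        return total, total  # (new carry, per-step output)

    _, out = jax.lax.scan(step, jnp.zeros((), xs.dtype), xs)
    return out

xs = jnp.arange(1.0, 6.0)
print(cumsum_loop(xs))  # both print the same cumulative sums
print(cumsum_scan(xs))
```

Both functions compute the same values; only the shape of the traced program differs, which is where the compile-time savings come from.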

If you think that jax's idioms are not very pythonic and are hard to read, we can keep both versions. I think there is value in providing an example that uses jax's idiomatic tools.

@vwxyzjn (Owner) commented Dec 7, 2022

@51616 thanks for the detailed explanation. I really like this prototype and think it's probably worth having both versions as references. At a high level, there are some remaining todos (in chronological order):

  • sanitize the code — remove the commented-out code.
  • minimize the difference between the two versions (try doing a file diff in VS Code: select both files and click "Compare Selected")


  • add an end-to-end test case in the tests folder and probably add a unit test on compute_gae. You can probably do
# alias the imports so the two functions don't shadow each other
from cleanrl.ppo_atari_envpool_xla_jax_scan import compute_gae as compute_gae_scan
from cleanrl.ppo_atari_envpool_xla_jax import compute_gae as compute_gae_loop

# fake the same input data and assert that the outputs of the two
# compute_gae functions are the same.
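As an illustration of what such a test could check, here is a self-contained sketch using simplified stand-in GAE implementations (no done-flag handling; the real cleanrl functions have different signatures):

```python
import jax
import jax.numpy as jnp
import numpy as np

def gae_loop(rewards, values, next_value, gamma=0.99, lam=0.95):
    # Reference GAE with a Python loop over time steps (backwards).
    values_all = jnp.append(values, next_value)
    gae = 0.0
    advantages = []
    for t in reversed(range(rewards.shape[0])):
        delta = rewards[t] + gamma * values_all[t + 1] - values_all[t]
        gae = delta + gamma * lam * gae
        advantages.append(gae)
    return jnp.stack(advantages[::-1])

def gae_scan(rewards, values, next_value, gamma=0.99, lam=0.95):
    # Same recurrence expressed with jax.lax.scan(reverse=True):
    # A_t = delta_t + gamma * lam * A_{t+1}
    values_all = jnp.append(values, next_value)
    deltas = rewards + gamma * values_all[1:] - values_all[:-1]

    def step(gae, delta):
        gae = delta + gamma * lam * gae
        return gae, gae

    _, advantages = jax.lax.scan(step, jnp.zeros((), deltas.dtype), deltas, reverse=True)
    return advantages

# fake the same data and assert that both implementations agree
rng = np.random.default_rng(0)
rewards = jnp.asarray(rng.normal(size=16), dtype=jnp.float32)
values = jnp.asarray(rng.normal(size=16), dtype=jnp.float32)
np.testing.assert_allclose(np.asarray(gae_loop(rewards, values, 0.5)),
                           np.asarray(gae_scan(rewards, values, 0.5)),
                           rtol=1e-5, atol=1e-6)
```

With reverse=True, scan walks the deltas from last to first while still stacking the outputs in forward time order, which matches the backwards Python loop.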

After these three steps feel free to ping me to review again, and the last step would be to do the following:

  • run benchmarks on three environments to ensure performance is okay (there is no reason to run all 57 Atari games in this case). You should use the following:
export WANDB_ENTITY=openrlbenchmark
python -m cleanrl_utils.benchmark \
    --env-ids Pong-v5 BeamRider-v5 Breakout-v5 \
    --command "poetry run python cleanrl/ppo_atari_envpool_xla_jax_scan.py --track --capture-video" \
    --num-seeds 3 \
    --workers 1

@51616 (Collaborator, Author) commented Dec 9, 2022

@vwxyzjn I cleaned up some of the code. Please let me know if there are any specific places I should fix. I have a few questions regarding the tests/benchmarks:

  • For the compute_gae function, I have tested locally against the original. It was a quick and dirty test: I pasted the function from the original file and compared the outputs of the two. Testing in a separate file is a bit complicated though, since it requires defining the function outside the if __name__ == '__main__' block to be importable. Should I move the function out? What about other functions?
  • How do I access the project's wandb account?

@vwxyzjn (Owner) commented Dec 11, 2022

For the compute_gae method, I have tested locally comparing with the original function. It was a quick and dirty test by pasting the function from the original file and comparing the output of the two. Testing in a separate file is a bit complicated though. It requires defining the function outside the if __name__ == '__main__' statement to be importable. Should I move the function out? What about other functions?

Ah, my bad for not thinking this through. In that case, maybe don't import compute_gae; instead, copy the functions into the test file and compare them there. If that's too much hassle, skipping the test is also ok :)

How do I access the project's wandb account?

Could you share with me your wandb account username? I will add you to the openrlbenchmark wandb team.

@51616 (Collaborator, Author) commented Dec 12, 2022

In that case, maybe don't import the compute_gae and copy them to the test files and compare.

I will make a test file for that.

Could you share with me your wandb account username? I will add you to the openrlbenchmark wandb team.

Here's my wandb account: https://wandb.ai/51616

@vwxyzjn (Owner) commented Dec 12, 2022

Thank you @51616 I have added you to the openrlbenchmark team. You might want to run pre-commit run --all-files to fix CI.

@51616 (Collaborator, Author) commented Dec 13, 2022

@vwxyzjn I already ran the pre-commit hooks for that commit but it still gives an error. I think it has something to do with the tests folder not being formatted by pre-commit but still being checked in CI? I will run the benchmarks today.

edit: turned out the test file was not formatted on my side.

@51616 (Collaborator, Author) commented Dec 15, 2022

@vwxyzjn I ran the benchmarks. Please let me know if you want any specific updates for this PR.

@vwxyzjn (Owner) commented Dec 20, 2022

Thanks for your patience. The results look great. The next step is to add documentation. Could you give the following command a try? It compares the jax.lax.scan variant with the for-loop variant and openai/baselines' PPO.

pip install openrlbenchmark
python -m openrlbenchmark.rlops \
    --filters '?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/avg_episodic_return' 'ppo_atari_envpool_xla_jax_scan?tag=pr-328' \
    --filters '?we=openrlbenchmark&wpn=baselines&ceik=env&cen=exp_name&metric=charts/episodic_return' 'baselines-ppo2-cnn' \
    --filters '?we=openrlbenchmark&wpn=envpool-atari&ceik=env_id&cen=exp_name&metric=charts/avg_episodic_return' 'ppo_atari_envpool_xla_jax_truncation' \
    --env-ids BeamRider-v5 Breakout-v5  Pong-v5 \
    --check-empty-runs False \
    --ncols 3 \
    --ncols-legend 2 \
    --output-filename compare \
    --scan-history \
    --report

It should generate a figure and tables (compare.md), which you can use to add the docs.


@51616 (Collaborator, Author) commented Dec 20, 2022

openrlbenchmark doesn't seem to work with python<3.9. I got the following error:

ERROR: Ignored the following versions that require a different python version: 0.1.1a0 Requires-Python >=3.9,<4.0; 0.1.1a1 Requires-Python >=3.9,<4.0
ERROR: Could not find a version that satisfies the requirement openrlbenchmark (from versions: none)
ERROR: No matching distribution found for openrlbenchmark

edit: I ran the command with a new environment and it worked just fine but I'm not sure if the python>=3.9 requirement is intended.

@vwxyzjn (Owner) commented Dec 20, 2022

Try again pip install openrlbenchmark==0.1.1a2 or pip install https://files.pythonhosted.org/packages/03/6c/a365d82a4653255cbb553414c9f15669ce7b947871233b5ab0f43a8de546/openrlbenchmark-0.1.1a2.tar.gz.

@vwxyzjn (Owner) commented Dec 20, 2022

Yeah, I have just made it compatible with python 3.7.1+.

@51616 (Collaborator, Author) commented Dec 20, 2022

Thank you for the quick response. I got the report but I'm not sure where to put it. Which specific doc are you referring to?

@vwxyzjn (Owner) commented Dec 20, 2022

Consider adding a section in https://github.com/vwxyzjn/cleanrl/blob/master/docs/rl-algorithms/ppo.md like the other PPO variants.

@51616 (Collaborator, Author) commented Dec 20, 2022

I added the documentation. I'm not sure if I did it right. Please take a look 🙏

@vwxyzjn (Owner) left a review comment:

Thanks for adding the docs. I added some docs and feedback.

Review threads on tests/test_jax_compute_gae.py and docs/rl-algorithms/ppo.md (resolved).
Comment on lines 757 to 763
???+ info

The speed of this variant and [ppo_atari_envpool_xla_jax.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_atari_envpool_xla_jax.py) are very similar, but the compilation time is reduced significantly (see [vwxyzjn/cleanrl#328](https://github.com/vwxyzjn/cleanrl/pull/328#issuecomment-1340474894)). In the following learning curve, the speed increase comes from the fact that better hardware was used.

Learning curves:

![](../ppo/ppo_atari_envpool_xla_jax_scan/compare.png)
@vwxyzjn (Owner) commented Dec 20, 2022
Yeah, the hardware is a bit of an issue. Would you mind doing two additional things?

1. run ppo_atari_envpool_xla_jax.py on your machine to align the hardware settings:

export WANDB_ENTITY=openrlbenchmark
python -m cleanrl_utils.benchmark \
    --env-ids Pong-v5 BeamRider-v5 Breakout-v5 \
    --command "poetry run python cleanrl/ppo_atari_envpool_xla_jax.py --track --capture-video" \
    --num-seeds 3 \
    --workers 1
2. regenerate the learning curves using openrlbenchmark==0.1.1a3, which also generates curves with time as the x-axis. Then you can compare the experiments via:
pip install --upgrade openrlbenchmark
python -m openrlbenchmark.rlops \
    --filters '?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/avg_episodic_return' 'ppo_atari_envpool_xla_jax_scan?tag=pr-328&user=51616' 'ppo_atari_envpool_xla_jax?tag=pr-328&user=51616' \
    --filters '?we=openrlbenchmark&wpn=baselines&ceik=env&cen=exp_name&metric=charts/episodic_return' 'baselines-ppo2-cnn' \
    --filters '?we=openrlbenchmark&wpn=envpool-atari&ceik=env_id&cen=exp_name&metric=charts/avg_episodic_return' 'ppo_atari_envpool_xla_jax_truncation?user=costa-huang' \
    --env-ids BeamRider-v5 Breakout-v5  Pong-v5 \
    --check-empty-runs False \
    --ncols 3 \
    --ncols-legend 2 \
    --output-filename compare \
    --scan-history \
    --report

and maybe in the docs we can note the hardware differences between costa-huang's and 51616's machines :)

@51616 (Collaborator, Author) replied:
For some reason, the runs with the Python loop version didn't get the pr-328 tag, so this filter doesn't work:

--filters '?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/avg_episodic_return' 'ppo_atari_envpool_xla_jax_scan?tag=pr-328&user=51616' 'ppo_atari_envpool_xla_jax?tag=pr-328&user=51616'

It gives the following error:

{
│   'wandb_project_name': 'cleanrl',
│   'wandb_entity': 'openrlbenchmark',
│   'custom_env_id_key': 'env_id',
│   'custom_exp_name': 'exp_name',
│   'metric': 'charts/avg_episodic_return'
}
========= ppo_atari_envpool_xla_jax_scan?tag=pr-328&user=51616
========= ppo_atari_envpool_xla_jax?tag=pr-328&user=51616
{
│   'wandb_project_name': 'baselines',
│   'wandb_entity': 'openrlbenchmark',
│   'custom_env_id_key': 'env',
│   'custom_exp_name': 'exp_name',
│   'metric': 'charts/episodic_return'
}
========= baselines-ppo2-cnn
{
│   'wandb_project_name': 'envpool-atari',
│   'wandb_entity': 'openrlbenchmark',
│   'custom_env_id_key': 'env_id',
│   'custom_exp_name': 'exp_name',
│   'metric': 'charts/avg_episodic_return'
}
========= ppo_atari_envpool_xla_jax_truncation?user=costa-huang
Traceback (most recent call last):
  File "/home/tan/miniconda3/envs/cleanrl/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/tan/miniconda3/envs/cleanrl/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/tan/miniconda3/envs/cleanrl/lib/python3.7/site-packages/openrlbenchmark/rlops.py", line 356, in <module>
    time_unit=args.time_unit,
  File "/home/tan/miniconda3/envs/cleanrl/lib/python3.7/site-packages/openrlbenchmark/rlops.py", line 161, in compare
    {(runsets[idx].report_runset.name, runsets[idx].runs[0].config[runsets[idx].exp_name]): runsets[idx].color}
  File "/home/tan/miniconda3/envs/cleanrl/lib/python3.7/site-packages/wandb/apis/public.py", line 1053, in __getitem__
    return self.objects[index]
IndexError: list index out of range

Not sure if this is intended behavior. Anyway, I manually added the tag and the command ran just fine.

@vwxyzjn (Owner) replied:

Hmm that was weird... The cleanrl_utils.benchmark utility should have autotagged.

@51616 (Collaborator, Author) replied:

Docs updated.

@vwxyzjn (Owner) replied:

I actually don't quite get it... How come your experiments are like 30% faster than mine? Is it because the 2080 Ti is faster than the 3060 Ti? What is your CPU? That could also be a factor.

@51616 (Collaborator, Author) replied:

My CPU is a Ryzen 5950X. Maybe that's a factor too.

Review thread on docs/rl-algorithms/ppo.md (outdated, resolved).
@vwxyzjn (Owner) left a review comment:
Great job! LGTM. Feel free to merge when you are ready. You should have the contributor's access (merge access) by now. Thanks again.

@51616 51616 merged commit 2dd73af into vwxyzjn:master Dec 21, 2022
@vwxyzjn vwxyzjn mentioned this pull request Jan 5, 2023