Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Refactor categorical dists: Masked one-hot and pass-through gradients #1488

Merged
merged 3 commits into from
Sep 5, 2023

Conversation

vmoens
Copy link
Contributor

@vmoens vmoens commented Sep 5, 2023

Description

Implements a masked one hot distribution.
Enables 2 reparam strategies for one-hot samples: RelaxedOneHot or Pass-through

Also renamed "mask" in "action_mask" in the MaskedAction transform.

@matteobettini @MateuszGuzek

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 5, 2023
@github-actions
Copy link

github-actions bot commented Sep 5, 2023

$\color{#D29922}\textsf{\Large⚠\kern{0.2cm}\normalsize Warning}$ Result of CPU Benchmark Tests

Total Benchmarks: 89. Improved: $\large\color{#35bf28}5$. Worsened: $\large\color{#d91a1a}4$.

Expand to view detailed results
Name Max Mean Ops Ops on Repo HEAD Change
test_single 0.1418s 0.1396s 7.1611 Ops/s 7.1508 Ops/s $\color{#35bf28}+0.15\%$
test_sync 0.1479s 78.5237ms 12.7350 Ops/s 12.5868 Ops/s $\color{#35bf28}+1.18\%$
test_async 0.1954s 72.7233ms 13.7507 Ops/s 13.8291 Ops/s $\color{#d91a1a}-0.57\%$
test_simple 0.6887s 0.6220s 1.6076 Ops/s 1.6173 Ops/s $\color{#d91a1a}-0.60\%$
test_transformed 1.6858s 1.6309s 0.6132 Ops/s 0.6142 Ops/s $\color{#d91a1a}-0.18\%$
test_serial 1.7625s 1.7123s 0.5840 Ops/s 0.5791 Ops/s $\color{#35bf28}+0.85\%$
test_parallel 1.4972s 1.4378s 0.6955 Ops/s 0.6795 Ops/s $\color{#35bf28}+2.35\%$
test_step_mdp_speed[True-True-True-True-True] 0.1794ms 44.9887μs 22.2278 KOps/s 22.1773 KOps/s $\color{#35bf28}+0.23\%$
test_step_mdp_speed[True-True-True-True-False] 0.2239ms 25.5373μs 39.1585 KOps/s 38.7855 KOps/s $\color{#35bf28}+0.96\%$
test_step_mdp_speed[True-True-True-False-True] 96.6010μs 31.5141μs 31.7318 KOps/s 31.6983 KOps/s $\color{#35bf28}+0.11\%$
test_step_mdp_speed[True-True-True-False-False] 38.6000μs 17.5417μs 57.0070 KOps/s 56.1848 KOps/s $\color{#35bf28}+1.46\%$
test_step_mdp_speed[True-True-False-True-True] 94.9010μs 46.7277μs 21.4006 KOps/s 21.4075 KOps/s $\color{#d91a1a}-0.03\%$
test_step_mdp_speed[True-True-False-True-False] 71.1010μs 27.5730μs 36.2674 KOps/s 36.5384 KOps/s $\color{#d91a1a}-0.74\%$
test_step_mdp_speed[True-True-False-False-True] 0.3685ms 33.7708μs 29.6114 KOps/s 29.3186 KOps/s $\color{#35bf28}+1.00\%$
test_step_mdp_speed[True-True-False-False-False] 68.8010μs 19.8516μs 50.3738 KOps/s 51.3053 KOps/s $\color{#d91a1a}-1.82\%$
test_step_mdp_speed[True-False-True-True-True] 79.6010μs 48.5993μs 20.5764 KOps/s 20.5366 KOps/s $\color{#35bf28}+0.19\%$
test_step_mdp_speed[True-False-True-True-False] 60.2000μs 29.5101μs 33.8867 KOps/s 34.1279 KOps/s $\color{#d91a1a}-0.71\%$
test_step_mdp_speed[True-False-True-False-True] 0.1010ms 33.3844μs 29.9541 KOps/s 29.4999 KOps/s $\color{#35bf28}+1.54\%$
test_step_mdp_speed[True-False-True-False-False] 84.5010μs 19.8849μs 50.2894 KOps/s 51.0246 KOps/s $\color{#d91a1a}-1.44\%$
test_step_mdp_speed[True-False-False-True-True] 0.1720ms 49.9006μs 20.0398 KOps/s 19.8256 KOps/s $\color{#35bf28}+1.08\%$
test_step_mdp_speed[True-False-False-True-False] 54.4000μs 31.0639μs 32.1917 KOps/s 32.6176 KOps/s $\color{#d91a1a}-1.31\%$
test_step_mdp_speed[True-False-False-False-True] 58.5010μs 35.7756μs 27.9520 KOps/s 28.4905 KOps/s $\color{#d91a1a}-1.89\%$
test_step_mdp_speed[True-False-False-False-False] 69.7000μs 21.3991μs 46.7310 KOps/s 47.3521 KOps/s $\color{#d91a1a}-1.31\%$
test_step_mdp_speed[False-True-True-True-True] 0.1369ms 48.5873μs 20.5815 KOps/s 20.7081 KOps/s $\color{#d91a1a}-0.61\%$
test_step_mdp_speed[False-True-True-True-False] 54.4000μs 29.4367μs 33.9712 KOps/s 34.2717 KOps/s $\color{#d91a1a}-0.88\%$
test_step_mdp_speed[False-True-True-False-True] 82.9010μs 38.1238μs 26.2304 KOps/s 26.7714 KOps/s $\color{#d91a1a}-2.02\%$
test_step_mdp_speed[False-True-True-False-False] 3.3822ms 21.8752μs 45.7139 KOps/s 45.2709 KOps/s $\color{#35bf28}+0.98\%$
test_step_mdp_speed[False-True-False-True-True] 94.8010μs 50.5544μs 19.7807 KOps/s 19.9191 KOps/s $\color{#d91a1a}-0.69\%$
test_step_mdp_speed[False-True-False-True-False] 62.3010μs 31.0424μs 32.2140 KOps/s 32.0972 KOps/s $\color{#35bf28}+0.36\%$
test_step_mdp_speed[False-True-False-False-True] 84.2010μs 39.6293μs 25.2339 KOps/s 25.8375 KOps/s $\color{#d91a1a}-2.34\%$
test_step_mdp_speed[False-True-False-False-False] 0.9742ms 23.6283μs 42.3221 KOps/s 42.5106 KOps/s $\color{#d91a1a}-0.44\%$
test_step_mdp_speed[False-False-True-True-True] 0.4577ms 52.4640μs 19.0607 KOps/s 19.2920 KOps/s $\color{#d91a1a}-1.20\%$
test_step_mdp_speed[False-False-True-True-False] 76.9000μs 32.9504μs 30.3487 KOps/s 30.1101 KOps/s $\color{#35bf28}+0.79\%$
test_step_mdp_speed[False-False-True-False-True] 70.5000μs 39.7010μs 25.1883 KOps/s 25.5554 KOps/s $\color{#d91a1a}-1.44\%$
test_step_mdp_speed[False-False-True-False-False] 43.0000μs 23.3009μs 42.9168 KOps/s 42.8993 KOps/s $\color{#35bf28}+0.04\%$
test_step_mdp_speed[False-False-False-True-True] 0.1755ms 53.2497μs 18.7794 KOps/s 18.6833 KOps/s $\color{#35bf28}+0.51\%$
test_step_mdp_speed[False-False-False-True-False] 92.8010μs 34.8639μs 28.6830 KOps/s 29.0744 KOps/s $\color{#d91a1a}-1.35\%$
test_step_mdp_speed[False-False-False-False-True] 59.9000μs 40.2576μs 24.8400 KOps/s 24.9074 KOps/s $\color{#d91a1a}-0.27\%$
test_step_mdp_speed[False-False-False-False-False] 57.1000μs 25.1747μs 39.7225 KOps/s 40.3614 KOps/s $\color{#d91a1a}-1.58\%$
test_values[generalized_advantage_estimate-True-True] 15.0419ms 13.2785ms 75.3098 Ops/s 72.1436 Ops/s $\color{#35bf28}+4.39\%$
test_values[vec_generalized_advantage_estimate-True-True] 51.3798ms 41.7116ms 23.9741 Ops/s 24.0014 Ops/s $\color{#d91a1a}-0.11\%$
test_values[td0_return_estimate-False-False] 0.3269ms 0.1936ms 5.1666 KOps/s 4.9529 KOps/s $\color{#35bf28}+4.31\%$
test_values[td1_return_estimate-False-False] 13.2744ms 12.7625ms 78.3549 Ops/s 74.6181 Ops/s $\textbf{\color{#35bf28}+5.01\%}$
test_values[vec_td1_return_estimate-False-False] 49.8292ms 41.7919ms 23.9281 Ops/s 24.1514 Ops/s $\color{#d91a1a}-0.92\%$
test_values[td_lambda_return_estimate-True-False] 32.4096ms 31.7870ms 31.4594 Ops/s 30.3299 Ops/s $\color{#35bf28}+3.72\%$
test_values[vec_td_lambda_return_estimate-True-False] 47.1532ms 41.4111ms 24.1481 Ops/s 23.8551 Ops/s $\color{#35bf28}+1.23\%$
test_gae_speed[generalized_advantage_estimate-False-1-512] 11.8588ms 11.7686ms 84.9718 Ops/s 83.8347 Ops/s $\color{#35bf28}+1.36\%$
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] 8.8428ms 3.3946ms 294.5863 Ops/s 297.8490 Ops/s $\color{#d91a1a}-1.10\%$
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] 1.3540ms 0.4608ms 2.1702 KOps/s 2.1213 KOps/s $\color{#35bf28}+2.30\%$
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] 58.4290ms 54.7609ms 18.2612 Ops/s 18.0248 Ops/s $\color{#35bf28}+1.31\%$
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] 8.9592ms 2.8540ms 350.3910 Ops/s 353.5988 Ops/s $\color{#d91a1a}-0.91\%$
test_dqn_speed 6.9471ms 1.8056ms 553.8220 Ops/s 540.9023 Ops/s $\color{#35bf28}+2.39\%$
test_ddpg_speed 18.7691ms 2.7145ms 368.3977 Ops/s 361.7681 Ops/s $\color{#35bf28}+1.83\%$
test_sac_speed 14.7899ms 7.7538ms 128.9692 Ops/s 125.1345 Ops/s $\color{#35bf28}+3.06\%$
test_redq_speed 21.9111ms 15.1499ms 66.0069 Ops/s 64.1186 Ops/s $\color{#35bf28}+2.95\%$
test_redq_deprec_speed 19.6961ms 12.2972ms 81.3194 Ops/s 78.9255 Ops/s $\color{#35bf28}+3.03\%$
test_td3_speed 10.7938ms 9.7377ms 102.6936 Ops/s 100.6557 Ops/s $\color{#35bf28}+2.02\%$
test_cql_speed 33.8275ms 27.6488ms 36.1679 Ops/s 38.7518 Ops/s $\textbf{\color{#d91a1a}-6.67\%}$
test_a2c_speed 15.6888ms 5.0621ms 197.5446 Ops/s 195.0564 Ops/s $\color{#35bf28}+1.28\%$
test_ppo_speed 11.0612ms 5.3966ms 185.3023 Ops/s 176.5066 Ops/s $\color{#35bf28}+4.98\%$
test_reinforce_speed 10.0408ms 3.9907ms 250.5799 Ops/s 252.1186 Ops/s $\color{#d91a1a}-0.61\%$
test_iql_speed 26.2949ms 20.3941ms 49.0338 Ops/s 47.3095 Ops/s $\color{#35bf28}+3.64\%$
test_sample_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 2.9786ms 2.5590ms 390.7844 Ops/s 383.6355 Ops/s $\color{#35bf28}+1.86\%$
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 4.8742ms 2.7265ms 366.7760 Ops/s 356.8004 Ops/s $\color{#35bf28}+2.80\%$
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 3.9742ms 2.7131ms 368.5878 Ops/s 362.6206 Ops/s $\color{#35bf28}+1.65\%$
test_sample_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 3.3823ms 2.5622ms 390.2892 Ops/s 387.0785 Ops/s $\color{#35bf28}+0.83\%$
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 4.9763ms 2.7758ms 360.2533 Ops/s 359.9019 Ops/s $\color{#35bf28}+0.10\%$
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 3.9235ms 2.7114ms 368.8103 Ops/s 358.6469 Ops/s $\color{#35bf28}+2.83\%$
test_sample_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 3.4729ms 2.5530ms 391.7033 Ops/s 387.4497 Ops/s $\color{#35bf28}+1.10\%$
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 0.1291s 3.1232ms 320.1878 Ops/s 362.0930 Ops/s $\textbf{\color{#d91a1a}-11.57\%}$
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 4.6801ms 2.7012ms 370.2011 Ops/s 361.5817 Ops/s $\color{#35bf28}+2.38\%$
test_iterate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 2.8214ms 2.5622ms 390.2960 Ops/s 388.4698 Ops/s $\color{#35bf28}+0.47\%$
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 4.5355ms 2.7459ms 364.1742 Ops/s 358.7391 Ops/s $\color{#35bf28}+1.52\%$
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 3.9272ms 2.7127ms 368.6347 Ops/s 361.7262 Ops/s $\color{#35bf28}+1.91\%$
test_iterate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 3.3370ms 2.5364ms 394.2558 Ops/s 382.2341 Ops/s $\color{#35bf28}+3.15\%$
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 4.4094ms 2.7724ms 360.6920 Ops/s 360.8414 Ops/s $\color{#d91a1a}-0.04\%$
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 4.6807ms 2.7112ms 368.8405 Ops/s 356.5159 Ops/s $\color{#35bf28}+3.46\%$
test_iterate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 3.3811ms 2.5437ms 393.1285 Ops/s 386.6728 Ops/s $\color{#35bf28}+1.67\%$
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 4.4248ms 2.7807ms 359.6161 Ops/s 358.7148 Ops/s $\color{#35bf28}+0.25\%$
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 4.3221ms 2.7051ms 369.6661 Ops/s 357.8994 Ops/s $\color{#35bf28}+3.29\%$
test_populate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-400] 0.2597s 28.9812ms 34.5051 Ops/s 34.4302 Ops/s $\color{#35bf28}+0.22\%$
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] 0.1407s 28.9996ms 34.4833 Ops/s 34.5354 Ops/s $\color{#d91a1a}-0.15\%$
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] 0.1381s 26.2357ms 38.1160 Ops/s 34.7735 Ops/s $\textbf{\color{#35bf28}+9.61\%}$
test_populate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] 0.1348s 28.5255ms 35.0564 Ops/s 37.4434 Ops/s $\textbf{\color{#d91a1a}-6.37\%}$
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] 0.1342s 26.0943ms 38.3226 Ops/s 34.5697 Ops/s $\textbf{\color{#35bf28}+10.86\%}$
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] 0.1356s 28.5864ms 34.9817 Ops/s 37.6791 Ops/s $\textbf{\color{#d91a1a}-7.16\%}$
test_populate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] 0.1391s 23.9238ms 41.7994 Ops/s 34.1709 Ops/s $\textbf{\color{#35bf28}+22.32\%}$
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] 0.1322s 28.1218ms 35.5596 Ops/s 37.1617 Ops/s $\color{#d91a1a}-4.31\%$
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] 0.1295s 25.8172ms 38.7339 Ops/s 34.3701 Ops/s $\textbf{\color{#35bf28}+12.70\%}$

Copy link
Contributor

@MateuszGuzek MateuszGuzek left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, single comment suggesting improvement in the docstring

torchrl/modules/distributions/discrete.py Outdated Show resolved Hide resolved
@vmoens vmoens added the enhancement New feature or request label Sep 5, 2023
@vmoens vmoens merged commit e133749 into main Sep 5, 2023
49 of 56 checks passed
@vmoens vmoens deleted the masked_onehot branch September 5, 2023 14:14
vmoens added a commit that referenced this pull request Sep 5, 2023
vmoens added a commit to hyerra/rl that referenced this pull request Oct 10, 2023
vmoens added a commit to hyerra/rl that referenced this pull request Oct 10, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants