-
Notifications
You must be signed in to change notification settings - Fork 328
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BugFix] Compatibility of tensordict primers with batched envs (specifically for LSTM and GRU) #2668
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2668
Note: Links to docs will display an error until the docs builds have been completed. ❌ 8 New Failures, 7 Unrelated FailuresAs of commit fafa7bd with merge base 133d709 (): NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
…fically for LSTM and GRU) ghstack-source-id: ff709b9f51e2e8dbcb50aed56fb4727902cd168e Pull Request resolved: #2668
This PR works with pytorch/tensordict#1150 |
|
Name | Max | Mean | Ops | Ops on Repo HEAD
|
Change |
---|---|---|---|---|---|
test_simple | 0.4282s | 0.4245s | 2.3559 Ops/s | 2.2083 Ops/s | |
test_transformed | 0.6087s | 0.6035s | 1.6571 Ops/s | 1.6354 Ops/s | |
test_serial | 1.3604s | 1.3543s | 0.7384 Ops/s | 0.7240 Ops/s | |
test_parallel | 1.3098s | 1.2165s | 0.8221 Ops/s | 0.8149 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.1942ms | 31.4436μs | 31.8029 KOps/s | 31.9152 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 79.4680μs | 18.3381μs | 54.5314 KOps/s | 54.4899 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 52.9080μs | 17.4976μs | 57.1506 KOps/s | 56.7665 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 70.1200μs | 10.2871μs | 97.2088 KOps/s | 96.6369 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 77.2640μs | 33.3220μs | 30.0102 KOps/s | 30.2840 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 85.4390μs | 20.1280μs | 49.6821 KOps/s | 48.9341 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 51.0750μs | 19.6332μs | 50.9343 KOps/s | 50.9360 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 65.0410μs | 12.3854μs | 80.7401 KOps/s | 81.2666 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 99.5160μs | 35.0569μs | 28.5251 KOps/s | 28.3703 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 56.6950μs | 22.4229μs | 44.5973 KOps/s | 44.8394 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 78.6670μs | 19.5143μs | 51.2444 KOps/s | 50.9645 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 40.2550μs | 12.4267μs | 80.4719 KOps/s | 81.7768 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 0.1143ms | 37.0164μs | 27.0151 KOps/s | 27.2529 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 70.5910μs | 24.4821μs | 40.8462 KOps/s | 41.2242 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 71.0420μs | 21.4049μs | 46.7183 KOps/s | 46.9464 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 61.3440μs | 14.0128μs | 71.3634 KOps/s | 71.0454 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 83.9560μs | 35.1420μs | 28.4560 KOps/s | 28.5091 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 70.6610μs | 22.4792μs | 44.4855 KOps/s | 44.5967 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 78.0160μs | 22.0964μs | 45.2562 KOps/s | 45.1898 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 47.5090μs | 13.6324μs | 73.3545 KOps/s | 71.3772 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 84.6770μs | 36.8197μs | 27.1594 KOps/s | 27.0118 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 74.3890μs | 24.2180μs | 41.2917 KOps/s | 41.1295 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 2.7226ms | 24.2753μs | 41.1942 KOps/s | 41.2464 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 53.4500μs | 15.6132μs | 64.0485 KOps/s | 64.4902 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 82.8940μs | 39.0889μs | 25.5827 KOps/s | 25.8268 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 65.7020μs | 25.9223μs | 38.5768 KOps/s | 38.4338 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 54.6810μs | 23.7977μs | 42.0209 KOps/s | 41.2455 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 69.9400μs | 15.4327μs | 64.7976 KOps/s | 64.6726 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 95.7480μs | 40.8700μs | 24.4678 KOps/s | 24.9933 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 0.6584ms | 27.9149μs | 35.8232 KOps/s | 36.4451 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 77.8850μs | 25.5721μs | 39.1051 KOps/s | 39.2201 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 68.9890μs | 17.2975μs | 57.8119 KOps/s | 58.2832 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 11.8596ms | 9.6345ms | 103.7932 Ops/s | 101.3407 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 37.8942ms | 35.4309ms | 28.2240 Ops/s | 29.8851 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.2278ms | 0.1780ms | 5.6166 KOps/s | 5.6644 KOps/s | |
test_values[td1_return_estimate-False-False] | 28.2595ms | 23.5459ms | 42.4702 Ops/s | 41.0406 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 39.3904ms | 35.6254ms | 28.0699 Ops/s | 29.7937 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 34.8404ms | 33.9562ms | 29.4497 Ops/s | 27.7944 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 38.0906ms | 35.5781ms | 28.1072 Ops/s | 29.8085 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 8.8362ms | 8.3253ms | 120.1161 Ops/s | 119.0441 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 2.3225ms | 1.8504ms | 540.4184 Ops/s | 539.0778 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.6440ms | 0.3520ms | 2.8410 KOps/s | 2.8018 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 55.3488ms | 43.7606ms | 22.8516 Ops/s | 26.5580 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 4.0201ms | 3.0167ms | 331.4881 Ops/s | 324.5610 Ops/s | |
test_dqn_speed[False-None] | 5.7826ms | 1.3873ms | 720.8506 Ops/s | 716.6196 Ops/s | |
test_dqn_speed[False-backward] | 2.1185ms | 1.8751ms | 533.3078 Ops/s | 524.9859 Ops/s | |
test_dqn_speed[True-None] | 0.6166ms | 0.4690ms | 2.1323 KOps/s | 2.0660 KOps/s | |
test_dqn_speed[True-backward] | 0.9156ms | 0.8838ms | 1.1314 KOps/s | 850.5411 Ops/s | |
test_dqn_speed[reduce-overhead-None] | 0.7222ms | 0.4773ms | 2.0952 KOps/s | 2.0755 KOps/s | |
test_dqn_speed[reduce-overhead-backward] | 0.9112ms | 0.8827ms | 1.1328 KOps/s | 1.0982 KOps/s | |
test_ddpg_speed[False-None] | 3.9713ms | 2.8924ms | 345.7379 Ops/s | 346.1660 Ops/s | |
test_ddpg_speed[False-backward] | 4.8436ms | 4.0412ms | 247.4515 Ops/s | 249.9097 Ops/s | |
test_ddpg_speed[True-None] | 1.5409ms | 1.0089ms | 991.1706 Ops/s | 987.5670 Ops/s | |
test_ddpg_speed[True-backward] | 2.5029ms | 1.9516ms | 512.4082 Ops/s | 522.2633 Ops/s | |
test_ddpg_speed[reduce-overhead-None] | 1.4171ms | 0.9990ms | 1.0010 KOps/s | 988.5683 Ops/s | |
test_ddpg_speed[reduce-overhead-backward] | 1.9582ms | 1.8843ms | 530.6906 Ops/s | 522.7217 Ops/s | |
test_sac_speed[False-None] | 10.2898ms | 7.9631ms | 125.5791 Ops/s | 123.1895 Ops/s | |
test_sac_speed[False-backward] | 12.3358ms | 10.7576ms | 92.9578 Ops/s | 92.0884 Ops/s | |
test_sac_speed[True-None] | 2.1986ms | 1.8240ms | 548.2413 Ops/s | 545.6319 Ops/s | |
test_sac_speed[True-backward] | 3.6078ms | 3.5303ms | 283.2591 Ops/s | 280.6457 Ops/s | |
test_sac_speed[reduce-overhead-None] | 2.3034ms | 1.8396ms | 543.5951 Ops/s | 540.5226 Ops/s | |
test_sac_speed[reduce-overhead-backward] | 3.5804ms | 3.4963ms | 286.0132 Ops/s | 282.5114 Ops/s | |
test_redq_speed[False-None] | 14.7502ms | 12.8208ms | 77.9985 Ops/s | 76.6866 Ops/s | |
test_redq_speed[False-backward] | 24.3281ms | 22.1470ms | 45.1528 Ops/s | 44.5987 Ops/s | |
test_redq_speed[True-None] | 5.6607ms | 4.5551ms | 219.5354 Ops/s | 218.6585 Ops/s | |
test_redq_speed[True-backward] | 13.9107ms | 12.2453ms | 81.6640 Ops/s | 82.3080 Ops/s | |
test_redq_speed[reduce-overhead-None] | 5.4868ms | 4.6403ms | 215.5032 Ops/s | 217.3352 Ops/s | |
test_redq_speed[reduce-overhead-backward] | 13.4155ms | 12.3069ms | 81.2553 Ops/s | 83.1601 Ops/s | |
test_redq_deprec_speed[False-None] | 15.0897ms | 12.7794ms | 78.2510 Ops/s | 75.7270 Ops/s | |
test_redq_deprec_speed[False-backward] | 19.8365ms | 18.5936ms | 53.7819 Ops/s | 52.5150 Ops/s | |
test_redq_deprec_speed[True-None] | 4.3888ms | 3.5908ms | 278.4862 Ops/s | 278.6985 Ops/s | |
test_redq_deprec_speed[True-backward] | 9.7837ms | 8.0503ms | 124.2187 Ops/s | 124.7312 Ops/s | |
test_redq_deprec_speed[reduce-overhead-None] | 4.3706ms | 3.5754ms | 279.6927 Ops/s | 278.0793 Ops/s | |
test_redq_deprec_speed[reduce-overhead-backward] | 8.2748ms | 7.9967ms | 125.0509 Ops/s | 123.5445 Ops/s | |
test_td3_speed[False-None] | 8.6657ms | 7.9841ms | 125.2490 Ops/s | 122.6170 Ops/s | |
test_td3_speed[False-backward] | 10.8336ms | 10.3531ms | 96.5894 Ops/s | 94.4764 Ops/s | |
test_td3_speed[True-None] | 1.9549ms | 1.7372ms | 575.6504 Ops/s | 578.1645 Ops/s | |
test_td3_speed[True-backward] | 3.3837ms | 3.3237ms | 300.8719 Ops/s | 294.4490 Ops/s | |
test_td3_speed[reduce-overhead-None] | 2.0636ms | 1.7318ms | 577.4222 Ops/s | 570.4972 Ops/s | |
test_td3_speed[reduce-overhead-backward] | 4.2731ms | 3.3900ms | 294.9812 Ops/s | 297.7999 Ops/s | |
test_cql_speed[False-None] | 38.7580ms | 36.4675ms | 27.4217 Ops/s | 26.9357 Ops/s | |
test_cql_speed[False-backward] | 49.4664ms | 46.8109ms | 21.3626 Ops/s | 21.2941 Ops/s | |
test_cql_speed[True-None] | 18.1309ms | 15.7874ms | 63.3415 Ops/s | 64.0168 Ops/s | |
test_cql_speed[True-backward] | 25.0191ms | 23.1190ms | 43.2545 Ops/s | 44.5537 Ops/s | |
test_cql_speed[reduce-overhead-None] | 17.4685ms | 15.7543ms | 63.4746 Ops/s | 63.9456 Ops/s | |
test_cql_speed[reduce-overhead-backward] | 24.1276ms | 22.5831ms | 44.2810 Ops/s | 44.2541 Ops/s | |
test_a2c_speed[False-None] | 8.8877ms | 7.1961ms | 138.9634 Ops/s | 138.4517 Ops/s | |
test_a2c_speed[False-backward] | 17.8818ms | 14.4378ms | 69.2625 Ops/s | 70.6553 Ops/s | |
test_a2c_speed[True-None] | 4.7186ms | 4.2156ms | 237.2114 Ops/s | 237.2288 Ops/s | |
test_a2c_speed[True-backward] | 12.0371ms | 10.7433ms | 93.0809 Ops/s | 93.4538 Ops/s | |
test_a2c_speed[reduce-overhead-None] | 5.0299ms | 4.2200ms | 236.9667 Ops/s | 235.9938 Ops/s | |
test_a2c_speed[reduce-overhead-backward] | 11.4558ms | 10.6329ms | 94.0480 Ops/s | 93.2941 Ops/s | |
test_ppo_speed[False-None] | 8.7562ms | 7.3840ms | 135.4284 Ops/s | 132.6841 Ops/s | |
test_ppo_speed[False-backward] | 16.5043ms | 14.6375ms | 68.3177 Ops/s | 66.4393 Ops/s | |
test_ppo_speed[True-None] | 4.4340ms | 3.6804ms | 271.7130 Ops/s | 267.0986 Ops/s | |
test_ppo_speed[True-backward] | 10.2817ms | 9.7645ms | 102.4120 Ops/s | 104.5430 Ops/s | |
test_ppo_speed[reduce-overhead-None] | 4.1056ms | 3.6800ms | 271.7375 Ops/s | 270.0292 Ops/s | |
test_ppo_speed[reduce-overhead-backward] | 10.3061ms | 9.5662ms | 104.5352 Ops/s | 103.3259 Ops/s | |
test_reinforce_speed[False-None] | 8.0459ms | 6.5299ms | 153.1418 Ops/s | 151.3005 Ops/s | |
test_reinforce_speed[False-backward] | 10.0567ms | 9.7896ms | 102.1493 Ops/s | 101.2330 Ops/s | |
test_reinforce_speed[True-None] | 3.3705ms | 2.6519ms | 377.0833 Ops/s | 375.1671 Ops/s | |
test_reinforce_speed[True-backward] | 9.9073ms | 8.5584ms | 116.8447 Ops/s | 115.1117 Ops/s | |
test_reinforce_speed[reduce-overhead-None] | 3.2287ms | 2.6378ms | 379.1034 Ops/s | 370.3569 Ops/s | |
test_reinforce_speed[reduce-overhead-backward] | 8.9612ms | 8.5733ms | 116.6417 Ops/s | 113.5054 Ops/s | |
test_iql_speed[False-None] | 35.2690ms | 32.3157ms | 30.9447 Ops/s | 30.3851 Ops/s | |
test_iql_speed[False-backward] | 58.9302ms | 46.0836ms | 21.6997 Ops/s | 21.6870 Ops/s | |
test_iql_speed[True-None] | 12.2878ms | 10.6308ms | 94.0661 Ops/s | 92.6133 Ops/s | |
test_iql_speed[True-backward] | 23.2291ms | 21.6090ms | 46.2770 Ops/s | 45.2087 Ops/s | |
test_iql_speed[reduce-overhead-None] | 12.1672ms | 10.6567ms | 93.8373 Ops/s | 93.2673 Ops/s | |
test_iql_speed[reduce-overhead-backward] | 24.2886ms | 22.1892ms | 45.0670 Ops/s | 45.7203 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 8.0066ms | 5.1431ms | 194.4370 Ops/s | 192.5790 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.7886ms | 0.5174ms | 1.9329 KOps/s | 1.8891 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.9165ms | 0.4985ms | 2.0061 KOps/s | 1.9909 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 7.5170ms | 4.7527ms | 210.4084 Ops/s | 211.1230 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 0.3593s | 0.7922ms | 1.2623 KOps/s | 1.9929 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.8638ms | 0.4903ms | 2.0394 KOps/s | 2.0542 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 2.4450ms | 1.6343ms | 611.8741 Ops/s | 604.7245 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 1.8776ms | 1.5782ms | 633.6239 Ops/s | 629.7349 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 7.6096ms | 4.9034ms | 203.9407 Ops/s | 205.7573 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.1300ms | 0.6466ms | 1.5465 KOps/s | 1.5225 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.9930ms | 0.6228ms | 1.6056 KOps/s | 1.5941 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 5.1656ms | 4.7423ms | 210.8693 Ops/s | 210.2729 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 1.0878ms | 0.5163ms | 1.9370 KOps/s | 1.9423 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 1.0383ms | 0.5029ms | 1.9886 KOps/s | 1.9835 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 5.7013ms | 4.7571ms | 210.2121 Ops/s | 211.2897 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 1.0460ms | 0.5042ms | 1.9833 KOps/s | 1.9465 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.7100ms | 0.4821ms | 2.0743 KOps/s | 2.0863 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 5.1665ms | 4.8271ms | 207.1655 Ops/s | 207.7481 Ops/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 3.2124ms | 0.6552ms | 1.5263 KOps/s | 1.5238 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.8337ms | 0.6220ms | 1.6077 KOps/s | 1.5851 KOps/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 0.4381s | 13.0605ms | 76.5665 Ops/s | 38.0294 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 3.4058ms | 2.1447ms | 466.2626 Ops/s | 454.8518 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 7.8657ms | 1.4902ms | 671.0554 Ops/s | 608.5312 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 6.0000ms | 4.3836ms | 228.1215 Ops/s | 228.6593 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 9.9551ms | 2.4455ms | 408.9154 Ops/s | 434.9172 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 5.4661ms | 1.4231ms | 702.7039 Ops/s | 705.6534 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 0.4125s | 12.7392ms | 78.4976 Ops/s | 231.4041 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 8.1290ms | 2.4454ms | 408.9357 Ops/s | 399.8724 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 5.0425ms | 1.5851ms | 630.8894 Ops/s | 571.2720 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-True] | 16.3644ms | 13.2208ms | 75.6387 Ops/s | 70.9111 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-False] | 16.2205ms | 14.8853ms | 67.1805 Ops/s | 64.0777 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-True] | 24.7037ms | 21.7689ms | 45.9372 Ops/s | 43.6514 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-False] | 16.6711ms | 15.0820ms | 66.3044 Ops/s | 65.1309 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-True] | 22.8006ms | 21.7852ms | 45.9027 Ops/s | 44.8891 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-False] | 18.3877ms | 16.5206ms | 60.5305 Ops/s | 59.7301 Ops/s |
|
Name | Max | Mean | Ops | Ops on Repo HEAD
|
Change |
---|---|---|---|---|---|
test_simple | 0.7228s | 0.7221s | 1.3849 Ops/s | 1.3414 Ops/s | |
test_transformed | 0.9842s | 0.9754s | 1.0252 Ops/s | 1.0237 Ops/s | |
test_serial | 2.2511s | 2.1682s | 0.4612 Ops/s | 0.4611 Ops/s | |
test_parallel | 1.9450s | 1.8507s | 0.5403 Ops/s | 0.5217 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.1788ms | 40.0378μs | 24.9764 KOps/s | 24.5499 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 53.0310μs | 23.3761μs | 42.7787 KOps/s | 42.2935 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 53.9410μs | 22.3860μs | 44.6707 KOps/s | 44.2122 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 47.1210μs | 13.0642μs | 76.5449 KOps/s | 75.9581 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 89.2820μs | 43.0290μs | 23.2401 KOps/s | 23.1811 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 63.1310μs | 25.4631μs | 39.2725 KOps/s | 38.6998 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 50.8900μs | 24.8686μs | 40.2113 KOps/s | 39.7351 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 39.7810μs | 15.2726μs | 65.4768 KOps/s | 63.8102 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 90.0710μs | 45.5759μs | 21.9414 KOps/s | 22.1109 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 60.2810μs | 27.8840μs | 35.8629 KOps/s | 35.3456 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 63.3420μs | 24.5904μs | 40.6663 KOps/s | 39.8307 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 40.2700μs | 15.4500μs | 64.7248 KOps/s | 64.3641 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 80.3910μs | 47.6809μs | 20.9727 KOps/s | 20.6805 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 66.3110μs | 30.6367μs | 32.6406 KOps/s | 32.5825 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 75.0620μs | 27.4768μs | 36.3943 KOps/s | 36.3839 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 46.3810μs | 17.6878μs | 56.5361 KOps/s | 56.7104 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 91.6820μs | 46.1812μs | 21.6538 KOps/s | 21.6813 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 66.1320μs | 28.5831μs | 34.9857 KOps/s | 34.7440 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 60.5710μs | 29.2562μs | 34.1808 KOps/s | 34.4828 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 43.2710μs | 17.3344μs | 57.6886 KOps/s | 57.8843 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 84.3010μs | 48.0892μs | 20.7947 KOps/s | 20.5513 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 82.3120μs | 30.2490μs | 33.0590 KOps/s | 32.3956 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 3.2170ms | 31.5588μs | 31.6869 KOps/s | 32.0919 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 45.4410μs | 19.5670μs | 51.1064 KOps/s | 50.1661 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 80.9310μs | 50.7938μs | 19.6874 KOps/s | 19.8560 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 66.9110μs | 33.2093μs | 30.1120 KOps/s | 29.9813 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 67.7510μs | 31.3631μs | 31.8846 KOps/s | 32.6464 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 46.4010μs | 19.6849μs | 50.8003 KOps/s | 50.9359 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 82.3410μs | 52.1634μs | 19.1705 KOps/s | 19.0559 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 67.8010μs | 35.4924μs | 28.1751 KOps/s | 28.2873 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 64.3610μs | 32.8098μs | 30.4787 KOps/s | 30.8855 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 54.4610μs | 21.8080μs | 45.8548 KOps/s | 45.8203 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 26.0171ms | 25.5078ms | 39.2036 Ops/s | 38.5537 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 95.5710ms | 2.8212ms | 354.4596 Ops/s | 312.1272 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.1129ms | 83.6459μs | 11.9552 KOps/s | 12.2418 KOps/s | |
test_values[td1_return_estimate-False-False] | 58.0369ms | 57.4348ms | 17.4110 Ops/s | 16.7228 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 1.3056ms | 1.1004ms | 908.7426 Ops/s | 910.9002 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 95.4247ms | 91.7772ms | 10.8960 Ops/s | 10.3433 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 1.3016ms | 1.1004ms | 908.7272 Ops/s | 914.6391 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 27.3810ms | 26.8788ms | 37.2041 Ops/s | 39.0383 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 1.0615ms | 0.7732ms | 1.2934 KOps/s | 1.3062 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.7837ms | 0.6853ms | 1.4593 KOps/s | 1.4641 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 1.5496ms | 1.4962ms | 668.3807 Ops/s | 667.9190 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 0.7404ms | 0.7003ms | 1.4280 KOps/s | 1.4348 KOps/s | |
test_dqn_speed[False-None] | 7.0562ms | 1.5530ms | 643.9109 Ops/s | 648.5017 Ops/s | |
test_dqn_speed[False-backward] | 2.3403ms | 2.1700ms | 460.8280 Ops/s | 464.7840 Ops/s | |
test_dqn_speed[True-None] | 0.6408ms | 0.5663ms | 1.7659 KOps/s | 1.7583 KOps/s | |
test_dqn_speed[True-backward] | 1.2007ms | 1.1368ms | 879.6989 Ops/s | 876.2883 Ops/s | |
test_dqn_speed[reduce-overhead-None] | 0.6658ms | 0.5791ms | 1.7267 KOps/s | 1.7330 KOps/s | |
test_dqn_speed[reduce-overhead-backward] | 1.0407ms | 0.9992ms | 1.0008 KOps/s | 997.7704 Ops/s | |
test_ddpg_speed[False-None] | 3.2211ms | 2.9042ms | 344.3310 Ops/s | 340.9356 Ops/s | |
test_ddpg_speed[False-backward] | 4.6423ms | 4.2019ms | 237.9891 Ops/s | 238.6794 Ops/s | |
test_ddpg_speed[True-None] | 1.2169ms | 1.1175ms | 894.8266 Ops/s | 889.5538 Ops/s | |
test_ddpg_speed[True-backward] | 2.3385ms | 2.2394ms | 446.5459 Ops/s | 452.5091 Ops/s | |
test_ddpg_speed[reduce-overhead-None] | 1.2136ms | 1.1335ms | 882.2143 Ops/s | 885.4251 Ops/s | |
test_ddpg_speed[reduce-overhead-backward] | 1.7659ms | 1.6989ms | 588.6288 Ops/s | 594.6596 Ops/s | |
test_sac_speed[False-None] | 8.5889ms | 8.1722ms | 122.3656 Ops/s | 118.7787 Ops/s | |
test_sac_speed[False-backward] | 11.7041ms | 11.2286ms | 89.0585 Ops/s | 89.0507 Ops/s | |
test_sac_speed[True-None] | 1.6252ms | 1.5744ms | 635.1814 Ops/s | 632.1043 Ops/s | |
test_sac_speed[True-backward] | 3.4559ms | 3.3441ms | 299.0300 Ops/s | 300.5562 Ops/s | |
test_sac_speed[reduce-overhead-None] | 22.7950ms | 12.7612ms | 78.3625 Ops/s | 79.2699 Ops/s | |
test_sac_speed[reduce-overhead-backward] | 1.6585ms | 1.5449ms | 647.3000 Ops/s | 725.4728 Ops/s | |
test_redq_speed[False-None] | 8.3589ms | 7.6037ms | 131.5151 Ops/s | 130.4663 Ops/s | |
test_redq_speed[False-backward] | 12.5583ms | 11.8068ms | 84.6966 Ops/s | 87.2569 Ops/s | |
test_redq_speed[True-None] | 2.2185ms | 2.0578ms | 485.9536 Ops/s | 491.0789 Ops/s | |
test_redq_speed[True-backward] | 4.2522ms | 3.7909ms | 263.7904 Ops/s | 267.6513 Ops/s | |
test_redq_speed[reduce-overhead-None] | 2.2050ms | 2.0604ms | 485.3400 Ops/s | 489.1820 Ops/s | |
test_redq_speed[reduce-overhead-backward] | 3.9005ms | 3.8000ms | 263.1579 Ops/s | 256.7631 Ops/s | |
test_redq_deprec_speed[False-None] | 9.8496ms | 9.2254ms | 108.3961 Ops/s | 107.0096 Ops/s | |
test_redq_deprec_speed[False-backward] | 13.1035ms | 12.2774ms | 81.4504 Ops/s | 78.9563 Ops/s | |
test_redq_deprec_speed[True-None] | 2.6423ms | 2.4447ms | 409.0403 Ops/s | 392.2691 Ops/s | |
test_redq_deprec_speed[True-backward] | 4.2456ms | 4.1579ms | 240.5060 Ops/s | 243.8328 Ops/s | |
test_redq_deprec_speed[reduce-overhead-None] | 2.6415ms | 2.4325ms | 411.0994 Ops/s | 420.6284 Ops/s | |
test_redq_deprec_speed[reduce-overhead-backward] | 4.2569ms | 4.1605ms | 240.3534 Ops/s | 244.4610 Ops/s | |
test_td3_speed[False-None] | 8.2624ms | 8.0666ms | 123.9686 Ops/s | 121.1492 Ops/s | |
test_td3_speed[False-backward] | 11.0031ms | 10.5241ms | 95.0204 Ops/s | 96.5791 Ops/s | |
test_td3_speed[True-None] | 1.7027ms | 1.6423ms | 608.8878 Ops/s | 617.8048 Ops/s | |
test_td3_speed[True-backward] | 3.3818ms | 3.2374ms | 308.8857 Ops/s | 314.7963 Ops/s | |
test_td3_speed[reduce-overhead-None] | 82.6124ms | 26.9755ms | 37.0706 Ops/s | 35.7318 Ops/s | |
test_td3_speed[reduce-overhead-backward] | 1.3925ms | 1.3534ms | 738.8904 Ops/s | 664.3730 Ops/s | |
test_cql_speed[False-None] | 17.6320ms | 17.0185ms | 58.7595 Ops/s | 58.0813 Ops/s | |
test_cql_speed[False-backward] | 22.8120ms | 22.3467ms | 44.7494 Ops/s | 43.9278 Ops/s | |
test_cql_speed[True-None] | 3.1761ms | 3.0637ms | 326.4001 Ops/s | 309.7993 Ops/s | |
test_cql_speed[True-backward] | 5.7531ms | 5.2804ms | 189.3804 Ops/s | 183.8864 Ops/s | |
test_cql_speed[reduce-overhead-None] | 21.9101ms | 13.4045ms | 74.6018 Ops/s | 74.8989 Ops/s | |
test_cql_speed[reduce-overhead-backward] | 1.6421ms | 1.5606ms | 640.7830 Ops/s | 574.6007 Ops/s | |
test_a2c_speed[False-None] | 3.4735ms | 3.3161ms | 301.5590 Ops/s | 302.8478 Ops/s | |
test_a2c_speed[False-backward] | 6.6651ms | 6.2199ms | 160.7732 Ops/s | 152.6398 Ops/s | |
test_a2c_speed[True-None] | 1.1729ms | 1.0835ms | 922.9158 Ops/s | 962.9405 Ops/s | |
test_a2c_speed[True-backward] | 2.9371ms | 2.8570ms | 350.0174 Ops/s | 350.9573 Ops/s | |
test_a2c_speed[reduce-overhead-None] | 21.6886ms | 11.5966ms | 86.2319 Ops/s | 85.7153 Ops/s | |
test_a2c_speed[reduce-overhead-backward] | 1.2108ms | 1.1498ms | 869.7181 Ops/s | 987.7807 Ops/s | |
test_ppo_speed[False-None] | 3.8694ms | 3.7656ms | 265.5623 Ops/s | 265.8940 Ops/s | |
test_ppo_speed[False-backward] | 7.6569ms | 7.1813ms | 139.2500 Ops/s | 144.9375 Ops/s | |
test_ppo_speed[True-None] | 1.0587ms | 0.9892ms | 1.0109 KOps/s | 1.0280 KOps/s | |
test_ppo_speed[True-backward] | 2.8472ms | 2.7984ms | 357.3424 Ops/s | 362.7879 Ops/s | |
test_ppo_speed[reduce-overhead-None] | 0.5953ms | 0.5331ms | 1.8759 KOps/s | 1.8502 KOps/s | |
test_ppo_speed[reduce-overhead-backward] | 1.2055ms | 1.1483ms | 870.8450 Ops/s | 981.4842 Ops/s | |
test_reinforce_speed[False-None] | 2.4273ms | 2.3127ms | 432.3907 Ops/s | 430.3892 Ops/s | |
test_reinforce_speed[False-backward] | 3.9167ms | 3.5113ms | 284.7931 Ops/s | 296.8455 Ops/s | |
test_reinforce_speed[True-None] | 0.9481ms | 0.8641ms | 1.1573 KOps/s | 1.1581 KOps/s | |
test_reinforce_speed[True-backward] | 2.7028ms | 2.6333ms | 379.7456 Ops/s | 403.1075 Ops/s | |
test_reinforce_speed[reduce-overhead-None] | 22.2326ms | 11.7725ms | 84.9439 Ops/s | 86.6727 Ops/s | |
test_reinforce_speed[reduce-overhead-backward] | 1.2440ms | 1.2078ms | 827.9235 Ops/s | 846.8949 Ops/s | |
test_iql_speed[False-None] | 9.9518ms | 9.4668ms | 105.6321 Ops/s | 105.5289 Ops/s | |
test_iql_speed[False-backward] | 14.0432ms | 13.4960ms | 74.0960 Ops/s | 74.4150 Ops/s | |
test_iql_speed[True-None] | 1.9371ms | 1.8316ms | 545.9754 Ops/s | 542.4814 Ops/s | |
test_iql_speed[True-backward] | 4.7065ms | 4.6082ms | 217.0024 Ops/s | 220.6218 Ops/s | |
test_iql_speed[reduce-overhead-None] | 20.3111ms | 11.6889ms | 85.5514 Ops/s | 87.1614 Ops/s | |
test_iql_speed[reduce-overhead-backward] | 1.6247ms | 1.5805ms | 632.7079 Ops/s | 691.8500 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 8.0829ms | 6.4672ms | 154.6263 Ops/s | 151.7603 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.4813ms | 0.2778ms | 3.6003 KOps/s | 2.9855 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.4307ms | 0.2573ms | 3.8868 KOps/s | 3.0627 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.7425ms | 6.1720ms | 162.0219 Ops/s | 159.4338 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 1.9493ms | 0.2695ms | 3.7102 KOps/s | 3.0971 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.4423ms | 0.2489ms | 4.0180 KOps/s | 3.1687 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 1.5589ms | 1.3347ms | 749.2566 Ops/s | 773.8013 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 1.4441ms | 1.2153ms | 822.8247 Ops/s | 805.1580 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.6004ms | 6.3991ms | 156.2717 Ops/s | 156.3077 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 2.0819ms | 0.4932ms | 2.0274 KOps/s | 2.3382 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.8086ms | 0.4249ms | 2.3537 KOps/s | 2.1338 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 6.4498ms | 6.2590ms | 159.7703 Ops/s | 160.0515 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 1.8507ms | 0.3309ms | 3.0221 KOps/s | 3.5819 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.4518ms | 0.2559ms | 3.9082 KOps/s | 3.8818 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.5275ms | 6.1563ms | 162.4345 Ops/s | 159.2547 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 0.9007ms | 0.2716ms | 3.6815 KOps/s | 3.1400 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.4405ms | 0.2461ms | 4.0629 KOps/s | 3.9135 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.7340ms | 6.4374ms | 155.3411 Ops/s | 155.3001 Ops/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 2.1295ms | 0.4405ms | 2.2703 KOps/s | 2.1623 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.5785ms | 0.3866ms | 2.5866 KOps/s | 2.1841 KOps/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 7.2128ms | 5.5397ms | 180.5165 Ops/s | 182.5044 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 10.5620ms | 2.1715ms | 460.5188 Ops/s | 431.2418 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 3.5355ms | 1.1751ms | 851.0160 Ops/s | 852.0223 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 10.2119ms | 5.6306ms | 177.6002 Ops/s | 183.9994 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 6.2484ms | 2.0498ms | 487.8615 Ops/s | 424.0813 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 0.5033s | 11.2730ms | 88.7078 Ops/s | 865.2095 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 8.3600ms | 5.7505ms | 173.8991 Ops/s | 32.6966 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 9.2678ms | 2.1486ms | 465.4243 Ops/s | 459.5543 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 8.2366ms | 1.3968ms | 715.9364 Ops/s | 745.1151 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-True] | 15.8455ms | 15.5830ms | 64.1725 Ops/s | 64.5685 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-False] | 19.6189ms | 17.6666ms | 56.6039 Ops/s | 57.4613 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-True] | 20.2932ms | 19.8029ms | 50.4977 Ops/s | 48.9985 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-False] | 19.8625ms | 17.8468ms | 56.0325 Ops/s | 56.7893 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-True] | 20.2542ms | 19.7965ms | 50.5139 Ops/s | 48.8564 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-False] | 20.7044ms | 19.2203ms | 52.0282 Ops/s | 51.7750 Ops/s |
@albertbou92 THere is an annoying behaviour that if the composite spec in the primer can be set to the parent shape it is, oterwise it is expanded. I didn't realize how unstable this was: this basically means that if your spec has shape [2] and by chance you're using a parallel env with 2 workers, you don't expand, but if you change the number of workers you do... # 1. No indication of shape
primer=TensorDictPrimer(stuff=spec) # no batch_size
parallel_env.append_transform(primer) # expands the specs
# 2. Using a composite spec to indicate the shape
primer=TensorDictPrimer(composite_spec)
parallel_env.append_transform(primer) # expands the spec shape only if the composite shape differs from the env shape
# 3. No indication of shape
primer=TensorDictPrimer(stuff=spec, batch_size=[n_envs]) # batch_size explicitly set
parallel_env.append_transform(primer) # expands the specs It's going to be hard to transition to that behaviour but I think it's worth it. So what we need is that for the case (1), if the shape can be reset, we warn users that we are changing the spec shape but tht will not be the case anymore in v0.8 (I'm doing just one release cycle here because the behaviour is pretty bad). |
…fically for LSTM and GRU) ghstack-source-id: d5981b4dbee8305250faa776c46424c7cf959578 Pull Request resolved: #2668
Yes I remember seeing this behavior when doing the PR for the What about having a boolean paramater in |
that's exactly what I implemented here! glad you agree it's a good idea |
…fically for LSTM and GRU) ghstack-source-id: e1da58ecfd36ca01b8a11fe90e5f3c5fe77f064c Pull Request resolved: #2668
Stack from ghstack (oldest at bottom):