-
Notifications
You must be signed in to change notification settings - Fork 328
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] Add Hash transform and UnaryTransform #2648
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2648
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes. |
ghstack-source-id: 80f920674e13db2fcbed6e82a990d35cb14c6d11 Pull Request resolved: #2648
|
Name | Max | Mean | Ops | Ops on Repo HEAD
|
Change |
---|---|---|---|---|---|
test_simple | 0.5290s | 0.4472s | 2.2361 Ops/s | 2.2616 Ops/s | |
test_transformed | 0.6161s | 0.6119s | 1.6343 Ops/s | 1.5894 Ops/s | |
test_serial | 1.4600s | 1.3796s | 0.7249 Ops/s | 0.7273 Ops/s | |
test_parallel | 1.2783s | 1.1934s | 0.8380 Ops/s | 0.8202 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.1267ms | 30.0075μs | 33.3250 KOps/s | 33.5134 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 46.5670μs | 17.8031μs | 56.1700 KOps/s | 56.5197 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 0.6046ms | 17.0524μs | 58.6426 KOps/s | 58.4483 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 39.7340μs | 10.0789μs | 99.2175 KOps/s | 100.0470 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 73.6870μs | 31.8865μs | 31.3612 KOps/s | 31.2222 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 57.8070μs | 19.6421μs | 50.9110 KOps/s | 50.2463 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 47.0970μs | 19.0108μs | 52.6016 KOps/s | 52.5110 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 49.5220μs | 11.8912μs | 84.0959 KOps/s | 83.8433 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 67.9670μs | 34.0286μs | 29.3871 KOps/s | 29.3157 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 51.2650μs | 21.5278μs | 46.4516 KOps/s | 46.0956 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 50.9550μs | 18.8029μs | 53.1832 KOps/s | 52.8013 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 38.8220μs | 11.8597μs | 84.3192 KOps/s | 85.0363 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 79.0670μs | 35.7029μs | 28.0090 KOps/s | 28.0791 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 57.8280μs | 23.1225μs | 43.2480 KOps/s | 43.1858 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 65.7120μs | 20.5800μs | 48.5908 KOps/s | 48.5559 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 38.2920μs | 13.6213μs | 73.4142 KOps/s | 73.2405 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 68.9680μs | 33.8079μs | 29.5789 KOps/s | 29.6191 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 62.9270μs | 21.6386μs | 46.2137 KOps/s | 46.2559 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 49.2010μs | 21.4495μs | 46.6211 KOps/s | 46.1631 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 42.2980μs | 13.1571μs | 76.0048 KOps/s | 76.0514 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 75.7320μs | 35.6965μs | 28.0140 KOps/s | 28.1114 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 0.6136ms | 23.2474μs | 43.0156 KOps/s | 43.0357 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 2.4282ms | 23.4999μs | 42.5533 KOps/s | 42.6470 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 42.2890μs | 14.8560μs | 67.3128 KOps/s | 67.0278 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 91.0100μs | 37.3216μs | 26.7942 KOps/s | 26.8146 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 59.1200μs | 25.0803μs | 39.8719 KOps/s | 39.6938 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 56.9260μs | 23.2556μs | 43.0004 KOps/s | 42.6197 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 40.2250μs | 14.8243μs | 67.4567 KOps/s | 66.6987 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 87.2120μs | 38.9866μs | 25.6498 KOps/s | 25.6890 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 74.4480μs | 26.4605μs | 37.7922 KOps/s | 37.2733 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 70.7610μs | 24.6234μs | 40.6118 KOps/s | 39.9908 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 46.2060μs | 16.5346μs | 60.4793 KOps/s | 59.9282 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 12.6528ms | 9.7742ms | 102.3100 Ops/s | 102.8789 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 41.4308ms | 33.4157ms | 29.9260 Ops/s | 27.5509 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.2344ms | 0.1744ms | 5.7332 KOps/s | 5.6530 KOps/s | |
test_values[td1_return_estimate-False-False] | 29.3649ms | 23.9156ms | 41.8137 Ops/s | 41.0564 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 44.2580ms | 33.9239ms | 29.4777 Ops/s | 27.6724 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 36.8336ms | 34.4002ms | 29.0696 Ops/s | 29.2700 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 35.0667ms | 33.3177ms | 30.0141 Ops/s | 27.6184 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 8.5767ms | 8.3494ms | 119.7696 Ops/s | 119.8755 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 2.2279ms | 1.8425ms | 542.7438 Ops/s | 545.0138 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.4744ms | 0.3627ms | 2.7569 KOps/s | 2.8602 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 37.4168ms | 35.7617ms | 27.9629 Ops/s | 23.0392 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 3.7893ms | 3.0438ms | 328.5375 Ops/s | 328.4141 Ops/s | |
test_dqn_speed[False-None] | 5.7222ms | 1.4129ms | 707.7817 Ops/s | 716.8092 Ops/s | |
test_dqn_speed[False-backward] | 3.2125ms | 1.9662ms | 508.6060 Ops/s | 532.9037 Ops/s | |
test_dqn_speed[True-None] | 0.8007ms | 0.4844ms | 2.0646 KOps/s | 2.0553 KOps/s | |
test_dqn_speed[True-backward] | 0.9610ms | 0.9003ms | 1.1108 KOps/s | 845.6088 Ops/s | |
test_dqn_speed[reduce-overhead-None] | 0.7935ms | 0.4807ms | 2.0802 KOps/s | 2.0371 KOps/s | |
test_dqn_speed[reduce-overhead-backward] | 0.9667ms | 0.8952ms | 1.1170 KOps/s | 1.1050 KOps/s | |
test_ddpg_speed[False-None] | 0.1710s | 3.4166ms | 292.6929 Ops/s | 342.3729 Ops/s | |
test_ddpg_speed[False-backward] | 4.3710ms | 4.0339ms | 247.9006 Ops/s | 247.6925 Ops/s | |
test_ddpg_speed[True-None] | 1.7000ms | 1.0195ms | 980.8614 Ops/s | 980.0364 Ops/s | |
test_ddpg_speed[True-backward] | 2.0325ms | 1.8975ms | 527.0156 Ops/s | 444.4961 Ops/s | |
test_ddpg_speed[reduce-overhead-None] | 1.3218ms | 1.0101ms | 990.0011 Ops/s | 979.3507 Ops/s | |
test_ddpg_speed[reduce-overhead-backward] | 1.9708ms | 1.8851ms | 530.4892 Ops/s | 491.8967 Ops/s | |
test_sac_speed[False-None] | 9.9145ms | 8.0705ms | 123.9086 Ops/s | 120.7672 Ops/s | |
test_sac_speed[False-backward] | 11.2898ms | 10.7922ms | 92.6599 Ops/s | 91.8345 Ops/s | |
test_sac_speed[True-None] | 2.3304ms | 1.8298ms | 546.4951 Ops/s | 531.3199 Ops/s | |
test_sac_speed[True-backward] | 3.5974ms | 3.5007ms | 285.6576 Ops/s | 284.6879 Ops/s | |
test_sac_speed[reduce-overhead-None] | 2.0863ms | 1.8468ms | 541.4900 Ops/s | 530.7679 Ops/s | |
test_sac_speed[reduce-overhead-backward] | 3.8320ms | 3.5686ms | 280.2202 Ops/s | 281.3172 Ops/s | |
test_redq_speed[False-None] | 14.6338ms | 13.0297ms | 76.7475 Ops/s | 66.9598 Ops/s | |
test_redq_speed[False-backward] | 24.3210ms | 22.4163ms | 44.6105 Ops/s | 43.8640 Ops/s | |
test_redq_speed[True-None] | 5.9590ms | 4.8347ms | 206.8373 Ops/s | 198.3693 Ops/s | |
test_redq_speed[True-backward] | 13.5484ms | 12.6626ms | 78.9724 Ops/s | 81.2935 Ops/s | |
test_redq_speed[reduce-overhead-None] | 5.4862ms | 4.9937ms | 200.2506 Ops/s | 205.8034 Ops/s | |
test_redq_speed[reduce-overhead-backward] | 13.4836ms | 12.4971ms | 80.0184 Ops/s | 78.7150 Ops/s | |
test_redq_deprec_speed[False-None] | 15.6847ms | 13.5296ms | 73.9122 Ops/s | 73.1946 Ops/s | |
test_redq_deprec_speed[False-backward] | 20.4480ms | 19.2759ms | 51.8783 Ops/s | 51.4507 Ops/s | |
test_redq_deprec_speed[True-None] | 4.5735ms | 3.7963ms | 263.4138 Ops/s | 268.4555 Ops/s | |
test_redq_deprec_speed[True-backward] | 8.8809ms | 8.3779ms | 119.3617 Ops/s | 118.9851 Ops/s | |
test_redq_deprec_speed[reduce-overhead-None] | 4.0915ms | 3.5667ms | 280.3734 Ops/s | 266.3378 Ops/s | |
test_redq_deprec_speed[reduce-overhead-backward] | 8.8301ms | 8.0133ms | 124.7920 Ops/s | 120.0081 Ops/s | |
test_td3_speed[False-None] | 8.4725ms | 8.0666ms | 123.9680 Ops/s | 120.1468 Ops/s | |
test_td3_speed[False-backward] | 11.8839ms | 10.4430ms | 95.7582 Ops/s | 93.2840 Ops/s | |
test_td3_speed[True-None] | 1.8121ms | 1.7195ms | 581.5631 Ops/s | 541.7906 Ops/s | |
test_td3_speed[True-backward] | 3.6002ms | 3.3290ms | 300.3932 Ops/s | 291.4078 Ops/s | |
test_td3_speed[reduce-overhead-None] | 2.1107ms | 1.7412ms | 574.3062 Ops/s | 549.9013 Ops/s | |
test_td3_speed[reduce-overhead-backward] | 3.8302ms | 3.3407ms | 299.3361 Ops/s | 287.3566 Ops/s | |
test_cql_speed[False-None] | 41.1688ms | 36.5928ms | 27.3278 Ops/s | 26.9225 Ops/s | |
test_cql_speed[False-backward] | 49.1439ms | 46.3088ms | 21.5942 Ops/s | 20.2801 Ops/s | |
test_cql_speed[True-None] | 17.3002ms | 15.8481ms | 63.0992 Ops/s | 61.8985 Ops/s | |
test_cql_speed[True-backward] | 22.9486ms | 22.1165ms | 45.2152 Ops/s | 43.6300 Ops/s | |
test_cql_speed[reduce-overhead-None] | 16.6903ms | 15.5519ms | 64.3009 Ops/s | 62.0220 Ops/s | |
test_cql_speed[reduce-overhead-backward] | 24.8505ms | 22.3095ms | 44.8240 Ops/s | 43.4792 Ops/s | |
test_a2c_speed[False-None] | 8.5368ms | 7.2017ms | 138.8556 Ops/s | 136.8660 Ops/s | |
test_a2c_speed[False-backward] | 16.2126ms | 14.3615ms | 69.6306 Ops/s | 63.8182 Ops/s | |
test_a2c_speed[True-None] | 4.9895ms | 4.2539ms | 235.0779 Ops/s | 234.7667 Ops/s | |
test_a2c_speed[True-backward] | 11.4599ms | 10.6771ms | 93.6581 Ops/s | 90.9029 Ops/s | |
test_a2c_speed[reduce-overhead-None] | 4.6352ms | 4.1924ms | 238.5278 Ops/s | 234.1488 Ops/s | |
test_a2c_speed[reduce-overhead-backward] | 11.5761ms | 10.8999ms | 91.7443 Ops/s | 89.9011 Ops/s | |
test_ppo_speed[False-None] | 8.8106ms | 7.5604ms | 132.2677 Ops/s | 131.1499 Ops/s | |
test_ppo_speed[False-backward] | 15.4575ms | 14.7885ms | 67.6201 Ops/s | 67.5794 Ops/s | |
test_ppo_speed[True-None] | 4.0291ms | 3.7029ms | 270.0554 Ops/s | 261.3203 Ops/s | |
test_ppo_speed[True-backward] | 10.0475ms | 9.5596ms | 104.6067 Ops/s | 101.6184 Ops/s | |
test_ppo_speed[reduce-overhead-None] | 4.0561ms | 3.6864ms | 271.2707 Ops/s | 265.2088 Ops/s | |
test_ppo_speed[reduce-overhead-backward] | 10.1259ms | 9.5208ms | 105.0337 Ops/s | 101.8031 Ops/s | |
test_reinforce_speed[False-None] | 7.4419ms | 6.5215ms | 153.3397 Ops/s | 145.9325 Ops/s | |
test_reinforce_speed[False-backward] | 10.6267ms | 9.7378ms | 102.6926 Ops/s | 95.8061 Ops/s | |
test_reinforce_speed[True-None] | 3.2625ms | 2.6431ms | 378.3477 Ops/s | 359.6615 Ops/s | |
test_reinforce_speed[True-backward] | 9.3378ms | 8.5689ms | 116.7009 Ops/s | 113.7759 Ops/s | |
test_reinforce_speed[reduce-overhead-None] | 3.1248ms | 2.6287ms | 380.4167 Ops/s | 367.4783 Ops/s | |
test_reinforce_speed[reduce-overhead-backward] | 9.1706ms | 8.5532ms | 116.9154 Ops/s | 114.6912 Ops/s | |
test_iql_speed[False-None] | 34.0847ms | 32.2927ms | 30.9667 Ops/s | 15.7194 Ops/s | |
test_iql_speed[False-backward] | 45.9114ms | 45.1407ms | 22.1529 Ops/s | 20.7136 Ops/s | |
test_iql_speed[True-None] | 10.9943ms | 10.5731ms | 94.5799 Ops/s | 90.0445 Ops/s | |
test_iql_speed[True-backward] | 25.8142ms | 21.6890ms | 46.1063 Ops/s | 44.8402 Ops/s | |
test_iql_speed[reduce-overhead-None] | 11.2877ms | 10.5670ms | 94.6340 Ops/s | 92.2287 Ops/s | |
test_iql_speed[reduce-overhead-backward] | 22.8135ms | 21.9855ms | 45.4844 Ops/s | 44.7075 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 5.4652ms | 4.9464ms | 202.1670 Ops/s | 199.6377 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.9227ms | 0.6489ms | 1.5411 KOps/s | 1.5530 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.8876ms | 0.6312ms | 1.5842 KOps/s | 1.6114 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 5.3456ms | 4.7591ms | 210.1246 Ops/s | 207.3786 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 1.2368ms | 0.6330ms | 1.5798 KOps/s | 1.5788 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.8707ms | 0.6111ms | 1.6363 KOps/s | 1.6439 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 2.1611ms | 1.8841ms | 530.7446 Ops/s | 538.7819 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 2.1178ms | 1.6968ms | 589.3576 Ops/s | 605.4755 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 5.4937ms | 4.8479ms | 206.2756 Ops/s | 205.4071 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.9244ms | 0.8015ms | 1.2477 KOps/s | 1.2765 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.9969ms | 0.7470ms | 1.3387 KOps/s | 1.3317 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 7.7625ms | 4.8325ms | 206.9332 Ops/s | 209.5905 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.8306ms | 0.6553ms | 1.5260 KOps/s | 467.1524 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.9221ms | 0.6353ms | 1.5741 KOps/s | 1.6651 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 7.1287ms | 4.7040ms | 212.5860 Ops/s | 209.3291 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 1.6500ms | 0.6444ms | 1.5518 KOps/s | 1.5658 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.8547ms | 0.6151ms | 1.6256 KOps/s | 1.7364 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 7.1431ms | 4.8355ms | 206.8056 Ops/s | 198.6828 Ops/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.4284ms | 0.7708ms | 1.2974 KOps/s | 1.2728 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.9935ms | 0.7531ms | 1.3279 KOps/s | 1.3568 KOps/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 8.3124ms | 4.9800ms | 200.8032 Ops/s | 226.8413 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 10.0841ms | 2.4398ms | 409.8766 Ops/s | 432.0863 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 2.1026ms | 1.4153ms | 706.5488 Ops/s | 714.0086 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 0.4115s | 12.5579ms | 79.6309 Ops/s | 239.0829 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 3.4316ms | 2.1753ms | 459.7006 Ops/s | 440.1789 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 7.3291ms | 1.5228ms | 656.6707 Ops/s | 683.7119 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 6.9487ms | 4.5621ms | 219.1953 Ops/s | 239.2955 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 7.3161ms | 2.4362ms | 410.4723 Ops/s | 366.7970 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 5.3928ms | 1.5598ms | 641.0878 Ops/s | 610.4649 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-True] | 13.2004ms | 12.9312ms | 77.3321 Ops/s | 74.0582 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-False] | 24.4711ms | 15.2878ms | 65.4114 Ops/s | 67.9429 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-True] | 23.8141ms | 22.1758ms | 45.0942 Ops/s | 44.1667 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-False] | 16.1023ms | 15.1217ms | 66.1300 Ops/s | 66.4519 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-True] | 0.3642s | 28.7157ms | 34.8242 Ops/s | 45.7333 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-False] | 17.3643ms | 16.5007ms | 60.6036 Ops/s | 60.9485 Ops/s |
|
Name | Max | Mean | Ops | Ops on Repo HEAD
|
Change |
---|---|---|---|---|---|
test_simple | 0.8076s | 0.7228s | 1.3834 Ops/s | 1.3660 Ops/s | |
test_transformed | 0.9477s | 0.9464s | 1.0566 Ops/s | 1.0140 Ops/s | |
test_serial | 2.1603s | 2.1239s | 0.4708 Ops/s | 0.4649 Ops/s | |
test_parallel | 1.8439s | 1.8156s | 0.5508 Ops/s | 0.5532 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.2328ms | 39.4697μs | 25.3359 KOps/s | 25.5178 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 52.1010μs | 22.7354μs | 43.9843 KOps/s | 43.8833 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 54.0200μs | 22.0555μs | 45.3402 KOps/s | 46.4338 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 51.9010μs | 12.6700μs | 78.9266 KOps/s | 78.5212 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 0.1050ms | 41.7003μs | 23.9807 KOps/s | 23.8179 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 0.2082ms | 24.5758μs | 40.6904 KOps/s | 40.0796 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 0.2059ms | 23.9456μs | 41.7613 KOps/s | 41.9835 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 56.3110μs | 14.4099μs | 69.3966 KOps/s | 67.7190 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 83.3220μs | 42.4752μs | 23.5432 KOps/s | 22.7368 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 71.0410μs | 26.9451μs | 37.1126 KOps/s | 36.5551 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 54.5410μs | 24.0904μs | 41.5102 KOps/s | 41.6255 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 50.9010μs | 14.9422μs | 66.9246 KOps/s | 67.4907 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 74.9620μs | 45.6589μs | 21.9015 KOps/s | 21.5280 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 65.3410μs | 28.9204μs | 34.5777 KOps/s | 33.6863 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 65.5210μs | 26.1485μs | 38.2431 KOps/s | 38.0237 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 0.1074ms | 16.9986μs | 58.8284 KOps/s | 59.0478 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 88.2110μs | 43.8301μs | 22.8154 KOps/s | 22.8703 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 0.1003ms | 27.5419μs | 36.3083 KOps/s | 36.8439 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 59.6610μs | 28.4744μs | 35.1193 KOps/s | 36.4748 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 50.7310μs | 16.8227μs | 59.4436 KOps/s | 60.8517 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 73.3510μs | 46.7684μs | 21.3820 KOps/s | 21.7651 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 61.8810μs | 29.5796μs | 33.8071 KOps/s | 34.2920 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 3.2336ms | 30.4365μs | 32.8553 KOps/s | 34.0343 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 52.4210μs | 18.5469μs | 53.9173 KOps/s | 53.2638 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 93.1410μs | 48.3628μs | 20.6771 KOps/s | 20.8381 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 72.0210μs | 31.7771μs | 31.4692 KOps/s | 32.2751 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 70.0810μs | 30.0062μs | 33.3265 KOps/s | 34.1748 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 56.1810μs | 18.6271μs | 53.6853 KOps/s | 53.9027 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 0.1037ms | 49.5058μs | 20.1997 KOps/s | 20.0388 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 77.0320μs | 32.9100μs | 30.3859 KOps/s | 29.8980 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 81.7710μs | 31.8505μs | 31.3967 KOps/s | 32.8574 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 67.0110μs | 20.5120μs | 48.7521 KOps/s | 48.0223 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 26.1825ms | 25.5076ms | 39.2040 Ops/s | 37.8826 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 98.1732ms | 2.8626ms | 349.3345 Ops/s | 338.9625 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.1053ms | 78.6232μs | 12.7189 KOps/s | 11.6823 KOps/s | |
test_values[td1_return_estimate-False-False] | 58.5348ms | 57.5260ms | 17.3835 Ops/s | 16.9726 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 1.4665ms | 1.1130ms | 898.4375 Ops/s | 917.6793 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 92.0199ms | 86.9700ms | 11.4982 Ops/s | 11.1488 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 1.3643ms | 1.0793ms | 926.5337 Ops/s | 922.5207 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 24.5622ms | 24.4525ms | 40.8957 Ops/s | 39.8866 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 1.0268ms | 0.7453ms | 1.3417 KOps/s | 1.3016 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.7575ms | 0.6667ms | 1.4999 KOps/s | 1.4869 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 1.6436ms | 1.4756ms | 677.6862 Ops/s | 673.9795 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 0.7099ms | 0.6835ms | 1.4631 KOps/s | 1.4498 KOps/s | |
test_dqn_speed[False-None] | 6.9220ms | 1.5120ms | 661.3613 Ops/s | 665.7231 Ops/s | |
test_dqn_speed[False-backward] | 2.1314ms | 2.0874ms | 479.0607 Ops/s | 473.9189 Ops/s | |
test_dqn_speed[True-None] | 0.6209ms | 0.5305ms | 1.8850 KOps/s | 1.8295 KOps/s | |
test_dqn_speed[True-backward] | 1.1058ms | 1.0667ms | 937.4630 Ops/s | 917.7319 Ops/s | |
test_dqn_speed[reduce-overhead-None] | 0.7620ms | 0.5546ms | 1.8030 KOps/s | 1.7416 KOps/s | |
test_dqn_speed[reduce-overhead-backward] | 1.0984ms | 0.9350ms | 1.0695 KOps/s | 1.0545 KOps/s | |
test_ddpg_speed[False-None] | 3.0876ms | 2.7684ms | 361.2156 Ops/s | 353.8366 Ops/s | |
test_ddpg_speed[False-backward] | 4.4869ms | 4.0268ms | 248.3347 Ops/s | 244.8165 Ops/s | |
test_ddpg_speed[True-None] | 1.1769ms | 1.0357ms | 965.5296 Ops/s | 949.7928 Ops/s | |
test_ddpg_speed[True-backward] | 2.1180ms | 2.0701ms | 483.0576 Ops/s | 441.2378 Ops/s | |
test_ddpg_speed[reduce-overhead-None] | 1.3260ms | 1.0959ms | 912.4818 Ops/s | 885.4017 Ops/s | |
test_ddpg_speed[reduce-overhead-backward] | 1.7106ms | 1.5869ms | 630.1603 Ops/s | 564.4047 Ops/s | |
test_sac_speed[False-None] | 8.2616ms | 7.8236ms | 127.8177 Ops/s | 124.2087 Ops/s | |
test_sac_speed[False-backward] | 11.1999ms | 10.7334ms | 93.1668 Ops/s | 89.4013 Ops/s | |
test_sac_speed[True-None] | 1.6826ms | 1.5227ms | 656.7085 Ops/s | 651.1188 Ops/s | |
test_sac_speed[True-backward] | 3.2251ms | 3.1207ms | 320.4387 Ops/s | 300.8818 Ops/s | |
test_sac_speed[reduce-overhead-None] | 22.8804ms | 12.5937ms | 79.4050 Ops/s | 79.7545 Ops/s | |
test_sac_speed[reduce-overhead-backward] | 1.4175ms | 1.3319ms | 750.7829 Ops/s | 743.2104 Ops/s | |
test_redq_speed[False-None] | 8.2488ms | 7.3957ms | 135.2145 Ops/s | 133.1476 Ops/s | |
test_redq_speed[False-backward] | 11.9209ms | 11.1280ms | 89.8632 Ops/s | 87.9077 Ops/s | |
test_redq_speed[True-None] | 2.0144ms | 1.9376ms | 516.0927 Ops/s | 489.8320 Ops/s | |
test_redq_speed[True-backward] | 3.6916ms | 3.5529ms | 281.4592 Ops/s | 264.2285 Ops/s | |
test_redq_speed[reduce-overhead-None] | 2.0852ms | 1.9920ms | 502.0147 Ops/s | 503.8005 Ops/s | |
test_redq_speed[reduce-overhead-backward] | 3.9074ms | 3.7563ms | 266.2200 Ops/s | 281.0493 Ops/s | |
test_redq_deprec_speed[False-None] | 9.3489ms | 8.8514ms | 112.9766 Ops/s | 110.1052 Ops/s | |
test_redq_deprec_speed[False-backward] | 12.6542ms | 12.1221ms | 82.4941 Ops/s | 82.5828 Ops/s | |
test_redq_deprec_speed[True-None] | 2.3898ms | 2.2615ms | 442.1804 Ops/s | 435.2707 Ops/s | |
test_redq_deprec_speed[True-backward] | 4.5006ms | 4.0535ms | 246.6975 Ops/s | 255.0296 Ops/s | |
test_redq_deprec_speed[reduce-overhead-None] | 2.3304ms | 2.2584ms | 442.7935 Ops/s | 411.9323 Ops/s | |
test_redq_deprec_speed[reduce-overhead-backward] | 4.1039ms | 4.0394ms | 247.5629 Ops/s | 240.1412 Ops/s | |
test_td3_speed[False-None] | 7.9599ms | 7.7300ms | 129.3658 Ops/s | 126.2744 Ops/s | |
test_td3_speed[False-backward] | 10.8963ms | 10.3856ms | 96.2869 Ops/s | 95.7370 Ops/s | |
test_td3_speed[True-None] | 1.6309ms | 1.5497ms | 645.2898 Ops/s | 629.6804 Ops/s | |
test_td3_speed[True-backward] | 3.5605ms | 3.1876ms | 313.7122 Ops/s | 314.0021 Ops/s | |
test_td3_speed[reduce-overhead-None] | 50.6054ms | 25.6889ms | 38.9274 Ops/s | 39.3726 Ops/s | |
test_td3_speed[reduce-overhead-backward] | 1.5208ms | 1.4459ms | 691.6266 Ops/s | 687.4419 Ops/s | |
test_cql_speed[False-None] | 16.9177ms | 16.3577ms | 61.1332 Ops/s | 59.6858 Ops/s | |
test_cql_speed[False-backward] | 24.4541ms | 22.0378ms | 45.3765 Ops/s | 44.4716 Ops/s | |
test_cql_speed[True-None] | 2.9408ms | 2.7958ms | 357.6737 Ops/s | 350.9445 Ops/s | |
test_cql_speed[True-backward] | 5.4648ms | 5.0430ms | 198.2963 Ops/s | 191.2798 Ops/s | |
test_cql_speed[reduce-overhead-None] | 0.3473s | 14.7800ms | 67.6588 Ops/s | 75.8656 Ops/s | |
test_cql_speed[reduce-overhead-backward] | 1.6780ms | 1.6189ms | 617.6880 Ops/s | 611.5129 Ops/s | |
test_a2c_speed[False-None] | 3.3613ms | 3.2022ms | 312.2892 Ops/s | 296.3951 Ops/s | |
test_a2c_speed[False-backward] | 6.4511ms | 6.2986ms | 158.7666 Ops/s | 154.1617 Ops/s | |
test_a2c_speed[True-None] | 1.4188ms | 0.9900ms | 1.0101 KOps/s | 960.0807 Ops/s | |
test_a2c_speed[True-backward] | 2.5872ms | 2.5088ms | 398.5974 Ops/s | 377.8812 Ops/s | |
test_a2c_speed[reduce-overhead-None] | 21.5120ms | 11.4843ms | 87.0752 Ops/s | 87.5093 Ops/s | |
test_a2c_speed[reduce-overhead-backward] | 1.0512ms | 0.9819ms | 1.0184 KOps/s | 1.0079 KOps/s | |
test_ppo_speed[False-None] | 4.0311ms | 3.6238ms | 275.9555 Ops/s | 268.7254 Ops/s | |
test_ppo_speed[False-backward] | 7.2378ms | 6.7729ms | 147.6464 Ops/s | 141.2037 Ops/s | |
test_ppo_speed[True-None] | 1.3777ms | 0.9498ms | 1.0528 KOps/s | 993.6404 Ops/s | |
test_ppo_speed[True-backward] | 2.6740ms | 2.6054ms | 383.8175 Ops/s | 376.2651 Ops/s | |
test_ppo_speed[reduce-overhead-None] | 0.6557ms | 0.5194ms | 1.9254 KOps/s | 69.6515 Ops/s | |
test_ppo_speed[reduce-overhead-backward] | 1.1805ms | 1.1047ms | 905.2424 Ops/s | 856.3782 Ops/s | |
test_reinforce_speed[False-None] | 2.3632ms | 2.2157ms | 451.3291 Ops/s | 443.8765 Ops/s | |
test_reinforce_speed[False-backward] | 3.5221ms | 3.3111ms | 302.0188 Ops/s | 296.0995 Ops/s | |
test_reinforce_speed[True-None] | 0.8678ms | 0.8034ms | 1.2447 KOps/s | 1.1744 KOps/s | |
test_reinforce_speed[True-backward] | 2.7046ms | 2.5644ms | 389.9547 Ops/s | 393.1938 Ops/s | |
test_reinforce_speed[reduce-overhead-None] | 0.2908s | 11.9007ms | 84.0287 Ops/s | 89.0319 Ops/s | |
test_reinforce_speed[reduce-overhead-backward] | 1.1067ms | 1.0258ms | 974.8572 Ops/s | 853.0907 Ops/s | |
test_iql_speed[False-None] | 9.7628ms | 9.1891ms | 108.8242 Ops/s | 107.7653 Ops/s | |
test_iql_speed[False-backward] | 13.6639ms | 13.0593ms | 76.5736 Ops/s | 75.2112 Ops/s | |
test_iql_speed[True-None] | 2.1027ms | 1.7168ms | 582.4845 Ops/s | 569.7085 Ops/s | |
test_iql_speed[True-backward] | 4.1575ms | 4.0779ms | 245.2234 Ops/s | 242.3399 Ops/s | |
test_iql_speed[reduce-overhead-None] | 20.5050ms | 11.5353ms | 86.6901 Ops/s | 87.7550 Ops/s | |
test_iql_speed[reduce-overhead-backward] | 1.4734ms | 1.4151ms | 706.6566 Ops/s | 609.1690 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 7.6930ms | 6.1477ms | 162.6617 Ops/s | 158.3958 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.5299ms | 0.2962ms | 3.3759 KOps/s | 3.2922 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.5037ms | 0.3068ms | 3.2594 KOps/s | 3.5978 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.1467ms | 5.9118ms | 169.1536 Ops/s | 166.1354 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 1.9191ms | 0.3112ms | 3.2138 KOps/s | 3.0230 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.5788ms | 0.2901ms | 3.4472 KOps/s | 3.1775 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 1.4687ms | 1.3051ms | 766.2092 Ops/s | 770.4529 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 1.4732ms | 1.2390ms | 807.1168 Ops/s | 844.9509 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.1520ms | 6.0737ms | 164.6441 Ops/s | 161.7937 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.6588ms | 0.4770ms | 2.0966 KOps/s | 2.1756 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.7870ms | 0.4402ms | 2.2718 KOps/s | 2.5665 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 6.0632ms | 5.9554ms | 167.9159 Ops/s | 166.1611 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 1.8374ms | 0.3552ms | 2.8153 KOps/s | 3.0571 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.6386ms | 0.3469ms | 2.8827 KOps/s | 3.7831 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 9.0609ms | 5.8707ms | 170.3376 Ops/s | 167.2954 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 1.7646ms | 0.2683ms | 3.7270 KOps/s | 3.4069 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.4582ms | 0.2465ms | 4.0563 KOps/s | 3.6694 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.3812ms | 6.0369ms | 165.6489 Ops/s | 162.6867 Ops/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.0845ms | 0.4369ms | 2.2887 KOps/s | 2.0226 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.7126ms | 0.4225ms | 2.3671 KOps/s | 2.2988 KOps/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 6.9612ms | 5.3328ms | 187.5197 Ops/s | 184.3116 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 6.1968ms | 1.9980ms | 500.4984 Ops/s | 429.7377 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 8.9725ms | 1.2451ms | 803.1428 Ops/s | 866.1653 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 7.7482ms | 5.3072ms | 188.4246 Ops/s | 186.1914 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 7.6569ms | 1.9967ms | 500.8247 Ops/s | 452.4583 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 7.8490ms | 1.2593ms | 794.0617 Ops/s | 801.5531 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 0.4887s | 15.2526ms | 65.5627 Ops/s | 33.4260 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 9.5617ms | 2.2172ms | 451.0210 Ops/s | 453.1721 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 7.1747ms | 1.3674ms | 731.3090 Ops/s | 735.8641 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-True] | 15.0227ms | 14.8505ms | 67.3379 Ops/s | 64.3550 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-False] | 19.6130ms | 17.6047ms | 56.8031 Ops/s | 56.4786 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-True] | 19.3297ms | 19.1073ms | 52.3361 Ops/s | 51.0484 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-False] | 20.1786ms | 17.8109ms | 56.1455 Ops/s | 56.4756 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-True] | 20.2087ms | 19.7857ms | 50.5415 Ops/s | 51.1188 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-False] | 20.9951ms | 19.1036ms | 52.3461 Ops/s | 51.9130 Ops/s |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we want to interact with SipHash and other hash modules?
Or perhpas a tokenizer (not strictly a hash but the signature is similar for strings).
Or perhaps a custom hash function?
I'm happy with this being restrictive but if possible I'd prefer to avoid having multple transforms that do the str -> int map
Good questions. I agree that the transform should allow the user to specify any hashing function or tokenizer they want to use, including
The It seems like what we really need is just a general key-wise
In order to use this to implement a Python-hash transform for ChessEnv, for instance, we would just do:
Or to implement a transform that uses just the Python-hash transform, like the one currently in this PR, we would just do this:
Although it would be nice if the user didn't have to even think about specs. I wonder if it would be possible to make a transform automatically guess what the spec updates need to be, based on the output of We could allow the user to optionally specify inverse keys and an inverse function to We could also consider other kinds of generalized transforms. For instance, What do you think? |
@kurtamohler I'm up to it up until |
For the record, here is a script that makes it possible to use sha256 hashes (fewer collisions) in a reproducible manner (ie, sort of seeded) import hashlib
def reproducible_hash_parts(string, seed):
"""
Creates a reproducible 256-bit hash from a string using a seed and splits it into four 64-bit parts.
Args:
string (str): The input string.
seed (str): The seed value.
Returns:
tuple: Four 64-bit integers representing the parts of the 256-bit hash value.
"""
# Prepend the seed to the string
seeded_string = seed + string
# Create a new SHA-256 hash object
hash_object = hashlib.sha256()
# Update the hash object with the seeded string
hash_object.update(seeded_string.encode('utf-8'))
# Get the hash value as bytes
hash_bytes = hash_object.digest()
# Split the hash bytes into four parts
part1 = hash_bytes[:8]
part2 = hash_bytes[8:16]
part3 = hash_bytes[16:24]
part4 = hash_bytes[24:]
# Convert each part to a 64-bit integer
part1_value = int.from_bytes(part1, 'big')
part2_value = int.from_bytes(part2, 'big')
part3_value = int.from_bytes(part3, 'big')
part4_value = int.from_bytes(part4, 'big')
return part1_value, part2_value, part3_value, part4_value
# Example usage:
string = "Hello, World!"
seed = "my_seed"
part1, part2, part3, part4 = reproducible_hash_parts(string, seed)
print(f"Part 1: {part1}")
print(f"Part 2: {part2}")
print(f"Part 3: {part3}")
print(f"Part 4: {part4}") |
Another random thought: we could add the option to store a table of hash-to-value within the transform class HashTransform(...):
_hash_table: Dict[HashType, str]
... and include that in the transform state-dict (or make it easy to save this for future use) |
ghstack-source-id: 80f920674e13db2fcbed6e82a990d35cb14c6d11 Pull Request resolved: pytorch#2648
ghstack-source-id: 80f920674e13db2fcbed6e82a990d35cb14c6d11 Pull Request resolved: pytorch#2648
ghstack-source-id: de4419eed15f97b79869b3701cfc2704da5a5e59 Pull Request resolved: #2648
Maybe I should add
What's the purpose you have in mind for that? Is there a reason someone would want to look up the value that corresponds to a hash if that pairing already exists in the output tensordict? Or would this be used to detect hash collision? I think this map would have to use weakrefs to the values so that we don't prevent them from being deallocated when all the other references are deleted |
The goal I had in mind was to keep track of the existing map for later use, although if Maybe we should default to hashlib for reproducibility |
ghstack-source-id: de4419eed15f97b79869b3701cfc2704da5a5e59 Pull Request resolved: pytorch#2648
ghstack-source-id: cf7d425e1297959b81a0902cb27bbffc5a51568f Pull Request resolved: #2648
ghstack-source-id: b966004ebf23f1d3888da673e383d95577406129 Pull Request resolved: #2648
ghstack-source-id: a55d94b8b782b579bfea63753350fd628365dc28 Pull Request resolved: #2648
ghstack-source-id: dccf63fe4f9d5f76947ddb7d5dedcff87ff8cdc5 Pull Request resolved: #2648
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks so much!
I added a an example in the docstrings and fixed the auto-spec for non-tensor
obs = torch.stack( | ||
[ | ||
NonTensorData(data="abcde"), | ||
NonTensorData(data="fghij"), | ||
NonTensorData(data="klmno"), | ||
] | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in the future you can also do NonTensorStack("a", "b", "c")
ghstack-source-id: dccf63fe4f9d5f76947ddb7d5dedcff87ff8cdc5 Pull Request resolved: #2648
Stack from ghstack (oldest at bottom):