[Feature] Device transform #1472

Merged
merged 10 commits into from
Aug 30, 2023

Conversation

vmoens
Contributor

@vmoens vmoens commented Aug 26, 2023

Description

Adds a DeviceCastTransform transform to move environment data from one device to another.

As part of this PR, transforms now can transform the device of the parent env through transform.transform_env_device.

@facebook-github-bot facebook-github-bot added the CLA Signed label Aug 26, 2023
@github-actions

github-actions bot commented Aug 26, 2023

$\color{#D29922}\textsf{\Large⚠\kern{0.2cm}\normalsize Warning}$ Result of CPU Benchmark Tests

Total Benchmarks: 89. Improved: $\large\color{#35bf28}14$. Worsened: $\large\color{#d91a1a}7$.

Name Max Mean Ops Ops on Repo HEAD Change
test_single 0.1706s 0.1669s 5.9899 Ops/s 5.8858 Ops/s $\color{#35bf28}+1.77\%$
test_sync 0.1746s 96.2884ms 10.3855 Ops/s 10.3557 Ops/s $\color{#35bf28}+0.29\%$
test_async 0.1965s 91.2416ms 10.9599 Ops/s 11.2341 Ops/s $\color{#d91a1a}-2.44\%$
test_simple 0.7990s 0.7267s 1.3761 Ops/s 1.3444 Ops/s $\color{#35bf28}+2.35\%$
test_transformed 2.0001s 1.9342s 0.5170 Ops/s 0.5138 Ops/s $\color{#35bf28}+0.62\%$
test_serial 2.2181s 2.2111s 0.4523 Ops/s 0.4400 Ops/s $\color{#35bf28}+2.78\%$
test_parallel 2.0208s 1.9084s 0.5240 Ops/s 0.5193 Ops/s $\color{#35bf28}+0.91\%$
test_step_mdp_speed[True-True-True-True-True] 0.2062ms 51.3868μs 19.4602 KOps/s 18.6945 KOps/s $\color{#35bf28}+4.10\%$
test_step_mdp_speed[True-True-True-True-False] 58.4000μs 29.0685μs 34.4016 KOps/s 33.9803 KOps/s $\color{#35bf28}+1.24\%$
test_step_mdp_speed[True-True-True-False-True] 80.0010μs 35.8117μs 27.9238 KOps/s 27.1530 KOps/s $\color{#35bf28}+2.84\%$
test_step_mdp_speed[True-True-True-False-False] 52.2010μs 19.4511μs 51.4109 KOps/s 49.8145 KOps/s $\color{#35bf28}+3.20\%$
test_step_mdp_speed[True-True-False-True-True] 0.1924ms 53.2941μs 18.7638 KOps/s 18.3483 KOps/s $\color{#35bf28}+2.26\%$
test_step_mdp_speed[True-True-False-True-False] 59.6010μs 31.0975μs 32.1569 KOps/s 31.5577 KOps/s $\color{#35bf28}+1.90\%$
test_step_mdp_speed[True-True-False-False-True] 77.7010μs 38.0246μs 26.2988 KOps/s 25.5693 KOps/s $\color{#35bf28}+2.85\%$
test_step_mdp_speed[True-True-False-False-False] 49.4010μs 21.7991μs 45.8734 KOps/s 45.0826 KOps/s $\color{#35bf28}+1.75\%$
test_step_mdp_speed[True-False-True-True-True] 0.1116ms 54.3931μs 18.3847 KOps/s 17.5988 KOps/s $\color{#35bf28}+4.47\%$
test_step_mdp_speed[True-False-True-True-False] 0.1011ms 32.6938μs 30.5869 KOps/s 29.4644 KOps/s $\color{#35bf28}+3.81\%$
test_step_mdp_speed[True-False-True-False-True] 85.3000μs 37.7121μs 26.5167 KOps/s 25.2090 KOps/s $\textbf{\color{#35bf28}+5.19\%}$
test_step_mdp_speed[True-False-True-False-False] 0.3260ms 21.7206μs 46.0392 KOps/s 44.4008 KOps/s $\color{#35bf28}+3.69\%$
test_step_mdp_speed[True-False-False-True-True] 0.1138ms 57.7712μs 17.3097 KOps/s 16.6760 KOps/s $\color{#35bf28}+3.80\%$
test_step_mdp_speed[True-False-False-True-False] 91.5010μs 34.5483μs 28.9450 KOps/s 27.4407 KOps/s $\textbf{\color{#35bf28}+5.48\%}$
test_step_mdp_speed[True-False-False-False-True] 98.4010μs 40.4414μs 24.7271 KOps/s 23.6901 KOps/s $\color{#35bf28}+4.38\%$
test_step_mdp_speed[True-False-False-False-False] 61.4010μs 24.0787μs 41.5304 KOps/s 40.6046 KOps/s $\color{#35bf28}+2.28\%$
test_step_mdp_speed[False-True-True-True-True] 0.1383ms 54.5561μs 18.3298 KOps/s 17.8572 KOps/s $\color{#35bf28}+2.65\%$
test_step_mdp_speed[False-True-True-True-False] 76.7010μs 32.3930μs 30.8709 KOps/s 29.4453 KOps/s $\color{#35bf28}+4.84\%$
test_step_mdp_speed[False-True-True-False-True] 80.0010μs 43.6621μs 22.9032 KOps/s 21.3792 KOps/s $\textbf{\color{#35bf28}+7.13\%}$
test_step_mdp_speed[False-True-True-False-False] 57.6010μs 24.4134μs 40.9611 KOps/s 39.4962 KOps/s $\color{#35bf28}+3.71\%$
test_step_mdp_speed[False-True-False-True-True] 0.1360ms 56.6163μs 17.6628 KOps/s 17.1870 KOps/s $\color{#35bf28}+2.77\%$
test_step_mdp_speed[False-True-False-True-False] 1.9990ms 35.0609μs 28.5218 KOps/s 28.2080 KOps/s $\color{#35bf28}+1.11\%$
test_step_mdp_speed[False-True-False-False-True] 79.6010μs 46.6426μs 21.4396 KOps/s 20.9367 KOps/s $\color{#35bf28}+2.40\%$
test_step_mdp_speed[False-True-False-False-False] 0.6048ms 27.1760μs 36.7972 KOps/s 36.3694 KOps/s $\color{#35bf28}+1.18\%$
test_step_mdp_speed[False-False-True-True-True] 93.0010μs 57.5965μs 17.3622 KOps/s 16.3340 KOps/s $\textbf{\color{#35bf28}+6.29\%}$
test_step_mdp_speed[False-False-True-True-False] 99.4010μs 36.6149μs 27.3113 KOps/s 25.9249 KOps/s $\textbf{\color{#35bf28}+5.35\%}$
test_step_mdp_speed[False-False-True-False-True] 0.1027ms 46.2558μs 21.6189 KOps/s 20.7751 KOps/s $\color{#35bf28}+4.06\%$
test_step_mdp_speed[False-False-True-False-False] 0.1035ms 26.5307μs 37.6922 KOps/s 36.5885 KOps/s $\color{#35bf28}+3.02\%$
test_step_mdp_speed[False-False-False-True-True] 0.1085ms 60.7695μs 16.4556 KOps/s 15.7390 KOps/s $\color{#35bf28}+4.55\%$
test_step_mdp_speed[False-False-False-True-False] 73.1000μs 39.1460μs 25.5454 KOps/s 24.7423 KOps/s $\color{#35bf28}+3.25\%$
test_step_mdp_speed[False-False-False-False-True] 0.1069ms 48.0814μs 20.7981 KOps/s 19.9347 KOps/s $\color{#35bf28}+4.33\%$
test_step_mdp_speed[False-False-False-False-False] 90.0010μs 28.3642μs 35.2556 KOps/s 33.6793 KOps/s $\color{#35bf28}+4.68\%$
test_values[generalized_advantage_estimate-True-True] 20.4198ms 15.3463ms 65.1621 Ops/s 64.4635 Ops/s $\color{#35bf28}+1.08\%$
test_values[vec_generalized_advantage_estimate-True-True] 59.0384ms 48.9144ms 20.4439 Ops/s 20.5920 Ops/s $\color{#d91a1a}-0.72\%$
test_values[td0_return_estimate-False-False] 0.3849ms 0.2340ms 4.2739 KOps/s 4.2871 KOps/s $\color{#d91a1a}-0.31\%$
test_values[td1_return_estimate-False-False] 15.3761ms 14.9407ms 66.9311 Ops/s 66.5584 Ops/s $\color{#35bf28}+0.56\%$
test_values[vec_td1_return_estimate-False-False] 60.2913ms 50.9271ms 19.6359 Ops/s 20.6701 Ops/s $\textbf{\color{#d91a1a}-5.00\%}$
test_values[td_lambda_return_estimate-True-False] 38.2264ms 35.8200ms 27.9174 Ops/s 27.5790 Ops/s $\color{#35bf28}+1.23\%$
test_values[vec_td_lambda_return_estimate-True-False] 57.4065ms 48.5279ms 20.6067 Ops/s 20.8179 Ops/s $\color{#d91a1a}-1.01\%$
test_gae_speed[generalized_advantage_estimate-False-1-512] 13.8016ms 13.3419ms 74.9516 Ops/s 73.1015 Ops/s $\color{#35bf28}+2.53\%$
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] 6.1508ms 3.9497ms 253.1861 Ops/s 257.6675 Ops/s $\color{#d91a1a}-1.74\%$
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] 7.1712ms 0.5684ms 1.7594 KOps/s 1.7519 KOps/s $\color{#35bf28}+0.43\%$
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] 70.2361ms 63.8463ms 15.6626 Ops/s 15.7069 Ops/s $\color{#d91a1a}-0.28\%$
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] 4.0284ms 3.2687ms 305.9304 Ops/s 293.2661 Ops/s $\color{#35bf28}+4.32\%$
test_dqn_speed 5.7149ms 2.1868ms 457.2837 Ops/s 391.0891 Ops/s $\textbf{\color{#35bf28}+16.93\%}$
test_ddpg_speed 9.6905ms 3.2701ms 305.8012 Ops/s 292.2160 Ops/s $\color{#35bf28}+4.65\%$
test_sac_speed 15.3859ms 9.6827ms 103.2775 Ops/s 103.0539 Ops/s $\color{#35bf28}+0.22\%$
test_redq_speed 26.1642ms 18.6491ms 53.6220 Ops/s 52.2202 Ops/s $\color{#35bf28}+2.68\%$
test_redq_deprec_speed 24.8650ms 16.1853ms 61.7844 Ops/s 65.2020 Ops/s $\textbf{\color{#d91a1a}-5.24\%}$
test_td3_speed 21.8242ms 12.0606ms 82.9144 Ops/s 82.4932 Ops/s $\color{#35bf28}+0.51\%$
test_cql_speed 47.9024ms 44.0649ms 22.6938 Ops/s 26.4305 Ops/s $\textbf{\color{#d91a1a}-14.14\%}$
test_a2c_speed 13.1931ms 6.3226ms 158.1620 Ops/s 156.6433 Ops/s $\color{#35bf28}+0.97\%$
test_ppo_speed 16.5614ms 6.9434ms 144.0223 Ops/s 136.5848 Ops/s $\textbf{\color{#35bf28}+5.45\%}$
test_reinforce_speed 13.6830ms 5.0145ms 199.4201 Ops/s 198.7375 Ops/s $\color{#35bf28}+0.34\%$
test_iql_speed 36.1572ms 26.7018ms 37.4506 Ops/s 37.5658 Ops/s $\color{#d91a1a}-0.31\%$
test_sample_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 3.8738ms 3.1655ms 315.9098 Ops/s 313.1811 Ops/s $\color{#35bf28}+0.87\%$
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 6.5373ms 3.4157ms 292.7661 Ops/s 298.2416 Ops/s $\color{#d91a1a}-1.84\%$
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 5.5589ms 3.3167ms 301.5002 Ops/s 301.1843 Ops/s $\color{#35bf28}+0.10\%$
test_sample_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 3.5864ms 3.1307ms 319.4217 Ops/s 243.4505 Ops/s $\textbf{\color{#35bf28}+31.21\%}$
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 5.5815ms 3.3614ms 297.4982 Ops/s 300.1275 Ops/s $\color{#d91a1a}-0.88\%$
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 5.6793ms 3.3209ms 301.1205 Ops/s 293.2555 Ops/s $\color{#35bf28}+2.68\%$
test_sample_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 4.4857ms 3.1962ms 312.8671 Ops/s 316.5940 Ops/s $\color{#d91a1a}-1.18\%$
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 6.2774ms 3.3359ms 299.7680 Ops/s 300.1600 Ops/s $\color{#d91a1a}-0.13\%$
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 0.1632s 3.8027ms 262.9731 Ops/s 298.4259 Ops/s $\textbf{\color{#d91a1a}-11.88\%}$
test_iterate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 4.1782ms 3.0862ms 324.0265 Ops/s 305.2262 Ops/s $\textbf{\color{#35bf28}+6.16\%}$
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 5.7881ms 3.3166ms 301.5157 Ops/s 293.5156 Ops/s $\color{#35bf28}+2.73\%$
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 5.7985ms 3.3784ms 295.9952 Ops/s 293.9190 Ops/s $\color{#35bf28}+0.71\%$
test_iterate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 4.3294ms 3.0976ms 322.8313 Ops/s 305.3118 Ops/s $\textbf{\color{#35bf28}+5.74\%}$
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 4.9160ms 3.3900ms 294.9852 Ops/s 296.8634 Ops/s $\color{#d91a1a}-0.63\%$
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 7.0524ms 3.3567ms 297.9141 Ops/s 293.9346 Ops/s $\color{#35bf28}+1.35\%$
test_iterate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 4.1296ms 3.0831ms 324.3459 Ops/s 316.4863 Ops/s $\color{#35bf28}+2.48\%$
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 6.6937ms 3.3646ms 297.2080 Ops/s 296.2106 Ops/s $\color{#35bf28}+0.34\%$
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 6.7701ms 3.3587ms 297.7371 Ops/s 295.3183 Ops/s $\color{#35bf28}+0.82\%$
test_populate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-400] 0.2695s 35.4243ms 28.2292 Ops/s 28.8679 Ops/s $\color{#d91a1a}-2.21\%$
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] 0.1582s 30.0272ms 33.3031 Ops/s 29.3111 Ops/s $\textbf{\color{#35bf28}+13.62\%}$
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] 0.1658s 33.1355ms 30.1792 Ops/s 32.4944 Ops/s $\textbf{\color{#d91a1a}-7.12\%}$
test_populate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] 0.1565s 29.5430ms 33.8489 Ops/s 29.2707 Ops/s $\textbf{\color{#35bf28}+15.64\%}$
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] 0.1580s 32.7299ms 30.5531 Ops/s 29.8215 Ops/s $\color{#35bf28}+2.45\%$
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] 0.1698s 32.9321ms 30.3655 Ops/s 32.4205 Ops/s $\textbf{\color{#d91a1a}-6.34\%}$
test_populate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] 0.1633s 30.4084ms 32.8856 Ops/s 29.4354 Ops/s $\textbf{\color{#35bf28}+11.72\%}$
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] 0.1609s 32.9677ms 30.3327 Ops/s 32.2089 Ops/s $\textbf{\color{#d91a1a}-5.83\%}$
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] 0.1639s 30.6880ms 32.5860 Ops/s 29.5246 Ops/s $\textbf{\color{#35bf28}+10.37\%}$

@vmoens vmoens added the enhancement New feature or request label Aug 30, 2023
@vmoens vmoens marked this pull request as ready for review August 30, 2023 13:45
@vmoens vmoens changed the title [WIP] Device transform [Feature] Device transform Aug 30, 2023
@@ -2708,6 +2738,87 @@ def __init__(
        super().__init__(torch.double, torch.float, in_keys, in_keys_inv)


class DeviceCastTransform(Transform):
    """Casts the env device.
Contributor

Is this really casting the env device?

Isn't it transforming the device of the data?

Contributor

@matteobettini matteobettini left a comment

I have a question:

In environments like VMAS, where there is an internal state that sits on a device, I thought the way to move the environment to another device was env.to(device).

If I understand correctly, here we are casting the env data and its specs to another device, but not its internal state (we do not call env.to()).

Does this make sense to use with a parent environment? When is this transform preferred to env.to(device)?

The only case I can imagine is when we need to keep an env on a specific device, apply some transforms there, and then move the data to another device.

@vmoens
Contributor Author

vmoens commented Aug 30, 2023

Does this make sense to use with a parent environment? When is this transform preferred to env.to(device)?

This is to address #1198 where the issue is that if the env naturally sits on MPS we can't use float64. So first you must transform the data into float32 and then cast to device. Doing env.to(device) will not work but this transform will.
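
To illustrate the ordering constraint described above, here is a toy sketch in plain Python (no torch or TorchRL involved). Dicts stand in for tensors, and `double_to_float`, `make_device_cast` and `apply_in_order` are hypothetical helpers that only mimic the float64-on-MPS restriction; they are not the library's implementation.

```python
# Toy model only: plain dicts stand in for tensors; none of these names are TorchRL APIs.

def double_to_float(data):
    """Downcast float64 to float32, leaving other dtypes alone."""
    if data["dtype"] == "float64":
        return {**data, "dtype": "float32"}
    return data


def make_device_cast(device):
    """Build a step that moves data to `device`, refusing float64 on mps."""
    def device_cast(data):
        if device == "mps" and data["dtype"] == "float64":
            raise TypeError("MPS does not support float64; use float32 instead")
        return {**data, "device": device}
    return device_cast


def apply_in_order(steps, data):
    """Apply each step to the output of the previous one."""
    for step in steps:
        data = step(data)
    return data


obs = {"dtype": "float64", "device": "cpu"}  # what a CPU-based env might emit

# Downcast first, then move: succeeds.
ok = apply_in_order([double_to_float, make_device_cast("mps")], obs)
print(ok)

# Move first: fails before the downcast ever runs.
try:
    apply_in_order([make_device_cast("mps"), double_to_float], obs)
    failed = False
except TypeError as err:
    failed = True
    print("error:", err)
```

The same reasoning suggests why the position of the device cast relative to the dtype cast matters when composing transforms.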

@vmoens vmoens merged commit 7bc9955 into main Aug 30, 2023
50 of 55 checks passed
osalpekar pushed a commit to osalpekar/rl that referenced this pull request Aug 30, 2023
@vmoens vmoens deleted the device_transform branch August 31, 2023 06:52
@EkaterinaAbramova

EkaterinaAbramova commented Sep 29, 2023

Does this make sense to use with a parent environment? When is this transform preferred to env.to(device)?

This is to address #1198 where the issue is that if the env naturally sits on MPS we can't use float64. So first you must transform the data into float32 and then cast to device. Doing env.to(device) will not work but this transform will.

Could you please explain exactly how to do this? I am confused. I am following this tutorial https://pytorch.org/tutorials/intermediate/reinforcement_ppo.html and get the MPS float64 error when running the line: base_env = GymEnv("InvertedDoublePendulum-v4", device=device, frame_skip=frame_skip). What code exactly should I write to correct this error, please?

@vmoens
Contributor Author

vmoens commented Sep 29, 2023

Can I ask what the value of device is in your case?

@EkaterinaAbramova

device="mps".

I believe I solved it with this (after quite a few hours of trying different things!!!):

base_env = GymEnv("InvertedDoublePendulum-v4", device="cpu", frame_skip=frame_skip)
env = TransformedEnv(
    base_env,
    Compose(
        ObservationNorm(in_keys=["observation"]), # normalise observations (make it about Standard Normal)
        DoubleToFloat(),   
        StepCounter(),                            # count the number of steps before the environment is terminated
        DeviceCastTransform(device=device, orig_device="cpu"),
    ),
)
print(env.device) # gives mps now

Could you please confirm whether what I've done is correct? I am on an Apple M2 Max, trying to use MPS.

@EkaterinaAbramova

Everything was progressing smoothly through the tutorial (https://pytorch.org/tutorials/intermediate/reinforcement_ppo.html), but at the code below I again get an error about MPS. Could you please advise how to solve this issue? I double-checked and everything seems to be on mps, so I don't understand where the error is coming from.

collector = SyncDataCollector(
    env,
    policy_module,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    split_trajs=False,
    device=device,
)


  File ~/anaconda3/envs/gpu-torch-rl/lib/python3.10/site-packages/torch/nn/modules/module.py:1143 in convert
    return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)

TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.


env.device
Out[23]: device(type='mps')

@vmoens
Contributor Author

vmoens commented Sep 30, 2023

@EkaterinaAbramova sorry you had this terrible experience; we should document things better for MPS.
It's something we're actively looking at and any feedback is very welcome.
The way you implemented your env looks great to me!

Regarding the error, I will look into it; it shouldn't be too difficult to solve. Have you tried moving the ObservationNorm and StepCounter after the device casting transform?
Like this:

env = TransformedEnv(
    base_env,
    Compose(
        DoubleToFloat(),   
        DeviceCastTransform(device=device, orig_device="cpu"),
        ObservationNorm(in_keys=["observation"]), # normalise observations (make it about Standard Normal)
        StepCounter(),                            # count the number of steps before the environment is terminated
    ),
)

This way, the buffers in the ObservationNorm transform will sit on mps as float32 rather than float64.

@EkaterinaAbramova

@vmoens thank you for helping me so swiftly. This issue is quite urgent, so I'm very glad you had a suggestion. It makes sense and I tried it, but with this ordering I get the error AttributeError: 'DoubleToFloat' object has no attribute 'init_stats'. It seems I need to pass some arguments; what should I pass to be able to follow the tutorial (https://pytorch.org/tutorials/intermediate/reinforcement_ppo.html)?


@vmoens
Contributor Author

vmoens commented Oct 1, 2023

You need to call init_stats on the ObservationNorm transform, which is now second to last in the Compose:

env.transform[-2].init_stats(...)

The index changes because the transform has moved.
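
To make the indexing concrete, here is a small sketch using plain lists of transform names (strings only, not real transform objects) that mirror the two Compose orderings shown in this thread. It suggests why a call on index 0, which worked for the original ordering, would hit DoubleToFloat after the reorder.

```python
# Transform names as plain strings; positions mirror the two Compose orderings in this thread.
original = ["ObservationNorm", "DoubleToFloat", "StepCounter", "DeviceCastTransform"]
reordered = ["DoubleToFloat", "DeviceCastTransform", "ObservationNorm", "StepCounter"]

# Index 0 pointed at ObservationNorm before the reorder...
print(original[0])

# ...but points at DoubleToFloat afterwards, which has no init_stats method.
print(reordered[0])

# Second to last is where ObservationNorm now lives.
print(reordered[-2])

# Looking the position up by name avoids hard-coding it.
idx = reordered.index("ObservationNorm")
print(idx)
```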

@EkaterinaAbramova

EkaterinaAbramova commented Oct 1, 2023

OK, I get it. Now that I've indexed the correct location, the MPS issue is back.
The code you suggested doesn't work for me:

env = TransformedEnv(
    base_env,
    Compose(
        DoubleToFloat(),   
        DeviceCastTransform(device=device, orig_device="cpu"),
        ObservationNorm(in_keys=["observation"]), # normalise observations (make it about Standard Normal)
        StepCounter(),                            # count the number of steps before the environment is terminated
    ),
)
print(env.device)
env.transform[-2].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0) 
I get the error: Cannot convert a MPS Tensor

TO RECAP: The original code was:

env = TransformedEnv(
    base_env,
    Compose(
        ObservationNorm(in_keys=["observation"]), # normalise observations (make it about Standard Normal)
        DoubleToFloat(),   
        StepCounter(),                            # count the number of steps before the environment is terminated
        DeviceCastTransform(device=device, orig_device="cpu"),
    ),
)
env.transform[0].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0) 

This code runs, however I have issues further below in the tutorial where I can't run:

collector = SyncDataCollector(
    env,
    policy_module,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    split_trajs=False,
    device=device,
)
Cannot convert a MPS Tensor

Could you please provide a solution? I am really in a rush now; I've been trying to solve this issue for days. The tutorial I am working on is online, so if it helps, maybe you could try the suggestions yourself to make sure they resolve the issue? Thank you very much for your help. I really need to move past this ASAP, please.

@vmoens
Contributor Author

vmoens commented Oct 1, 2023

Ok so that's an interesting bug, which basically boils down to some internal machinery within rollout, resets and transforms.
To quickly unblock you: can you compute the stats manually with your env?

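# Minimal env: downcast to float32 first, then move the data to the target device.
simple_env = TransformedEnv(
    base_env,
    Compose(
        DoubleToFloat(),
        DeviceCastTransform(device=device, orig_device="cpu"),
    ),
)
# Collect 100 steps and compute per-dimension statistics by hand.
td0 = simple_env.rollout(100)
loc = td0["observation"].mean(dim=0)
scale = td0["observation"].std(dim=0)
env = TransformedEnv(
    base_env,
    Compose(
        DoubleToFloat(),
        DeviceCastTransform(device=device, orig_device="cpu"),
        # standard_normal=True applies (obs - loc) / scale; the default is obs * scale + loc.
        ObservationNorm(in_keys=["observation"], loc=loc, scale=scale, standard_normal=True),
        StepCounter(),
    ),
)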

Hopefully that should help!
I should be able to put my hands on an apple silicon computer tomorrow morning if you're still stuck!
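
For reference, the manual statistics above amount to a per-dimension mean and standard deviation, followed by the standardisation (obs - loc) / scale. The sketch below reproduces that arithmetic with the standard library on made-up numbers. The `observations` data and the `standardise` helper are hypothetical. As far as I can tell from the TorchRL docs, ObservationNorm applies obs * scale + loc unless standard_normal=True is passed, so it is worth checking which convention your version uses.

```python
import statistics

# Hypothetical rollout: 4 steps of a 2-dimensional observation (made-up numbers).
observations = [
    [1.0, 10.0],
    [2.0, 20.0],
    [3.0, 30.0],
    [4.0, 40.0],
]

# One tuple per observation dimension, i.e. a reduction over the time axis (dim=0).
columns = list(zip(*observations))
loc = [statistics.mean(col) for col in columns]
# Sample standard deviation (n - 1 in the denominator), matching torch's default for std.
scale = [statistics.stdev(col) for col in columns]


def standardise(obs):
    """Apply (obs - loc) / scale element-wise to one observation."""
    return [(x - m) / s for x, m, s in zip(obs, loc, scale)]


normalised = [standardise(obs) for obs in observations]
for row in normalised:
    print(row)
```

After standardising, each dimension of the sample has mean close to 0 and standard deviation close to 1.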

@EkaterinaAbramova

The suggestion didn't work, I'm afraid. Is there anything else you could propose at this stage, or only tomorrow?


So why, in the version I provided above, is everything fine until I get down to the SyncDataCollector? Why is it failing there? What has not yet been converted to float32?

@vmoens
Contributor Author

vmoens commented Oct 2, 2023

#1589 will solve your problem!

@EkaterinaAbramova

I have read through that thread, but I don't understand what I am supposed to do. Sorry! Should I reinstall the packages from scratch, or download something in particular from that page?

@vmoens
Contributor Author

vmoens commented Oct 2, 2023

Just wait until we merge it; then you can reinstall from git and things should work OK.

@vmoens
Contributor Author

vmoens commented Oct 2, 2023

@EkaterinaAbramova you should be good to go now!

vmoens added a commit to hyerra/rl that referenced this pull request Oct 10, 2023
@EkaterinaAbramova

@vmoens hey, I started from scratch but things are not working. I posted about it here, but nobody has replied yet: #1198

@vmoens
Contributor Author

vmoens commented Jan 22, 2024

I can reproduce your issue, let me push a fix!
