Hello,
I have been trying to vectorize functions with `jax.vmap`, but I run into a problem: with a larger number of samples it uses so much memory that the system crashes. I am computing the GAE (Generalized Advantage Estimation) function for an actor-critic algorithm in RL, and I compared a manual vectorization against `vmap`.
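For reference, both the manual and the `vmap` version should compute the standard GAE recursion. It is written here with a termination mask $m_t$ (0 if the episode ended at step $t$, 1 otherwise), discount $\gamma$ and GAE parameter $\lambda$; the mask convention is an assumption matching the usual PPO setup:

```math
\begin{aligned}
\delta_t &= r_t + \gamma\, m_t\, V(s_{t+1}) - V(s_t) \\
\hat{A}_t &= \delta_t + \gamma\lambda\, m_t\, \hat{A}_{t+1}, \qquad \hat{A}_T = 0
\end{aligned}
```

Because $\hat{A}_t$ depends on $\hat{A}_{t+1}$, the computation is a sequential loop over time, while the environment (batch) axis is embarrassingly parallel.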
Using this data:
Manual vectorization:
Result:
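Since the original code is not shown here, the following is a minimal sketch of what a manual (batch-axis) vectorization of GAE can look like. It is not the poster's code: the name `gae_manual`, the `(steps, num_envs)` layout, the `masks` convention (0.0 at terminal steps) and the toy data are all assumptions.

```python
import jax
import jax.numpy as jnp


def gae_manual(rewards, values, masks, discount=0.99, gae_lambda=0.95):
    """GAE for a whole batch of environments at once (hypothetical helper).

    rewards, masks: (steps, num_envs); values: (steps + 1, num_envs).
    masks[t] is 0.0 where the episode terminated at step t, else 1.0.
    The env axis is handled by broadcasting; the Python loop runs over time.
    """
    steps = rewards.shape[0]
    gae = jnp.zeros(rewards.shape[1:])
    advantages = []
    for t in reversed(range(steps)):
        delta = rewards[t] + discount * values[t + 1] * masks[t] - values[t]
        gae = delta + discount * gae_lambda * masks[t] * gae
        advantages.append(gae)
    # Advantages were collected last-step-first; restore time order.
    return jnp.stack(advantages[::-1])


# Hypothetical toy data: 8 steps, 4 parallel environments.
key_r, key_v = jax.random.split(jax.random.PRNGKey(0))
rewards = jax.random.normal(key_r, (8, 4))
values = jax.random.normal(key_v, (9, 4))
masks = jnp.ones((8, 4))
adv = gae_manual(rewards, values, masks)
print(adv.shape)  # (8, 4)
```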
Vmap vectorization:
(same implementation as here: https://github.com/google/flax/tree/main/examples/ppo)
Result:
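A minimal sketch of the per-environment `vmap` pattern, written in the style of the GAE helper in the linked Flax PPO example rather than copied from it; `gae_vmap`, the argument order and the toy data are hypothetical. The detail worth noticing is the Python `for` loop: under `jax.jit` it is unrolled at trace time, so the traced program contains one copy of the loop body per step. If the original follows this pattern, that unrolling is a plausible (unconfirmed) contributor to memory growing with `steps`.

```python
import functools

import jax
import jax.numpy as jnp


@jax.jit
@functools.partial(jax.vmap, in_axes=(1, 1, 1, None, None), out_axes=1)
def gae_vmap(rewards, values, masks, discount, gae_lambda):
    """Per-environment GAE, mapped over the env axis (axis 1) with vmap.

    Inside vmap, rewards and masks have shape (steps,), values (steps + 1,).
    The Python loop below is unrolled while tracing: `steps` copies of the body.
    """
    gae = 0.0
    advantages = []
    for t in reversed(range(rewards.shape[0])):
        delta = rewards[t] + discount * values[t + 1] * masks[t] - values[t]
        gae = delta + discount * gae_lambda * masks[t] * gae
        advantages.append(gae)
    return jnp.stack(advantages[::-1])


# Hypothetical toy data: 8 steps, 4 parallel environments.
key_r, key_v = jax.random.split(jax.random.PRNGKey(0))
rewards = jax.random.normal(key_r, (8, 4))
values = jax.random.normal(key_v, (9, 4))
masks = jnp.ones((8, 4))
adv = gae_vmap(rewards, values, masks, 0.99, 0.95)
print(adv.shape)  # (8, 4)
```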
Question
Is there any way to optimize the vmap vectorization memory-wise, or is it just a trade-off?
As it stands, it is impossible to use for steps >= 1024.
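One mitigation that is often suggested, assuming the memory growth comes from an unrolled Python loop over time rather than from `vmap` itself, is to express the time recursion with `jax.lax.scan`, which traces the loop body once regardless of the number of steps. A sketch under that assumption, with rewards and masks shaped `(steps, num_envs)` and values shaped `(steps + 1, num_envs)`; `gae_scan` and the toy data are illustrative only:

```python
import functools

import jax
import jax.numpy as jnp


@jax.jit
@functools.partial(jax.vmap, in_axes=(1, 1, 1, None, None), out_axes=1)
def gae_scan(rewards, values, masks, discount, gae_lambda):
    """Per-environment GAE with the time loop expressed as lax.scan.

    scan traces `step` once, so the compiled program does not grow with steps.
    """
    # All TD residuals delta_t at once, shape (steps,) inside vmap.
    deltas = rewards + discount * values[1:] * masks - values[:-1]

    def step(gae, inputs):
        delta, mask = inputs
        gae = delta + discount * gae_lambda * mask * gae
        return gae, gae  # (new carry, per-step output)

    init = jnp.zeros((), dtype=rewards.dtype)
    # reverse=True iterates t = steps-1, ..., 0 but returns outputs in time order.
    _, advantages = jax.lax.scan(step, init, (deltas, masks), reverse=True)
    return advantages


# Hypothetical toy data: 8 steps, 4 envs, with an episode ending at step 3.
key_r, key_v = jax.random.split(jax.random.PRNGKey(0))
rewards = jax.random.normal(key_r, (8, 4))
values = jax.random.normal(key_v, (9, 4))
masks = jnp.ones((8, 4)).at[3].set(0.0)
adv = gae_scan(rewards, values, masks, 0.99, 0.95)
print(adv.shape)  # (8, 4)
```

`reverse=True` lets `scan` walk from the last step to the first while still stacking the advantages in the original time order, so no explicit flip is needed. Whether this removes the crash for a given setup would need to be measured.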