I have to solve a linear system and, due to memory constraints, I'm currently using `jax.scipy.sparse.linalg.cg`. To speed things up and save memory, I'm using multiple GPUs. The code of a pmapped function (with pmapped axis `'p'`) looks something like this:
Is having a `psum` inside `cg` a sensible thing to do? I'm asking because the more devices I use, the longer it takes to solve the system (once I go beyond 2 GPUs). Could this be due to the communication between the GPUs being relatively slow?
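For concreteness, here is a minimal sketch of the kind of pattern being asked about. The original code is not shown, so everything below is an assumption: the matrix is split by columns across devices, `b` is replicated, and the matvec handed to `cg` ends in a `psum` over axis `'p'`. The names `A_cols`, `solve`, `k` and the random SPD test matrix are hypothetical stand-ins.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

n_dev = jax.local_device_count()
k = 8                      # hypothetical: columns of A held by each device
n = k * n_dev

# Hypothetical well-conditioned SPD system standing in for the real operator.
M = jax.random.normal(jax.random.PRNGKey(0), (n, n))
A = M @ M.T + n * jnp.eye(n)
b = jnp.arange(1.0, n + 1.0)

# Column blocks: A_cols[d] == A[:, d*k:(d+1)*k], shape (n_dev, n, k).
A_cols = A.reshape(n, n_dev, k).transpose(1, 0, 2)
b_rep = jnp.broadcast_to(b, (n_dev, n))  # b replicated on every device

def solve(A_local, b):
    d = jax.lax.axis_index('p')
    def matvec(x):
        # Each device multiplies its column block by its slice of x ...
        x_local = jax.lax.dynamic_slice(x, (d * k,), (k,))
        # ... and the partial products are summed across devices: one
        # all-reduce of a length-n vector every time cg calls matvec.
        return jax.lax.psum(A_local @ x_local, 'p')
    # Every device runs the whole CG loop on replicated length-n vectors.
    x, _ = cg(matvec, b, tol=1e-6)
    return x

x_all = jax.pmap(solve, axis_name='p')(A_cols, b_rep)  # shape (n_dev, n)
```

In this layout each CG iteration performs one all-reduce of a full length-n vector, while the dot products and vector updates inside `cg` are repeated on every device, so only the matvec work shrinks as devices are added. If the matvec is cheap relative to that all-reduce, adding GPUs could plausibly make each iteration slower, which would be consistent with the slowdown described; a profile would be needed to confirm it.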
Also, `cg` is inside the pmapped function because the full (non-distributed) version of `pmap_var1` is large, so splitting it across the GPUs is more convenient; at the same time, every GPU calls `cg`. Is there a smart way to have only one GPU call it (without, for example, using `lax.cond` inside the pmapped function)?
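One possible alternative, sketched here under assumptions and not tested on a multi-GPU machine, is to drop `pmap` and make a single logical call to `cg` under `jax.jit` on a globally sharded array. The row-wise split `P('p', None)`, the mesh axis name and the names `solve` and `mesh` are illustrative choices; XLA's SPMD partitioner, not user code, decides which collectives `A @ v` needs.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n_dev = jax.device_count()
n = 8 * n_dev  # keep the sharded dimension divisible by the device count

# Hypothetical SPD system standing in for the real operator.
M = jax.random.normal(jax.random.PRNGKey(0), (n, n))
A = M @ M.T + n * jnp.eye(n)
b = jnp.arange(1.0, n + 1.0)

# 1-D device mesh: A is split by rows across devices, b is replicated.
mesh = Mesh(np.array(jax.devices()), axis_names=('p',))
A_sharded = jax.device_put(A, NamedSharding(mesh, P('p', None)))
b_repl = jax.device_put(b, NamedSharding(mesh, P()))

@jax.jit
def solve(A, b):
    # cg is traced once on the global arrays; the compiler inserts
    # whatever communication the sharded matvec requires.
    x, _ = cg(lambda v: A @ v, b, tol=1e-6)
    return x

x = solve(A_sharded, b_repl)
```

Building `A` on one device before sharding it, as done here for brevity, would defeat the memory savings for a large operator; in that case the per-device shards would need to be created directly, for example with `jax.make_array_from_callback`.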