Hi all!
My collaborator and I are working on a package that uses CUDA to perform some computation and then shares the resulting array via `dlpack`.
The one issue is that we need to force users to set `export XLA_PYTHON_CLIENT_PREALLOCATE=false` when they use our library, which I see as a wart. Otherwise, users are likely to segfault on the GPU: JAX preallocates its memory budget, and our computation then tries to claim memory beyond what JAX already owns.
Has there been any thinking about the ergonomics of this option? Or about how interop between managed runtimes should work (in an ideal world)? With the way memory allocation currently works, there seems to be no way around forcing an environment option like the one above, but I wondered whether this has been considered.
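For concreteness, here is a minimal sketch of how a library could set this option on the user's behalf instead of asking for a manual `export`. The helper name `configure_jax_memory` is hypothetical. `XLA_PYTHON_CLIENT_PREALLOCATE` and `XLA_PYTHON_CLIENT_MEM_FRACTION` are the environment variables described in JAX's GPU memory allocation docs. The sketch assumes it runs before JAX initializes its GPU backend (safest: before `import jax`), since setting the variables later has no effect.

```python
import os


def configure_jax_memory(preallocate=False, mem_fraction=None):
    """Set JAX/XLA GPU memory options unless the user already chose values.

    Hypothetical helper: must run before JAX initializes its GPU backend.
    """
    # setdefault keeps any value the user exported explicitly.
    os.environ.setdefault(
        "XLA_PYTHON_CLIENT_PREALLOCATE", "true" if preallocate else "false"
    )
    if mem_fraction is not None:
        # Fraction of GPU memory JAX preallocates when preallocation is enabled.
        os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", str(mem_fraction))
    # Report the effective settings so callers can log them.
    return {
        key: os.environ[key]
        for key in ("XLA_PYTHON_CLIENT_PREALLOCATE", "XLA_PYTHON_CLIENT_MEM_FRACTION")
        if key in os.environ
    }


# Typical placement: at the top of the package's __init__.py, before `import jax`.
configure_jax_memory()
```

This only hides the wart rather than removing it: if the user has already imported and used JAX before importing the library, the setting arrives too late.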
There’s a second issue. If I want to pass a JAX array the other way, I currently need to pull it off the device and pipe it across `dlpack` to my custom C code. What’s the right way to avoid pulling it off the device?
Is the XLA custom call path the correct route?
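To make the second question concrete, here is a rough sketch of the export direction using the standard DLPack protocol methods (`__dlpack__` and `__dlpack_device__`), which JAX arrays implement. The helper `export_on_device` and the demo function are hypothetical names. The sketch assumes the consuming C/CUDA code accepts a DLPack capsule and handles stream synchronization itself. It does not show an XLA custom call.

```python
# DLPack device-type codes from the DLPack specification.
KDL_CPU = 1
KDL_CUDA = 2


def export_on_device(array):
    """Hand out a DLPack capsule for `array` without copying it to the host.

    Works with any object that implements the DLPack protocol.
    Returns (device_type, device_id, capsule).
    """
    # Ask where the buffer lives before exporting it.
    device_type, device_id = array.__dlpack_device__()
    # The capsule wraps the existing device buffer; no host transfer happens here.
    capsule = array.__dlpack__()
    return device_type, device_id, capsule


def demo_with_jax():
    """Illustrative only: requires JAX to be installed."""
    import jax.numpy as jnp

    x = jnp.arange(4.0)
    device_type, device_id, capsule = export_on_device(x)
    # On a CUDA backend device_type is KDL_CUDA; on CPU it is KDL_CPU.
    return device_type, device_id, capsule
```

Whether this is preferable to an XLA custom call depends on whether the C code needs to run inside a jitted computation. The capsule route only covers eager hand-off of a concrete array.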