This doesn't answer your question, but have you considered using PyKeOps (https://www.kernel-operations.io/keops/index.html), which computes reductions over very large matrices without storing them in memory? Although its documentation says bindings exist only for NumPy and PyTorch, you can easily use https://github.com/rdyro/torch2jax as the go-between. This setup supports autodiff, vmap, and so on. The extra overhead is minimal; in fact, because of how efficiently KeOps works, it can be faster even with the intermediate step.
-
Hi,
I'm seeing continuously increasing memory use with JAX that doesn't make sense to me (see code below).
I have a very large 500,000 x 4,000 array whose rows essentially need to be vmap'ed to get 500,000 numbers (using func0, given below).
When I tried doing that directly, I ran out of memory, so I decided to split the vmap into chunks: I apply vmap to array[0:5000, :], then to array[5000:10000, :], and so forth, and concatenate the results.
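A minimal sketch of the chunked approach described above, for illustration only. It assumes a hypothetical per-row function `func0` (the real one is not shown here) that maps one row of length npix to a scalar, and a hypothetical helper `apply_in_chunks`. The vmapped function is wrapped in `jax.jit`, so every full-size chunk reuses the same compiled function, and each chunk's result is converted to a NumPy array before the next chunk starts. This is one common way to structure such a loop, not a diagnosis of the memory growth reported here.

```python
import numpy as np
import jax
import jax.numpy as jnp


def func0(row):
    # Hypothetical stand-in for the real per-row function (not shown here):
    # reduces one row of length npix to a single scalar.
    return jnp.sum(jnp.sin(row) ** 2)


# vmap over rows, then jit: every chunk with the same shape reuses
# the same compiled function instead of being traced again.
batched_func0 = jax.jit(jax.vmap(func0))


def apply_in_chunks(data, chunk_size=5000):
    """Apply func0 to every row of `data`, chunk_size rows at a time."""
    results = []
    for start in range(0, data.shape[0], chunk_size):
        chunk = data[start:start + chunk_size]
        # np.asarray waits for this chunk's computation to finish and
        # returns a NumPy array; only these small per-chunk results
        # (chunk_size values each) are kept across iterations.
        results.append(np.asarray(batched_func0(chunk)))
    return np.concatenate(results)


if __name__ == "__main__":
    # Small sizes for illustration; the real problem is 500,000 x 4,000.
    nspec, npix = 20_000, 400
    data = np.random.default_rng(0).standard_normal((nspec, npix), dtype=np.float32)
    out = apply_in_chunks(data)
    print(out.shape)  # (20000,)
```

If the number of rows is not a multiple of `chunk_size`, the smaller final chunk has a new shape and triggers one extra compilation; with 500,000 rows and chunks of 5,000 that does not happen.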
To my surprise, even with chunking, the memory use still increases by 3 GB per iteration, and in the end I run out of memory. I understand that each vmap call needs to keep information about the gradients, but that should be nowhere near 3 GB.
So I'm wondering whether this is a bug or whether I'm missing something.
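For scale, a quick back-of-the-envelope check of the array sizes involved, assuming dense storage as float32 (JAX's default unless 64-bit mode is enabled) or float64. The helper `footprint_gb` is hypothetical, and the snippet only computes raw array footprints; it says nothing about what JAX allocates internally.

```python
def footprint_gb(rows, cols, bytes_per_element):
    """Size in GB (10**9 bytes) of a dense rows x cols array."""
    return rows * cols * bytes_per_element / 1e9


for name, nbytes in [("float32", 4), ("float64", 8)]:
    full = footprint_gb(500_000, 4_000, nbytes)
    chunk = footprint_gb(5_000, 4_000, nbytes)
    print(f"{name}: full array {full:.0f} GB, one 5000-row chunk {chunk:.2f} GB")
# float32: full array 8 GB, one 5000-row chunk 0.08 GB
# float64: full array 16 GB, one 5000-row chunk 0.16 GB
```

So a 3 GB increase per iteration is roughly 19 (float64) to 38 (float32) times the size of one input chunk.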
Thanks in advance!
The test code illustrating the issue is given below (it requires 70 GB of RAM to run with the current values of the nspec and npix parameters).
I'm running on CPU with JAX 0.4.35.