Replies: 2 comments 1 reply
-
Hi - it's not clear from your description which arrays have which shapes. Given the memory blow-up, I suspect you may be running into unexpected rank promotion due to the alignment of various array axes in your code. It would be helpful if you could put together a complete minimal example of the behavior you're seeing, so that it's clearer to the reader what array sizes you're working with.
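To illustrate the kind of rank promotion I mean (an illustrative sketch with made-up shapes, not necessarily your exact case):

```python
import jax.numpy as jnp

# Two modest arrays (hypothetical shapes, not taken from your code):
a = jnp.ones((2000, 1))  # shape (2000, 1)
b = jnp.ones((2000,))    # shape (2000,)

# Broadcasting silently promotes the pair to a full (2000, 2000) matrix,
# a 2000x memory blow-up relative to either input.
c = a + b
print(c.shape)  # (2000, 2000)
```

If one of your intermediate arrays picks up an extra axis like this, every downstream operation inherits the inflated shape.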
-
Hey @jakevdp thank you for your reply. I am trying to put together a minimal example, but of course the minimal example is working, so I am not sure exactly how to reproduce this unless I just send you my notebook. If you are willing to take a look at it, I would be happy to send it. Below is a minimal example of what I am trying to do (and what I think I am doing):

```python
import os
import jax
import jax.numpy as jnp
from jax import random

def getGradients(sks):
    ...

def getCost(grad_1, grad_2, sigmas, sks):
    ...

# gpu_device is defined elsewhere in the notebook
jit_getCost = jax.jit(getCost, device=gpu_device)

def run(sks):
    rng_key = random.key(123)
    num_t_points = 61
    new_key, *sks = random.split(rng_key, 6)
    ...

cost = run(sks)
```

In the actual code, the gradients and sigmas are calculated in another function, but they all share the same dimensions (61, 11, 2000). I am then trying to draw samples of size (61, 11, 2000, 100). It seems this operation is working fine, and is doing so also in the minimal example. I can even do operations (like taking the mean) on the samples. However, it is when I try to use the results of those operations that I get the error. Now it is saying:
It still seems to be memory related, but if I print the size of the array, it shows the correct dimensions (61, 11, 2000) after taking the mean along axis=-1.

Benjamin
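For reference, a scaled-down check of the mean step (random stand-ins, shapes reduced from the real (61, 11, 2000, 100)):

```python
import jax
import jax.numpy as jnp

# Scaled-down stand-in for the (61, 11, 2000, 100) sample array.
samples = jax.random.normal(jax.random.key(0), (61, 11, 200, 100))

means = jnp.mean(samples, axis=-1)
print(means.shape)  # (61, 11, 200) -- the trailing sample axis is gone
```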
-
I am working on porting a project from NumPy to JAX to make it, hopefully, real-time fast. My current algorithm is running at 100 Hz, which is plenty fast for what I need. However, I am trying to add one more piece, and it is leading to OOM errors no matter what size I change the inputs to. I am at a loss as to where to go from here. I have tried setting:
```python
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
```
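(For reference, this flag only takes effect if it is set before JAX is first imported — a minimal sketch of the required ordering:)

```python
import os

# Must run before JAX is first imported; by default XLA preallocates
# a large fraction (~75%) of GPU memory at startup.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax  # now allocates on demand instead of preallocating
```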
The algorithm is generating a bunch of samples and then performing operations on them before returning my final result.
I have this line:
```python
grad_cost = jnp.sum(jnp.sum(jnp.square(grad_1),axis=1) + jnp.sum(jnp.square(grad_2),axis=1),axis=0)
```
which is operating on arrays of size (61, 11, 2000) with no problem.
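A self-contained sketch of that reduction, using random stand-ins for the real gradient arrays:

```python
import jax
import jax.numpy as jnp

# Random stand-ins for the (61, 11, 2000) gradient arrays.
k1, k2 = jax.random.split(jax.random.key(0))
grad_1 = jax.random.normal(k1, (61, 11, 2000))
grad_2 = jax.random.normal(k2, (61, 11, 2000))

# (61, 11, 2000) --sum over axis 1--> (61, 2000) --sum over axis 0--> (2000,)
grad_cost = jnp.sum(jnp.sum(jnp.square(grad_1), axis=1)
                    + jnp.sum(jnp.square(grad_2), axis=1), axis=0)
print(grad_cost.shape)  # (2000,)
```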
I then added these lines:
where these samples are now size (61, 11, 2000, 100). This too runs fine and quite quickly. However, if I try to pass these mean values into the original function like:
```python
grad_cost = jnp.sum(jnp.sum(jnp.square(grad_1_m),axis=1) + jnp.sum(jnp.square(grad_2_m),axis=1),axis=0)
```
Then I always get the resource-exhausted error:
RESOURCE_EXHAUSTED: Out of memory while trying to allocate 2899692453488 bytes.
I don't understand why it thinks it needs 2.9 TB. These arrays should not be anywhere near that big in memory.
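Some back-of-the-envelope arithmetic on those sizes (assuming float32):

```python
# Largest intended array in the pipeline: (61, 11, 2000, 100) in float32.
elements = 61 * 11 * 2000 * 100
expected_bytes = elements * 4          # float32 = 4 bytes per element
print(expected_bytes)                  # 536800000, i.e. ~0.5 GB

# The failed allocation asked for ~2.9 TB:
requested = 2_899_692_453_488
print(round(requested / expected_bytes, 1))  # 5401.8 -- ~5000x too large
```

A discrepancy that large usually points to an unintended broadcast creating a much higher-rank intermediate, rather than the named arrays themselves.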
I have also tried to implement a version using lax.fori_loop, thinking maybe it was really exhausting the memory, but that too failed immediately with the same error. Any suggestions on where I am going wrong? How can I get this to run with the additional samples? I can handle the speed coming down a little, but I really don't understand why it won't run at all at this point. Any input is very much appreciated.
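For reference, a shape-only trace with `jax.eval_shape` (using a hypothetical stand-in for the cost function) confirms the intended output shape without allocating any device memory:

```python
import jax
import jax.numpy as jnp

def cost_stub(grad_1, grad_2):
    # Hypothetical stand-in for the real cost computation.
    return jnp.sum(jnp.sum(jnp.square(grad_1), axis=1)
                   + jnp.sum(jnp.square(grad_2), axis=1), axis=0)

# jax.eval_shape traces with abstract values only, so it reports output
# shapes/dtypes without touching array memory.
g = jax.ShapeDtypeStruct((61, 11, 2000), jnp.float32)
out = jax.eval_shape(cost_stub, g, g)
print(out.shape)  # (2000,)
```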
Thanks!
Benjamin