Hello,
I implemented a brutally simple infinite-width model, calling the kernel_fn with a batch of a single vector.
When I run this on CPU, I don't run into any exorbitant memory issues.
However, when I run this on an A100 GPU, it allocates just under 60GB after running this tiny calculation!
This also happens (on the GPU only) when infinite-width CNNs are used on image datasets (CIFAR, MNIST, etc.).
Does anyone know what could be causing this to happen?
```python
import jax  # missing in the original snippet, but needed for jax.devices()
import numpy as np
import neural_tangents as nt

print(jax.devices())

def linear_model():
    return nt.stax.serial(
        nt.stax.Dense(512), nt.stax.Relu(),
        nt.stax.Dense(512), nt.stax.Relu(),
        nt.stax.Dense(1),
    )

init_fn, apply_fn, kernel_fn = linear_model()

total = 1
X = np.ones((total, 200), dtype=np.float32)

!nvidia-smi
ntk = kernel_fn(X, None, 'ntk')
!nvidia-smi
print(ntk)
```
Usage from first nvidia-smi call: 426MiB / 81920MiB
Usage from second nvidia-smi call: 61352MiB / 81920MiB
Sorry, I don't have access to a GPU right now, but this seems in line with https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html:

"JAX will preallocate 75% of the total GPU memory when the first JAX operation is run."

So it could just be JAX preallocating memory (but not actually using all of it for this computation).
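If the goal is just to confirm that the ~60GiB is the allocator's reserved pool rather than memory this kernel actually needs, the same docs page lists environment variables that change the preallocation behavior. A minimal sketch, assuming the same model and A100 as above (the commented-out fraction value is arbitrary, chosen only for illustration):

```python
import os

# Documented at https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html;
# these must be set before jax is imported.

# Disable the up-front preallocation entirely, so GPU memory grows on demand.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Or, alternatively, preallocate only a small fraction of total GPU memory
# (0.10 is an arbitrary example value):
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.10"

import jax
import numpy as np
import neural_tangents as nt

print(jax.devices())

_, _, kernel_fn = nt.stax.serial(
    nt.stax.Dense(512), nt.stax.Relu(),
    nt.stax.Dense(512), nt.stax.Relu(),
    nt.stax.Dense(1),
)

X = np.ones((1, 200), dtype=np.float32)
ntk = kernel_fn(X, None, 'ntk')
print(ntk)
# Running nvidia-smi at this point should report far less than ~60GiB reserved.
```

Note that even with preallocation disabled, nvidia-smi can still over-report "used" memory somewhat, since JAX's default allocator holds on to blocks it has already acquired; the point of the experiment is only to distinguish reserved memory from memory the computation actually requires.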