High Memory Usage on Infinite-Width NTK for GPU Only #204

Open
deoliveirajoshua opened this issue Jul 29, 2024 · 1 comment

Hello,

I implemented an extremely simple infinite-width model and call kernel_fn on a batch containing a single vector.

When I run this on CPU, I don't run into any exorbitant memory issues.

However, when I run this on an A100 GPU, it allocates just under 60GB after running this tiny calculation!

The same thing happens, again only on the GPU, when infinite-width CNNs are used on image datasets (CIFAR, MNIST, etc.).
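For reference, a minimal infinite-width CNN of that kind (a sketch only, not the exact model from those runs) looks like this:

import numpy as np
import neural_tangents as nt

# Hypothetical small convolutional model applied to a single CIFAR-sized image.
init_fn, apply_fn, kernel_fn = nt.stax.serial(
    nt.stax.Conv(32, (3, 3), padding='SAME'), nt.stax.Relu(),
    nt.stax.Conv(32, (3, 3), padding='SAME'), nt.stax.Relu(),
    nt.stax.Flatten(),
    nt.stax.Dense(1)
)

X_img = np.ones((1, 32, 32, 3), dtype=np.float32)  # one NHWC image
ntk_img = kernel_fn(X_img, None, 'ntk')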

Does anyone know what could be causing this to happen?

import jax  # needed for the jax.devices() call below
import numpy as np
import neural_tangents as nt

print(jax.devices())

def linear_model():
    return nt.stax.serial(
        nt.stax.Dense(512), nt.stax.Relu(),
        nt.stax.Dense(512), nt.stax.Relu(),
        nt.stax.Dense(1)
    )

init_fn, apply_fn, kernel_fn = linear_model()

total = 1
X = np.ones((total, 200), dtype=np.float32)  # a single 200-dimensional input vector

!nvidia-smi  # notebook shell escape: GPU memory before the kernel call
ntk = kernel_fn(X, None, 'ntk')
!nvidia-smi  # GPU memory after the kernel call
print(ntk)

Usage from the first nvidia-smi call: 426MiB / 81920MiB
Usage from the second nvidia-smi call: 61352MiB / 81920MiB


romanngg commented Sep 4, 2024

Sorry, I don't have access to a GPU right now, but this seems in line with
https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html

JAX will preallocate 75% of the total GPU memory when the first JAX operation is run

So it could just be JAX preallocating memory up front (without actually using all of it for this computation). Note that 61352MiB / 81920MiB is roughly 75%, which matches that default preallocation fraction.
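If you want to confirm, the page above also describes environment variables that control the allocator; a minimal sketch based on that documentation (set before JAX initializes the backend):

import os

# Disable the default preallocation so JAX allocates GPU memory on demand.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
# Or cap the preallocated fraction instead (the value here is just an example).
# os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.20'

import jax
print(jax.devices())

On recent JAX versions you can also inspect what the allocator is actually using, as opposed to what nvidia-smi reports as reserved, with something like jax.devices()[0].memory_stats().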
