NTK/NNGP behavior in the infinite regime when weights are drawn from Gaussians with high standard deviation #197
Hi, how can I enable float64 precision?
Sorry for the late reply!

@zhangbububu see https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision

@tengandreaxu could you try using […]?
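For reference, a minimal sketch of enabling 64-bit precision in JAX, following the linked page (the toy `jnp.ones` check is just for illustration):

```python
import jax

# Enable double precision globally; best done at startup, before other JAX calls.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
print(jnp.ones(3).dtype)  # float64 once x64 is enabled
```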
Thank you so much, Roman. It's no problem at all!

```python
import numpy as np
from jax import jit
from neural_tangents import stax

W_stds = list(range(1, 17))
# W_stds.reverse()
layer_fn = []
for i in range(len(W_stds) - 1):
    layer_fn.append(stax.Dense(1, W_std=W_stds[i]))
    layer_fn.append(stax.Relu(do_stabilize=True))
layer_fn.append(stax.Dense(1, 1.0, 0.0))

_, _, kernel_fn = stax.serial(*layer_fn)
kernel_fn = jit(kernel_fn, static_argnames="get")

x = np.random.rand(100, 100)
print(kernel_fn(x, x, "ntk"))
```

results in

```
[[2.61008562e+20 1.12163820e+20 1.23732785e+20 ... 1.08229372e+20
  1.05533967e+20 1.10687273e+20]
 [1.12163820e+20 2.92078984e+20 1.31143308e+20 ... 1.16449180e+20
  1.15616286e+20 1.19062657e+20]
 [1.23732785e+20 1.31143308e+20 3.36093753e+20 ... 1.28641726e+20
  1.19473708e+20 1.28997387e+20]
 ...
 [1.08229363e+20 1.16449180e+20 1.28641726e+20 ... 2.74442324e+20
  1.07858132e+20 1.20695995e+20]
 [1.05533967e+20 1.15616286e+20 1.19473708e+20 ... 1.07858132e+20
  2.69344883e+20 1.11830439e+20]
 [1.10687273e+20 1.19062657e+20 1.28997387e+20 ... 1.20695995e+20
  1.11830439e+20 2.83645061e+20]]
```

Do you think it makes no sense to draw weights with higher standard deviations as we go deeper into the network in the infinite-width limit?
Hi, I've run into a confusing problem:
if I increase the number of training samples (N_tr), I get all NaNs.
I think so. Ideally you would want the mean and variance of your network outputs to match the mean and variance of your training labels, as a sensible prior. But even if your training labels have a large variance, it's common practice to just standardize them (together with the test labels) to have mean 0 and variance 1 for best numerical stability. Then, in a ReLU network, to get mean-zero / variance-one outputs (given mean-zero, variance-one inputs), you would want to set `W_std = sqrt(2)` in each layer.

@zhangbububu replied in your separate thread, let's continue there.
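A minimal sketch of that advice, reusing the `stax` pattern from the snippet above (the depth of 15 and the uniform inputs are only assumptions chosen to mirror it): with every hidden layer at the ReLU critical scale `W_std = sqrt(2)`, `b_std = 0`, the kernel entries stay of order one instead of ~1e20.

```python
import numpy as np
from neural_tangents import stax

depth = 15  # mirrors the 15 Dense + Relu blocks in the snippet above

layers = []
for _ in range(depth):
    # ReLU criticality: W_std = sqrt(2) keeps the kernel diagonal roughly constant.
    layers.append(stax.Dense(1, W_std=np.sqrt(2), b_std=0.0))
    layers.append(stax.Relu())
layers.append(stax.Dense(1, W_std=1.0, b_std=0.0))  # linear readout

_, _, kernel_fn = stax.serial(*layers)

x = np.random.rand(100, 100)  # placeholder inputs, as in the snippet above
print(kernel_fn(x, x, "ntk"))  # entries remain moderate rather than ~1e20
```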
Thank you for your prompt help, Roman!
Hi everyone, thank you so much for your exceptional work!
I'm encountering some numerical issues when weights are drawn from Gaussians with a high standard deviation. Please see the snippet below:
The result is:
With float64 precision enabled, the values blow up:
What's interesting is that the behavior appears to depend more on the depth than on how large the weights' standard deviations are. If the standard deviations were reversed (by uncommenting `W_stds.reverse()`), so that layer 1 had the largest one, $w_{ij} \sim \mathcal{N}(0, 16^2)$ (i.e. `W_std = 16`), and so on, the results would remain unchanged.
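A rough back-of-the-envelope check (using the standard infinite-width ReLU kernel recursion, so only a sketch): with `b_std = 0`, each `Dense(W_std=σ_l)` + `Relu` block maps the diagonal of the NNGP covariance as

$$K^{(l+1)}_{xx} = \frac{\sigma_l^2}{2}\,K^{(l)}_{xx},$$

so after $L$ such blocks

$$K^{(L)}_{xx} = \Big(\prod_{l=1}^{L}\frac{\sigma_l^2}{2}\Big)\,K^{(0)}_{xx},$$

and the NTK picks up roughly the same product as its leading factor. With $\sigma_l = 1, 2, \dots, 15$ this product is $(15!)^2 / 2^{15} \approx 5 \times 10^{19}$, the same order of magnitude as the entries printed above. Since a product does not depend on the order of its factors, permuting the per-layer standard deviations should not change this scale: it is the depth (the number of $\sigma_l^2/2 > 1$ factors multiplied together), rather than where the large $\sigma_l$ sit, that drives the blow-up.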
Thank you in advance, and happy new year!