I think the reason is that this 1D function is hard to fit with a ReLU kernel. Sampling only 15 points makes the training objective simpler, so a smaller diagonal regularizer is enough. You can avoid the NaNs by increasing diag_reg, which I did below, but as you can see it's a poor fit in any case (the NTK prediction, evaluated at 1000 sampled test points, is in orange).
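A minimal NumPy sketch of why more training points can produce NaNs: with many closely spaced 1D inputs, neighboring rows of the train-train kernel become nearly identical, its condition number blows up, and a Cholesky factorization can break down. JAX's Cholesky returns NaNs rather than raising on a matrix that is not numerically positive definite, which is a plausible source of the NaNs here. Adding diag_reg times the mean diagonal entry restores positive definiteness. Assumptions: the kernel is the NNGP kernel (not the NTK, for brevity; the NTK should suffer from the same issue) of a hypothetical Dense -> Relu -> Dense network with W_std = b_std = 1, the target sin(3x) is made up, both helpers are hypothetical and not Neural Tangents functions, and the relative diag_reg scaling is my assumption about the library's default convention.

```python
import numpy as np


def relu_nngp_kernel(x1, x2, w_std=1.0, b_std=1.0):
    """NNGP kernel of a Dense -> Relu -> Dense network on scalar inputs.

    Hypothetical helper (not a Neural Tangents function): uses the closed-form
    order-1 arc-cosine expectation E[relu(u) relu(v)] for jointly Gaussian u, v.
    """
    # Pre-activation covariances after the first Dense layer.
    k12 = w_std**2 * np.outer(x1, x2) + b_std**2
    k11 = w_std**2 * x1**2 + b_std**2
    k22 = w_std**2 * x2**2 + b_std**2
    norms = np.sqrt(np.outer(k11, k22))
    cos = np.clip(k12 / norms, -1.0, 1.0)
    theta = np.arccos(cos)
    # E[relu(u) relu(v)] = sqrt(k11 k22) / (2 pi) * (sin(theta) + (pi - theta) cos(theta))
    relu_cov = norms * (np.sin(theta) + (np.pi - theta) * cos) / (2 * np.pi)
    # Covariance after the readout Dense layer.
    return w_std**2 * relu_cov + b_std**2


def posterior_mean(x_train, y_train, x_test, diag_reg):
    """Kernel-regression mean k(x_test, X) (K + reg * I)^{-1} y.

    Assumption: the ridge is diag_reg times the mean diagonal entry of K.
    """
    k_train = relu_nngp_kernel(x_train, x_train)
    k_test = relu_nngp_kernel(x_test, x_train)
    reg = diag_reg * np.mean(np.diag(k_train))
    # Cholesky needs a positive-definite matrix; the ridge provides the margin.
    chol = np.linalg.cholesky(k_train + reg * np.eye(len(x_train)))
    alpha = np.linalg.solve(chol.T, np.linalg.solve(chol, y_train))
    return k_test @ alpha


x_dense = np.linspace(-3.0, 3.0, 1000)
x_sparse = x_dense[::70]        # 15 points, a subset of the 1000
y_dense = np.sin(3 * x_dense)   # hypothetical 1D target

cond_sparse = np.linalg.cond(relu_nngp_kernel(x_sparse, x_sparse))
cond_dense = np.linalg.cond(relu_nngp_kernel(x_dense, x_dense))
print(f"cond(K), 15 points:   {cond_sparse:.2e}")
print(f"cond(K), 1000 points: {cond_dense:.2e}")

x_test = np.linspace(-3.0, 3.0, 200)
mean = posterior_mean(x_dense, y_dense, x_test, diag_reg=1e-2)
print("finite predictions:", np.all(np.isfinite(mean)))
```

x_sparse is taken as a subset of x_dense so that, by eigenvalue interlacing, the 1000-point kernel matrix cannot be better conditioned than the 15-point one; with diag_reg=0 the factorization of the larger matrix may fail.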
1000 training points, diag_reg=1e-2:
100 training points, diag_reg=1e-3:
15 training points, diag_reg=1e-4:
Given your training targets, I'd guess a periodic nonlinearity fits this particular example better (stax.Sin(), diag_reg=1e-4):
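To get a feel for why a periodic nonlinearity can suit this target better than ReLU, it helps to look at random functions drawn from each prior. A minimal NumPy sketch, assuming a single hidden layer of hypothetical width 2048, W_std = b_std = 1 and 1/sqrt(fan_in) weight scaling, so that these finite-width draws roughly approximate the infinite-width NNGP prior; sample_prior_functions is a hypothetical helper and does not call Neural Tangents.

```python
import numpy as np


def sample_prior_functions(x, activation, n_draws=5, width=2048,
                           w_std=1.0, b_std=1.0, seed=0):
    """Draw random functions from a wide Dense -> activation -> Dense network.

    Hypothetical helper: with 1/sqrt(fan_in) scaling, wide finite-width draws
    approximate samples from the infinite-width NNGP prior.
    """
    rng = np.random.default_rng(seed)
    x = x[:, None]  # shape [n_points, 1]
    draws = []
    for _ in range(n_draws):
        # First layer: fan_in = 1, so weights are scaled by w_std only.
        w1 = rng.normal(size=(1, width)) * w_std
        b1 = rng.normal(size=width) * b_std
        # Readout layer: scale by w_std / sqrt(width).
        w2 = rng.normal(size=(width, 1)) * w_std / np.sqrt(width)
        b2 = rng.normal(size=1) * b_std
        hidden = activation(x @ w1 + b1)
        draws.append((hidden @ w2 + b2)[:, 0])
    return np.stack(draws)  # shape [n_draws, n_points]


x_grid = np.linspace(-3.0, 3.0, 400)
relu_draws = sample_prior_functions(x_grid, lambda z: np.maximum(z, 0.0))
sin_draws = sample_prior_functions(x_grid, np.sin)
```

Plotting each row of relu_draws and sin_draws against x_grid (e.g. with matplotlib) shows the two priors side by side. For exact infinite-width samples, you could instead draw from a Gaussian whose covariance is the stax model's kernel_fn evaluated with get='nngp'.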
Otherwise, trying different architectures and plotting predictions or draws from the prior is a good way to build intuition for what works best. Note that for time-series data of shape [batch_size, time_duration, n_features], you may want to use a 1D convolution (stax.Conv or stax.ConvLocal) over the time_duration axis to build time locality into your model.
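A 1D convolution over the time axis applies the same small set of weights to every window of consecutive time steps, so each output depends only on nearby times; that is the locality prior stax.Conv would add, while stax.ConvLocal keeps the locality but drops the weight sharing across positions. A minimal NumPy sketch of the forward computation for a 'VALID' convolution on a [batch, time_duration, n_features] input; conv1d_time is a hypothetical helper, not the Neural Tangents API, and the shapes below are made up.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def conv1d_time(x, kernel):
    """'VALID' 1D convolution (cross-correlation) over the time axis.

    x:      [batch, time, in_features]
    kernel: [window, in_features, out_features]
    returns [batch, time - window + 1, out_features]
    """
    window = kernel.shape[0]
    # patches[b, t, f, w] == x[b, t + w, f]; shape [batch, time - window + 1, in_features, window]
    patches = sliding_window_view(x, window_shape=window, axis=1)
    # Contract the window and input-feature axes; the same weights are used at every t.
    return np.einsum("btfw,wfo->bto", patches, kernel)


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 50, 3))       # hypothetical: 4 series, 50 time steps, 3 features
kernel = rng.normal(size=(5, 3, 8))   # window of 5 time steps, 8 output channels
y = conv1d_time(x, kernel)
print(y.shape)  # (4, 46, 8)
```

The output at time step t depends only on steps t through t + window - 1. The stax layers additionally apply W_std/b_std scaling and a bias, which this sketch leaves out.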
Hi, I've run into a confusing problem:
if I increase the number of training samples (N_tr), nkt_mean comes out as all NaNs.