Hi there, thanks for the great repo!
I was working through the Neural Tangents Cookbook and am a bit confused by the loss_fn (reproduced below):
import jax.numpy as jnp

def loss_fn(predict_fn, ys, t, xs=None):
  # Mean and covariance of the closed-form NTK predictions at training time(s) t.
  mean, cov = predict_fn(t=t, get='ntk', x_test=xs, compute_cov=True)
  mean = jnp.reshape(mean, mean.shape[:1] + (-1,))  # flatten outputs per time step
  var = jnp.diagonal(cov, axis1=1, axis2=2)         # per-point predictive variances
  ys = jnp.reshape(ys, (1, -1))
  mean_predictions = 0.5 * jnp.mean(ys ** 2 - 2 * mean * ys + var + mean ** 2,
                                    axis=1)
  return mean_predictions
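For context, this is roughly how I believe a predict_fn with this signature is built and how loss_fn is then evaluated over a grid of training times; the network, data, and time grid below are placeholders I made up rather than the notebook's actual values:

import jax.numpy as jnp
import neural_tangents as nt
from neural_tangents import stax

# Placeholder infinite-width network and 1-D regression data.
_, _, kernel_fn = stax.serial(stax.Dense(512), stax.Erf(), stax.Dense(1))
train_xs = jnp.linspace(-jnp.pi, jnp.pi, 10).reshape(-1, 1)
train_ys = jnp.sin(train_xs)
test_xs = jnp.linspace(-jnp.pi, jnp.pi, 50).reshape(-1, 1)
test_ys = jnp.sin(test_xs)

# Closed-form NTK/GP predictions as a function of training time t.
predict_fn = nt.predict.gradient_descent_mse_ensemble(
    kernel_fn, train_xs, train_ys, diag_reg=1e-4)

ts = jnp.logspace(-2.0, 3.0, 20)                   # training times to plot over
train_losses = loss_fn(predict_fn, train_ys, ts)   # xs=None -> training inputs
test_losses = loss_fn(predict_fn, test_ys, ts, test_xs)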
It looks like this function is later used to calculate the training and test losses for plotting (roughly as sketched above). What I am confused by is that the calculation for (each test point in) mean_predictions contains var, making it effectively the sum of the squared error (between a prediction and a label) and the variance. It makes sense to include the variance as part of the performance (or loss), but why this specific form? For example, why $+ 1 \times \text{var}$ instead of $2 \times \text{var}$, why the variance rather than the standard deviation, and why is there a $0.5$ in front? Perhaps you could point me to a reference that I missed somewhere?
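For concreteness, writing $\mu_i$ and $\sigma_i^2$ for the predictive mean and variance at test point $i$ (and $y_i$ for its label), each entry of mean_predictions appears to compute

$$\frac{1}{2N}\sum_{i=1}^{N}\left(y_i^2 - 2\mu_i y_i + \sigma_i^2 + \mu_i^2\right) \;=\; \frac{1}{2N}\sum_{i=1}^{N}\left[(y_i - \mu_i)^2 + \sigma_i^2\right],$$

i.e. half the average squared error plus half the average predictive variance.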
Thanks again!