Draw Phase Diagram for CNTK #164

Open
hhorace opened this issue Sep 2, 2022 · 1 comment
Labels
enhancement New feature or request

Comments

@hhorace

hhorace commented Sep 2, 2022

I'm curious about the initialization for the CNTK, so I replaced the kernel_fn in the c_map(W_var, b_var) function in the Colab with:

# Original single layer from the notebook: an affine transformation composed
# with an Erf nonlinearity.
# kernel_fn = stax.serial(stax.Dense(1024, W_std, b_std), stax.Erf())[2]
# Replacement: a convolutional layer with a ReLU nonlinearity, followed by a
# dense readout.
kernel_fn = stax.serial(
    stax.Conv(out_chan=1024, filter_shape=(3, 3), strides=None, padding='SAME', W_std=W_std, b_std=b_std),
    stax.Relu(),
    stax.Flatten(),
    stax.Dense(10, W_std=W_std, b_std=b_std, parameterization='ntk')
)[2]

However, when I try to plot, a low-level error is raised. The error message is as follows:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_54123/684423119.py in <module>
----> 1 plt.contourf(W_var, b_var, c_star(W_var, b_var))
      2 plt.colorbar()
      3 plt.title('$C^*$ as a function of weight and bias variance', fontsize=14)
      4 
      5 format_plot('$\\sigma_w^2$', '$\\sigma_b^2$')

    [... skipping hidden 18 frame]

/tmp/ipykernel_54123/2333898834.py in <lambda>(W_var, b_var)
     51   return c_map_fn
     52 
---> 53 c_star = lambda W_var, b_var: fixed_point(c_map(W_var, b_var), 0.1, 1e-7)
     54 chi = lambda c, W_var, b_var: grad(c_map(W_var, b_var))(c)
     55 chi_1 = partial(chi, 1.)

/tmp/ipykernel_54123/2333898834.py in c_map(W_var, b_var)
     42     return kernel_fn(Kernel(np.array([[q]]))).nngp[0, 0]
     43 
---> 44   qstar = fixed_point(q_map_fn, 1.0, 1e-7)
     45 
     46   def c_map_fn(c):

/tmp/ipykernel_54123/3420146269.py in fixed_point(f, initial_value, threshold)
     38     return x - g(x) / dg(x), x
     39 
---> 40   return lax.while_loop(cond_fn, body_fn, (initial_value, 0.0))[0]

    [... skipping hidden 12 frame]

/tmp/ipykernel_54123/3420146269.py in body_fn(x)
     36   def body_fn(x):
     37     x, _ = x
---> 38     return x - g(x) / dg(x), x
     39 
     40   return lax.while_loop(cond_fn, body_fn, (initial_value, 0.0))[0]

/tmp/ipykernel_54123/3420146269.py in <lambda>(x)
     27 def fixed_point(f, initial_value, threshold):
     28   """Find fixed-points of a function f:R->R using Newton's method."""
---> 29   g = lambda x: f(x) - x
     30   dg = grad(g)
     31 

/tmp/ipykernel_54123/2333898834.py in q_map_fn(q)
     40   def q_map_fn(q):
     41     print(q)
---> 42     return kernel_fn(Kernel(np.array([[q]]))).nngp[0, 0]
     43 
     44   qstar = fixed_point(q_map_fn, 1.0, 1e-7)

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/utils/utils.py in h(*args, **kwargs)
    174     @functools.wraps(f)
    175     def h(*args, **kwargs):
--> 176       return g(*args, **kwargs)
    177 
    178     h.__signature__ = inspect.signature(f)

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/utils/utils.py in getter_fn(*args, **kwargs)
    208                                                           len(args)])
    209 
--> 210       fn_out = fn(*canonicalized_args, **kwargs)
    211 
    212       @nt_tree_fn()

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in kernel_fn_any(x1_or_kernel, x2, get, pattern, mask_constant, diagonal_batch, diagonal_spatial, **kwargs)
   4293     """
   4294     if utils.is_nt_tree_of(x1_or_kernel, Kernel):
-> 4295       return kernel_fn_kernel(x1_or_kernel,
   4296                               pattern=pattern,
   4297                               diagonal_batch=diagonal_batch,

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in kernel_fn_kernel(kernel, **kwargs)
   4212 
   4213   def kernel_fn_kernel(kernel, **kwargs):
-> 4214     out_kernel = kernel_fn(kernel, **kwargs)
   4215     return _set_shapes(init_fn, apply_fn, kernel, out_kernel, **kwargs)
   4216 

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/utils/utils.py in h(*args, **kwargs)
    174     @functools.wraps(f)
    175     def h(*args, **kwargs):
--> 176       return g(*args, **kwargs)
    177 
    178     h.__signature__ = inspect.signature(f)

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in new_kernel_fn(k, **kwargs)
    191               pass
    192 
--> 193       return kernel_fn(k, **kwargs)
    194 
    195     setattr(new_kernel_fn, _INPUT_REQ, frozendict.frozendict(static_reqs))

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in kernel_fn(k, **kwargs)
    325     # inside kernel functions here and parallel below.
    326     for f in kernel_fns:
--> 327       k = f(k, **kwargs)
    328     return k
    329 

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/utils/utils.py in h(*args, **kwargs)
    174     @functools.wraps(f)
    175     def h(*args, **kwargs):
--> 176       return g(*args, **kwargs)
    177 
    178     h.__signature__ = inspect.signature(f)

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/utils/utils.py in getter_fn(*args, **kwargs)
    208                                                           len(args)])
    209 
--> 210       fn_out = fn(*canonicalized_args, **kwargs)
    211 
    212       @nt_tree_fn()

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in kernel_fn_any(x1_or_kernel, x2, get, pattern, mask_constant, diagonal_batch, diagonal_spatial, **kwargs)
   4293     """
   4294     if utils.is_nt_tree_of(x1_or_kernel, Kernel):
-> 4295       return kernel_fn_kernel(x1_or_kernel,
   4296                               pattern=pattern,
   4297                               diagonal_batch=diagonal_batch,

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in kernel_fn_kernel(kernel, **kwargs)
   4212 
   4213   def kernel_fn_kernel(kernel, **kwargs):
-> 4214     out_kernel = kernel_fn(kernel, **kwargs)
   4215     return _set_shapes(init_fn, apply_fn, kernel, out_kernel, **kwargs)
   4216 

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in kernel_fn_with_masking(k, **user_reqs)
    277         mask1, mask2 = mask_fn(mask1, shape1), mask_fn(mask2, shape2)
    278 
--> 279         k = kernel_fn(k, **user_reqs)  # type: Kernel
    280 
    281         if remask_kernel:

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/utils/utils.py in h(*args, **kwargs)
    174     @functools.wraps(f)
    175     def h(*args, **kwargs):
--> 176       return g(*args, **kwargs)
    177 
    178     h.__signature__ = inspect.signature(f)

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in new_kernel_fn(k, **kwargs)
    191               pass
    192 
--> 193       return kernel_fn(k, **kwargs)
    194 
    195     setattr(new_kernel_fn, _INPUT_REQ, frozendict.frozendict(static_reqs))

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in kernel_fn(k, **kwargs)
   1506       return out
   1507 
-> 1508     cov1 = conv(cov1, 1 if k.diagonal_batch else 2)
   1509     cov2 = conv(cov2, 1 if k.diagonal_batch else 2)
   1510 

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in conv(lhs, batch_ndim)
   1502 
   1503     def conv(lhs, batch_ndim):
-> 1504       out = conv_unscaled(lhs, batch_ndim)
   1505       out = affine(out, W_std**2, b_std**2, batch_ndim)
   1506       return out

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in conv_unscaled(lhs, batch_ndim)
   1477 
   1478     def conv_unscaled(lhs, batch_ndim):
-> 1479       lhs = conv_kernel(lhs,
   1480                         filter_shape_kernel,
   1481                         strides_kernel,

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in _conv_kernel_full_spatial_shared(lhs, filter_shape, strides, padding, batch_ndim)
   4759     return n_channels
   4760 
-> 4761   out = _conv_kernel_full_spatial_loop(lhs, filter_shape, strides, padding,
   4762                                        lax_conv, get_n_channels)
   4763   return out

~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in _conv_kernel_full_spatial_loop(lhs, filter_shape, strides, padding, lax_conv, get_n_channels)
   4912     spatial_i = (i - batch_ndim) // 2
   4913 
-> 4914     lhs = np.moveaxis(lhs, (i - 1, i), (-2, -1))
   4915     preshape = lhs.shape[:-2]
   4916     n_channels = get_n_channels(utils.size_at(preshape))

~/jax/jax/_src/numpy/lax_numpy.py in moveaxis(a, source, destination)
   1535     destination_axes = tuple(cast(Sequence[int], destination))
   1536   source_axes = tuple(_canonicalize_axis(i, ndim(a)) for i in source_axes)
-> 1537   destination_axes = tuple(_canonicalize_axis(i, ndim(a))
   1538                            for i in destination_axes)
   1539   if len(source_axes) != len(destination_axes):

~/jax/jax/_src/numpy/lax_numpy.py in <genexpr>(.0)
   1535     destination_axes = tuple(cast(Sequence[int], destination))
   1536   source_axes = tuple(_canonicalize_axis(i, ndim(a)) for i in source_axes)
-> 1537   destination_axes = tuple(_canonicalize_axis(i, ndim(a))
   1538                            for i in destination_axes)
   1539   if len(source_axes) != len(destination_axes):

~/jax/jax/_src/util.py in canonicalize_axis(axis, num_dims)
    275   axis = operator.index(axis)
    276   if not -num_dims <= axis < num_dims:
--> 277     raise ValueError(
    278         "axis {} is out of bounds for array of dimension {}".format(
    279             axis, num_dims))

ValueError: axis -2 is out of bounds for array of dimension 1

Am I misunderstanding the phase diagram? (Is a phase diagram for the CNTK fundamentally impossible to draw?)
Also, I found that the phase diagram does not change at all when I simply make an FC network deeper, e.g.

def DenseGroup(n, neurons, W_std, b_std):
    blocks = []
    for _ in range(n):
        blocks += [stax.Dense(neurons, W_std, b_std), stax.Erf()]
    return stax.serial(*blocks)

for layer in range(1,11):
    def c_map(W_var, b_var):
        ...
        kernel_fn = stax.serial(DenseGroup(layer, 1024, W_std, b_std))[2]
        ...
    c_star = lambda W_var, b_var: fixed_point(c_map(W_var, b_var), 0.1, 1e-7)
    chi = lambda c, W_var, b_var: grad(c_map(W_var, b_var))(c)
    chi_1 = partial(chi, 1.)
    
    c_star = jit(vectorize_over_sw_sb(c_star))
    chi_1 = jit(vectorize_over_sw_sb(chi_1))

    plt.contourf(W_var, b_var, c_star(W_var, b_var))

Does this mean that the depth of the network doesn't affect the initialization?

@romanngg
Contributor

romanngg commented Sep 4, 2022

@SiuMath and @sschoenholz may answer better, but I can give some brief comments:

  1. Re changing the depth, your observation is correct. The $C^*$ diagram shows the fixed-point correlation, i.e. the limiting correlation value $c^*$ in the infinite-depth limit, so it shouldn't matter whether you repeat 1 or 3 identical layers infinitely many times. The $\chi$ plot will change, but note that by definition and the chain rule, $\chi$ for $n$ identical layers equals $\chi$ for one layer raised to the power $n$, so the phase boundary where $\chi = 1$ remains the same.

  2. I imagine the code could be generalized to CNNs, but it would need to support vector-valued variances $q$ and covariances $c$ (for spatial locations), so it may need some work. Note that per https://arxiv.org/abs/1806.05393, for standard/ntk parameterization and CIRCULAR padding it should yield the same phase diagram as for the fully-connected network. In Figure 11 of https://arxiv.org/abs/1810.05148 we ran some experiments with SAME padding and obtained reasonable agreement too.
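To make point 1 concrete, here is a minimal pure-Python sketch. It is not the notebook's code: it assumes the standard closed-form arcsine kernel for a Dense layer followed by Erf, and the helper names (erf_layer, q_star, c_map, chi_1, c_star) are hypothetical. It checks numerically that stacking $n$ identical layers leaves $c^*$ unchanged and raises $\chi$ at $c = 1$ to the $n$-th power.

```python
import math

def erf_layer(q, k, W_var, b_var):
    """Apply one Dense + Erf layer to a variance q and a covariance k (NNGP).

    Assumes both inputs share the same variance q, as in the phase-diagram setup.
    """
    q_pre = W_var * q + b_var          # affine step on the variance
    k_pre = W_var * k + b_var          # affine step on the covariance
    scale = 1.0 + 2.0 * q_pre
    q_out = (2.0 / math.pi) * math.asin(2.0 * q_pre / scale)
    k_out = (2.0 / math.pi) * math.asin(2.0 * k_pre / scale)
    return q_out, k_out

def q_star(W_var, b_var, iters=500):
    """Fixed-point variance, found by plain iteration (hypothetical helper)."""
    q = 1.0
    for _ in range(iters):
        q, _ = erf_layer(q, q, W_var, b_var)
    return q

def c_map(c, q, n_layers, W_var, b_var):
    """Correlation map of a block made of n_layers identical Dense + Erf layers."""
    k = c * q
    for _ in range(n_layers):
        q, k = erf_layer(q, k, W_var, b_var)
    return k / q

def chi_1(n_layers, W_var, b_var, h=1e-5):
    """Central-difference estimate of the slope of c_map at c = 1."""
    q = q_star(W_var, b_var)
    up = c_map(1.0 + h, q, n_layers, W_var, b_var)
    down = c_map(1.0 - h, q, n_layers, W_var, b_var)
    return (up - down) / (2.0 * h)

def c_star(n_layers, W_var, b_var, iters=2000):
    """Fixed-point correlation, found by iterating the block's c_map from 0.1."""
    q = q_star(W_var, b_var)
    c = 0.1
    for _ in range(iters):
        c = c_map(c, q, n_layers, W_var, b_var)
    return c

if __name__ == "__main__":
    W_var, b_var = 2.0, 0.1
    # Same fixed point whether the repeated block has 1 or 3 layers.
    print(c_star(1, W_var, b_var), c_star(3, W_var, b_var))
    # chi for 3 layers versus (chi for 1 layer) cubed.
    print(chi_1(3, W_var, b_var), chi_1(1, W_var, b_var) ** 3)
```

Since $\chi^n$ crosses 1 exactly where $\chi$ does, the ordered/chaotic boundary is the same for every block depth in this sketch.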

One other comment: this notebook relies on being able to determine a fixed-point variance $q^*$, which does not always exist for the ReLU that you used in your example (for weight variance above 2, no stable non-zero variance exists, and it explodes in the infinite-depth limit), so the ReLU nonlinearity won't work in the notebook at the moment, even for FCNs. But you can find the ReLU phase diagram in Figure 4 (b) of https://arxiv.org/abs/1711.00165.
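The ReLU remark can be checked with a few lines of arithmetic. For a zero-mean Gaussian pre-activation $z$ with variance $q$, $E[\mathrm{relu}(z)^2] = q/2$, so a Dense layer followed by ReLU maps $q \mapsto (\sigma_w^2 q + \sigma_b^2)/2$. Below is a minimal pure-Python sketch of that map; the helper names are hypothetical and it is independent of the notebook code.

```python
def relu_q_map(q, W_var, b_var):
    """Variance map of one Dense + ReLU layer in the infinite-width limit."""
    return 0.5 * (W_var * q + b_var)

def iterate_q(W_var, b_var, q0=1.0, depth=200):
    """Propagate the variance through `depth` identical Dense + ReLU layers."""
    q = q0
    for _ in range(depth):
        q = relu_q_map(q, W_var, b_var)
    return q

def relu_q_star(W_var, b_var):
    """Closed-form fixed point b_var / (2 - W_var), valid only for W_var < 2.

    Returns None for W_var >= 2, where the linear map has slope >= 1 and no
    attracting fixed point exists (W_var == 2 with b_var == 0 is the marginal
    case in which every q is preserved).
    """
    if W_var >= 2.0:
        return None
    return b_var / (2.0 - W_var)

if __name__ == "__main__":
    # Below the threshold the iteration settles at the closed-form fixed point.
    print(iterate_q(1.5, 0.1), relu_q_star(1.5, 0.1))
    # Above the threshold the variance grows geometrically with depth.
    print(iterate_q(2.5, 0.1), relu_q_star(2.5, 0.1))
```

A Newton-based fixed_point search like the one in the notebook has nothing to converge to in the second case, which is consistent with the comment above.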

Finally, note that these diagrams study the forward propagation of the signal, so they are only working with the CNN-GP kernel (and not the CNTK), so the parameterization argument should have no impact on them.

Hope this helps!

@romanngg romanngg added the enhancement New feature or request label Sep 4, 2022