Feature masks do not get reduced in the kernel #159

Open · jglaser opened this issue Aug 15, 2022 · 2 comments
Labels: bug

jglaser (Contributor) commented Aug 15, 2022

I am observing an error when providing masked inputs with more than one feature dimension to a kernel that involves stax.GlobalAvgPool().

Reproducer:

import jax
import jax.numpy as np
import neural_tangents as nt
from neural_tangents import stax

if __name__ == '__main__':
    # input tokens
    X = 3 * np.ones((10, 512))

    mask_constant = 10
    pad_token = 0

    # pad some elements
    X = X.at[0, 4].set(pad_token)
    X = X.at[7, 422].set(pad_token)
    print('before encode ',X.shape)

    # vocabulary size
    n_vocab = 5
    def encode(x, mask_constant):
        # zero-mean embeddings
        res = jax.nn.one_hot(x, n_vocab)
        res -= np.mean(res, axis=-1, keepdims=True)
        # padded positions carry the mask constant across all feature dimensions
        return np.where(x[..., None] == pad_token, mask_constant, res)

    X = encode(X, mask_constant=mask_constant)
    print('after encode ', X.shape)

    # trace over output correlations
    _, _, kernel_fn_avg = stax.GlobalAvgPool()
    input_fn = nt.batch(kernel_fn_avg, batch_size=2)
    cov = input_fn(X, X, 'nngp', mask_constant=mask_constant, diagonal_spatial=True)
    print('output ', cov.shape)

Output:

$ python mask_reproducer.py 
Attempting to register factory for plugin cuBLAS when one has already been registered
before encode  (10, 512)
after encode  (10, 512, 5)
Traceback (most recent call last):
  File "/gpfs/alpine/bif136/world-shared/gpbind/mask_reproducer.py", line 32, in <module>
    cov = input_fn(X, X, 'nngp', mask_constant=mask_constant, diagonal_spatial=True)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 188, in h
    return g(*args, **kwargs)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 471, in serial_fn
    return serial_fn_x1(x1_or_kernel, x2, *args, **kwargs)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 398, in serial_fn_x1
    _, kernel = _scan(row_fn, 0, (x1s, kwargs_np1))
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 151, in _scan
    carry, y = f(carry, x)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 387, in row_fn
    return _, _scan(col_fn, x1, (x2s, kwargs_np2))[1]
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 151, in _scan
    carry, y = f(carry, x)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 396, in col_fn
    return (x1, kwargs1), kernel_fn(x1, x2, *args, **kwargs_merge)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 188, in h
    return g(*args, **kwargs)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 758, in f_pmapped
    return _f(x_or_kernel, *args_np, **kwargs_np)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 751, in _f
    return f(_x_or_kernel, *_args, **_kwargs)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 188, in h
    return g(*args, **kwargs)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 222, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py", line 1008, in kernel_fn_any
    return kernel_fn_x1(x1_or_kernel, x2, get,
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py", line 921, in kernel_fn_x1
    out_kernel = kernel_fn(kernel, **kwargs)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py", line 222, in kernel_fn_with_masking
    mask1, mask2 = mask_fn(mask1, shape1), mask_fn(mask2, shape2)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 188, in h
    return g(*args, **kwargs)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py", line 188, in mask_fn
    return _mask_fn(mask, input_shape)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/linear.py", line 1756, in mask_fn
    _check_is_implemented(mask, channel_axis)
  File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/linear.py", line 3621, in _check_is_implemented
    raise NotImplementedError(
NotImplementedError: Different channel-wise masks as inputs to pooling layers are not yet supported. Please let us know about your use case at https://github.com/google/neural-tangents/issues/new

Expected output (with the fix from #158):

before encode  (10, 512)
after encode  (10, 512, 5)
/gpfs/alpine/world-shared/bif136/jax_env_summit/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py:769: UserWarning: Assuming consistent masks (all zero or one) for features of dimension > 1, which is not verified.
  warnings.warn("Assuming consistent masks (all zero or one) for features of dimension > 1, which is not verified.")
/gpfs/alpine/world-shared/bif136/jax_env_summit/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py:769: UserWarning: Assuming consistent masks (all zero or one) for features of dimension > 1, which is not verified.
  warnings.warn("Assuming consistent masks (all zero or one) for features of dimension > 1, which is not verified.")
output  (10, 10)

I admit the warning is a little noisy; perhaps it could be omitted, and the reduction mentioned in the documentation instead.
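
For reference, a sketch of the reduction I have in mind (my reading of #158, not the exact code): collapse the per-feature mask to a single per-position mask, on the assumption that all features of a position are masked together.

import jax.numpy as np

# mask: boolean array of shape (batch, spatial, features).
# Assuming the features of one position are all masked or all unmasked,
# the feature axis can be reduced away before pooling:
position_mask = np.all(mask, axis=-1, keepdims=True)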

romanngg (Contributor) commented
Thanks for the detailed repro!

This touches on a somewhat fragile part of the library, and it may take a while to figure out how to fix it properly: we don't support different channel-wise masks inside the network, and pooling is a layer that can produce different channel-wise masks when given inputs whose channel-wise masks differ. I'm hesitant to adopt the proposed solution, though, because it could lead to silent errors if a user does use different channel masks and doesn't notice the warning.
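
To make "different channel-wise masks" concrete, here is a minimal hand-rolled illustration (shapes and values are made up, not taken from the library internals):

import jax.numpy as np

# Boolean mask of shape (batch, spatial, channels); True marks a masked entry.
mask = np.array([[[True, False],     # channels disagree at this position
                  [False, False]]])  # channels agree at this position

# Pooling reduces the spatial axis, so positions whose channels disagree
# would need a per-channel reduction, which isn't implemented yet.
consistent = np.all(mask == mask[..., :1], axis=-1)  # True where channels agree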

Two user-side short-term solutions:

  • Have the network start at the bottom with a parametric layer (e.g. Dense, Conv), since these layers produce outputs whose masks are identical across all channels, regardless of differing per-channel masks in the inputs (see the sketch after this list).
  • Apply any non-parametric layers (e.g. pooling, relu) as part of dataset preprocessing, and not as part of the network.
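
For example, a minimal sketch of the first workaround, reusing X and mask_constant from the reproducer above (the width 512 and the Relu layer are illustrative choices, not prescribed here):

import neural_tangents as nt
from neural_tangents import stax

# Start with a parametric layer so that all channels share a single mask
# by the time the pooling layer sees them.
_, _, kernel_fn = stax.serial(
    stax.Dense(512),   # parametric: outputs identical masks for all channels
    stax.Relu(),
    stax.GlobalAvgPool(),
)
kernel_fn_batched = nt.batch(kernel_fn, batch_size=2)
cov = kernel_fn_batched(X, X, 'nngp', mask_constant=mask_constant)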

Lmk if this would work for the short-term!

romanngg added the bug label Aug 19, 2022
jglaser (Contributor, Author) commented Sep 11, 2022
Yup, starting the network with a Dense layer works.
