I get an error when passing masked inputs with more than one feature dimension to a kernel that involves stax.GlobalAvgPool().
Reproducer:
import jax
import jax.numpy as np
import neural_tangents as nt
from neural_tangents import stax

if __name__ == '__main__':
    # input tokens
    X = 3*np.ones((10,512))
    mask_constant = 10
    pad_token = 0
    # pad some elements
    X = X.at[0,4].set(pad_token)
    X = X.at[7,422].set(pad_token)
    print('before encode ',X.shape)
    # vocabulary size
    n_vocab = 5

    def encode(x, mask_constant):
        # zero mean embeddings
        res = jax.nn.one_hot(x, n_vocab)
        res -= np.mean(res, axis=-1, keepdims=True)
        return np.where(x[..., None] == pad_token, mask_constant, res)

    X = encode(X, mask_constant=mask_constant)
    print('after encode ', X.shape)
    # trace over output correlations
    _, _, kernel_fn_avg = stax.GlobalAvgPool()
    input_fn = nt.batch(kernel_fn_avg, batch_size=2)
    cov = input_fn(X, X, 'nngp', mask_constant=mask_constant, diagonal_spatial=True)
    print('output ', cov.shape)
Output
$ python mask_reproducer.py
Attempting to register factory for plugin cuBLAS when one has already been registered
before encode (10, 512)
after encode (10, 512, 5)
Traceback (most recent call last):
File "/gpfs/alpine/bif136/world-shared/gpbind/mask_reproducer.py", line 32, in <module>
cov = input_fn(X, X, 'nngp', mask_constant=mask_constant, diagonal_spatial=True)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 188, in h
return g(*args, **kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 471, in serial_fn
return serial_fn_x1(x1_or_kernel, x2, *args, **kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 398, in serial_fn_x1
_, kernel = _scan(row_fn, 0, (x1s, kwargs_np1))
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 151, in _scan
carry, y = f(carry, x)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 387, in row_fn
return _, _scan(col_fn, x1, (x2s, kwargs_np2))[1]
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 151, in _scan
carry, y = f(carry, x)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 396, in col_fn
return (x1, kwargs1), kernel_fn(x1, x2, *args, **kwargs_merge)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 188, in h
return g(*args, **kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 758, in f_pmapped
return _f(x_or_kernel, *args_np, **kwargs_np)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/_src/api.py", line 525, in cache_miss
out_flat = xla.xla_call(
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/core.py", line 1919, in bind
return call_bind(self, fun, *args, **params)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/core.py", line 1935, in call_bind
outs = top_trace.process_call(primitive, fun_, tracers, params)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/core.py", line 687, in process_call
return primitive.impl(f, *tracers, **params)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/_src/dispatch.py", line 199, in _xla_call_impl
compiled_fun = xla_callable(fun, device, backend, name, donated_invars,
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/linear_util.py", line 295, in memoized_fun
ans = call(fun, *args)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/_src/dispatch.py", line 248, in _xla_callable_uncached
return lower_xla_callable(fun, device, backend, name, donated_invars, False,
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/_src/profiler.py", line 294, in wrapper
return func(*args, **kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/_src/dispatch.py", line 293, in lower_xla_callable
jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/_src/profiler.py", line 294, in wrapper
return func(*args, **kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2167, in trace_to_jaxpr_final2
jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2117, in trace_to_subjaxpr_dynamic2
ans = fun.call_wrapped(*in_tracers_)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 751, in _f
return f(_x_or_kernel, *_args, **_kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 188, in h
return g(*args, **kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 222, in getter_fn
fn_out = fn(*canonicalized_args, **kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py", line 1008, in kernel_fn_any
return kernel_fn_x1(x1_or_kernel, x2, get,
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py", line 921, in kernel_fn_x1
out_kernel = kernel_fn(kernel, **kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py", line 222, in kernel_fn_with_masking
mask1, mask2 = mask_fn(mask1, shape1), mask_fn(mask2, shape2)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 188, in h
return g(*args, **kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py", line 188, in mask_fn
return _mask_fn(mask, input_shape)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/linear.py", line 1756, in mask_fn
_check_is_implemented(mask, channel_axis)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/linear.py", line 3621, in _check_is_implemented
raise NotImplementedError(
jax._src.traceback_util.UnfilteredStackTrace: NotImplementedError: Different channel-wise masks as inputs to pooling layers are not yet supported. Please let us know about your use case at https://github.com/google/neural-tangents/issues/new
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/gpfs/alpine/bif136/world-shared/gpbind/mask_reproducer.py", line 32, in <module>
cov = input_fn(X, X, 'nngp', mask_constant=mask_constant, diagonal_spatial=True)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 188, in h
return g(*args, **kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 471, in serial_fn
return serial_fn_x1(x1_or_kernel, x2, *args, **kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 398, in serial_fn_x1
_, kernel = _scan(row_fn, 0, (x1s, kwargs_np1))
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 151, in _scan
carry, y = f(carry, x)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 387, in row_fn
return _, _scan(col_fn, x1, (x2s, kwargs_np2))[1]
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 151, in _scan
carry, y = f(carry, x)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 396, in col_fn
return (x1, kwargs1), kernel_fn(x1, x2, *args, **kwargs_merge)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 188, in h
return g(*args, **kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 758, in f_pmapped
return _f(x_or_kernel, *args_np, **kwargs_np)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 751, in _f
return f(_x_or_kernel, *_args, **_kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 188, in h
return g(*args, **kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 222, in getter_fn
fn_out = fn(*canonicalized_args, **kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py", line 1008, in kernel_fn_any
return kernel_fn_x1(x1_or_kernel, x2, get,
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py", line 921, in kernel_fn_x1
out_kernel = kernel_fn(kernel, **kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py", line 222, in kernel_fn_with_masking
mask1, mask2 = mask_fn(mask1, shape1), mask_fn(mask2, shape2)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 188, in h
return g(*args, **kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py", line 188, in mask_fn
return _mask_fn(mask, input_shape)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/linear.py", line 1756, in mask_fn
_check_is_implemented(mask, channel_axis)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/linear.py", line 3621, in _check_is_implemented
raise NotImplementedError(
NotImplementedError: Different channel-wise masks as inputs to pooling layers are not yet supported. Please let us know about your use case at https://github.com/google/neural-tangents/issues/new
Expected output (with the fix from #158):
before encode (10, 512)
after encode (10, 512, 5)
/gpfs/alpine/world-shared/bif136/jax_env_summit/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py:769: UserWarning: Assuming consistent masks (all zero or one) for features of dimension > 1, which is not verified.
warnings.warn("Assuming consistent masks (all zero or one) for features of dimension > 1, which is not verified.")
/gpfs/alpine/world-shared/bif136/jax_env_summit/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py:769: UserWarning: Assuming consistent masks (all zero or one) for features of dimension > 1, which is not verified.
warnings.warn("Assuming consistent masks (all zero or one) for features of dimension > 1, which is not verified.")
output (10, 10)
I admit that the warning is a little noisy; perhaps it could be omitted and the reduction mentioned in the documentation instead.
This touches on a fragile part of the library, and it may take a while to figure out how to fix it properly: we don't support different channel-wise masks inside the network, and pooling is a layer that can produce different channel-wise masks when given different input channel-wise masks. I'm hesitant to adopt the proposed solution, though, because it could lead to silent errors if a user does pass different channel masks and doesn't pay attention to the warning.
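To make "different channel-wise masks" concrete, here is a minimal sketch of a user-side check that a masked input carries the same mask in every channel. It assumes the mask is derived elementwise as `x == mask_constant`, which is an assumption about the library's behaviour rather than something stated in this thread. The helper name `masks_consistent_across_channels` is hypothetical, and plain NumPy is used because the check runs outside the network.

```python
import numpy as np

def masks_consistent_across_channels(x, mask_constant, channel_axis=-1):
    """Return True if every position is masked in all channels or in none.

    Hypothetical helper; assumes the mask is `x == mask_constant`.
    """
    mask = (x == mask_constant)
    # Compare each channel's mask against the first channel's mask.
    first = np.take(mask, [0], axis=channel_axis)
    return bool(np.all(mask == first))

# Toy embeddings of shape (batch, sequence, channels).
x = np.full((2, 4, 3), 0.5)
mask_constant = 10.0

# A whole token is padded: all of its channels hold mask_constant.
x_ok = x.copy()
x_ok[0, 1, :] = mask_constant

# Only one channel of a token is padded: masks differ between channels.
x_bad = x.copy()
x_bad[0, 1, 0] = mask_constant

print(masks_consistent_across_channels(x_ok, mask_constant))   # True
print(masks_consistent_across_channels(x_bad, mask_constant))  # False
```

Running a check like this before calling the kernel would catch the silent-error case described above, at the cost of one extra pass over the data.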
Two user-side short-term solutions:
1. Start the network with a parametric layer (e.g. Dense, Conv, etc.), since these layers produce outputs that have identical masks for all channels, regardless of different masks for different channels in the inputs.
2. Apply any non-parametric layers (e.g. pooling, relu, etc.) as part of dataset preprocessing rather than as part of the network.
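The second workaround could look roughly like the sketch below, which average-pools over the sequence axis in preprocessing and skips padded tokens. It is written in plain NumPy and reuses the shapes and token values from the reproducer. The helper `masked_global_avg_pool` is hypothetical, and averaging over unmasked tokens only is an assumption that may not match how the library normalizes masked pooling.

```python
import numpy as np

def masked_global_avg_pool(x, token_mask):
    """Average over the sequence axis, ignoring masked tokens.

    x: (batch, sequence, channels) embeddings.
    token_mask: (batch, sequence) boolean array, True where a token is padded.
    Hypothetical helper, not part of neural_tangents.
    """
    keep = ~token_mask[..., None]              # (batch, sequence, 1)
    counts = np.maximum(keep.sum(axis=1), 1)   # (batch, 1); avoids dividing by zero
    return np.where(keep, x, 0.0).sum(axis=1) / counts

# Same setup as the reproducer: every token is 3, two positions are padded.
pad_token = 0
tokens = np.full((10, 512), 3)
tokens[0, 4] = pad_token
tokens[7, 422] = pad_token

# Zero-mean one-hot embeddings, shape (10, 512, 5).
n_vocab = 5
one_hot = np.eye(n_vocab)[tokens]
embedded = one_hot - one_hot.mean(axis=-1, keepdims=True)

pooled = masked_global_avg_pool(embedded, tokens == pad_token)
print(pooled.shape)  # (10, 5)
```

The pooled array contains no masked entries, so later steps can consume it without `mask_constant`. Note that `pad_token = 0` is also a valid vocabulary index here, as in the reproducer, so the padding mask is taken from the raw tokens rather than from the embeddings.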