I implemented a basic transformer block with residual connections and am getting the following error:
```
NotImplementedError: `FanInSum` is only implemented for the case where all input layers guaranteed to be mean-zero Gaussian, i.e. having all `is_gaussian` set to `True`, got [True, False].
```
Eventually I would also like to include causal masking, and pointers there would be great too, since it is not clear how to construct an upper-triangular mask over the sequence length in the infinite-width case.
Hi,

> I implemented a basic transformer block with residual connections and am getting the following error:
It appears that it's due to `stax.Identity()`.

Here is the implementation:

And then taking the example data from the cookbook, the error occurs in the `kernel_fn` calculation.

What is odd is that the `ResBlock` works in the cookbook:

And it appears that with `linear_scaling=True`, the `is_gaussian=True` comes from this line:
https://github.com/google/neural-tangents/blob/c17e770bb74f1771da7be4a69fabfa68b6078960/neural_tangents/_src/stax/linear.py#L2464C14-L2468C39
Eventually I would also like to include causal masking, and pointers there would be great too, since it is not clear how to construct an upper-triangular mask over the sequence length in the infinite-width case.
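For the finite-sequence-length case, the causal mask itself is just a lower-triangular matrix applied to the attention logits. A small JAX sketch for illustration (mine, and it does not answer the infinite-width question, only pins down what the mask does at finite length):

```python
# Illustration only (not from the thread): at sequence length T, a causal
# mask blocks attention from position i to any future position j > i by
# setting those logits to -inf before the softmax.
import jax.numpy as jnp

def causal_mask_logits(logits):
    # logits: (T, T) attention scores; entry (i, j) scores query i vs key j.
    T = logits.shape[-1]
    mask = jnp.tril(jnp.ones((T, T), dtype=bool))  # True where j <= i
    return jnp.where(mask, logits, -jnp.inf)       # -inf blocks the future

masked = causal_mask_logits(jnp.zeros((4, 4)))
# Row i now has finite entries only for j <= i, so a softmax over each row
# attends only to past and current positions.
```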