Transformer Block Implementation gives NotImplementedError #203

Open
esnvidia opened this issue Jul 19, 2024 · 0 comments

Hi,

I implemented a basic transformer block with residual connections and am getting the following error:

NotImplementedError: `FanInSum` is only implemented for the case where all input layers guaranteed to be mean-zero Gaussian, i.e. having all `is_gaussian` set to `True`, got [True, False].

It appears to be due to stax.Identity().
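
A quick check that seems consistent with this (just a sketch; I'm assuming that calling a layer's kernel_fn on raw arrays without a get argument returns the full Kernel object, whose is_gaussian field is what FanInSum inspects):

import jax.numpy as jnp
from neural_tangents import stax

x = jnp.ones((3, 4))

# Dense outputs are treated as mean-zero Gaussian in the infinite-width limit.
_, _, dense_kernel_fn = stax.Dense(4)
print(dense_kernel_fn(x).is_gaussian)  # I expect True here

# Identity just forwards the raw (deterministic) inputs.
_, _, id_kernel_fn = stax.Identity()
print(id_kernel_fn(x).is_gaussian)     # I expect False here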

Here is the implementation:

import jax.numpy as jnp
from jax import random, jit
from neural_tangents import stax


def FeedForwardNetwork(hidden_dim, output_dim):
    return stax.serial(stax.Dense(hidden_dim), stax.Relu(),
                       stax.Dense(output_dim))

AttnBlock = stax.serial(stax.FanOut(2),
                        stax.parallel(
                            stax.serial(
                                stax.GlobalSelfAttention(
                                    n_chan_out=1,
                                    n_chan_key=1,
                                    n_chan_val=1,
                                    pos_emb_type='SUM',
                                    W_pos_emb_std=1,
                                    # pos_emb_decay_fn=lambda d: 1 / (1 + d**2),
                                    attention_mechanism='SOFTMAX',
                                    linear_scaling=True,
                                    n_heads=1)
                            ),
                            stax.Identity()  # skip connection carrying the block input
                        ),
                        stax.FanInSum())  # the failing is_gaussian check lives in FanInSum's kernel_fn

def TransformerBlock(ff_dim, d_model):
    return stax.serial(AttnBlock,
                       stax.LayerNorm(),
                       stax.FanOut(2),
                       stax.parallel(
                           FeedForwardNetwork(ff_dim, d_model),
                           stax.Identity()
                       ),
                       stax.FanInSum(),
                       stax.LayerNorm())


def Transformer(num_layers, ff_dim, d_model):
    layers = []
    for _ in range(num_layers):
        layers.append(TransformerBlock(ff_dim, d_model))
    layers.append(stax.Dense(out_dim=1))
    return stax.serial(*layers)

num_layers = 1
ff_dim = 128
d_model = 256

init_fn, apply_fn, kernel_fn = Transformer(num_layers, ff_dim, d_model)

And then taking the example data from the cookbook:

key = random.PRNGKey(10)
train_points = 5
test_points = 50
noise_scale = 1e-1

target_fn = lambda x: jnp.sin(x)

key, x_key, y_key = random.split(key, 3)

train_xs = random.uniform(x_key, (train_points, 1), minval=-jnp.pi, maxval=jnp.pi)

train_ys = target_fn(train_xs)
train_ys += noise_scale * random.normal(y_key, (train_points, 1))
train = (train_xs, train_ys)

test_xs = jnp.linspace(-jnp.pi, jnp.pi, test_points)
test_xs = jnp.reshape(test_xs, (test_points, 1))

test_ys = target_fn(test_xs)
test = (test_xs, test_ys)

apply_fn = jit(apply_fn)
kernel_fn = jit(kernel_fn, static_argnames='get')

kernel = kernel_fn(test_xs, test_xs, 'nngp')
std_dev = jnp.sqrt(jnp.diag(kernel))

The error occurs during this kernel_fn call.
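
For what it's worth, I think the same check can be hit without any attention layer at all, which is why I suspect the skip branch rather than GlobalSelfAttention itself. A minimal sketch (widths are arbitrary):

import jax.numpy as jnp
from neural_tangents import stax

_, _, kfn = stax.serial(
    stax.FanOut(2),
    stax.parallel(stax.Dense(4), stax.Identity()),  # one Gaussian branch, one raw-input branch
    stax.FanInSum(),
)

# I expect this to raise the same
# NotImplementedError: `FanInSum` is only implemented ... got [True, False].
kfn(jnp.ones((3, 4)), None, 'nngp')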

What is odd is that the ResBlock works in the cookbook:

ResBlock = stax.serial(
    stax.FanOut(2),
    stax.parallel(
        stax.serial(
            stax.Erf(),
            stax.Dense(512, W_std=1.1, b_std=0),
        ),
        stax.Identity()
    ),
    stax.FanInSum()
)
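
Though if I am reading the cookbook correctly, that ResBlock is only ever used after an initial stax.Dense, so its Identity branch already carries a mean-zero Gaussian signal by the time FanInSum runs, whereas in my Transformer the very first FanOut sees the raw inputs. A rough sketch of that kind of composition (widths and stds are placeholders, not the cookbook's exact values):

_, _, res_kernel_fn = stax.serial(
    stax.Dense(512, W_std=1.0, b_std=0.0),  # makes the skip-branch input Gaussian
    ResBlock,
    stax.Dense(1, W_std=1.0, b_std=0.0),
)

If that is indeed the difference, putting a stax.Dense in front of the first TransformerBlock might at least get past the first FanInSum, though I have not verified how LayerNorm propagates is_gaussian for the second skip connection.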

And it appears that with linear_scaling=True, the attention layer gets is_gaussian=True from this line (which would make the False in [True, False] the Identity branch):
https://github.com/google/neural-tangents/blob/c17e770bb74f1771da7be4a69fabfa68b6078960/neural_tangents/_src/stax/linear.py#L2464C14-L2468C39

Eventually I would also like to include causal masking; if you have pointers there, that would be great, since it is also not clear to me how to apply an upper-triangular mask over the sequence length in the infinite-width case.
