Why is Flax Linear not identical to matrix multiplication? #4020

Closed · Answered by thijs-vanweezel
thijs-vanweezel asked this question in Q&A

The answer has been provided on Stack Overflow.

The matmul should be transposed relative to the usual W @ x convention: nnx.Linear stores its kernel with shape (in_features, out_features) and computes

y = x.squeeze() @ layer.kernel.value + layer.bias.value
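
A minimal sketch to verify this convention (the layer sizes, input, and rngs seed here are illustrative assumptions, not from the original post):

import jax
import jax.numpy as jnp
from flax import nnx

# Hypothetical square layer, so it is also invertible below
layer = nnx.Linear(in_features=3, out_features=3, rngs=nnx.Rngs(0))
x = jnp.ones((1, 3))

y = layer(x)  # nnx.Linear computes x @ kernel + bias
y_manual = x.squeeze() @ layer.kernel.value + layer.bias.value
assert jnp.allclose(y.squeeze(), y_manual)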

So, to invert an nnx.Linear operation: since y = x @ kernel + bias, each example satisfies kernel.T @ x = y - bias, which is exactly the a @ x = b form that jnp.linalg.tensorsolve expects:

import jax
import jax.numpy as jnp

# Batch the solve over the leading (example) axis
solve_batched = jax.vmap(jnp.linalg.tensorsolve)
# Solves kernel.T @ x = y - bias for each example, recovering x
x_recovered = solve_batched(
    a=jnp.broadcast_to(
        layer.kernel.value.T,  # note the transposition
        (y.shape[0], *layer.kernel.value.shape)),
    b=y - layer.bias.value)
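
As a round-trip check, continuing the illustrative setup above (the batch size and PRNG seed are assumptions; the kernel must be square and invertible for the solve to exist):

x_batch = jax.random.normal(jax.random.key(1), (4, 3))  # hypothetical batch of 4
y = layer(x_batch)
x_recovered = solve_batched(
    a=jnp.broadcast_to(layer.kernel.value.T, (y.shape[0], *layer.kernel.value.shape)),
    b=y - layer.bias.value)
assert jnp.allclose(x_batch, x_recovered, atol=1e-5)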
