Native Dot product attention #22760
Unanswered
AakashKumarNain asked this question in General
Replies: 2 comments · 7 replies
-
Looking into this; I will make the API accept unbatched input so it works better with vmap.
-
Can you give this change a shot: #22830?
With the latest release, we now have a native implementation of dot product attention via `jax.nn.dot_product_attention(...)`. This is great except for one thing. When building neural networks in JAX with libraries like Equinox, we are used to implementing functionality that works on a single example (not a batch_size of 1), and then `vmap`-ping over the model to make it work on a batch. With the batch axis baked into the attention implementation, I have two options; one of them is to add a leading batch dimension of 1 to `qkv` and then `squeeze` the output of attention. I can do either, but it seems to have broken my mental model for using `vmap`, because in that case, what does `batch` even mean anymore?