Thanks for open-sourcing the `flash_attention` kernel! A feature to support prefix caching is badly needed.
When `q_seq_len < kv_seq_len`, the current implementation left-aligns `q` to `k`/`v`, e.g.:

```
# Left alignment:
[1, 2, 3, 4]              # q
[1, 2, 3, 4, 5, 6, 7, 8]  # k/v
```
However, for prefix-cache-aware prefill, we need `q` to be right-aligned to `k`/`v`, like:

```
# Right alignment:
            [1, 2, 3, 4]  # q
[1, 2, 3, 4, 5, 6, 7, 8]  # k/v
```
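To make the difference concrete, here is a small illustrative sketch in plain `jax.numpy` (not the Pallas kernel code) of how the causal mask differs between the two alignments; the variable names are just for illustration:

```python
import jax.numpy as jnp

q_seq_len, kv_seq_len = 4, 8
q_pos = jnp.arange(q_seq_len)[:, None]    # shape (q_seq_len, 1)
kv_pos = jnp.arange(kv_seq_len)[None, :]  # shape (1, kv_seq_len)

# Left alignment (current behavior): query i sits at absolute position i,
# so it may only attend to kv positions <= i.
left_mask = kv_pos <= q_pos

# Right alignment (prefix-cache-aware prefill): query i sits at absolute
# position i + (kv_seq_len - q_seq_len), so it also attends to the entire
# cached prefix.
right_mask = kv_pos <= q_pos + (kv_seq_len - q_seq_len)
```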
It would be great to add an `offset` parameter to `flash_attention`, defaulting to 0. If `offset > 0`, right-shift `q` by `offset` tokens; otherwise, leave it left-aligned. A usage sketch follows below.
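A minimal usage sketch of the proposal, assuming a TPU backend; the `offset` keyword is only the proposed parameter and does not exist in the current signature, so it is left commented out:

```python
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention

batch, heads, head_dim = 1, 8, 128
q_seq_len, kv_seq_len = 512, 1024  # new tokens vs. cached prefix + new tokens

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (batch, heads, q_seq_len, head_dim), dtype=jnp.float32)
k = jax.random.normal(kk, (batch, heads, kv_seq_len, head_dim), dtype=jnp.float32)
v = jax.random.normal(kv, (batch, heads, kv_seq_len, head_dim), dtype=jnp.float32)

# Today this call left-aligns q to k/v. With the proposed parameter,
# offset = kv_seq_len - q_seq_len would right-align q so that the new
# tokens attend to the cached prefix plus themselves.
out = flash_attention(
    q, k, v,
    causal=True,
    # offset=kv_seq_len - q_seq_len,  # proposed parameter, not yet supported
)
```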
Which attention kernel are you referring to? (We have like 3 floating around now).
this one: jax/jax/experimental/pallas/ops/tpu/flash_attention.py, line 140 in 3a5ac48