
Average Attention or Attention on last token? #14

Closed
feiyu12138 opened this issue Apr 26, 2024 · 8 comments

Comments

@feiyu12138

Hi,

In the paper and the code comments, it is stated that the visual tokens are reranked by the average attention they receive across all tokens. However, the code actually computes the attention from the last token only, which is clearly different from the description. Would you mind clarifying which strategy is better? Thank you very much!
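
For concreteness, a minimal sketch of the two ranking strategies being compared (the tensor shape, index names, and the function itself are illustrative assumptions, not the repo's actual code):

```python
import torch

# attn: attention weights from one layer, shape (num_heads, seq_len, seq_len)
# img_start, img_end: index range of the visual tokens in the sequence
def rank_visual_tokens(attn, img_start, img_end, strategy="last"):
    attn = attn.mean(dim=0)  # average over heads -> (seq_len, seq_len)
    if strategy == "average":
        # average attention each visual token receives from ALL query tokens
        scores = attn[:, img_start:img_end].mean(dim=0)
    else:
        # attention each visual token receives from the LAST token only
        scores = attn[-1, img_start:img_end]
    # higher score = more important; return indices sorted by importance
    return img_start + torch.argsort(scores, descending=True)
```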

@feiyu12138
Author

Now I get it. Essentially, you can't do the filtering at the first step if you want to reproduce the performance (w/o cache). But at the following steps, you can only access the last token. That is the reason for the last-token-guided ranking.

@zjysteven

Hi @feiyu12138, I'm actually having the same question. Would you mind elaborating on what you mean by "you can't do the filtering at the first step if you want to reproduce the performance (w/o cache)"?

@feiyu12138
Author

feiyu12138 commented Jul 4, 2024

> Hi @feiyu12138, I'm actually having the same question. Would you mind elaborating on what you mean by "you can't do the filtering at the first step if you want to reproduce the performance (w/o cache)"?

Hi,
If you purge the low-ranked tokens at the first step, you can't re-rank the tokens when predicting the following tokens (from the 2nd answer token onward). Therefore, you have to store the KV states for all tokens at all layers at the first step. Does that make sense?

@zjysteven

zjysteven commented Jul 4, 2024

I'm not yet familiar with the internal inference and KV cache of LLMs (only have high-level ideas), so I haven't really followed you here.

Let's assume KV cache is not used (which should make things simpler?). Say we prune the tokens at layer K. Then we will have all image tokens available across layers 1 to K - 1, regardless of which answer token is being generated, right? So I don't see why pruning at step 1 would affect step 2 and beyond.

I know it's definitely not your job to answer my question, so I do appreciate your time and discussion.

@chenllliang reopened this Jul 4, 2024
@feiyu12138
Author

feiyu12138 commented Jul 4, 2024

> I'm not yet familiar with the internal inference and KV cache of LLMs (only have high-level ideas), so I haven't really followed you here.
>
> Let's assume KV cache is not used (which should make things simpler?). Say we prune the tokens at layer K. Then we will have all image tokens available across layers 1 to K - 1, regardless of which answer token is being generated, right? So I don't see why pruning at step 1 would affect step 2 and beyond.
>
> I know it's definitely not your job to answer my question, so I do appreciate your time and discussion.

You are right that all image tokens are available across layers 1 to K - 1. However, if the second token attends to the visual tokens differently, then for the second token's generation we may need to keep a different set of tokens at layers K to 32. And if you didn't store all of the KV states, the KV entries you need at layers K to 32 simply aren't available.
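
To make this concrete, here is a minimal sketch (the indices and data layout are made up for illustration; this is not the repo's code) of why the first step has to keep the full KV states if later steps are allowed to re-rank:

```python
# Step 1: the first answer token ranks the visual tokens and keeps the top-k
# at layers K..32; only their KV states exist in the deep-layer cache.
kept_step1 = {5, 9, 12}                              # hypothetical kept visual-token indices
deep_layer_kv = {i: ("K_i", "V_i") for i in kept_step1}

# Step 2: the second answer token attends differently and ranks a different set.
wanted_step2 = {5, 9, 17}

missing = wanted_step2 - deep_layer_kv.keys()
print(missing)  # {17}: its KV at layers K..32 was never computed or stored at step 1,
                # so step 2 cannot freely re-rank unless step 1 stored KV for ALL tokens.
```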

@zjysteven

@chenllliang Is it true that the first generated token's forward pass won't have vision tokens pruned? If so, I'm really confused, since at least when evaluating aokvqa the max new tokens is set to 1. I would appreciate confirmation and thoughts from the authors too.

@chenllliang
Member

@zjysteven, FastV works differently with and without KV cache, so I'll explain the two cases separately.

  • When KV cache is not applied, the decoding of each output token is independent, so every token's forward pass has its own set of vision tokens pruned. In the okvqa evaluation from the repo, the generated token's forward pass does have vision tokens pruned.
  • When KV cache is applied, there are different ways to integrate FastV (see the sketch after this list):
    • In the first forward pass, do not prune the visual tokens, so that the full KV caches of the visual tokens are saved. When decoding the later output tokens, prune the KV caches of the visual tokens. This is expected to give the same results as the non-KV-cache version.
    • In the first forward pass, prune the visual tokens and only save the KV cache of the unpruned tokens. When decoding the later output tokens, do not prune again, since they can only see the pruned KV cache instead of the full image's KV cache. This leads to slightly different results compared to the non-KV-cache version, since all output tokens have the same visual tokens pruned. (We implement this in the lmms-eval part.)
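
A minimal sketch of the two KV-cache strategies above, assuming a toy cache represented as a dict {visual_token_index: (K, V)} for the deep layers; the function names and data layout are illustrative only, not FastV's actual implementation:

```python
def variant_full_cache(full_visual_kv, scores_per_step, keep_k):
    """Strategy 1: keep the FULL visual KV at prefill, prune per decoding step.

    Each decoding step ranks by its own last-token attention scores, so the
    kept set may change from step to step (matches the non-KV-cache version)."""
    for scores in scores_per_step:                 # one {idx: score} dict per output token
        top = sorted(scores, key=scores.get, reverse=True)[:keep_k]
        yield {i: full_visual_kv[i] for i in top}  # every needed entry is available


def variant_pruned_cache(full_visual_kv, prefill_scores, keep_k, num_steps):
    """Strategy 2: prune once at prefill and cache only the kept tokens.

    Later steps cannot re-rank; they all see the same pruned cache, which is
    why results differ slightly from the non-KV-cache version."""
    top = sorted(prefill_scores, key=prefill_scores.get, reverse=True)[:keep_k]
    pruned = {i: full_visual_kv[i] for i in top}   # only these KV entries ever exist
    for _ in range(num_steps):
        yield pruned
```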

@feiyu12138 thanks for your clear explanation as well!

@zjysteven

That makes perfect sense and clears up my confusion. Thank you both!
