Support cross attention kv cache #187
base: main
Conversation
self.cross_attention_cache = StaticCache(
    config=self.config,
    max_batch_size=batch_size,
    max_cache_len=getattr(self.config, "max_source_positions", max_static_cache_length),  # This is fixed in whisper
Pull this outside into a var like the other arguments
Also, what do you mean by "this is fixed in whisper"? Will this work for T5?
Basically Whisper always has 1500 for max_source_positions, which translates to 30 seconds of audio, so we should use that for the cache length. For T5 I don't know, and that's why I named this class WhisperCrossAttention.
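A minimal sketch of the suggested refactor, pulling the cache length into its own variable; the name cross_attention_cache_len is illustrative, and the snippet assumes self.config, batch_size, and max_static_cache_length are already in scope in the surrounding constructor:

# Whisper always has max_source_positions = 1500 (30 seconds of audio),
# so use it for the cross-attention cache length when it is available.
cross_attention_cache_len = getattr(self.config, "max_source_positions", max_static_cache_length)

self.cross_attention_cache = StaticCache(
    config=self.config,
    max_batch_size=batch_size,
    max_cache_len=cross_attention_cache_len,
)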
| f"cross_attention_value_cache_{i}", self.cross_attention_cache.layers[i].values, persistent=False | ||
| ) | ||
|
|
||
| # Massage decoder to use cross attention. |
Suggested change:
- # Massage decoder to use cross attention.
+ # Use custom cross attention for Whisper.
# limitations under the License.

# Export friendly cross attention implementation for Whisper. Adopted
# from https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py#L241
Suggested change:
- # from https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py#L241
+ # from https://github.com/huggingface/transformers/blob/454c0a7ccf33f7fc13e3e2eb9b188a5c09ab708b/src/transformers/models/whisper/modeling_whisper.py#L241
A permalink is better in case the code changes.
| {"cache_position": None}, | ||
| ) | ||
|
|
||
| else: |
Should we remove this else branch if we aren't expecting to use it?
)

# Update the KV cache outside of torch.cond.
past_key_values.update(
Why not put this inside the recompute_kv branch?
jackzhxng left a comment:
Oh, also run make style for formatting.
To avoid excessive computation, we want to support a KV cache for cross attention in Whisper.

Fundamentally, we only need to run k_proj and v_proj on the encoder output hidden state once, at the first token generation; after that we should keep the key_states and value_states and reuse them for all subsequent token generation.

For whisper-large-v3-turbo, which has 4 decoder layers: without a KV cache in encoder_attn, we do 2 1280x1280 matmuls per layer, so 8 1280x1280 matmuls in total for each generated token. This largely impacts the tokens/sec perf number.

This PR replaces encoder_attn with a WhisperCrossAttention class, where we replace the if condition with torch.cond (a sketch of the resulting logic follows below). Notice that we still have 1 extra read and 1 extra write, but those should be much faster than the matmuls.
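A minimal sketch of that torch.cond pattern, not the exact code from this PR: it assumes a fixed encoder length (1500 for Whisper) so both branches return tensors of the same shape, and the names recompute_kv, key_cache, value_cache, and embed_dim are illustrative.

import torch


class CrossAttentionKVSketch(torch.nn.Module):
    """Sketch: project encoder states once, then reuse cached key/value states."""

    def __init__(self, embed_dim: int, max_source_positions: int, batch_size: int):
        super().__init__()
        self.k_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.v_proj = torch.nn.Linear(embed_dim, embed_dim)
        # Static buffers sized for the fixed encoder length (1500 for Whisper).
        shape = (batch_size, max_source_positions, embed_dim)
        self.register_buffer("key_cache", torch.zeros(shape), persistent=False)
        self.register_buffer("value_cache", torch.zeros(shape), persistent=False)

    def forward(self, encoder_hidden_states: torch.Tensor, recompute_kv: torch.Tensor):
        def compute(encoder_hidden_states, key_cache, value_cache):
            # First generated token: run k_proj/v_proj on the encoder output.
            return self.k_proj(encoder_hidden_states), self.v_proj(encoder_hidden_states)

        def reuse(encoder_hidden_states, key_cache, value_cache):
            # Subsequent tokens: read back the cached projections (the extra read).
            return key_cache.clone(), value_cache.clone()

        key_states, value_states = torch.cond(
            recompute_kv, compute, reuse,
            (encoder_hidden_states, self.key_cache, self.value_cache),
        )
        # Write back outside torch.cond, since the branches must stay side-effect
        # free for export (the extra write).
        self.key_cache.copy_(key_states)
        self.value_cache.copy_(value_states)
        return key_states, value_states

In eager mode a plain if on recompute_kv would work, but torch.cond keeps both branches traceable so the exported graph stays static.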