Skip to content

Conversation

@larryliu0820
Copy link
Collaborator

To avoid excessive computation we want to support kv cache for cross attention in Whisper.

Fundamentally we only run k_proj and v_proj once on the encoder output hidden state, at the first token generation, then we should keep the key_states and value_states and reuse them in all the subsequent token generation.

For whisper-large-v3-turbo, where we have 4 layers of decoder:

WhisperDecoder(
  (embed_tokens): Embedding(51866, 1280, padding_idx=50257)
  (embed_positions): WhisperPositionalEmbedding(448, 1280)
  (layers): ModuleList(
    (0-3): 4 x WhisperDecoderLayer(
      (self_attn): WhisperAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=False)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (activation_fn): GELUActivation()
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (encoder_attn): WhisperAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=False)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (encoder_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
)

Without KV cache in encoder_attn, we are doing 2 1280x1280 MM for each layer, so in total 8 1280x1280 MM for each token generated. This largely impacts token/sec perf number.

This PR replaces encoder_attn with a WhisperCrossAttention class, where we replaces if condition with torch.cond. The logic becomes:

  • If KV cache values are all zero:
    • Compute KV projections
  • Otherwise:
    • Clone from KV cache. Note here we can't directly return KV cache, due to the non-aliasing requirement.
  • After torch.cond:
    • Write back the values from either branch back to KV cache

Notice that we still have 1 extra read and 1 extra write, but it should be much faster than MM.

self.cross_attention_cache = StaticCache(
config=self.config,
max_batch_size=batch_size,
max_cache_len=getattr(self.config, "max_source_positions", max_static_cache_length), # This is fixed in whisper
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull this outside into a var like the other arguments

self.cross_attention_cache = StaticCache(
config=self.config,
max_batch_size=batch_size,
max_cache_len=getattr(self.config, "max_source_positions", max_static_cache_length), # This is fixed in whisper
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also what do you mean this is fixed in whisper? Will this work for t5?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Basically they always have 1500 for max_source_positions and that translates to 30 seconds of audio. So we should use that for cache len. For T5 I don't know and that's why I name this class WhisperCrossAttention.

f"cross_attention_value_cache_{i}", self.cross_attention_cache.layers[i].values, persistent=False
)

# Massage decoder to use cross attention.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Massage decoder to use cross attention.
# Use custom cross attention for Whisper.

# limitations under the License.

# Export friendly cross attention implementation for Whisper. Adopted
# from https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py#L241
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# from https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py#L241
# from https://github.com/huggingface/transformers/blob/454c0a7ccf33f7fc13e3e2eb9b188a5c09ab708b/src/transformers/models/whisper/modeling_whisper.py#L241

Permalink is better in case code changes

{"cache_position": None},
)

else:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we remove this else branch if we aren't expecting to use it?

)

# Update the KV cache outside of torch.cond.
past_key_values.update(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not put this inside the recompute_kv branch?

Copy link
Collaborator

@jackzhxng jackzhxng left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh also run make style for formatting

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants