
GQA models have not supported prefix caching #2873

Closed
toslunar wants to merge 1 commit into vllm-project:main from toslunar:prefix-gqa-not-yet

Conversation

@toslunar
Contributor

I found that a model using GQA returns wrong results with prefix_pos. After some investigation, I traced it to the code that supports MQA/GQA:

if self.num_kv_heads != self.num_heads:
    # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
    # project the key and value tensors to the desired number of
    # heads.
    # TODO(woosuk): Use MQA/GQA kernels for higher performance.
    query = query.view(query.shape[0], self.num_kv_heads,
                       self.num_queries_per_kv, query.shape[-1])
    key = key[:, :, None, :].expand(key.shape[0], self.num_kv_heads,
                                    self.num_queries_per_kv,
                                    key.shape[-1])
    value = value[:, :, None, :].expand(value.shape[0],
                                        self.num_kv_heads,
                                        self.num_queries_per_kv,
                                        value.shape[-1])

This code, which repeats the key and value tensors across the query heads, is not compatible with the current implementation of prefix caching (context_attention_fwd).
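For intuition, here is a runnable sketch of that shape bookkeeping with made-up dimensions (num_tokens, num_heads, etc. are stand-ins for the module attributes, not vLLM code):

import torch

# Illustrative sizes only: num_heads = num_kv_heads * num_queries_per_kv.
num_tokens, num_heads, num_kv_heads, head_size = 4, 8, 2, 16
num_queries_per_kv = num_heads // num_kv_heads

query = torch.randn(num_tokens, num_heads, head_size)
key = torch.randn(num_tokens, num_kv_heads, head_size)

# Group the query heads by the KV head they share ...
query = query.view(num_tokens, num_kv_heads, num_queries_per_kv, head_size)
# ... and broadcast each KV head across its query group.
key = key[:, :, None, :].expand(num_tokens, num_kv_heads,
                                num_queries_per_kv, head_size)
assert query.shape == key.shape == (4, 2, 4, 16)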

To support MQA/GQA, something like

if self.num_kv_heads != self.num_heads:
    query = query.view(batch_size * seq_len, self.num_heads, self.head_size)
    key = key.reshape(batch_size * seq_len, self.num_heads, self.head_size)
    value = value.reshape(batch_size * seq_len, self.num_heads, self.head_size)

is closer, but the prefix KV should also be expanded (after it is read from key_cache and value_cache).
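A minimal sketch of what that expansion could look like. expand_kv_for_prefix is a hypothetical helper for illustration, not vLLM API; the real fix is the one referenced below:

import torch

def expand_kv_for_prefix(cached_kv: torch.Tensor,
                         num_queries_per_kv: int) -> torch.Tensor:
    # cached_kv: [num_prefix_tokens, num_kv_heads, head_size], as read
    # back from key_cache / value_cache.
    num_tokens, num_kv_heads, head_size = cached_kv.shape
    # Broadcast each KV head across its query group, then flatten the
    # group dimension so the result has num_heads attention heads.
    return (cached_kv[:, :, None, :]
            .expand(num_tokens, num_kv_heads, num_queries_per_kv, head_size)
            .reshape(num_tokens, num_kv_heads * num_queries_per_kv, head_size))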

@sighingnow
Collaborator

The issue was addressed by #3007

@WoosukKwon WoosukKwon closed this Mar 2, 2024