[Unity][BYOC] Support offloading multi-query attention by Flash Attention #15831
MQA, as used by Llama 2 70B and CodeLlama 34B, is a common way to reduce runtime memory bandwidth for big LLMs. So far we haven't been taking advantage of this optimization: we do an explicit `repeat` of the KV tensors (https://github.com/mlc-ai/mlc-llm/blob/main/mlc_llm/relax_model/llama.py#L384-L385) and use the regular attention op, which defeats the purpose of MQA.
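For concreteness, here is a minimal numpy sketch (not the actual MLC-LLM code; head counts and shapes are made up) of what the explicit-`repeat` approach does for an MQA layer:

```python
import numpy as np

def attention(q, k, v):
    # q, k, v: [num_heads, seq_len, head_dim]
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

num_q_heads, num_kv_heads, seq_len, head_dim = 8, 1, 16, 64
rng = np.random.default_rng(0)
q = rng.standard_normal((num_q_heads, seq_len, head_dim))
k = rng.standard_normal((num_kv_heads, seq_len, head_dim))
v = rng.standard_normal((num_kv_heads, seq_len, head_dim))

# Current approach: repeat K/V so that every query head gets its own copy, then
# run a regular attention kernel. The repeat materializes redundant KV data,
# which is exactly the memory-bandwidth cost MQA is meant to avoid.
k_rep = np.repeat(k, num_q_heads // num_kv_heads, axis=0)
v_rep = np.repeat(v, num_q_heads // num_kv_heads, axis=0)
out = attention(q, k_rep, v_rep)  # shape: (num_q_heads, seq_len, head_dim)
```

A kernel with native MQA support consumes `k` and `v` directly with `num_kv_heads < num_q_heads`, so the repeated copies never have to be materialized or read back from memory.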
While CUTLASS fMHA doesn't support MQA (where `num_q_head != num_kv_head`), flash attention does. So I added a new option to `partition_for_cutlass` to enable pattern matching the `repeat` op during attention rewriting: when we detect an attention pattern where the KV tensors are first expanded by `repeat`, we recognize it as MQA and dispatch it to flash attention.
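Roughly, the kind of Relax pattern the rewrite is meant to recognize looks like the TVMScript sketch below; the shapes, dtype, and exact op spellings are illustrative assumptions, not code taken from this PR:

```python
from tvm.script import ir as I
from tvm.script import relax as R

@I.ir_module
class MQAModule:
    @R.function
    def main(
        q: R.Tensor((1, 16, 8, 64), "float16"),
        k: R.Tensor((1, 16, 1, 64), "float16"),
        v: R.Tensor((1, 16, 1, 64), "float16"),
    ) -> R.Tensor((1, 16, 8, 64), "float16"):
        with R.dataflow():
            # KV tensors carry a single head and are expanded along the
            # num_heads axis to match the query before the attention op.
            k_rep = R.repeat(k, repeats=8, axis=2)
            v_rep = R.repeat(v, repeats=8, axis=2)
            out = R.nn.attention(q, k_rep, v_rep)
            R.output(out)
        return out
```

With the new option enabled, the `repeat` + attention group would be offloaded as a single flash attention MQA call, so `k_rep` / `v_rep` are never materialized.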
For now this feature is opt-in, since it would force using flash attention for causal inference, and I haven't thoroughly validated its performance against such workloads. For example, even though flash attention v2 supports causal decoding inference where `seq_q_len = 1` as of Dao-AILab/flash-attention@e07aa03, CUTLASS fMHA can still be faster for such workloads. But based on feedback I can enable MQA offloading by default to avoid introducing another param.
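As a usage sketch, opting in might look like the following; `use_flash_mqa` is a placeholder name for the new option (see the diff for the actual parameter name and default):

```python
from tvm.relax.backend.contrib.cutlass import partition_for_cutlass

# MQAModule is the TVMScript module from the sketch above. With the (assumed)
# use_flash_mqa flag enabled, the repeat + attention pattern is matched and
# dispatched to flash attention; without it, the behavior stays as before.
mod = partition_for_cutlass(MQAModule, use_flash_mqa=True)
```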
That said, flash is definitely advantageous for MQA. For example, the following nvprof output for CodeLlama 34B with 16k context length, using `repeat` and CUTLASS fMHA, shows large overhead from `repeat`:

Using flash attention MQA, the `repeat` overhead is completely gone, and I get a few tokens / sec improvement from this optimization. Also note that the CUTLASS fMHA time above and the flash attention MQA time below are roughly the same, indicating the relative superiority of the former kernel for this workload.
@vinx13 @cyx-6 @yzh119 @sunggg