masahi commented on Sep 27, 2023

MQA, as used by llama 2 70B and codellama 34B, is a common way to reduce runtime memory bandwidth for big LLMs. So far we haven't been taking advantage of this optimization: we do an explicit repeat of the KV tensors (https://github.com/mlc-ai/mlc-llm/blob/main/mlc_llm/relax_model/llama.py#L384-L385) and use regular attention, which defeats the purpose of MQA.
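For reference, here is a minimal NumPy sketch of what that explicit repeat does. It is illustrative only: the actual mlc-llm code uses Relax ops, and the shapes, names, and layout below are my assumptions, not the code linked above.

```python
# Minimal NumPy sketch of the explicit KV repeat; illustrative only, the real
# code in llama.py uses Relax ops and may use a different tensor layout.
import numpy as np

batch, seq_len, head_dim = 1, 16, 128
num_q_heads, num_kv_heads = 64, 8            # a codellama-34B-like GQA/MQA setting
n_rep = num_q_heads // num_kv_heads

q = np.random.randn(batch, seq_len, num_q_heads, head_dim).astype("float16")
k = np.random.randn(batch, seq_len, num_kv_heads, head_dim).astype("float16")
v = np.random.randn(batch, seq_len, num_kv_heads, head_dim).astype("float16")

# The step we want to avoid: materializing n_rep copies of K and V so that a
# "regular" attention kernel sees num_q_heads KV heads.
k_rep = np.repeat(k, n_rep, axis=2)          # (batch, seq_len, num_q_heads, head_dim)
v_rep = np.repeat(v, n_rep, axis=2)
assert k_rep.shape == q.shape and v_rep.shape == q.shape
```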

While CUTLASS fMHA doesn't support MQA (where num_q_head != num_kv_head), flash attention does. So I added a new option to partition_for_cutlass that enables pattern-matching on the repeat op during attention rewriting: when we detect an attention pattern whose KV tensors are first expanded by repeat, we recognize it as MQA and dispatch it to flash attention.
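Usage would look roughly like the sketch below. The keyword name `use_flash_mqa` is a placeholder for the new opt-in option, not necessarily the actual parameter name; check the signature of partition_for_cutlass in this PR for the real name and default.

```python
# Hedged usage sketch: `use_flash_mqa` is a placeholder name for the new
# opt-in option; check partition_for_cutlass's signature for the actual
# parameter name and default.
import tvm
from tvm.relax.backend.contrib.cutlass import partition_for_cutlass


def offload_attention(mod: tvm.IRModule) -> tvm.IRModule:
    """Partition a Relax module for CUTLASS / flash attention offloading.

    With the option enabled, subgraphs matching "repeat -> attention" are
    recognized as MQA and handed to flash attention with the original,
    unrepeated K/V; everything else keeps the regular CUTLASS fMHA path.
    """
    return partition_for_cutlass(mod, use_flash_mqa=True)
```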

For now this feature is opt-in, since it would force flash attention to be used for causal inference, and I haven't thoroughly validated its performance on such workloads. For example, even though flash attention v2 has supported causal decoding inference with seq_q_len = 1 since Dao-AILab/flash-attention@e07aa03, cutlass fMHA can still be faster for such workloads. Based on feedback, I can enable MQA offloading by default to avoid introducing another param.

That said, flash is definitely advantageous for MQA. For example, the following nvprof output for codellama 34B with 16k context length, using repeat and cutlass fMHA, shows large overhead from repeat:

 Time (%)  Total Time (ns)  Instances   Avg (ns)    Med (ns)  Min (ns)  Max (ns)  StdDev (ns)  Name
 --------  ---------------  ---------  ---------  ----------  --------  --------  -----------  ----------------------------------------------------------------
     30.5      21627490775      49152   440012.4    440224.0    420416    454336       3642.8  ampere_h16816gemm_64x64_sliced1x2_ldg8_stages_64x6_tn
     22.0      15608458542      98304   158777.5    155791.0     85152    245310      69134.0  ampere_h16816gemm_128x64_sliced1x2_ldg8_relu_stages_64x6_tn
     12.3       8753808009      99840    87678.4     87809.0     82080    103328       3242.2  repeat_kernel
     11.1       7900613613       2304  3429085.8   1829756.5   1206877   7376321    2590665.8  ampere_h16816gemm_256x128_ldg8_stages_64x3_tn
      8.1       5772286492      50192   115004.1    110656.0    102752    331968      30345.8  ampere_h16816gemm_64x64_sliced1x2_ldg8_stages_64x5_tn
      7.8       5541100467      49920   110999.6     93888.0     90080   1252571     135536.9  void attention_kernel_batched_impl<AttentionKer

Using flash attention MQA, the repeat overhead is completely gone, and I get an improvement of a few tokens/sec from this optimization. Also note that the cutlass fMHA kernel time above and the flash attention MQA kernel time below are roughly the same, which suggests the former kernel is still relatively strong for this workload.

 Time (%)  Total Time (ns)  Instances   Avg (ns)    Med (ns)  Min (ns)  Max (ns)  StdDev (ns)  Name
 --------  ---------------  ---------  ---------  ----------  --------  --------  -----------  ----------------------------------------------------------------
     44.7       7950055127       2304  3450544.8   1851210.0   1204583   7478071    2605192.6  ampere_h16816gemm_256x128_ldg8_stages_64x3_tn
     30.8       5464607061      49920   109467.3    100545.0     96961    684644      66473.6  void flash_fwd_kernel<Flash_fwd_kernel_traits<(int)128, (int)128, (int)64, (int)4, (bool)0, (bool)0…
     15.9       2817276679        768  3668329.0   3675380.5   3196017   3876989      91606.0  ampere_h16816gemm_128x256_ldg8_stages_64x3_tn
      1.9        340279055       1088   312756.5    321794.0    107648    332097      42981.1  ampere_h16816gemm_64x64_sliced1x2_ldg8_stages_64x5_tn
@vinx13 @cyx-6 @yzh119 @sunggg

masahi and others added 5 commits September 27, 2023 09:58
commit 99c2a59
Author: Masahiro Masuda <[email protected]>
Date:   Wed Sep 27 09:57:21 2023 +0900

    Revert "Revert "[Unity] Avoid trivial `var2 = var1` bindings in pattern matcher (apache#15578)""

    This reverts commit 0a6a617.

commit 9a3ca64
Author: Masahiro Masuda <[email protected]>
Date:   Wed Sep 27 09:55:02 2023 +0900

    wip

commit be01900
Author: Masahiro Masuda <[email protected]>
Date:   Tue Sep 26 19:55:29 2023 +0900

    fix test

commit a026b65
Author: Masahiro Masuda <[email protected]>
Date:   Thu Aug 31 22:24:38 2023 +0000

    wip

commit 233d2d0
Author: Masahiro Masuda <[email protected]>
Date:   Tue Aug 29 17:42:11 2023 +0000

    wip

commit 0a6a617
Author: Masahiro Masuda <[email protected]>
Date:   Tue Aug 29 17:28:25 2023 +0000

    Revert "[Unity] Avoid trivial `var2 = var1` bindings in pattern matcher (apache#15578)"

    This reverts commit 567848e.

commit 6c5a435
Author: Masahiro Masuda <[email protected]>
Date:   Tue Aug 29 17:28:16 2023 +0000

    wip

commit 7926cbc
Author: Masahiro Masuda <[email protected]>
Date:   Mon Aug 28 06:17:01 2023 +0000

    wip

commit 9828698
Author: Masahiro Masuda <[email protected]>
Date:   Mon Aug 28 15:11:47 2023 +0900

    wip

commit 5d01fd1
Author: Masahiro Masuda <[email protected]>
Date:   Mon Aug 28 06:05:56 2023 +0000

    wip

commit ae657b7
Author: Masahiro Masuda <[email protected]>
Date:   Mon Aug 28 14:49:21 2023 +0900

    wip

commit ddcab38
Author: Masahiro Masuda <[email protected]>
Date:   Mon Aug 28 05:42:41 2023 +0000

    wip

commit ab3572d
Author: Masahiro Masuda <[email protected]>
Date:   Mon Aug 28 10:40:34 2023 +0900

    wip

commit 690b88e
Author: Masahiro Masuda <[email protected]>
Date:   Mon Aug 28 10:25:33 2023 +0900

    update rev