diff --git a/CMakeLists.txt b/CMakeLists.txt
index 31943c36c5..eb09469b48 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -97,6 +97,14 @@ set_target_properties(mlc_llm_static PROPERTIES OUTPUT_NAME mlc_llm)
 target_link_libraries(mlc_llm PUBLIC tvm_runtime)
 target_link_libraries(mlc_llm PRIVATE tokenizers_cpp)
 
+find_library(FLASH_ATTN_LIBRARY flash_attn)
+
+if (FLASH_ATTN_LIBRARY STREQUAL "FLASH_ATTN_LIBRARY-NOTFOUND")
+  message(WARNING "Cannot find libflash_attn. The model must not have been built with the --use-flash-attn-mqa option.")
+else ()
+  target_link_libraries(mlc_llm PUBLIC -Wl,--no-as-needed ${FLASH_ATTN_LIBRARY})
+endif()
+
 if (BUILD_CPP_TEST)
   message(STATUS "Building cpp unittests")
   add_subdirectory(3rdparty/googletest)
diff --git a/mlc_llm/core.py b/mlc_llm/core.py
index 190b7b833f..43aa83d0da 100644
--- a/mlc_llm/core.py
+++ b/mlc_llm/core.py
@@ -182,8 +182,7 @@ class BuildArgs:
         default=False,
         metadata={
             "help": (
-                "Offload attention operations to CUTLASS when the target is CUDA"
-                "and TVM has been built with CUTLASS enabled."
+                "Disable offloading attention operations to CUTLASS."
             ),
             "action": "store_true",
         },
@@ -192,8 +191,7 @@ class BuildArgs:
         default=False,
         metadata={
             "help": (
-                "Offload layer and RMS norm operations to CUTLASS when the target is CUDA"
-                "and TVM has been built with CUTLASS enabled."
+                "Disable offloading layer and RMS norm operations to CUTLASS."
             ),
             "action": "store_true",
         },
@@ -228,6 +226,15 @@ class BuildArgs:
             ),
         },
     )
+    use_flash_attn_mqa: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Offload multi-query attention workload to Flash Attention."
+            ),
+            "action": "store_true",
+        },
+    )
 
 
 def convert_build_args_to_argparser() -> argparse.ArgumentParser:
@@ -399,8 +406,13 @@ def mod_transform_before_build(
     has_cutlass = tvm.get_global_func("relax.ext.cutlass", True)
 
     if has_cutlass and not args.no_cutlass_attn:
-        mod["prefill"] = rewrite_attention(mod["prefill"])
-        mod["decode"] = rewrite_attention(mod["decode"])
+        if args.use_flash_attn_mqa:
+            mod["prefill"] = rewrite_attention(mod["prefill"], use_flash_mqa=True)
+            mod["decode"] = rewrite_attention(mod["decode"], use_flash_mqa=True)
+
+        mod["prefill"] = rewrite_attention(mod["prefill"], use_flash_mqa=False)
+        mod["decode"] = rewrite_attention(mod["decode"], use_flash_mqa=False)
+
         patterns += get_patterns_with_prefix("cutlass.attention")
 
     if has_cutlass and not args.no_cutlass_norm:
diff --git a/mlc_llm/transform/rewrite_attention.py b/mlc_llm/transform/rewrite_attention.py
index 05ee7a8f29..b6d2a493ab 100644
--- a/mlc_llm/transform/rewrite_attention.py
+++ b/mlc_llm/transform/rewrite_attention.py
@@ -2,14 +2,19 @@
 from tvm.script import relax as R
 
 
-def rewrite_attention(f):
+def rewrite_attention(f, use_flash_mqa=False):
     Q = wildcard()
     K = wildcard()
     V = wildcard()
 
     Q_BNSH = is_op("relax.permute_dims")(Q)
-    K_BNSH = is_op("relax.permute_dims")(K)
-    V_BNSH = is_op("relax.permute_dims")(V)
+
+    if use_flash_mqa:
+        K_BNSH = is_op("relax.permute_dims")(is_op("relax.repeat")(K))
+        V_BNSH = is_op("relax.permute_dims")(is_op("relax.repeat")(V))
+    else:
+        K_BNSH = is_op("relax.permute_dims")(K)
+        V_BNSH = is_op("relax.permute_dims")(V)
 
     K_BNSH_T = is_op("relax.permute_dims")(K_BNSH)
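
For context on why the use_flash_mqa=True pattern wraps K and V in is_op("relax.repeat"): in multi-query attention the single shared K/V head is repeated across all query heads before the usual BSNH -> BNSH transpose, and that repeat op is what distinguishes the MQA workload handed to Flash Attention from the ordinary attention sites handled by CUTLASS. The NumPy sketch below is only an illustration of that shape pattern, not code from this patch; the function name, shapes, and the absence of a causal mask are assumptions made for the sketch.

import numpy as np

def mqa_attention_reference(q, k, v):
    """Illustrative only: q is [batch, seq, num_q_heads, head_dim];
    k and v are [batch, seq, 1, head_dim], i.e. one shared KV head."""
    num_q_heads = q.shape[2]
    # The step captured by is_op("relax.repeat") in the pattern: broadcast
    # the single KV head across every query head.
    k = np.repeat(k, num_q_heads, axis=2)
    v = np.repeat(v, num_q_heads, axis=2)
    # BSNH -> BNSH, matching the relax.permute_dims wrappers in the pattern.
    q, k, v = (np.transpose(x, (0, 2, 1, 3)) for x in (q, k, v))
    scores = (q @ np.swapaxes(k, -1, -2)) / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v  # [batch, num_q_heads, seq, head_dim]

# Example shapes (hypothetical): 32 query heads sharing 1 KV head.
q = np.random.rand(1, 8, 32, 128).astype("float32")
k = v = np.random.rand(1, 8, 1, 128).astype("float32")
out = mqa_attention_reference(q, k, v)  # (1, 32, 8, 128)

Note the ordering in mod_transform_before_build: the use_flash_mqa=True rewrite runs before the generic use_flash_mqa=False pass, so attention sites containing the repeat are claimed by Flash Attention first, and any remaining attention still falls through to the CUTLASS patterns.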