From f58a5e0cadac3590af244c5c432e717e57337d72 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Thu, 31 Aug 2023 17:18:34 +0000
Subject: [PATCH 1/6] wip

---
 CMakeLists.txt                         | 4 +++-
 cpp/cli_main.cc                        | 4 ++++
 mlc_llm/transform/rewrite_attention.py | 4 ++--
 3 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 31943c36c5..fddae3a0ad 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,6 +1,8 @@
 cmake_minimum_required(VERSION 3.18)
 project(mlc_llm C CXX)
 
+link_directories(/home/masahi/projects/dev/tvm/build/3rdparty/libflash_attn/src)
+
 include(CheckCXXCompilerFlag)
 if(NOT MSVC)
   check_cxx_compiler_flag("-std=c++17" SUPPORT_CXX17)
@@ -94,7 +96,7 @@ add_library(mlc_llm_static STATIC $<TARGET_OBJECTS:mlc_llm_objs>)
 add_dependencies(mlc_llm_static tokenizers_cpp sentencepiece-static tokenizers_c tvm_runtime)
 set_target_properties(mlc_llm_static PROPERTIES OUTPUT_NAME mlc_llm)
 
-target_link_libraries(mlc_llm PUBLIC tvm_runtime)
+target_link_libraries(mlc_llm PUBLIC tvm_runtime -Wl,--no-as-needed flash_attn)
 target_link_libraries(mlc_llm PRIVATE tokenizers_cpp)
 
 if (BUILD_CPP_TEST)
diff --git a/cpp/cli_main.cc b/cpp/cli_main.cc
index 93063aba68..08fd78f04a 100644
--- a/cpp/cli_main.cc
+++ b/cpp/cli_main.cc
@@ -426,6 +426,10 @@ void Chat(ChatModule* chat, const std::string& device_name, std::string local_id
   PrintSpecialCommands();
   chat->Reload(model);
   chat->ProcessSystemPrompts();
+  auto input = "Who is Shohei Ohtani?";
+  Converse(chat, input, stream_interval, std::cout);
+  std::cout << chat->RuntimeStatsText() << std::endl << std::flush;
+  return;
   while (true) {
     std::string input;
     std::cout << chat->GetRole0() << ": " << std::flush;
diff --git a/mlc_llm/transform/rewrite_attention.py b/mlc_llm/transform/rewrite_attention.py
index 05ee7a8f29..1ad2874b5b 100644
--- a/mlc_llm/transform/rewrite_attention.py
+++ b/mlc_llm/transform/rewrite_attention.py
@@ -8,8 +8,8 @@ def rewrite_attention(f):
     V = wildcard()
 
     Q_BNSH = is_op("relax.permute_dims")(Q)
-    K_BNSH = is_op("relax.permute_dims")(K)
-    V_BNSH = is_op("relax.permute_dims")(V)
+    K_BNSH = is_op("relax.permute_dims")(is_op("relax.repeat")(K))
+    V_BNSH = is_op("relax.permute_dims")(is_op("relax.repeat")(V))
 
     K_BNSH_T = is_op("relax.permute_dims")(K_BNSH)
 

From 3190210f4de8f66a7463240f077ae57ac4ecc2e7 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Wed, 27 Sep 2023 22:40:06 +0000
Subject: [PATCH 2/6] update

---
 mlc_llm/core.py                        | 26 ++++++++++++++++++++------
 mlc_llm/transform/rewrite_attention.py | 10 +++++++---
 2 files changed, 27 insertions(+), 9 deletions(-)

diff --git a/mlc_llm/core.py b/mlc_llm/core.py
index 190b7b833f..12d2d727b1 100644
--- a/mlc_llm/core.py
+++ b/mlc_llm/core.py
@@ -182,8 +182,7 @@ class BuildArgs:
         default=False,
         metadata={
             "help": (
-                "Offload attention operations to CUTLASS when the target is CUDA"
-                "and TVM has been built with CUTLASS enabled."
+                "Disable offloading attention operations to CUTLASS."
             ),
             "action": "store_true",
         },
@@ -192,8 +191,7 @@ class BuildArgs:
         default=False,
         metadata={
             "help": (
-                "Offload layer and RMS norm operations to CUTLASS when the target is CUDA"
-                "and TVM has been built with CUTLASS enabled."
+                "Disable offloading layer and RMS norm operations to CUTLASS."
             ),
             "action": "store_true",
         },
@@ -228,6 +226,15 @@ class BuildArgs:
             ),
         },
     )
+    no_flash_attn_mqa: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Disable offloading multi-query attention workload to Flash Attention."
+            ),
+            "action": "store_true",
+        },
+    )
 
 
 def convert_build_args_to_argparser() -> argparse.ArgumentParser:
@@ -399,8 +406,15 @@ def mod_transform_before_build(
     has_cutlass = tvm.get_global_func("relax.ext.cutlass", True)
 
     if has_cutlass and not args.no_cutlass_attn:
-        mod["prefill"] = rewrite_attention(mod["prefill"])
-        mod["decode"] = rewrite_attention(mod["decode"])
+        use_flash_mqa = not args.no_flash_attn_mqa
+
+        if use_flash_mqa:
+            mod["prefill"] = rewrite_attention(mod["prefill"], use_flash_mqa=True)
+            mod["decode"] = rewrite_attention(mod["decode"], use_flash_mqa=True)
+
+        mod["prefill"] = rewrite_attention(mod["prefill"], use_flash_mqa=False)
+        mod["decode"] = rewrite_attention(mod["decode"], use_flash_mqa=False)
+
         patterns += get_patterns_with_prefix("cutlass.attention")
 
     if has_cutlass and not args.no_cutlass_norm:
diff --git a/mlc_llm/transform/rewrite_attention.py b/mlc_llm/transform/rewrite_attention.py
index 1ad2874b5b..bc90a58718 100644
--- a/mlc_llm/transform/rewrite_attention.py
+++ b/mlc_llm/transform/rewrite_attention.py
@@ -2,14 +2,18 @@
 from tvm.script import relax as R
 
 
-def rewrite_attention(f):
+def rewrite_attention(f, use_flash_mqa=False):
     Q = wildcard()
     K = wildcard()
     V = wildcard()
 
+    if use_flash_mqa:
+        K = is_op("relax.repeat")(K)
+        V = is_op("relax.repeat")(V)
+
     Q_BNSH = is_op("relax.permute_dims")(Q)
-    K_BNSH = is_op("relax.permute_dims")(is_op("relax.repeat")(K))
-    V_BNSH = is_op("relax.permute_dims")(is_op("relax.repeat")(V))
+    K_BNSH = is_op("relax.permute_dims")(K)
+    V_BNSH = is_op("relax.permute_dims")(V)
 
     K_BNSH_T = is_op("relax.permute_dims")(K_BNSH)
 

From 03f9bf34609c07c7942864b860545938ee0eb493 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Wed, 27 Sep 2023 23:24:06 +0000
Subject: [PATCH 3/6] fix

---
 mlc_llm/transform/rewrite_attention.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/mlc_llm/transform/rewrite_attention.py b/mlc_llm/transform/rewrite_attention.py
index bc90a58718..b6d2a493ab 100644
--- a/mlc_llm/transform/rewrite_attention.py
+++ b/mlc_llm/transform/rewrite_attention.py
@@ -7,13 +7,14 @@ def rewrite_attention(f, use_flash_mqa=False):
     K = wildcard()
     V = wildcard()
 
-    if use_flash_mqa:
-        K = is_op("relax.repeat")(K)
-        V = is_op("relax.repeat")(V)
-
     Q_BNSH = is_op("relax.permute_dims")(Q)
-    K_BNSH = is_op("relax.permute_dims")(K)
-    V_BNSH = is_op("relax.permute_dims")(V)
+
+    if use_flash_mqa:
+        K_BNSH = is_op("relax.permute_dims")(is_op("relax.repeat")(K))
+        V_BNSH = is_op("relax.permute_dims")(is_op("relax.repeat")(V))
+    else:
+        K_BNSH = is_op("relax.permute_dims")(K)
+        V_BNSH = is_op("relax.permute_dims")(V)
 
     K_BNSH_T = is_op("relax.permute_dims")(K_BNSH)
 

From 40a582f8ffd44d884e915ae6cf6966e5c66815b5 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Wed, 27 Sep 2023 23:57:03 +0000
Subject: [PATCH 4/6] fix cmake

---
 CMakeLists.txt | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index fddae3a0ad..84df97ebfd 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,8 +1,6 @@
 cmake_minimum_required(VERSION 3.18)
 project(mlc_llm C CXX)
 
-link_directories(/home/masahi/projects/dev/tvm/build/3rdparty/libflash_attn/src)
-
 include(CheckCXXCompilerFlag)
 if(NOT MSVC)
   check_cxx_compiler_flag("-std=c++17" SUPPORT_CXX17)
@@ -96,9 +94,17 @@ add_library(mlc_llm_static STATIC $<TARGET_OBJECTS:mlc_llm_objs>)
 add_dependencies(mlc_llm_static tokenizers_cpp sentencepiece-static tokenizers_c tvm_runtime)
 set_target_properties(mlc_llm_static PROPERTIES OUTPUT_NAME mlc_llm)
 
-target_link_libraries(mlc_llm PUBLIC tvm_runtime -Wl,--no-as-needed flash_attn)
+target_link_libraries(mlc_llm PUBLIC tvm_runtime)
 target_link_libraries(mlc_llm PRIVATE tokenizers_cpp)
 
+find_library(FLASH_ATTN_LIBRARY flash_attn)
+
+if (FLASH_ATTN_LIBRARY STREQUAL "FLASH_ATTN_LIBRARY-NOTFOUND")
+  message(WARNING "Cannot find libflash_attn.")
+else ()
+  target_link_libraries(mlc_llm PUBLIC -Wl,--no-as-needed ${FLASH_ATTN_LIBRARY})
+endif()
+
 if (BUILD_CPP_TEST)
   message(STATUS "Building cpp unittests")
   add_subdirectory(3rdparty/googletest)

From 93ac3375957b36b3f3790b627db412d948195954 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Thu, 28 Sep 2023 00:13:02 +0000
Subject: [PATCH 5/6] disable by default

---
 mlc_llm/core.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/mlc_llm/core.py b/mlc_llm/core.py
index 12d2d727b1..43aa83d0da 100644
--- a/mlc_llm/core.py
+++ b/mlc_llm/core.py
@@ -226,11 +226,11 @@ class BuildArgs:
             ),
         },
     )
-    no_flash_attn_mqa: bool = field(
+    use_flash_attn_mqa: bool = field(
         default=False,
         metadata={
             "help": (
-                "Disable offloading multi-query attention workload to Flash Attention."
+                "Offload multi-query attention workload to Flash Attention."
             ),
             "action": "store_true",
         },
@@ -406,9 +406,7 @@ def mod_transform_before_build(
     has_cutlass = tvm.get_global_func("relax.ext.cutlass", True)
 
     if has_cutlass and not args.no_cutlass_attn:
-        use_flash_mqa = not args.no_flash_attn_mqa
-
-        if use_flash_mqa:
+        if args.use_flash_attn_mqa:
             mod["prefill"] = rewrite_attention(mod["prefill"], use_flash_mqa=True)
             mod["decode"] = rewrite_attention(mod["decode"], use_flash_mqa=True)
 

From c0856f09b5b1f555ce06a2ec1a59f068b1e75fbe Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Thu, 28 Sep 2023 00:15:50 +0000
Subject: [PATCH 6/6] fix

---
 CMakeLists.txt  | 2 +-
 cpp/cli_main.cc | 4 ----
 2 files changed, 1 insertion(+), 5 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 84df97ebfd..eb09469b48 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -100,7 +100,7 @@ target_link_libraries(mlc_llm PRIVATE tokenizers_cpp)
 find_library(FLASH_ATTN_LIBRARY flash_attn)
 
 if (FLASH_ATTN_LIBRARY STREQUAL "FLASH_ATTN_LIBRARY-NOTFOUND")
-  message(WARNING "Cannot find libflash_attn.")
+  message(WARNING "Cannot find libflash_attn. The model must not have been built with --use-flash-attn-mqa option.")
 else ()
   target_link_libraries(mlc_llm PUBLIC -Wl,--no-as-needed ${FLASH_ATTN_LIBRARY})
 endif()
diff --git a/cpp/cli_main.cc b/cpp/cli_main.cc
index 08fd78f04a..93063aba68 100644
--- a/cpp/cli_main.cc
+++ b/cpp/cli_main.cc
@@ -426,10 +426,6 @@ void Chat(ChatModule* chat, const std::string& device_name, std::string local_id
   PrintSpecialCommands();
   chat->Reload(model);
   chat->ProcessSystemPrompts();
-  auto input = "Who is Shohei Ohtani?";
-  Converse(chat, input, stream_interval, std::cout);
-  std::cout << chat->RuntimeStatsText() << std::endl << std::flush;
-  return;
   while (true) {
     std::string input;
     std::cout << chat->GetRole0() << ": " << std::flush;
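Usage note (not part of the patches above): a minimal sketch of how the new option is expected to be exercised once the series is applied. It assumes a patched mlc-llm checkout with TVM importable; the flag name --use-flash-attn-mqa follows from the BuildArgs field in PATCH 5 and the warning string in PATCH 6, and everything else below is illustrative.

    # Sketch only: confirm the PATCH 5 flag is exposed by the generated argument
    # parser and lands on the namespace that mod_transform_before_build() reads.
    from mlc_llm.core import convert_build_args_to_argparser

    parser = convert_build_args_to_argparser()
    args = parser.parse_args(["--use-flash-attn-mqa"])
    assert args.use_flash_attn_mqa  # store_true flag; stays False unless passed explicitly

For an actual build, the flag would simply be appended to the usual build invocation; at runtime the MQA offload additionally needs libflash_attn to be discoverable by the find_library() call added in PATCH 4, otherwise CMake only emits the PATCH 6 warning and skips linking it.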