
Commit 1c82625

Enable offloading multi-query attention by Flash Attention (#990)
* wip
* update
* fix
* fix cmake
* disable by default
* fix
1 parent 1384f23 commit 1c82625

3 files changed (+34, -9 lines)

CMakeLists.txt

Lines changed: 8 additions & 0 deletions
@@ -97,6 +97,14 @@ set_target_properties(mlc_llm_static PROPERTIES OUTPUT_NAME mlc_llm)
 target_link_libraries(mlc_llm PUBLIC tvm_runtime)
 target_link_libraries(mlc_llm PRIVATE tokenizers_cpp)
 
+find_library(FLASH_ATTN_LIBRARY flash_attn)
+
+if (FLASH_ATTN_LIBRARY STREQUAL "FLASH_ATTN_LIBRARY-NOTFOUND")
+  message(WARNING "Cannot find libflash_attn. The model must not have been built with --use-flash-attn-mqa option.")
+else ()
+  target_link_libraries(mlc_llm PUBLIC -Wl,--no-as-needed ${FLASH_ATTN_LIBRARY})
+endif()
+
 if (BUILD_CPP_TEST)
   message(STATUS "Building cpp unittests")
   add_subdirectory(3rdparty/googletest)

mlc_llm/core.py

Lines changed: 18 additions & 6 deletions
@@ -183,8 +183,7 @@ class BuildArgs:
         default=False,
         metadata={
             "help": (
-                "Offload attention operations to CUTLASS when the target is CUDA"
-                "and TVM has been built with CUTLASS enabled."
+                "Disable offloading attention operations to CUTLASS."
             ),
             "action": "store_true",
         },
@@ -193,8 +192,7 @@ class BuildArgs:
         default=False,
         metadata={
             "help": (
-                "Offload layer and RMS norm operations to CUTLASS when the target is CUDA"
-                "and TVM has been built with CUTLASS enabled."
+                "Disable offloading layer and RMS norm operations to CUTLASS."
             ),
             "action": "store_true",
         },
@@ -229,6 +227,15 @@ class BuildArgs:
             ),
         },
     )
+    use_flash_attn_mqa: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Offload multi-query attention workload to Flash Attention."
+            ),
+            "action": "store_true",
+        },
+    )
 
 
 def convert_build_args_to_argparser() -> argparse.ArgumentParser:
@@ -404,8 +411,13 @@ def mod_transform_before_build(
     has_cutlass = tvm.get_global_func("relax.ext.cutlass", True)
 
     if has_cutlass and not args.no_cutlass_attn:
-        mod["prefill"] = rewrite_attention(mod["prefill"])
-        mod["decode"] = rewrite_attention(mod["decode"])
+        if args.use_flash_attn_mqa:
+            mod["prefill"] = rewrite_attention(mod["prefill"], use_flash_mqa=True)
+            mod["decode"] = rewrite_attention(mod["decode"], use_flash_mqa=True)
+
+        mod["prefill"] = rewrite_attention(mod["prefill"], use_flash_mqa=False)
+        mod["decode"] = rewrite_attention(mod["decode"], use_flash_mqa=False)
+
         patterns += get_patterns_with_prefix("cutlass.attention")
 
     if has_cutlass and not args.no_cutlass_norm:
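
The new use_flash_attn_mqa field follows the same dataclass-plus-metadata convention as the other BuildArgs entries, and the CMake warning added in this commit refers to the resulting command-line switch as --use-flash-attn-mqa. Below is a minimal, self-contained sketch of that dataclass-to-argparse pattern; Args and to_argparser are illustrative stand-ins, not mlc_llm's actual BuildArgs or convert_build_args_to_argparser.

# Illustrative sketch only: "Args" and "to_argparser" are hypothetical stand-ins,
# not the repository's BuildArgs / convert_build_args_to_argparser implementation.
import argparse
from dataclasses import dataclass, field, fields


@dataclass
class Args:
    use_flash_attn_mqa: bool = field(
        default=False,
        metadata={
            "help": "Offload multi-query attention workload to Flash Attention.",
            "action": "store_true",
        },
    )


def to_argparser(cls) -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    for f in fields(cls):
        meta = dict(f.metadata)
        flag = "--" + f.name.replace("_", "-")
        if meta.get("action") == "store_true":
            # boolean switch: absent -> False, present -> True
            parser.add_argument(flag, action="store_true", help=meta.get("help"))
        else:
            parser.add_argument(flag, default=f.default, help=meta.get("help"))
    return parser


args = to_argparser(Args).parse_args(["--use-flash-attn-mqa"])
print(args.use_flash_attn_mqa)  # True

With the flag set, the mod_transform_before_build hunk first applies the MQA-aware rewrite (use_flash_mqa=True) and then still applies the generic rewrite (use_flash_mqa=False), so attention that does not match the multi-query pattern keeps being offloaded as before.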

mlc_llm/transform/rewrite_attention.py

Lines changed: 8 additions & 3 deletions
@@ -2,14 +2,19 @@
 from tvm.script import relax as R
 
 
-def rewrite_attention(f):
+def rewrite_attention(f, use_flash_mqa=False):
     Q = wildcard()
     K = wildcard()
     V = wildcard()
 
     Q_BNSH = is_op("relax.permute_dims")(Q)
-    K_BNSH = is_op("relax.permute_dims")(K)
-    V_BNSH = is_op("relax.permute_dims")(V)
+
+    if use_flash_mqa:
+        K_BNSH = is_op("relax.permute_dims")(is_op("relax.repeat")(K))
+        V_BNSH = is_op("relax.permute_dims")(is_op("relax.repeat")(V))
+    else:
+        K_BNSH = is_op("relax.permute_dims")(K)
+        V_BNSH = is_op("relax.permute_dims")(V)
 
     K_BNSH_T = is_op("relax.permute_dims")(K_BNSH)
 
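
The only change to the matcher is that, with use_flash_mqa=True, K and V are matched through an extra relax.repeat. In TVM's relax dataflow-pattern language, is_op("name")(sub_pattern) matches a call to that operator whose argument matches the nested sub-pattern, which is how the MQA branch looks through the repeat. A small sketch of the combinators follows; it assumes wildcard and is_op are importable from tvm.relax.dpl (presumably what this file imports on its first line), so adjust the path for your TVM version if needed.

# Sketch only: assumes wildcard and is_op come from tvm.relax.dpl
# (the relax dataflow-pattern module this file relies on).
from tvm.relax.dpl import is_op, wildcard

K = wildcard()  # matches any relax expression feeding in as K

# Plain multi-head attention: K is permuted (transposed) directly.
plain_pattern = is_op("relax.permute_dims")(K)

# Multi-query attention: K is first repeated across query heads and only then
# permuted, so the repeat call is nested inside the permute_dims pattern.
mqa_pattern = is_op("relax.permute_dims")(is_op("relax.repeat")(K))

print(type(plain_pattern).__name__, type(mqa_pattern).__name__)  # CallPattern CallPattern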

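Why the repeat is there at all: in multi-query attention a single shared K/V head is broadcast across every query head before the usual transpose-and-attend sequence, and that broadcast shows up in the relax graph as relax.repeat. A short numpy sketch of the shape logic (illustrative only; the batch/seq/heads/head_dim layout and sizes here are assumptions, not taken from the repository):

import numpy as np

b, s, num_q_heads, d = 1, 4, 8, 16
Q = np.random.randn(b, s, num_q_heads, d)   # one query per attention head
K = np.random.randn(b, s, 1, d)             # single shared key head (MQA)
V = np.random.randn(b, s, 1, d)             # single shared value head (MQA)

# The broadcast that relax.repeat performs: copy the shared K/V head across
# all query heads so the shapes line up with Q.
K_rep = np.repeat(K, num_q_heads, axis=2)
V_rep = np.repeat(V, num_q_heads, axis=2)

# Standard scaled dot-product attention, done per head.
scores = np.einsum("bsnd,btnd->bnst", Q, K_rep) / np.sqrt(d)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights = weights / weights.sum(axis=-1, keepdims=True)
out = np.einsum("bnst,btnd->bsnd", weights, V_rep)

print(out.shape)  # (1, 4, 8, 16): batch, seq, heads, head_dim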