@@ -183,8 +183,7 @@ class BuildArgs:
183183 default = False ,
184184 metadata = {
185185 "help" : (
186- "Offload attention operations to CUTLASS when the target is CUDA"
187- "and TVM has been built with CUTLASS enabled."
186+ "Disable offloading attention operations to CUTLASS."
188187 ),
189188 "action" : "store_true" ,
190189 },
@@ -193,8 +192,7 @@ class BuildArgs:
193192 default = False ,
194193 metadata = {
195194 "help" : (
196- "Offload layer and RMS norm operations to CUTLASS when the target is CUDA"
197- "and TVM has been built with CUTLASS enabled."
195+ "Disable offloading layer and RMS norm operations to CUTLASS."
198196 ),
199197 "action" : "store_true" ,
200198 },
@@ -229,6 +227,15 @@ class BuildArgs:
229227 ),
230228 },
231229 )
230+ use_flash_attn_mqa : bool = field (
231+ default = False ,
232+ metadata = {
233+ "help" : (
234+ "Offload multi-query attention workload to Flash Attention."
235+ ),
236+ "action" : "store_true" ,
237+ },
238+ )
232239
233240
234241def convert_build_args_to_argparser () -> argparse .ArgumentParser :
@@ -404,8 +411,13 @@ def mod_transform_before_build(
404411 has_cutlass = tvm .get_global_func ("relax.ext.cutlass" , True )
405412
406413 if has_cutlass and not args .no_cutlass_attn :
407- mod ["prefill" ] = rewrite_attention (mod ["prefill" ])
408- mod ["decode" ] = rewrite_attention (mod ["decode" ])
414+ if args .use_flash_attn_mqa :
415+ mod ["prefill" ] = rewrite_attention (mod ["prefill" ], use_flash_mqa = True )
416+ mod ["decode" ] = rewrite_attention (mod ["decode" ], use_flash_mqa = True )
417+
418+ mod ["prefill" ] = rewrite_attention (mod ["prefill" ], use_flash_mqa = False )
419+ mod ["decode" ] = rewrite_attention (mod ["decode" ], use_flash_mqa = False )
420+
409421 patterns += get_patterns_with_prefix ("cutlass.attention" )
410422
411423 if has_cutlass and not args .no_cutlass_norm :
0 commit comments