@@ -8,10 +8,12 @@
 import tvm
 from tvm import meta_schedule as ms
 from tvm import relax
+from tvm.relax.backend.pattern_registry import get_pattern
 
 import mlc_llm
 from mlc_llm import utils
 from mlc_llm.relax_model import gpt_neox, llama, moss
+from mlc_llm.transform import rewrite_attention
 
 
 def _parse_args():
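
Note on the new imports: they feed the CUTLASS offload path added below. `rewrite_attention` rewrites MLC's hand-written attention into a single `relax.op.nn.attention` call that the registered CUTLASS patterns can match, and `get_pattern` looks entries up in TVM's BYOC pattern registry. A minimal usage sketch, assuming an IRModule `mod` with a "prefill" entry function and the "cutlass.attention" registry name (both are assumptions; this hunk only adds the imports):

    # Hypothetical usage; "prefill" and the pattern name are illustrative.
    mod["prefill"] = rewrite_attention(mod["prefill"])  # fold Q/K/V math into R.nn.attention
    attn_pattern = get_pattern("cutlass.attention")     # fetch a registered BYOC pattern entry
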
@@ -36,6 +38,7 @@ def _parse_args():
         choices=[*utils.quantization_dict.keys()],
         default=list(utils.quantization_dict.keys())[0],
     )
+    args.add_argument("--cutlass-offload", action="store_true", default=False)
     args.add_argument("--max-seq-len", type=int, default=-1)
     args.add_argument("--target", type=str, default="auto")
     args.add_argument(
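
The new `--cutlass-offload` flag is opt-in and, as gated in `mod_transform_before_build` below, only takes effect for CUDA targets. A hypothetical invocation (model and quantization names are illustrative):

    python build.py --model vicuna-v1-7b --quantization q4f16_0 --target cuda --cutlass-offload
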
@@ -185,6 +188,8 @@ def debug_dump_script(mod, name, args):
         return
     dump_path = os.path.join(args.artifact_path, "debug", name)
     with open(dump_path, "w", encoding="utf-8") as outfile:
+        # Remove runtime modules from external codegen so that the IR module can be printed.
+        mod = mod.without_attr("external_mods").without_attr("const_name_to_constant")
         outfile.write(mod.script(show_meta=True))
     print(f"Dump mod to {dump_path}")
 
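
The stripping matters once `RunCodegen` has run: it attaches the compiled `runtime.Module` handles under the `external_mods` attribute (and codegen constants under `const_name_to_constant`), which TVMScript cannot serialize, so `mod.script()` would fail on them. Removing a missing attribute should be a no-op, so the pre-codegen dumps are unaffected. The safe-dump pattern in isolation, assuming `mod` came out of `RunCodegen`:

    # Strip the unprintable attributes, then pretty-print; without_attr returns
    # a modified copy, so the original module keeps its runtime modules.
    printable = mod.without_attr("external_mods").without_attr("const_name_to_constant")
    print(printable.script(show_meta=True))
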
@@ -240,11 +245,23 @@ def mod_transform_before_build(
             storage_nbit=args.quantization.storage_nbit,
             dtype=args.quantization.model_dtype,
         )(mod)
-    mod = mlc_llm.transform.FuseTransposeMatmul()(mod)  # pylint: disable=not-callable
-    mod = relax.pipeline.get_pipeline()(mod)  # pylint: disable=no-value-for-parameter
-    mod = mlc_llm.transform.FuseDecodeMatmulEwise(  # pylint: disable=not-callable
-        args.quantization.model_dtype, args.target_kind
-    )(mod)
+    if args.target_kind == "cuda" and args.cutlass_offload:
+        from tvm.relax.backend.contrib.cutlass import partition_for_cutlass
+
+        debug_dump_script(mod, "mod_before_cutlass.py", args)
+        mod = partition_for_cutlass(mod)
+        debug_dump_script(mod, "mod_after_cutlass_partition.py", args)
+        codegen_pass = relax.transform.RunCodegen(
+            {"cutlass": {"sm": 80, "find_first_valid": False}},
+            entry_functions=model_names,
+        )
+        mod = codegen_pass(mod)
+        debug_dump_script(mod, "mod_after_cutlass_codegen.py", args)
+
+    mod = mlc_llm.transform.FuseTransposeMatmul()(mod)
+
+    mod = relax.pipeline.get_pipeline()(mod)
+    mod = mlc_llm.transform.FuseDecodeMatmulEwise(args.dtype)(mod)
     mod = relax.transform.DeadCodeElimination(model_names)(mod)
     mod = relax.transform.LiftTransformParams()(mod)
     mod_transform, mod_deploy = utils.split_transform_deploy_mod(mod, model_names)
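
For intuition on the new branch: `partition_for_cutlass` groups offloadable subgraphs into composite functions annotated for the "cutlass" codegen, and `RunCodegen` then compiles them into external runtime modules; `"sm": 80` targets Ampere GPUs, while `"find_first_valid": False` asks the profiler to pick the best kernel rather than the first one that compiles. A self-contained sketch of the same two steps on a toy module, assuming a CUDA build of TVM Unity with CUTLASS support enabled:

    import tvm
    from tvm import relax
    from tvm.script import ir_module
    from tvm.script import relax as R
    from tvm.relax.backend.contrib.cutlass import partition_for_cutlass

    @ir_module
    class ToyModule:
        @R.function
        def main(x: R.Tensor((16, 64), "float16"), w: R.Tensor((64, 64), "float16")):
            with R.dataflow():
                y = R.matmul(x, w)
                R.output(y)
            return y

    # Match the fp16 matmul against the registered CUTLASS patterns, then
    # compile the resulting composite functions with the same options as above.
    mod = partition_for_cutlass(ToyModule)
    mod = relax.transform.RunCodegen({"cutlass": {"sm": 80, "find_first_valid": False}})(mod)
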
@@ -317,10 +334,10 @@ def build(mod_deploy: tvm.IRModule, args: argparse.Namespace) -> None:
     ex = relax.build(mod_deploy, args.target, system_lib=args.system_lib)
 
     output_filename = (
-        f"{args.model}-{args.quantization.name}-{target_kind}.{args.lib_format}"
+        f"{args.model}-{args.quantization.name}-{target_kind}_{args.dtype}.{args.lib_format}"
     )
 
-    debug_dump_shader(ex, f"{args.model}_{args.quantization.name}_{target_kind}", args)
+    debug_dump_shader(ex, f"{args.model}_{args.quantization.name}_{target_kind}_{args.dtype}", args)
     lib_path = os.path.join(args.artifact_path, output_filename)
     ex.export_library(lib_path, **args.export_kwargs)
     print(f"Finish exporting to {lib_path}")
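
With the dtype folded into the artifact names, a CUDA fp16 build would now emit something like `vicuna-v1-7b-q4f16_0-cuda_float16.so` (model and quantization names illustrative). Loading the exported library back follows the usual TVM flow:

    import tvm
    # lib_path is the path printed by the build step above.
    lib = tvm.runtime.load_module(lib_path)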