@@ -592,6 +592,7 @@ def get_tuning_config(self) -> TuningConfig:
592592 5 , 0 ,
593593 helper .infer_shape_max_num_tiles )),
594594 inputs_pre_hook = helper .inputs_pre_hook ,
595+ use_cuda_graph = True ,
595596 )
596597 return self .__class__ .tuning_config_cache [key ]
597598
@@ -869,6 +870,7 @@ def get_tuning_config(self) -> TuningConfig:
869870 7 , 0 , helper .infer_shape_max_num_permuted_tokens ),
870871 ConstraintSpec (9 , 0 , helper .infer_shape_num_tokens )),
871872 inputs_pre_hook = helper .inputs_pre_hook_finalize_fusion ,
873+ use_cuda_graph = True ,
872874 )
873875 return self .__class__ .tuning_config_cache [key ]
874876
@@ -1183,6 +1185,7 @@ def get_tuning_config(self) -> TuningConfig:
11831185 5 , 0 ,
11841186 helper .infer_shape_max_num_tiles )),
11851187 inputs_pre_hook = helper .inputs_pre_hook ,
1188+ use_cuda_graph = True ,
11861189 )
11871190 return self .__class__ .tuning_config_cache [key ]
11881191
@@ -1470,6 +1473,7 @@ def get_tuning_config(self) -> TuningConfig:
14701473 ConstraintSpec (
14711474 3 , 0 , helper .infer_shape_num_tokens )),
14721475 inputs_pre_hook = helper .inputs_pre_hook ,
1476+ use_cuda_graph = True ,
14731477 )
14741478 return self .__class__ .tuning_config_cache [key ]
14751479
0 commit comments