27
27
import coloredlogs
28
28
import onnx
29
29
from fusion_options import FusionOptions
30
+ from onnx_model_bert import BertOnnxModel
30
31
from onnx_model_clip import ClipOnnxModel
31
32
from onnx_model_unet import UnetOnnxModel
32
33
from onnx_model_vae import VaeOnnxModel
@@ -46,9 +47,20 @@ def has_external_data(onnx_model_path):
46
47
return False
47
48
48
49
50
+ def _get_model_list (source_dir : Path ):
51
+ is_xl = (source_dir / "text_encoder_2" ).exists ()
52
+ is_sd3 = (source_dir / "text_encoder_3" ).exists ()
53
+ model_list_sd3 = ["text_encoder" , "text_encoder_2" , "text_encoder_3" , "transformer" , "vae_encoder" , "vae_decoder" ]
54
+ model_list_sdxl = ["text_encoder" , "text_encoder_2" , "unet" , "vae_encoder" , "vae_decoder" ]
55
+ model_list_sd = ["text_encoder" , "unet" , "vae_encoder" , "vae_decoder" ]
56
+ model_list = model_list_sd3 if is_sd3 else (model_list_sdxl if is_xl else model_list_sd )
57
+ return model_list
58
+
59
+
49
60
def _optimize_sd_pipeline (
50
61
source_dir : Path ,
51
62
target_dir : Path ,
63
+ model_list : List [str ],
52
64
use_external_data_format : Optional [bool ],
53
65
float16 : bool ,
54
66
force_fp32_ops : List [str ],
@@ -60,6 +72,7 @@ def _optimize_sd_pipeline(
60
72
Args:
61
73
source_dir (Path): Root of input directory of stable diffusion onnx pipeline with float32 models.
62
74
target_dir (Path): Root of output directory of stable diffusion onnx pipeline with optimized models.
75
+ model_list (List[str]): list of directory names with onnx model.
63
76
use_external_data_format (Optional[bool]): use external data format.
64
77
float16 (bool): use half precision
65
78
force_fp32_ops(List[str]): operators that are forced to run in float32.
@@ -70,18 +83,21 @@ def _optimize_sd_pipeline(
70
83
RuntimeError: output onnx model path existed
71
84
"""
72
85
model_type_mapping = {
86
+ "transformer" : "mmdit" ,
73
87
"unet" : "unet" ,
74
88
"vae_encoder" : "vae" ,
75
89
"vae_decoder" : "vae" ,
76
90
"text_encoder" : "clip" ,
77
91
"text_encoder_2" : "clip" ,
78
92
"safety_checker" : "unet" ,
93
+ "text_encoder_3" : "clip" ,
79
94
}
80
95
81
96
model_type_class_mapping = {
82
97
"unet" : UnetOnnxModel ,
83
98
"vae" : VaeOnnxModel ,
84
99
"clip" : ClipOnnxModel ,
100
+ "mmdit" : BertOnnxModel , # TODO: have a new class for DiT
85
101
}
86
102
87
103
force_fp32_operators = {
@@ -91,10 +107,10 @@ def _optimize_sd_pipeline(
91
107
"text_encoder" : [],
92
108
"text_encoder_2" : [],
93
109
"safety_checker" : [],
110
+ "text_encoder_3" : [],
111
+ "transformer" : [],
94
112
}
95
113
96
- is_xl = (source_dir / "text_encoder_2" ).exists ()
97
-
98
114
if force_fp32_ops :
99
115
for fp32_operator in force_fp32_ops :
100
116
parts = fp32_operator .split (":" )
@@ -108,8 +124,8 @@ def _optimize_sd_pipeline(
108
124
for name , model_type in model_type_mapping .items ():
109
125
onnx_model_path = source_dir / name / "model.onnx"
110
126
if not os .path .exists (onnx_model_path ):
111
- if name != "safety_checker" :
112
- logger .info ("input onnx model does not exist: %s" , onnx_model_path )
127
+ if name != "safety_checker" and name in model_list :
128
+ logger .warning ("input onnx model does not exist: %s" , onnx_model_path )
113
129
# Some models are optional, so we do not raise an error here.
114
130
continue
115
131
@@ -122,7 +138,7 @@ def _optimize_sd_pipeline(
122
138
use_external_data_format = has_external_data (onnx_model_path )
123
139
124
140
# Graph fusion before fp16 conversion, otherwise they cannot be fused later.
125
- logger .info (f "Optimize { onnx_model_path } ..." )
141
+ logger .info ("Optimize %s ..." , onnx_model_path )
126
142
127
143
args .model_type = model_type
128
144
fusion_options = FusionOptions .parse (args )
@@ -147,6 +163,7 @@ def _optimize_sd_pipeline(
147
163
148
164
if float16 :
149
165
# For SD-XL, using FP16 in the VAE decoder causes NaN and black images, so we keep it in FP32.
166
+ is_xl = (source_dir / "text_encoder_2" ).exists ()
150
167
if is_xl and name == "vae_decoder" :
151
168
logger .info ("Skip converting %s to float16 to avoid NaN" , name )
152
169
else :
@@ -181,17 +198,18 @@ def _optimize_sd_pipeline(
181
198
logger .info ("*" * 20 )
182
199
183
200
184
- def _copy_extra_directory (source_dir : Path , target_dir : Path ):
201
+ def _copy_extra_directory (source_dir : Path , target_dir : Path , model_list : List [ str ] ):
185
202
"""Copy extra directory that does not have onnx model
186
203
187
204
Args:
188
205
source_dir (Path): source directory
189
206
target_dir (Path): target directory
207
+ model_list (List[str]): list of directory names with onnx model.
190
208
191
209
Raises:
192
210
RuntimeError: source path does not exist
193
211
"""
194
- extra_dirs = ["scheduler" , "tokenizer" , "tokenizer_2" , "feature_extractor" ]
212
+ extra_dirs = ["scheduler" , "tokenizer" , "tokenizer_2" , "tokenizer_3" , " feature_extractor" ]
195
213
196
214
for name in extra_dirs :
197
215
source_path = source_dir / name
@@ -213,8 +231,7 @@ def _copy_extra_directory(source_dir: Path, target_dir: Path):
213
231
logger .info ("%s => %s" , source_path , target_path )
214
232
215
233
# Some directories are optional
216
- onnx_model_dirs = ["text_encoder" , "text_encoder_2" , "unet" , "vae_encoder" , "vae_decoder" , "safety_checker" ]
217
- for onnx_model_dir in onnx_model_dirs :
234
+ for onnx_model_dir in model_list :
218
235
source_path = source_dir / onnx_model_dir / "config.json"
219
236
target_path = target_dir / onnx_model_dir / "config.json"
220
237
if source_path .exists ():
@@ -236,17 +253,20 @@ def optimize_stable_diffusion_pipeline(
236
253
if overwrite :
237
254
shutil .rmtree (output_dir , ignore_errors = True )
238
255
else :
239
- raise RuntimeError ("output directory existed:{output_dir}. Add --overwrite to empty the directory." )
256
+ raise RuntimeError (f "output directory existed:{ output_dir } . Add --overwrite to empty the directory." )
240
257
241
258
source_dir = Path (input_dir )
242
259
target_dir = Path (output_dir )
243
260
target_dir .mkdir (parents = True , exist_ok = True )
244
261
245
- _copy_extra_directory (source_dir , target_dir )
262
+ model_list = _get_model_list (source_dir )
263
+
264
+ _copy_extra_directory (source_dir , target_dir , model_list )
246
265
247
266
_optimize_sd_pipeline (
248
267
source_dir ,
249
268
target_dir ,
269
+ model_list ,
250
270
use_external_data_format ,
251
271
float16 ,
252
272
args .force_fp32_ops ,
0 commit comments