4949)
5050from diffusers .optimization import get_scheduler
5151from diffusers .utils import check_min_version , is_wandb_available
52+ from diffusers .utils .hub_utils import load_or_create_model_card , populate_model_card
5253from diffusers .utils .import_utils import is_xformers_available
5354from diffusers .utils .torch_utils import is_compiled_module
5455
@@ -195,7 +196,7 @@ def import_model_class_from_model_name_or_path(
195196 raise ValueError (f"{ model_class } is not supported." )
196197
197198
198- def save_model_card (repo_id : str , image_logs = None , base_model = str , repo_folder = None ):
199+ def save_model_card (repo_id : str , image_logs : dict = None , base_model : str = None , repo_folder : str = None ):
199200 img_str = ""
200201 if image_logs is not None :
201202 img_str = "You can find some example images below.\n "
@@ -209,27 +210,25 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
209210 image_grid (images , 1 , len (images )).save (os .path .join (repo_folder , f"images_{ i } .png" ))
210211 img_str += f"\n "
211212
212- yaml = f"""
213- ---
214- license: creativeml-openrail-m
215- base_model: { base_model }
216- tags:
217- - stable-diffusion-xl
218- - stable-diffusion-xl-diffusers
219- - text-to-image
220- - diffusers
221- - t2iadapter
222- inference: true
223- ---
224- """
225- model_card = f"""
213+ model_description = f"""
226214# t2iadapter-{ repo_id }
227215
228216These are t2iadapter weights trained on { base_model } with new type of conditioning.
229217{ img_str }
230218"""
231- with open (os .path .join (repo_folder , "README.md" ), "w" ) as f :
232- f .write (yaml + model_card )
219+ model_card = load_or_create_model_card (
220+ repo_id_or_path = repo_id ,
221+ from_training = True ,
222+ license = "creativeml-openrail-m" ,
223+ base_model = base_model ,
224+ model_description = model_description ,
225+ inference = True ,
226+ )
227+
228+ tags = ["stable-diffusion-xl" , "stable-diffusion-xl-diffusers" , "text-to-image" , "diffusers" , "t2iadapter" ]
229+ model_card = populate_model_card (model_card , tags = tags )
230+
231+ model_card .save (os .path .join (repo_folder , "README.md" ))
233232
234233
235234def parse_args (input_args = None ):
0 commit comments