Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 28 additions & 25 deletions examples/dreambooth/train_dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module

Expand All @@ -69,33 +70,20 @@

def save_model_card(
repo_id: str,
images=None,
base_model=str,
images: list = None,
base_model: str = None,
train_text_encoder=False,
prompt=str,
repo_folder=None,
prompt: str = None,
repo_folder: str = None,
pipeline: DiffusionPipeline = None,
):
img_str = ""
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"![img_{i}](./image_{i}.png)\n"

yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
instance_prompt: {prompt}
tags:
- {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'}
- {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'}
- text-to-image
- diffusers
- dreambooth
inference: true
---
"""
model_card = f"""
if images is not None:
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"![img_{i}](./image_{i}.png)\n"

model_description = f"""
# DreamBooth - {repo_id}

This is a dreambooth model derived from {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/).
Expand All @@ -104,8 +92,23 @@ def save_model_card(

DreamBooth for the text encoder was enabled: {train_text_encoder}.
"""
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
license="creativeml-openrail-m",
base_model=base_model,
instance_prompt=prompt,
model_description=model_description,
inference=True,
)

tags = ["text-to-image", "dreambooth"]
if isinstance(pipeline, StableDiffusionPipeline):
tags.extend(["stable-diffusion", "stable-diffusion-diffusers"])
else:
tags.extend(["if", "if-diffusers"])
model_card = populate_model_card(model_card, tags=tags)

model_card.save(os.path.join(repo_folder, "README.md"))


def log_validation(
Expand Down
39 changes: 28 additions & 11 deletions src/diffusers/utils/hub_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import traceback
import warnings
from pathlib import Path
from typing import Dict, Optional, Union
from typing import Dict, List, Optional, Union
from uuid import uuid4

from huggingface_hub import (
Expand Down Expand Up @@ -94,14 +94,14 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:


def load_or_create_model_card(
repo_id_or_path: Optional[str] = None, token: Optional[str] = None, is_pipeline: bool = False
repo_id_or_path: Optional[str] = None, token: Optional[str] = None, is_pipeline: bool = False, **kwargs
) -> ModelCard:
"""
Loads or creates a model card.

Args:
repo_id (`str`):
The repo_id where to look for the model card.
repo_id_or_path (`str`):
The repo id (e.g., "runwayml/stable-diffusion-v1-5") or local path where to look for the model card.
token (`str`, *optional*):
Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more details.
is_pipeline (`bool`, *optional*):
Expand All @@ -110,27 +110,44 @@ def load_or_create_model_card(
if not is_jinja_available():
raise ValueError(
"Modelcard rendering is based on Jinja templates."
" Please make sure to have `jinja` installed before using `create_model_card`."
" Please make sure to have `jinja` installed before using `load_or_create_model_card`."
" To install it, please run `pip install Jinja2`."
)

try:
# Check if the model card is present on the remote repo
model_card = ModelCard.load(repo_id_or_path, token=token)
except EntryNotFoundError:
# Otherwise create a simple model card from template
except (EntryNotFoundError, RepositoryNotFoundError):
# Otherwise create a model card from template
component = "pipeline" if is_pipeline else "model"
model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed on the Hub. This model card has been automatically generated."
card_data = ModelCardData()

if kwargs is not None:
model_description = kwargs.pop("model_description", None)
card_data = ModelCardData(**kwargs) if kwargs is not None else ModelCardData()
else:
model_description = None

if model_description is None:
model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed on the Hub. This model card has been automatically generated."

model_card = ModelCard.from_template(card_data, model_description=model_description)

return model_card


def populate_model_card(model_card: ModelCard) -> ModelCard:
"""Populates the `model_card` with library name."""
def populate_model_card(model_card: ModelCard, tags: Union[str, List[str]] = None) -> ModelCard:
"""Populates the `model_card` with library name and optional tags."""
if model_card.data.library_name is None:
model_card.data.library_name = "diffusers"

if tags is not None:
if isinstance(tags, str):
tags = [tags]
if model_card.data.tags is None:
model_card.data.tags = []
for tag in tags:
model_card.data.tags.append(tag)

return model_card


Expand Down