Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
d5f2cf6
feat: explicitly tag to diffusers when using push_to_hub
sayakpaul Jan 23, 2024
dc7f55d
remove tags.
sayakpaul Jan 23, 2024
91b26ff
reset repo.
sayakpaul Jan 23, 2024
156586f
Merge branch 'main' into add-diffusers-tag
sayakpaul Jan 23, 2024
03a704a
Apply suggestions from code review
sayakpaul Jan 23, 2024
d33ed6c
fix: tests
sayakpaul Jan 23, 2024
2baf0d2
fix: push_to_hub behaviour for tagging from save_pretrained
sayakpaul Jan 23, 2024
5d9e664
Apply suggestions from code review
sayakpaul Jan 23, 2024
62ddbb8
Apply suggestions from code review
sayakpaul Jan 23, 2024
0d73555
import fixes.
sayakpaul Jan 23, 2024
5297ad4
add library name to existing model card.
sayakpaul Jan 23, 2024
99ce47c
add: standalone test for generate_model_card
sayakpaul Jan 23, 2024
19d26da
Merge branch 'main' into add-diffusers-tag
sayakpaul Jan 23, 2024
2b93dcc
fix tests for standalone method
sayakpaul Jan 23, 2024
0f31032
moved library_name to a better place.
sayakpaul Jan 23, 2024
987178b
merge create_model_card and generate_model_card.
sayakpaul Jan 23, 2024
5bd864c
fix test
sayakpaul Jan 23, 2024
33e2d91
address lucain's comments
sayakpaul Jan 23, 2024
322c0e1
fix return identation
sayakpaul Jan 23, 2024
73ea51d
Apply suggestions from code review
sayakpaul Jan 23, 2024
e32e5e1
address further comments.
sayakpaul Jan 23, 2024
6b36050
Merge branch 'main' into add-diffusers-tag
sayakpaul Jan 23, 2024
ffc3845
Update src/diffusers/pipelines/pipeline_utils.py
sayakpaul Jan 23, 2024
183fd65
Merge branch 'main' into add-diffusers-tag
sayakpaul Jan 23, 2024
2d5c555
Merge branch 'main' into add-diffusers-tag
sayakpaul Jan 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
is_torch_version,
logging,
)
from ..utils.hub_utils import PushToHubMixin
from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card


logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -377,6 +377,11 @@ def save_pretrained(
logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")

if push_to_hub:
# Create a new empty model card and eventually tag it
model_card = load_or_create_model_card(repo_id, token=token)
model_card = populate_model_card(model_card)
model_card.save(os.path.join(save_directory, "README.md"))

self._upload_folder(
save_directory,
repo_id,
Expand Down
6 changes: 6 additions & 0 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
logging,
numpy_to_pil,
)
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
from ..utils.torch_utils import is_compiled_module


Expand Down Expand Up @@ -725,6 +726,11 @@ def is_saveable_module(name, value):
self.save_config(save_directory)

if push_to_hub:
# Create a new empty model card and eventually tag it
model_card = load_or_create_model_card(repo_id, token=token, is_pipeline=True)
model_card = populate_model_card(model_card)
model_card.save(os.path.join(save_directory, "README.md"))

self._upload_folder(
save_directory,
repo_id,
Expand Down
81 changes: 39 additions & 42 deletions src/diffusers/utils/hub_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
ModelCard,
ModelCardData,
create_repo,
get_full_repo_name,
hf_hub_download,
upload_folder,
)
Expand Down Expand Up @@ -67,7 +66,6 @@
logger = get_logger(__name__)


MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "model_card_template.md"
SESSION_ID = uuid4().hex


Expand Down Expand Up @@ -95,53 +93,45 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
return ua


def create_model_card(args, model_name):
def load_or_create_model_card(
repo_id_or_path: Optional[str] = None, token: Optional[str] = None, is_pipeline: bool = False
) -> ModelCard:
"""
Loads or creates a model card.

Args:
repo_id_or_path (`str`):
The repo id (on the Hub) or local path where to look for the model card.
token (`str`, *optional*):
Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more details.
is_pipeline (`bool`, *optional*):
Boolean to indicate if we're adding tag to a [`DiffusionPipeline`].
"""
if not is_jinja_available():
raise ValueError(
"Modelcard rendering is based on Jinja templates."
" Please make sure to have `jinja` installed before using `create_model_card`."
" To install it, please run `pip install Jinja2`."
)

if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Follow the discussion here: #6678 (comment).

return

hub_token = args.hub_token if hasattr(args, "hub_token") else None
repo_name = get_full_repo_name(model_name, token=hub_token)

model_card = ModelCard.from_template(
card_data=ModelCardData( # Card metadata object that will be converted to YAML block
language="en",
license="apache-2.0",
library_name="diffusers",
tags=[],
datasets=args.dataset_name,
metrics=[],
),
template_path=MODEL_CARD_TEMPLATE_PATH,
model_name=model_name,
repo_name=repo_name,
dataset_name=args.dataset_name if hasattr(args, "dataset_name") else None,
learning_rate=args.learning_rate,
train_batch_size=args.train_batch_size,
eval_batch_size=args.eval_batch_size,
gradient_accumulation_steps=(
args.gradient_accumulation_steps if hasattr(args, "gradient_accumulation_steps") else None
),
adam_beta1=args.adam_beta1 if hasattr(args, "adam_beta1") else None,
adam_beta2=args.adam_beta2 if hasattr(args, "adam_beta2") else None,
adam_weight_decay=args.adam_weight_decay if hasattr(args, "adam_weight_decay") else None,
adam_epsilon=args.adam_epsilon if hasattr(args, "adam_epsilon") else None,
lr_scheduler=args.lr_scheduler if hasattr(args, "lr_scheduler") else None,
lr_warmup_steps=args.lr_warmup_steps if hasattr(args, "lr_warmup_steps") else None,
ema_inv_gamma=args.ema_inv_gamma if hasattr(args, "ema_inv_gamma") else None,
ema_power=args.ema_power if hasattr(args, "ema_power") else None,
ema_max_decay=args.ema_max_decay if hasattr(args, "ema_max_decay") else None,
mixed_precision=args.mixed_precision,
)

card_path = os.path.join(args.output_dir, "README.md")
model_card.save(card_path)
try:
# Check if the model card is present on the remote repo
model_card = ModelCard.load(repo_id_or_path, token=token)
except EntryNotFoundError:
# Otherwise create a simple model card from template
component = "pipeline" if is_pipeline else "model"
model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed on the Hub. This model card has been automatically generated."
card_data = ModelCardData()
model_card = ModelCard.from_template(card_data, model_description=model_description)

return model_card


def populate_model_card(model_card: ModelCard) -> ModelCard:
    """
    Tag `model_card` with the `diffusers` library name when no library is declared yet.

    Args:
        model_card (`ModelCard`):
            The card whose `data.library_name` metadata field is inspected and, if unset, filled in.

    Returns:
        `ModelCard`: The same `model_card` object that was passed in (it is modified in place).
    """
    card_data = model_card.data
    if card_data.library_name is not None:
        # A library name is already declared on the card: leave it untouched.
        return model_card

    card_data.library_name = "diffusers"
    return model_card


def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str] = None):
Expand Down Expand Up @@ -435,6 +425,10 @@ def push_to_hub(
"""
repo_id = create_repo(repo_id, private=private, token=token, exist_ok=True).repo_id

# Create a new empty model card and eventually tag it
model_card = load_or_create_model_card(repo_id, token=token)
model_card = populate_model_card(model_card)

# Save all files.
save_kwargs = {"safe_serialization": safe_serialization}
if "Scheduler" not in self.__class__.__name__:
Expand All @@ -443,6 +437,9 @@ def push_to_hub(
with tempfile.TemporaryDirectory() as tmpdir:
self.save_pretrained(tmpdir, **save_kwargs)

# Update model card if needed:
model_card.save(os.path.join(tmpdir, "README.md"))

return self._upload_folder(
tmpdir,
repo_id,
Expand Down
50 changes: 0 additions & 50 deletions src/diffusers/utils/model_card_template.md

This file was deleted.

26 changes: 25 additions & 1 deletion tests/models/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
import numpy as np
import requests_mock
import torch
from huggingface_hub import delete_repo
from huggingface_hub import ModelCard, delete_repo
from huggingface_hub.utils import is_jinja_available
from requests.exceptions import HTTPError

from diffusers.models import UNet2DConditionModel
Expand Down Expand Up @@ -732,3 +733,26 @@ def test_push_to_hub_in_organization(self):

# Reset repo
delete_repo(self.org_repo_id, token=TOKEN)

@unittest.skipIf(
not is_jinja_available(),
reason="Model card tests cannot be performed without Jinja installed.",
)
def test_push_to_hub_library_name(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) I would add an explicit test both for when the model card doesn't exist yet and for when the model card already exists. Maybe not needed to test the full push_to_hub method but simply the create_and_tag_model_card helper (or whatever its name :) )

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does 99ce47c work for you?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I didn't notice that there is both a generate_model_card and a create_model_card in the hub utils. Should we merge them since they seem to do closely related things? (The difference is generating for a training run versus generating from anywhere, right?) The naming is misleading in that case (sorry, I didn't notice this earlier when I suggested generate_model_card).

Regarding the test, yes it looks good to me :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, but it won't work since tmpdir is local. What's the best way to test here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Wauplin I merged them. However, I am not sure about the test since tmpdir is local and ModelCard.load() would fail there.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think tests are good. I've added a comment below to test from existing file with existing library_name that is not diffusers.

But what I meant above is that with this PR the hub utils feels clunky. We now have:

  • create_model_card, which creates a model card for a training run. The method looks outdated and is not used anywhere. However, it introduces a template that is nice.
  • generate_model_card, which either loads an existing model card or creates a new one (from a different template) and adds library_name: diffusers to it. This method is used in the codebase.

Maybe what I would do to solve this (and sorry if it's a revamp of the PR):

  • deprecate create_model_card (or even remove it completely) if it's not used
  • add in hub_utils.py a load_or_create_model_card helper that returns a ModelCard object without modifying anything. It's similar to the try: (ModelCard.load) except EntryNotFound: (ModelCard.from_template) part
  • add in hub_utils.py a populate_model_card helper that takes a ModelCard object as input and adds library_name: diffusers if it doesn't exist yet.
  • then in the codebase (for example in pipeline_utils.py), you would do
model_card = load_or_create_model_card(repo_id, token=token, is_pipeline=True)
populate_model_card(model_card)
model_card.save(os.path.join(save_directory, "README.md"))

WDYT?

(to take with a grain of salt, I'm not expert in diffusers codebase so I might be missing some parts)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(sorry @sayakpaul I didn't see your comment while posting this message)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense. The latest commit should reflect these. Let me know if that makes sense.

I have opted to remove create_model_card() and the related test as it's not really used.

model = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
)
model.push_to_hub(self.repo_id, token=TOKEN)

model_card = ModelCard.load(f"{USER}/{self.repo_id}", token=TOKEN).data
assert model_card.library_name == "diffusers"

# Reset repo
delete_repo(self.repo_id, token=TOKEN)
36 changes: 7 additions & 29 deletions tests/others/test_hub_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,37 +15,15 @@
import unittest
from pathlib import Path
from tempfile import TemporaryDirectory
from unittest.mock import Mock, patch

import diffusers.utils.hub_utils
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card


class CreateModelCardTest(unittest.TestCase):
@patch("diffusers.utils.hub_utils.get_full_repo_name")
def test_create_model_card(self, repo_name_mock: Mock) -> None:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Follow the discussion here: #6678 (comment).

repo_name_mock.return_value = "full_repo_name"
def test_generate_model_card_with_library_name(self):
with TemporaryDirectory() as tmpdir:
# Dummy args values
args = Mock()
args.output_dir = tmpdir
args.local_rank = 0
args.hub_token = "hub_token"
args.dataset_name = "dataset_name"
args.learning_rate = 0.01
args.train_batch_size = 100000
args.eval_batch_size = 10000
args.gradient_accumulation_steps = 0.01
args.adam_beta1 = 0.02
args.adam_beta2 = 0.03
args.adam_weight_decay = 0.0005
args.adam_epsilon = 0.000001
args.lr_scheduler = 1
args.lr_warmup_steps = 10
args.ema_inv_gamma = 0.001
args.ema_power = 0.1
args.ema_max_decay = 0.2
args.mixed_precision = True

# Model card mush be rendered and saved
diffusers.utils.hub_utils.create_model_card(args, model_name="model_name")
self.assertTrue((Path(tmpdir) / "README.md").is_file())
file_path = Path(tmpdir) / "README.md"
file_path.write_text("---\nlibrary_name: foo\n---\nContent\n")
model_card = load_or_create_model_card(file_path)
populate_model_card(model_card)
assert model_card.data.library_name == "foo"
18 changes: 17 additions & 1 deletion tests/pipelines/test_pipelines_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
import numpy as np
import PIL.Image
import torch
from huggingface_hub import delete_repo
from huggingface_hub import ModelCard, delete_repo
from huggingface_hub.utils import is_jinja_available
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

import diffusers
Expand Down Expand Up @@ -1142,6 +1143,21 @@ def test_push_to_hub_in_organization(self):
# Reset repo
delete_repo(self.org_repo_id, token=TOKEN)

@unittest.skipIf(
not is_jinja_available(),
reason="Model card tests cannot be performed without Jinja installed.",
)
def test_push_to_hub_library_name(self):
    """
    Check that pushing a pipeline to the Hub tags its model card with `library_name: diffusers`.

    The pipeline is pushed to `self.repo_id`; the model card is then loaded back from `{USER}/{self.repo_id}`
    and its metadata is inspected.
    """
    components = self.get_pipeline_components()
    pipeline = StableDiffusionPipeline(**components)
    pipeline.push_to_hub(self.repo_id, token=TOKEN)

    try:
        # `.data` is the card's metadata object, which carries `library_name`.
        card_data = ModelCard.load(f"{USER}/{self.repo_id}", token=TOKEN).data
        assert card_data.library_name == "diffusers"
    finally:
        # Reset repo even when loading the card or the assertion fails, so that a failing run does not
        # leave the test repo behind.
        delete_repo(self.repo_id, token=TOKEN)


# For SDXL and its derivative pipelines (such as ControlNet), we have the text encoders
# and the tokenizers as optional components. So, we need to override the `test_save_load_optional_components()`
Expand Down