Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 38 additions & 14 deletions bindings/python/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@
import json
import os
import shutil
from tempfile import TemporaryDirectory
from collections import defaultdict
from inspect import signature
from typing import Optional, List
from tempfile import TemporaryDirectory
from typing import Dict, List, Optional

import torch

from huggingface_hub import CommitOperationAdd, HfApi, hf_hub_download, get_repo_discussions
from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, hf_hub_download
from huggingface_hub.file_download import repo_folder_name
from safetensors.torch import save_file
from transformers import AutoConfig
from transformers.pipelines.base import infer_framework_load_model
from safetensors.torch import save_file


class AlreadyExists(Exception):
Expand All @@ -30,15 +30,18 @@ def shared_pointers(tensors):
failing.append(names)
return failing


def check_file_size(sf_filename: str, pt_filename: str):
sf_size = os.stat(sf_filename).st_size
pt_size = os.stat(pt_filename).st_size

if (sf_size - pt_size) / pt_size > 0.01:
raise RuntimeError(f"""The file size different is more than 1%:
raise RuntimeError(
f"""The file size different is more than 1%:
- {sf_filename}: {sf_size}
- {pt_filename}: {pt_size}
""")
"""
)


def rename(pt_filename: str) -> str:
Expand All @@ -47,12 +50,13 @@ def rename(pt_filename: str) -> str:
return local


def convert_multi(model_id: str) -> List["CommitOperationAdd"]:
def convert_multi(model_id: str, folder: str) -> List["CommitOperationAdd"]:
filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin.index.json")
with open(filename, "r") as f:
data = json.load(f)

filenames = set(data["weight_map"].values())
local_filenames = []
for filename in filenames:
cached_filename = hf_hub_download(repo_id=model_id, filename=filename)
loaded = torch.load(cached_filename)
Expand All @@ -71,7 +75,9 @@ def convert_multi(model_id: str) -> List["CommitOperationAdd"]:
json.dump(newdata, f)
local_filenames.append(index)

operations = [CommitOperationAdd(path_in_repo=local.split("/")[-1], path_or_fileobj=local) for local in local_filenames]
operations = [
CommitOperationAdd(path_in_repo=local.split("/")[-1], path_or_fileobj=local) for local in local_filenames
]

return operations

Expand All @@ -97,16 +103,34 @@ def convert_single(model_id: str, folder: str) -> List["CommitOperationAdd"]:
operations = [CommitOperationAdd(path_in_repo=sf_filename, path_or_fileobj=local)]
return operations


def create_diff(pt_infos: Dict[str, List[str]], sf_infos: Dict[str, List[str]]) -> str:
    """Describe how two sets of model-loading warnings differ.

    Compares the ``missing_keys``, ``mismatched_keys`` and ``unexpected_keys``
    entries of the two loading-info dictionaries and reports, per category,
    the entries that appear on one side only.

    Args:
        pt_infos: Loading info collected when loading the PyTorch weights.
        sf_infos: Loading info collected when loading the safetensors weights.

    Returns:
        One line per one-sided difference, joined with newlines; an empty
        string when the three categories match on both sides.
    """
    errors = []
    for key in ("missing_keys", "mismatched_keys", "unexpected_keys"):
        # A category absent from one dict is treated as empty instead of
        # raising KeyError while we are already building an error report.
        pt_set = set(pt_infos.get(key, ()))
        sf_set = set(sf_infos.get(key, ()))

        # Sort so the message is reproducible across runs (raw set ordering
        # depends on string hash randomisation); `repr` as the sort key keeps
        # this safe even if entries are not mutually comparable.
        pt_only = sorted(pt_set - sf_set, key=repr)
        sf_only = sorted(sf_set - pt_set, key=repr)

        if pt_only:
            errors.append(f"{key} : PT warnings contain {pt_only} which are not present in SF warnings")
        if sf_only:
            errors.append(f"{key} : SF warnings contain {sf_only} which are not present in PT warnings")
    return "\n".join(errors)


def check_final_model(model_id: str, folder: str):
config = hf_hub_download(repo_id=model_id, filename="config.json")
shutil.copy(config, os.path.join(folder, "config.json"))
config = AutoConfig.from_pretrained(folder)

_, pt_model = infer_framework_load_model(model_id, config)
_, sf_model = infer_framework_load_model(folder, config)
_, (pt_model, pt_infos) = infer_framework_load_model(model_id, config, output_loading_info=True)
_, (sf_model, sf_infos) = infer_framework_load_model(folder, config, output_loading_info=True)

pt_model = pt_model
sf_model = sf_model
if pt_infos != sf_infos:
error_string = create_diff(pt_infos, sf_infos)
raise ValueError(f"Different infos when reloading the model: {error_string}")

pt_params = pt_model.state_dict()
sf_params = sf_model.state_dict()
Expand Down Expand Up @@ -134,7 +158,6 @@ def check_final_model(model_id: str, folder: str):
if "image" in sig.parameters:
kwargs["image"] = pixel_values


if torch.cuda.is_available():
pt_model = pt_model.cuda()
sf_model = sf_model.cuda()
Expand All @@ -146,6 +169,7 @@ def check_final_model(model_id: str, folder: str):
torch.testing.assert_close(sf_logits, pt_logits)
print(f"Model {model_id} is ok !")


def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
try:
discussions = api.get_repo_discussions(repo_id=model_id)
Expand All @@ -156,7 +180,7 @@ def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discuss
return discussion


def convert(api: "HfApi", model_id: str, force: bool=False) -> Optional["CommitInfo"]:
def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["CommitInfo"]:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Didn't catch this before, but we shouldn't use any api object here, just the top-level methods (create_commit is available directly from the main `__init__`).

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why ?

I figured it was actually better to pass around an HfApi codewise, so you can include a specific token (for instance to create conversions directly).
The original space conversion was passing around a token, but I figured HfApi was even easier to transport around, and there's less risk of forgetting to include token=token in whatever call.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think anyone who uses huggingface_hub uses the HfApi object.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should include a kwarg for token instead IMO

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think anyone who uses huggingface_hub uses the HfApi object.

I do :)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(this feature of setting a token on HfApi is quite recent though, but it's useful IMO)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me rephrase, I don't think many people except Julien use HfApi :-p

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually did use it naturally, and I think it's pretty cool to encapsulate common client options (the most important of which is the token).

I really think the "oh but you forgot to pass the token" is a real failure case, and it did happen when I used get_repo_discussions where I failed to pass the token and it would fail on private models on the space.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll go ahead and merge this.

Please open an issue if you feel so strongly about the HfApi object :)

pr_title = "Adding `safetensors` variant of this model"
info = api.model_info(model_id)
filenames = set(s.rfilename for s in info.siblings)
Expand Down
22 changes: 17 additions & 5 deletions bindings/python/convert_all.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,30 @@
"""Simple utility tool to convert automatically most downloaded models"""
from convert import AlreadyExists, convert
from huggingface_hub import HfApi, ModelFilter, ModelSearchArguments
from convert import convert, AlreadyExists
from transformers import AutoConfig


if __name__ == "__main__":
api = HfApi()
args = ModelSearchArguments()

total = 100
models = list(api.list_models(filter=ModelFilter(library=args.library.Transformers), sort="downloads", direction=-1))[:total]
total = 50
models = list(
api.list_models(filter=ModelFilter(library=args.library.Transformers), sort="downloads", direction=-1)
)[:total]

correct = 0
errors = set()
for model in models:
model = api.model_info(model.modelId, files_metadata=True)
size = None
for sibling in model.siblings:
if sibling.rfilename == "pytorch_model.bin":
size = sibling.size
if size is None or size > 2_000_000_000:
print(f"[{model.downloads}] Skipping {model.modelId} (too large {size})")
continue

model_id = model.modelId
print(f"[{model.downloads}] {model.modelId}")
try:
Expand All @@ -22,10 +34,10 @@
correct += 1
print(e)
except Exception as e:
errors.add( model_id)
config = AutoConfig.from_pretrained(model_id)
errors.add(config.__class__.__name__)
print(e)


print(f"Errors: {errors}")
print(f"File size is difference {len(errors)}")
print(f"Correct rate {correct}/{total} ({correct/total * 100:.2f}%)")