Merged
Changes from 47 commits
Commits (94)
0105fc4
update
DN6 Dec 19, 2023
2afad15
Merge branch 'main' into refactor-single-file
DN6 Dec 22, 2023
2686fdd
update
DN6 Dec 22, 2023
ef656d7
update
DN6 Dec 22, 2023
daf4d05
update
DN6 Dec 25, 2023
8b7eecd
update
DN6 Dec 26, 2023
0cd1be4
update
DN6 Dec 26, 2023
16a80d3
update
DN6 Dec 26, 2023
0b24f88
Merge branch 'main' into refactor-single-file
DN6 Dec 27, 2023
7289be1
update
DN6 Dec 28, 2023
0012dd2
update
DN6 Dec 28, 2023
2616e03
update
DN6 Dec 28, 2023
7db4f50
update
DN6 Dec 28, 2023
872aa6c
update
DN6 Dec 28, 2023
83c5b8e
update
DN6 Dec 29, 2023
5a8e10e
update
DN6 Dec 29, 2023
7a8c722
update
DN6 Dec 29, 2023
ccf8d62
update
DN6 Dec 29, 2023
da9c9d5
update
DN6 Dec 29, 2023
b791a71
up
DN6 Dec 29, 2023
c6c8fc7
update
DN6 Dec 29, 2023
6ba7a50
update
DN6 Dec 29, 2023
b44d2b4
update
DN6 Dec 30, 2023
41e97e0
update
DN6 Dec 30, 2023
658d80f
update
DN6 Dec 30, 2023
5daf61a
update
DN6 Dec 30, 2023
af6cd36
update
DN6 Dec 30, 2023
6d743ef
update
DN6 Dec 30, 2023
b7732a0
update
DN6 Dec 30, 2023
9d10d2d
update
DN6 Dec 30, 2023
820313b
update
DN6 Dec 30, 2023
efc6380
update
DN6 Dec 30, 2023
9453626
up
DN6 Dec 30, 2023
afa62e6
update
DN6 Dec 30, 2023
e033f9f
update
DN6 Dec 30, 2023
c0d62ac
update
DN6 Dec 30, 2023
9605db5
update
DN6 Dec 30, 2023
e945e18
update
DN6 Dec 30, 2023
fa3a0d6
update
DN6 Jan 2, 2024
bbc60be
update
DN6 Jan 2, 2024
b69cddb
update
DN6 Jan 2, 2024
3ae0b83
update
DN6 Jan 2, 2024
6c19f0a
update
DN6 Jan 2, 2024
ba704fd
update
DN6 Jan 2, 2024
f304528
update
DN6 Jan 2, 2024
3c806be
update
DN6 Jan 2, 2024
f86ba55
update
DN6 Jan 2, 2024
cf2fe1e
Merge branch 'main' into refactor-single-file
DN6 Jan 12, 2024
cf560a7
update
DN6 Jan 15, 2024
0ec1ed7
update
DN6 Jan 16, 2024
4bb4ed4
update
DN6 Jan 16, 2024
68a49b1
update
DN6 Jan 16, 2024
e37abaf
update
DN6 Jan 17, 2024
1bd8ba3
update
DN6 Jan 17, 2024
1cce591
update
DN6 Jan 17, 2024
df4a8ea
update
DN6 Jan 17, 2024
249f78e
update
DN6 Jan 17, 2024
8a24733
update
DN6 Jan 17, 2024
de77ff6
update
DN6 Jan 18, 2024
0939565
update
DN6 Jan 18, 2024
c22c2aa
update
DN6 Jan 18, 2024
eb71c80
update
DN6 Jan 18, 2024
32349c5
update
DN6 Jan 18, 2024
a076513
update
DN6 Jan 18, 2024
db3eb06
update
DN6 Jan 19, 2024
9b42fbf
update
DN6 Jan 19, 2024
1ca79f7
update
DN6 Jan 19, 2024
ffde123
update
DN6 Jan 19, 2024
fd2ec36
update
DN6 Jan 19, 2024
aee8b5f
update
DN6 Jan 19, 2024
2fb9baf
update
DN6 Jan 19, 2024
bb8d317
clean
DN6 Jan 19, 2024
480a4b4
update
DN6 Jan 19, 2024
2483d51
update
DN6 Jan 19, 2024
dab7f01
clean up
DN6 Jan 19, 2024
68ddb25
clean up
DN6 Jan 19, 2024
7395283
update
DN6 Jan 19, 2024
153e746
clean
DN6 Jan 19, 2024
a371c3b
clean
DN6 Jan 19, 2024
ba66fb8
update
DN6 Jan 19, 2024
b658618
update
DN6 Jan 19, 2024
3620357
clean up
DN6 Jan 19, 2024
dae09d0
fix docs
DN6 Jan 19, 2024
0746cf9
update
DN6 Jan 22, 2024
dbfb8f1
update
DN6 Jan 22, 2024
82ce94e
Revert "update"
DN6 Jan 22, 2024
6f8446a
update
DN6 Jan 22, 2024
e1d82e2
Merge branch 'main' into refactor-single-file
DN6 Jan 22, 2024
b2c9561
update
DN6 Jan 22, 2024
e297ac8
update
DN6 Jan 23, 2024
d1e3466
update
DN6 Jan 23, 2024
650a632
fix controlnet
DN6 Jan 23, 2024
99fdba9
fix scheduler
DN6 Jan 23, 2024
8c9af6c
fix controlnet tests
DN6 Jan 23, 2024
5 changes: 3 additions & 2 deletions src/diffusers/loaders/__init__.py
@@ -54,12 +54,13 @@ def text_encoder_attn_modules(text_encoder):
 _import_structure = {}

 if is_torch_available():
-    _import_structure["single_file"] = ["FromOriginalControlnetMixin", "FromOriginalVAEMixin"]
+    _import_structure["autoencoder"] = ["FromOriginalVAEMixin"]
+    _import_structure["controlnet"] = ["FromOriginalControlnetMixin"]
     _import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
     _import_structure["utils"] = ["AttnProcsLayers"]

 if is_transformers_available():
-    _import_structure["single_file"].extend(["FromSingleFileMixin"])
+    _import_structure["single_file"] = ["FromSingleFileMixin"]
     _import_structure["lora"] = ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin"]
     _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
     _import_structure["ip_adapter"] = ["IPAdapterMixin"]
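For context, a minimal sketch of what these `_import_structure` entries expose after the split. The submodule paths below are inferred from the diff above, not taken verbatim from the PR:

```py
# Hedged sketch: the single-file mixins now live in dedicated modules.
# Paths follow the _import_structure mapping in the diff above.
from diffusers.loaders.autoencoder import FromOriginalVAEMixin
from diffusers.loaders.controlnet import FromOriginalControlnetMixin
from diffusers.loaders.single_file import FromSingleFileMixin  # requires transformers
```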
223 changes: 223 additions & 0 deletions src/diffusers/loaders/autoencoder.py
@@ -0,0 +1,223 @@
from contextlib import nullcontext
from io import BytesIO
from pathlib import Path

import requests
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import validate_hf_hub_args

from ..utils import (
    is_accelerate_available,
    is_omegaconf_available,
    is_transformers_available,
    logging,
)
from ..utils.import_utils import BACKENDS_MAPPING


if is_transformers_available():
    pass
Review comment (Contributor), with a suggested change to the `if is_transformers_available(): pass` block above:

What is the purpose of this?


if is_accelerate_available():
    from accelerate import init_empty_weights

logger = logging.get_logger(__name__)


class FromOriginalVAEMixin:
    """
    Load pretrained AutoencoderKL (VAE) weights saved in the `.ckpt` or `.safetensors` format into an
    [`AutoencoderKL`].
    """

    @classmethod
    @validate_hf_hub_args
    def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
r"""
Instantiate a [`AutoencoderKL`] from pretrained ControlNet weights saved in the original `.ckpt` or
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.

Parameters:
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- A link to the `.ckpt` file (for example
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
Review comment (Contributor) — suggested change:
-                      `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
+                      `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.safetensors"`) on the Hub.

Review comment (Member): Is `.ckpt` still used? If we know the answer to be "yes", I think it would make sense to add a note below that `.ckpt` files are also supported.

                    - A path to a *file* containing all pipeline weights.
            torch_dtype (`str` or `torch.dtype`, *optional*):
                Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
Review comment (Contributor): Do we allow passing `"auto"` here?

                dtype is automatically derived from the model's weights.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
                incompletely downloaded files are deleted.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or `bool`, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            image_size (`int`, *optional*, defaults to 512):
                The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
                Diffusion v2 base model. Use 768 for Stable Diffusion v2.
            use_safetensors (`bool`, *optional*, defaults to `None`):
                If set to `None`, the safetensors weights are downloaded if they're available **and** if the
                safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
                weights. If set to `False`, safetensors weights are not loaded.
            upcast_attention (`bool`, *optional*, defaults to `None`):
                Whether the attention computation should always be upcast.
            scaling_factor (`float`, *optional*, defaults to 0.18215):
                The component-wise standard deviation of the trained latent space computed using the first batch of the
                training set. This is used to scale the latent space to have unit variance when training the diffusion
                model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
                diffusion model. When decoding, the latents are scaled back to the original scale with the formula
                `z = 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution
                Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to overwrite loadable and saveable variables (for example the pipeline components of the
                specific pipeline class). The overwritten components are passed directly to the pipeline's `__init__`
                method. See the example below for more information.

        <Tip warning={true}>

        Make sure to pass both `image_size` and `scaling_factor` to `from_single_file()` if you're loading a VAE from
        an SDXL or a Stable Diffusion v2 (or later) model.

        </Tip>

        Examples:

        ```py
        from diffusers import AutoencoderKL

        url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"  # can also be a local file
        model = AutoencoderKL.from_single_file(url)
        ```
        """
        if not is_omegaconf_available():
            raise ValueError(BACKENDS_MAPPING["omegaconf"][1])

        from omegaconf import OmegaConf

        from ..models import AutoencoderKL

        # import here to avoid circular dependency
        from ..pipelines.stable_diffusion.convert_from_ckpt import (
            convert_ldm_vae_checkpoint,
            create_vae_diffusers_config,
        )

        config_file = kwargs.pop("config_file", None)
        cache_dir = kwargs.pop("cache_dir", None)
        resume_download = kwargs.pop("resume_download", False)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        image_size = kwargs.pop("image_size", None)
        scaling_factor = kwargs.pop("scaling_factor", None)
        kwargs.pop("upcast_attention", None)

        torch_dtype = kwargs.pop("torch_dtype", None)

        use_safetensors = kwargs.pop("use_safetensors", None)

        file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
        from_safetensors = file_extension == "safetensors"

        if from_safetensors and use_safetensors is False:
            raise ValueError("The checkpoint is a safetensors file, but `use_safetensors=False` was passed.")

        # strip the Hub URL prefix so only `repo_id/path` remains
        for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
            if pretrained_model_link_or_path.startswith(prefix):
                pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]

        # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
        ckpt_path = Path(pretrained_model_link_or_path)
        if not ckpt_path.is_file():
            # get repo_id and (potentially nested) file path of ckpt in repo
            repo_id = "/".join(ckpt_path.parts[:2])
            file_path = "/".join(ckpt_path.parts[2:])

            if file_path.startswith("blob/"):
                file_path = file_path[len("blob/") :]

            if file_path.startswith("main/"):
                file_path = file_path[len("main/") :]

            pretrained_model_link_or_path = hf_hub_download(
                repo_id,
                filename=file_path,
                cache_dir=cache_dir,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
                force_download=force_download,
            )

        if from_safetensors:
            from safetensors import safe_open

            checkpoint = {}
            with safe_open(pretrained_model_link_or_path, framework="pt", device="cpu") as f:
                for key in f.keys():
                    checkpoint[key] = f.get_tensor(key)
        else:
            checkpoint = torch.load(pretrained_model_link_or_path, map_location="cpu")

        if "state_dict" in checkpoint:
            checkpoint = checkpoint["state_dict"]

        if config_file is None:
            config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
            config_file = BytesIO(requests.get(config_url).content)

        original_config = OmegaConf.load(config_file)
Review comment (Contributor): Actually, is there a way we could get rid of OmegaConf? It's just a YAML file, so couldn't we load it with `open(...)` and a little parsing code? It would be great to not require the additional OmegaConf dependency here. Or we could use the more widely used pyyaml library: https://github.com/yaml/pyyaml?
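A minimal sketch of that suggestion, assuming the original LDM config is plain YAML and `pyyaml` stands in for OmegaConf; attribute access such as `original_config.model.params` would then become dict lookups:

```py
import yaml  # pip install pyyaml

# Hedged sketch: load the original config without OmegaConf.
# `config_file` may be a filesystem path or a file-like object (e.g. BytesIO), as above.
if isinstance(config_file, (str, Path)):
    with open(config_file) as f:
        original_config = yaml.safe_load(f)
else:
    original_config = yaml.safe_load(config_file)

# OmegaConf attribute access becomes plain dict access:
scale_factor = (
    original_config.get("model", {}).get("params", {}).get("scale_factor", 0.18215)
)
```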


        # default to sd-v1-5
        image_size = image_size or 512

        vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
        converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
Review comment (Contributor): We should refactor the `convert_ldm_vae_checkpoint` function.

Review comment (Member): Could it make sense to do that in a follow-up PR?


        if scaling_factor is None:
            if (
                "model" in original_config
                and "params" in original_config.model
                and "scale_factor" in original_config.model.params
            ):
                scaling_factor = original_config.model.params.scale_factor
            else:
                scaling_factor = 0.18215  # default SD scaling factor

        vae_config["scaling_factor"] = scaling_factor

        ctx = init_empty_weights if is_accelerate_available() else nullcontext
        with ctx():
            vae = AutoencoderKL(**vae_config)

        if is_accelerate_available():
            from ..models.modeling_utils import load_model_dict_into_meta

            load_model_dict_into_meta(vae, converted_vae_checkpoint, device="cpu")
        else:
            vae.load_state_dict(converted_vae_checkpoint)

        if torch_dtype is not None:
            vae.to(dtype=torch_dtype)

        return vae
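To round out the docstring example above, a hedged usage sketch that also exercises the `image_size` and `scaling_factor` arguments called out in the Tip. The SDXL values below (1024, 0.13025) and the local path are illustrative assumptions, not taken from this diff:

```py
import torch

from diffusers import AutoencoderKL

# SD v1.x VAE: the defaults (image_size=512, scaling_factor=0.18215) apply.
url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
vae = AutoencoderKL.from_single_file(url, torch_dtype=torch.float16)

# SDXL VAE: per the Tip, pass image_size and scaling_factor explicitly.
# 0.13025 is the commonly cited SDXL scaling factor; treat it as an assumption here.
sdxl_vae = AutoencoderKL.from_single_file(
    "./sdxl_vae.safetensors",  # hypothetical local checkpoint
    image_size=1024,
    scaling_factor=0.13025,
)
```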