From 34efcc20342974a59bb24b24c0a55c7f4c59247c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 15 Feb 2024 15:58:47 +0530 Subject: [PATCH 1/8] feat: support single file checkpoint from from_pretrained() --- src/diffusers/pipelines/pipeline_utils.py | 583 ++++++++++++---------- src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/constants.py | 1 + 3 files changed, 311 insertions(+), 274 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 18a4b5cb346b..852108418c61 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -23,6 +23,7 @@ from dataclasses import dataclass from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Union +from urllib.parse import urlparse import numpy as np import PIL.Image @@ -45,6 +46,7 @@ from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from ..utils import ( + _ACCEPTED_SINGLE_FILE_FORMATS, CONFIG_NAME, DEPRECATED_REVISION_ARGS, SAFETENSORS_WEIGHTS_NAME, @@ -566,6 +568,21 @@ def _fetch_class_library_tuple(module): return (library, class_name) +def is_valid_url(url): + result = urlparse(url) + if result.scheme and result.netloc: + return True + + +def is_single_file_checkpoint(filepath): + if filepath.endswith(_ACCEPTED_SINGLE_FILE_FORMATS): + if is_valid_url(filepath): + return True + elif os.path.isfile(filepath): + return True + return False + + class DiffusionPipeline(ConfigMixin, PushToHubMixin): r""" Base class for all pipelines. @@ -1054,308 +1071,326 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P >>> pipeline.scheduler = scheduler ``` """ - cache_dir = kwargs.pop("cache_dir", None) - resume_download = kwargs.pop("resume_download", False) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - revision = kwargs.pop("revision", None) - from_flax = kwargs.pop("from_flax", False) - torch_dtype = kwargs.pop("torch_dtype", None) - custom_pipeline = kwargs.pop("custom_pipeline", None) - custom_revision = kwargs.pop("custom_revision", None) - provider = kwargs.pop("provider", None) - sess_options = kwargs.pop("sess_options", None) - device_map = kwargs.pop("device_map", None) - max_memory = kwargs.pop("max_memory", None) - offload_folder = kwargs.pop("offload_folder", None) - offload_state_dict = kwargs.pop("offload_state_dict", False) - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) - variant = kwargs.pop("variant", None) - use_safetensors = kwargs.pop("use_safetensors", None) - use_onnx = kwargs.pop("use_onnx", None) - load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) - - # 1. Download the checkpoints and configs - # use snapshot download here to get it working from from_pretrained - if not os.path.isdir(pretrained_model_name_or_path): - if pretrained_model_name_or_path.count("/") > 1: - raise ValueError( - f'The provided pretrained_model_name_or_path "{pretrained_model_name_or_path}"' - " is neither a valid local path nor a valid repo id. Please check the parameter." 
- ) - cached_folder = cls.download( - pretrained_model_name_or_path, - cache_dir=cache_dir, - resume_download=resume_download, - force_download=force_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - from_flax=from_flax, - use_safetensors=use_safetensors, - use_onnx=use_onnx, - custom_pipeline=custom_pipeline, - custom_revision=custom_revision, - variant=variant, - load_connected_pipeline=load_connected_pipeline, - **kwargs, - ) + if is_single_file_checkpoint(pretrained_model_name_or_path): + logger.info("Single file checkpoint detected...") + model = cls.from_single_file(pretrained_model_name_or_path, **kwargs) else: - cached_folder = pretrained_model_name_or_path - - config_dict = cls.load_config(cached_folder) - - # pop out "_ignore_files" as it is only needed for download - config_dict.pop("_ignore_files", None) - - # 2. Define which model components should load variants - # We retrieve the information by matching whether variant - # model checkpoints exist in the subfolders - model_variants = {} - if variant is not None: - for folder in os.listdir(cached_folder): - folder_path = os.path.join(cached_folder, folder) - is_folder = os.path.isdir(folder_path) and folder in config_dict - variant_exists = is_folder and any( - p.split(".")[1].startswith(variant) for p in os.listdir(folder_path) + cache_dir = kwargs.pop("cache_dir", None) + resume_download = kwargs.pop("resume_download", False) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + from_flax = kwargs.pop("from_flax", False) + torch_dtype = kwargs.pop("torch_dtype", None) + custom_pipeline = kwargs.pop("custom_pipeline", None) + custom_revision = kwargs.pop("custom_revision", None) + provider = kwargs.pop("provider", None) + sess_options = kwargs.pop("sess_options", None) + device_map = kwargs.pop("device_map", None) + max_memory = kwargs.pop("max_memory", None) + offload_folder = kwargs.pop("offload_folder", None) + offload_state_dict = kwargs.pop("offload_state_dict", False) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + variant = kwargs.pop("variant", None) + use_safetensors = kwargs.pop("use_safetensors", None) + use_onnx = kwargs.pop("use_onnx", None) + load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) + + # 1. Download the checkpoints and configs + # use snapshot download here to get it working from from_pretrained + if not os.path.isdir(pretrained_model_name_or_path): + if pretrained_model_name_or_path.count("/") > 1: + raise ValueError( + f'The provided pretrained_model_name_or_path "{pretrained_model_name_or_path}"' + " is neither a valid local path nor a valid repo id. Please check the parameter." + ) + cached_folder = cls.download( + pretrained_model_name_or_path, + cache_dir=cache_dir, + resume_download=resume_download, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + from_flax=from_flax, + use_safetensors=use_safetensors, + use_onnx=use_onnx, + custom_pipeline=custom_pipeline, + custom_revision=custom_revision, + variant=variant, + load_connected_pipeline=load_connected_pipeline, + **kwargs, ) - if variant_exists: - model_variants[folder] = variant - - # 3. 
Load the pipeline class, if using custom module then load it from the hub - # if we load from explicit class, let's use it - custom_class_name = None - if os.path.isfile(os.path.join(cached_folder, f"{custom_pipeline}.py")): - custom_pipeline = os.path.join(cached_folder, f"{custom_pipeline}.py") - elif isinstance(config_dict["_class_name"], (list, tuple)) and os.path.isfile( - os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py") - ): - custom_pipeline = os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py") - custom_class_name = config_dict["_class_name"][1] - - pipeline_class = _get_pipeline_class( - cls, - config_dict, - load_connected_pipeline=load_connected_pipeline, - custom_pipeline=custom_pipeline, - class_name=custom_class_name, - cache_dir=cache_dir, - revision=custom_revision, - ) - - # DEPRECATED: To be removed in 1.0.0 - if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse( - version.parse(config_dict["_diffusers_version"]).base_version - ) <= version.parse("0.5.1"): - from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy - - pipeline_class = StableDiffusionInpaintPipelineLegacy - - deprecation_message = ( - "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the" - f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For" - " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting" - " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your" - f" checkpoint {pretrained_model_name_or_path} to the format of" - " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain" - " the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0." - ) - deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False) - - # 4. Define expected modules given pipeline signature - # and define non-None initialized modules (=`init_kwargs`) + else: + cached_folder = pretrained_model_name_or_path + + config_dict = cls.load_config(cached_folder) + + # pop out "_ignore_files" as it is only needed for download + config_dict.pop("_ignore_files", None) + + # 2. Define which model components should load variants + # We retrieve the information by matching whether variant + # model checkpoints exist in the subfolders + model_variants = {} + if variant is not None: + for folder in os.listdir(cached_folder): + folder_path = os.path.join(cached_folder, folder) + is_folder = os.path.isdir(folder_path) and folder in config_dict + variant_exists = is_folder and any( + p.split(".")[1].startswith(variant) for p in os.listdir(folder_path) + ) + if variant_exists: + model_variants[folder] = variant - # some modules can be passed directly to the init - # in this case they are already instantiated in `kwargs` - # extract them here - expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class) - passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} - passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} + # 3. 
Load the pipeline class, if using custom module then load it from the hub + # if we load from explicit class, let's use it + custom_class_name = None + if os.path.isfile(os.path.join(cached_folder, f"{custom_pipeline}.py")): + custom_pipeline = os.path.join(cached_folder, f"{custom_pipeline}.py") + elif isinstance(config_dict["_class_name"], (list, tuple)) and os.path.isfile( + os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py") + ): + custom_pipeline = os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py") + custom_class_name = config_dict["_class_name"][1] - init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) + pipeline_class = _get_pipeline_class( + cls, + config_dict, + load_connected_pipeline=load_connected_pipeline, + custom_pipeline=custom_pipeline, + class_name=custom_class_name, + cache_dir=cache_dir, + revision=custom_revision, + ) - # define init kwargs and make sure that optional component modules are filtered out - init_kwargs = { - k: init_dict.pop(k) - for k in optional_kwargs - if k in init_dict and k not in pipeline_class._optional_components - } - init_kwargs = {**init_kwargs, **passed_pipe_kwargs} + # DEPRECATED: To be removed in 1.0.0 + if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse( + version.parse(config_dict["_diffusers_version"]).base_version + ) <= version.parse("0.5.1"): + from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy - # remove `null` components - def load_module(name, value): - if value[0] is None: - return False - if name in passed_class_obj and passed_class_obj[name] is None: - return False - return True + pipeline_class = StableDiffusionInpaintPipelineLegacy - init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} + deprecation_message = ( + "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the" + f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For" + " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting" + " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your" + f" checkpoint {pretrained_model_name_or_path} to the format of" + " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain" + " the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0." + ) + deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False) - # Special case: safety_checker must be loaded separately when using `from_flax` - if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj: - raise NotImplementedError( - "The safety checker cannot be automatically loaded when loading weights `from_flax`." - " Please, pass `safety_checker=None` to `from_pretrained`, and load the safety checker" - " separately if you need it." - ) + # 4. Define expected modules given pipeline signature + # and define non-None initialized modules (=`init_kwargs`) - # 5. Throw nice warnings / errors for fast accelerate loading - if len(unused_kwargs) > 0: - logger.warning( - f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored." 
- ) + # some modules can be passed directly to the init + # in this case they are already instantiated in `kwargs` + # extract them here + expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class) + passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} + passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} - if low_cpu_mem_usage and not is_accelerate_available(): - low_cpu_mem_usage = False - logger.warning( - "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" - " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" - " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" - " install accelerate\n```\n." - ) + init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) - if device_map is not None and not is_torch_version(">=", "1.9.0"): - raise NotImplementedError( - "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" - " `device_map=None`." - ) + # define init kwargs and make sure that optional component modules are filtered out + init_kwargs = { + k: init_dict.pop(k) + for k in optional_kwargs + if k in init_dict and k not in pipeline_class._optional_components + } + init_kwargs = {**init_kwargs, **passed_pipe_kwargs} + + # remove `null` components + def load_module(name, value): + if value[0] is None: + return False + if name in passed_class_obj and passed_class_obj[name] is None: + return False + return True + + init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} + + # Special case: safety_checker must be loaded separately when using `from_flax` + if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj: + raise NotImplementedError( + "The safety checker cannot be automatically loaded when loading weights `from_flax`." + " Please, pass `safety_checker=None` to `from_pretrained`, and load the safety checker" + " separately if you need it." + ) - if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): - raise NotImplementedError( - "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" - " `low_cpu_mem_usage=False`." - ) + # 5. Throw nice warnings / errors for fast accelerate loading + if len(unused_kwargs) > 0: + logger.warning( + f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored." + ) - if low_cpu_mem_usage is False and device_map is not None: - raise ValueError( - f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and" - " dispatching. Please make sure to set `low_cpu_mem_usage=True`." - ) + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) - # import it here to avoid circular import - from diffusers import pipelines - - # 6. 
Load each module in the pipeline - for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."): - # 6.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names - class_name = class_name[4:] if class_name.startswith("Flax") else class_name - - # 6.2 Define all importable classes - is_pipeline_module = hasattr(pipelines, library_name) - importable_classes = ALL_IMPORTABLE_CLASSES - loaded_sub_model = None - - # 6.3 Use passed sub model or load class_name from library_name - if name in passed_class_obj: - # if the model is in a pipeline module, then we load it from the pipeline - # check that passed_class_obj has correct parent class - maybe_raise_or_warn( - library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module + if device_map is not None and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `device_map=None`." ) - loaded_sub_model = passed_class_obj[name] - else: - # load sub model - loaded_sub_model = load_sub_model( - library_name=library_name, - class_name=class_name, - importable_classes=importable_classes, - pipelines=pipelines, - is_pipeline_module=is_pipeline_module, - pipeline_class=pipeline_class, - torch_dtype=torch_dtype, - provider=provider, - sess_options=sess_options, - device_map=device_map, - max_memory=max_memory, - offload_folder=offload_folder, - offload_state_dict=offload_state_dict, - model_variants=model_variants, - name=name, - from_flax=from_flax, - variant=variant, - low_cpu_mem_usage=low_cpu_mem_usage, - cached_folder=cached_folder, - revision=revision, + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." ) - logger.info( - f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}." + + if low_cpu_mem_usage is False and device_map is not None: + raise ValueError( + f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and" + " dispatching. Please make sure to set `low_cpu_mem_usage=True`." ) - init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) - - if pipeline_class._load_connected_pipes and os.path.isfile(os.path.join(cached_folder, "README.md")): - modelcard = ModelCard.load(os.path.join(cached_folder, "README.md")) - connected_pipes = {prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS} - load_kwargs = { - "cache_dir": cache_dir, - "resume_download": resume_download, - "force_download": force_download, - "proxies": proxies, - "local_files_only": local_files_only, - "token": token, - "revision": revision, - "torch_dtype": torch_dtype, - "custom_pipeline": custom_pipeline, - "custom_revision": custom_revision, - "provider": provider, - "sess_options": sess_options, - "device_map": device_map, - "max_memory": max_memory, - "offload_folder": offload_folder, - "offload_state_dict": offload_state_dict, - "low_cpu_mem_usage": low_cpu_mem_usage, - "variant": variant, - "use_safetensors": use_safetensors, - } + # import it here to avoid circular import + from diffusers import pipelines + + # 6. 
Load each module in the pipeline + for name, (library_name, class_name) in logging.tqdm( + init_dict.items(), desc="Loading pipeline components..." + ): + # 6.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names + class_name = class_name[4:] if class_name.startswith("Flax") else class_name + + # 6.2 Define all importable classes + is_pipeline_module = hasattr(pipelines, library_name) + importable_classes = ALL_IMPORTABLE_CLASSES + loaded_sub_model = None + + # 6.3 Use passed sub model or load class_name from library_name + if name in passed_class_obj: + # if the model is in a pipeline module, then we load it from the pipeline + # check that passed_class_obj has correct parent class + maybe_raise_or_warn( + library_name, + library, + class_name, + importable_classes, + passed_class_obj, + name, + is_pipeline_module, + ) - def get_connected_passed_kwargs(prefix): - connected_passed_class_obj = { - k.replace(f"{prefix}_", ""): w for k, w in passed_class_obj.items() if k.split("_")[0] == prefix + loaded_sub_model = passed_class_obj[name] + else: + # load sub model + loaded_sub_model = load_sub_model( + library_name=library_name, + class_name=class_name, + importable_classes=importable_classes, + pipelines=pipelines, + is_pipeline_module=is_pipeline_module, + pipeline_class=pipeline_class, + torch_dtype=torch_dtype, + provider=provider, + sess_options=sess_options, + device_map=device_map, + max_memory=max_memory, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + model_variants=model_variants, + name=name, + from_flax=from_flax, + variant=variant, + low_cpu_mem_usage=low_cpu_mem_usage, + cached_folder=cached_folder, + revision=revision, + ) + logger.info( + f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}." + ) + + init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) 
+ + if pipeline_class._load_connected_pipes and os.path.isfile(os.path.join(cached_folder, "README.md")): + modelcard = ModelCard.load(os.path.join(cached_folder, "README.md")) + connected_pipes = { + prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS } - connected_passed_pipe_kwargs = { - k.replace(f"{prefix}_", ""): w for k, w in passed_pipe_kwargs.items() if k.split("_")[0] == prefix + load_kwargs = { + "cache_dir": cache_dir, + "resume_download": resume_download, + "force_download": force_download, + "proxies": proxies, + "local_files_only": local_files_only, + "token": token, + "revision": revision, + "torch_dtype": torch_dtype, + "custom_pipeline": custom_pipeline, + "custom_revision": custom_revision, + "provider": provider, + "sess_options": sess_options, + "device_map": device_map, + "max_memory": max_memory, + "offload_folder": offload_folder, + "offload_state_dict": offload_state_dict, + "low_cpu_mem_usage": low_cpu_mem_usage, + "variant": variant, + "use_safetensors": use_safetensors, } - connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs} - return connected_passed_kwargs + def get_connected_passed_kwargs(prefix): + connected_passed_class_obj = { + k.replace(f"{prefix}_", ""): w + for k, w in passed_class_obj.items() + if k.split("_")[0] == prefix + } + connected_passed_pipe_kwargs = { + k.replace(f"{prefix}_", ""): w + for k, w in passed_pipe_kwargs.items() + if k.split("_")[0] == prefix + } - connected_pipes = { - prefix: DiffusionPipeline.from_pretrained( - repo_id, **load_kwargs.copy(), **get_connected_passed_kwargs(prefix) - ) - for prefix, repo_id in connected_pipes.items() - if repo_id is not None - } + connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs} + return connected_passed_kwargs - for prefix, connected_pipe in connected_pipes.items(): - # add connected pipes to `init_kwargs` with _, e.g. "prior_text_encoder" - init_kwargs.update( - {"_".join([prefix, name]): component for name, component in connected_pipe.components.items()} - ) + connected_pipes = { + prefix: DiffusionPipeline.from_pretrained( + repo_id, **load_kwargs.copy(), **get_connected_passed_kwargs(prefix) + ) + for prefix, repo_id in connected_pipes.items() + if repo_id is not None + } - # 7. Potentially add passed objects if expected - missing_modules = set(expected_modules) - set(init_kwargs.keys()) - passed_modules = list(passed_class_obj.keys()) - optional_modules = pipeline_class._optional_components - if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules): - for module in missing_modules: - init_kwargs[module] = passed_class_obj.get(module, None) - elif len(missing_modules) > 0: - passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs - raise ValueError( - f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed." - ) + for prefix, connected_pipe in connected_pipes.items(): + # add connected pipes to `init_kwargs` with _, e.g. "prior_text_encoder" + init_kwargs.update( + {"_".join([prefix, name]): component for name, component in connected_pipe.components.items()} + ) + + # 7. 
Potentially add passed objects if expected + missing_modules = set(expected_modules) - set(init_kwargs.keys()) + passed_modules = list(passed_class_obj.keys()) + optional_modules = pipeline_class._optional_components + if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules): + for module in missing_modules: + init_kwargs[module] = passed_class_obj.get(module, None) + elif len(missing_modules) > 0: + passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs + raise ValueError( + f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed." + ) - # 8. Instantiate the pipeline - model = pipeline_class(**init_kwargs) + # 8. Instantiate the pipeline + model = pipeline_class(**init_kwargs) - # 9. Save where the model was instantiated from - model.register_to_config(_name_or_path=pretrained_model_name_or_path) + # 9. Save where the model was instantiated from + model.register_to_config(_name_or_path=pretrained_model_name_or_path) return model @property diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 35aba10d7e58..33bfb03d85ba 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -19,6 +19,7 @@ from .. import __version__ from .constants import ( + _ACCEPTED_SINGLE_FILE_FORMATS, CONFIG_NAME, DEPRECATED_REVISION_ARGS, DIFFUSERS_DYNAMIC_MODULE_NAME, diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index bc4268a32ac5..db8899a93719 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -37,6 +37,7 @@ DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules")) DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"] +_ACCEPTED_SINGLE_FILE_FORMATS = (".safetensors", ".ckpt", ".bin", ".pth", ".pt") # Below should be `True` if the current version of `peft` and `transformers` are compatible with # PEFT backend. 
Will automatically fall back to PEFT backend if the correct versions of the libraries are From 64998bca1b89cab0c996d3a8c7e0d4cb82efb319 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 15 Feb 2024 16:29:28 +0530 Subject: [PATCH 2/8] support models too --- src/diffusers/loaders/autoencoder.py | 1 + src/diffusers/loaders/controlnet.py | 1 + src/diffusers/models/modeling_utils.py | 434 +++++++++++----------- src/diffusers/pipelines/pipeline_utils.py | 18 +- src/diffusers/utils/__init__.py | 2 +- src/diffusers/utils/loading_utils.py | 17 + 6 files changed, 242 insertions(+), 231 deletions(-) diff --git a/src/diffusers/loaders/autoencoder.py b/src/diffusers/loaders/autoencoder.py index 4bcdda9bf6ef..fa5e75c85881 100644 --- a/src/diffusers/loaders/autoencoder.py +++ b/src/diffusers/loaders/autoencoder.py @@ -144,4 +144,5 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs): if torch_dtype is not None: vae = vae.to(torch_dtype) + vae.eval() return vae diff --git a/src/diffusers/loaders/controlnet.py b/src/diffusers/loaders/controlnet.py index 4999ab32694b..075d084ab7b3 100644 --- a/src/diffusers/loaders/controlnet.py +++ b/src/diffusers/loaders/controlnet.py @@ -134,4 +134,5 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs): if torch_dtype is not None: controlnet = controlnet.to(torch_dtype) + controlnet.eval() return controlnet diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index cea76c53d945..dc5f3b46f2f8 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -39,6 +39,7 @@ _get_model_file, deprecate, is_accelerate_available, + is_single_file_checkpoint, is_torch_version, logging, ) @@ -497,102 +498,87 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. ``` """ - cache_dir = kwargs.pop("cache_dir", None) - ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) - force_download = kwargs.pop("force_download", False) - from_flax = kwargs.pop("from_flax", False) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - output_loading_info = kwargs.pop("output_loading_info", False) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - revision = kwargs.pop("revision", None) - torch_dtype = kwargs.pop("torch_dtype", None) - subfolder = kwargs.pop("subfolder", None) - device_map = kwargs.pop("device_map", None) - max_memory = kwargs.pop("max_memory", None) - offload_folder = kwargs.pop("offload_folder", None) - offload_state_dict = kwargs.pop("offload_state_dict", False) - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) - variant = kwargs.pop("variant", None) - use_safetensors = kwargs.pop("use_safetensors", None) - - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - - if low_cpu_mem_usage and not is_accelerate_available(): - low_cpu_mem_usage = False - logger.warning( - "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" - " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" - " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" - " install accelerate\n```\n." 
- ) + single_file_ckpt = False + if is_single_file_checkpoint(pretrained_model_name_or_path): + logger.info("Single file checkpoint detected...") + model = cls.from_single_file(pretrained_model_name_or_path, **kwargs) + single_file_ckpt = True - if device_map is not None and not is_accelerate_available(): - raise NotImplementedError( - "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set" - " `device_map=None`. You can install accelerate with `pip install accelerate`." - ) + else: + cache_dir = kwargs.pop("cache_dir", None) + ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) + force_download = kwargs.pop("force_download", False) + from_flax = kwargs.pop("from_flax", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + output_loading_info = kwargs.pop("output_loading_info", False) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + subfolder = kwargs.pop("subfolder", None) + device_map = kwargs.pop("device_map", None) + max_memory = kwargs.pop("max_memory", None) + offload_folder = kwargs.pop("offload_folder", None) + offload_state_dict = kwargs.pop("offload_state_dict", False) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + variant = kwargs.pop("variant", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) - # Check if we can handle device_map and dispatching the weights - if device_map is not None and not is_torch_version(">=", "1.9.0"): - raise NotImplementedError( - "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" - " `device_map=None`." - ) + if device_map is not None and not is_accelerate_available(): + raise NotImplementedError( + "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set" + " `device_map=None`. You can install accelerate with `pip install accelerate`." + ) - if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): - raise NotImplementedError( - "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" - " `low_cpu_mem_usage=False`." - ) + # Check if we can handle device_map and dispatching the weights + if device_map is not None and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `device_map=None`." + ) - if low_cpu_mem_usage is False and device_map is not None: - raise ValueError( - f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and" - " dispatching. Please make sure to set `low_cpu_mem_usage=True`." 
- ) + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." + ) + + if low_cpu_mem_usage is False and device_map is not None: + raise ValueError( + f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and" + " dispatching. Please make sure to set `low_cpu_mem_usage=True`." + ) - # Load config if we don't provide a configuration - config_path = pretrained_model_name_or_path - - user_agent = { - "diffusers": __version__, - "file_type": "model", - "framework": "pytorch", - } - - # load config - config, unused_kwargs, commit_hash = cls.load_config( - config_path, - cache_dir=cache_dir, - return_unused_kwargs=True, - return_commit_hash=True, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - device_map=device_map, - max_memory=max_memory, - offload_folder=offload_folder, - offload_state_dict=offload_state_dict, - user_agent=user_agent, - **kwargs, - ) - - # load model - model_file = None - if from_flax: - model_file = _get_model_file( - pretrained_model_name_or_path, - weights_name=FLAX_WEIGHTS_NAME, + # Load config if we don't provide a configuration + config_path = pretrained_model_name_or_path + + user_agent = { + "diffusers": __version__, + "file_type": "model", + "framework": "pytorch", + } + + # load config + config, unused_kwargs, commit_hash = cls.load_config( + config_path, cache_dir=cache_dir, + return_unused_kwargs=True, + return_commit_hash=True, force_download=force_download, resume_download=resume_download, proxies=proxies, @@ -600,40 +586,20 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P token=token, revision=revision, subfolder=subfolder, + device_map=device_map, + max_memory=max_memory, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, user_agent=user_agent, - commit_hash=commit_hash, + **kwargs, ) - model = cls.from_config(config, **unused_kwargs) - - # Convert the weights - from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model - model = load_flax_checkpoint_in_pytorch_model(model, model_file) - else: - if use_safetensors: - try: - model_file = _get_model_file( - pretrained_model_name_or_path, - weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - commit_hash=commit_hash, - ) - except IOError as e: - if not allow_pickle: - raise e - pass - if model_file is None: + # load model + model_file = None + if from_flax: model_file = _get_model_file( pretrained_model_name_or_path, - weights_name=_add_variant(WEIGHTS_NAME, variant), + weights_name=FLAX_WEIGHTS_NAME, cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, @@ -645,76 +611,90 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P user_agent=user_agent, commit_hash=commit_hash, ) + model = cls.from_config(config, **unused_kwargs) - if low_cpu_mem_usage: - # Instantiate model with empty weights - with accelerate.init_empty_weights(): - model = cls.from_config(config, **unused_kwargs) + # Convert 
the weights + from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model - # if device_map is None, load the state dict and move the params from meta device to the cpu - if device_map is None: - param_device = "cpu" - state_dict = load_state_dict(model_file, variant=variant) - model._convert_deprecated_attention_blocks(state_dict) - # move the params from meta device to cpu - missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) - if len(missing_keys) > 0: - raise ValueError( - f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are" - f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass" - " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize" - " those weights or else make sure your checkpoint file is correct." + model = load_flax_checkpoint_in_pytorch_model(model, model_file) + else: + if use_safetensors: + try: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, ) - - unexpected_keys = load_model_dict_into_meta( - model, - state_dict, - device=param_device, - dtype=torch_dtype, - model_name_or_path=pretrained_model_name_or_path, + except IOError as e: + if not allow_pickle: + raise e + pass + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, ) - if cls._keys_to_ignore_on_load_unexpected is not None: - for pat in cls._keys_to_ignore_on_load_unexpected: - unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] - - if len(unexpected_keys) > 0: - logger.warn( - f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" - ) + if low_cpu_mem_usage: + # Instantiate model with empty weights + with accelerate.init_empty_weights(): + model = cls.from_config(config, **unused_kwargs) + + # if device_map is None, load the state dict and move the params from meta device to the cpu + if device_map is None: + param_device = "cpu" + state_dict = load_state_dict(model_file, variant=variant) + model._convert_deprecated_attention_blocks(state_dict) + # move the params from meta device to cpu + missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) + if len(missing_keys) > 0: + raise ValueError( + f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are" + f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass" + " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize" + " those weights or else make sure your checkpoint file is correct." + ) - else: # else let accelerate handle loading and dispatching. 
- # Load weights and dispatch according to the device_map - # by default the device_map is None and the weights are loaded on the CPU - try: - accelerate.load_checkpoint_and_dispatch( + unexpected_keys = load_model_dict_into_meta( model, - model_file, - device_map, - max_memory=max_memory, - offload_folder=offload_folder, - offload_state_dict=offload_state_dict, + state_dict, + device=param_device, dtype=torch_dtype, + model_name_or_path=pretrained_model_name_or_path, ) - except AttributeError as e: - # When using accelerate loading, we do not have the ability to load the state - # dict and rename the weight names manually. Additionally, accelerate skips - # torch loading conventions and directly writes into `module.{_buffers, _parameters}` - # (which look like they should be private variables?), so we can't use the standard hooks - # to rename parameters on load. We need to mimic the original weight names so the correct - # attributes are available. After we have loaded the weights, we convert the deprecated - # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert - # the weights so we don't have to do this again. - - if "'Attention' object has no attribute" in str(e): + + if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + if len(unexpected_keys) > 0: logger.warn( - f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}" - " was saved with deprecated attention block weight names. We will load it with the deprecated attention block" - " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion," - " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint," - " please also re-upload it or open a PR on the original repository." + f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" ) - model._temp_convert_self_to_deprecated_attention_blocks() + + else: # else let accelerate handle loading and dispatching. + # Load weights and dispatch according to the device_map + # by default the device_map is None and the weights are loaded on the CPU + try: accelerate.load_checkpoint_and_dispatch( model, model_file, @@ -724,49 +704,77 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P offload_state_dict=offload_state_dict, dtype=torch_dtype, ) - model._undo_temp_convert_self_to_deprecated_attention_blocks() - else: - raise e - - loading_info = { - "missing_keys": [], - "unexpected_keys": [], - "mismatched_keys": [], - "error_msgs": [], - } - else: - model = cls.from_config(config, **unused_kwargs) + except AttributeError as e: + # When using accelerate loading, we do not have the ability to load the state + # dict and rename the weight names manually. Additionally, accelerate skips + # torch loading conventions and directly writes into `module.{_buffers, _parameters}` + # (which look like they should be private variables?), so we can't use the standard hooks + # to rename parameters on load. We need to mimic the original weight names so the correct + # attributes are available. After we have loaded the weights, we convert the deprecated + # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert + # the weights so we don't have to do this again. 
+ + if "'Attention' object has no attribute" in str(e): + logger.warn( + f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}" + " was saved with deprecated attention block weight names. We will load it with the deprecated attention block" + " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion," + " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint," + " please also re-upload it or open a PR on the original repository." + ) + model._temp_convert_self_to_deprecated_attention_blocks() + accelerate.load_checkpoint_and_dispatch( + model, + model_file, + device_map, + max_memory=max_memory, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + dtype=torch_dtype, + ) + model._undo_temp_convert_self_to_deprecated_attention_blocks() + else: + raise e + + loading_info = { + "missing_keys": [], + "unexpected_keys": [], + "mismatched_keys": [], + "error_msgs": [], + } + else: + model = cls.from_config(config, **unused_kwargs) - state_dict = load_state_dict(model_file, variant=variant) - model._convert_deprecated_attention_blocks(state_dict) + state_dict = load_state_dict(model_file, variant=variant) + model._convert_deprecated_attention_blocks(state_dict) - model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( - model, - state_dict, - model_file, - pretrained_model_name_or_path, - ignore_mismatched_sizes=ignore_mismatched_sizes, - ) + model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( + model, + state_dict, + model_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=ignore_mismatched_sizes, + ) - loading_info = { - "missing_keys": missing_keys, - "unexpected_keys": unexpected_keys, - "mismatched_keys": mismatched_keys, - "error_msgs": error_msgs, - } + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + "error_msgs": error_msgs, + } - if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): - raise ValueError( - f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." - ) - elif torch_dtype is not None: - model = model.to(torch_dtype) + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." 
+ ) + elif torch_dtype is not None: + model = model.to(torch_dtype) - model.register_to_config(_name_or_path=pretrained_model_name_or_path) + model.register_to_config(_name_or_path=pretrained_model_name_or_path) # Set model in evaluation mode to deactivate DropOut modules by default model.eval() - if output_loading_info: + if output_loading_info and not single_file_ckpt: return model, loading_info return model diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 852108418c61..2c5bb32d428c 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -23,7 +23,6 @@ from dataclasses import dataclass from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Union -from urllib.parse import urlparse import numpy as np import PIL.Image @@ -46,7 +45,6 @@ from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from ..utils import ( - _ACCEPTED_SINGLE_FILE_FORMATS, CONFIG_NAME, DEPRECATED_REVISION_ARGS, SAFETENSORS_WEIGHTS_NAME, @@ -57,6 +55,7 @@ is_accelerate_available, is_accelerate_version, is_peft_available, + is_single_file_checkpoint, is_torch_version, is_transformers_available, logging, @@ -568,21 +567,6 @@ def _fetch_class_library_tuple(module): return (library, class_name) -def is_valid_url(url): - result = urlparse(url) - if result.scheme and result.netloc: - return True - - -def is_single_file_checkpoint(filepath): - if filepath.endswith(_ACCEPTED_SINGLE_FILE_FORMATS): - if is_valid_url(filepath): - return True - elif os.path.isfile(filepath): - return True - return False - - class DiffusionPipeline(ConfigMixin, PushToHubMixin): r""" Base class for all pipelines. diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 33bfb03d85ba..5ce06a4e8a4c 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -84,7 +84,7 @@ is_xformers_available, requires_backends, ) -from .loading_utils import load_image +from .loading_utils import is_single_file_checkpoint, load_image from .logging import get_logger from .outputs import BaseOutput from .peft_utils import ( diff --git a/src/diffusers/utils/loading_utils.py b/src/diffusers/utils/loading_utils.py index 18f6ead64c4e..c962abb78448 100644 --- a/src/diffusers/utils/loading_utils.py +++ b/src/diffusers/utils/loading_utils.py @@ -1,10 +1,27 @@ import os from typing import Callable, Union +from urllib.parse import urlparse import PIL.Image import PIL.ImageOps import requests +from ..utils.constants import _ACCEPTED_SINGLE_FILE_FORMATS + + +def is_single_file_checkpoint(filepath): + def is_valid_url(url): + result = urlparse(url) + if result.scheme and result.netloc: + return True + + if filepath.endswith(_ACCEPTED_SINGLE_FILE_FORMATS): + if is_valid_url(filepath): + return True + elif os.path.isfile(filepath): + return True + return False + def load_image( image: Union[str, PIL.Image.Image], convert_method: Callable[[PIL.Image.Image], PIL.Image.Image] = None From 46f4c4399c8277e8a54b20f2087c76bf35ccf510 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 15 Feb 2024 16:56:29 +0530 Subject: [PATCH 3/8] add proper error handling through loadable classes check. 
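
Passing a single-file checkpoint to a class that does not support
from_single_file() now fails early with a ValueError that lists the
supported classes, rather than erroring out further down in the loader.
A rough sketch of the intended behavior (the local checkpoint path is
illustrative and assumed to exist):

    from diffusers import UNet2DConditionModel

    # UNet2DConditionModel is not in SINGLE_FILE_LOADABLE_CLASSES, so this
    # is expected to raise a ValueError instead of attempting the load.
    UNet2DConditionModel.from_pretrained("./v1-5-pruned.ckpt")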
---
 src/diffusers/models/modeling_utils.py    |  6 ++++++
 src/diffusers/pipelines/pipeline_utils.py | 18 ++++++++++++++++++
 2 files changed, 24 insertions(+)

diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index dc5f3b46f2f8..6c456347d1a3 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -49,6 +49,8 @@
 
 logger = logging.get_logger(__name__)
 
+SINGLE_FILE_LOADABLE_CLASSES = {"ControlNetModel", "AutoencoderKL"}
+
 if is_torch_version(">=", "1.9.0"):
     _LOW_CPU_MEM_USAGE_DEFAULT = True
 else:
@@ -500,6 +502,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         """
         single_file_ckpt = False
         if is_single_file_checkpoint(pretrained_model_name_or_path):
+            if cls.__name__ not in SINGLE_FILE_LOADABLE_CLASSES:
+                raise ValueError(
+                    f"Loading a single-file checkpoint is not supported for {cls.__name__}. Supported classes are: {', '.join(sorted(SINGLE_FILE_LOADABLE_CLASSES))}."
+                )
             logger.info("Single file checkpoint detected...")
             model = cls.from_single_file(pretrained_model_name_or_path, **kwargs)
             single_file_ckpt = True
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 2c5bb32d428c..69dd404d6e0b 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -109,6 +109,20 @@
     },
 }
 
+SINGLE_FILE_LOADABLE_CLASSES = {
+    "StableDiffusionPipeline",
+    "StableDiffusionImg2ImgPipeline",
+    "StableDiffusionInpaintPipeline",
+    "StableDiffusionUpscalePipeline",
+    "StableDiffusionControlNetPipeline",
+    "StableDiffusionControlNetImg2ImgPipeline",
+    "StableDiffusionControlNetInpaintPipeline",
+    "StableDiffusionXLPipeline",
+    "StableDiffusionXLImg2ImgPipeline",
+    "StableDiffusionXLInpaintPipeline",
+    "StableDiffusionXLControlNetImg2ImgPipeline",
+}
+
 ALL_IMPORTABLE_CLASSES = {}
 for library in LOADABLE_CLASSES:
     ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
@@ -1056,6 +1070,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         ```
         """
         if is_single_file_checkpoint(pretrained_model_name_or_path):
+            if cls.__name__ not in SINGLE_FILE_LOADABLE_CLASSES:
+                raise ValueError(
+                    f"Loading a single-file checkpoint is not supported for {cls.__name__}. Supported classes are: {', '.join(sorted(SINGLE_FILE_LOADABLE_CLASSES))}."
+                )
             logger.info("Single file checkpoint detected...")
             model = cls.from_single_file(pretrained_model_name_or_path, **kwargs)
         else:

From 9c734f78e8e070afe7f00083111b6b36457d9bc2 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 15 Feb 2024 17:02:54 +0530
Subject: [PATCH 4/8] fix: condition for loading_info

---
 src/diffusers/models/modeling_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 6c456347d1a3..a304e16f45b1 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -780,7 +780,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
 
         # Set model in evaluation mode to deactivate DropOut modules by default
         model.eval()
-        if output_loading_info and not single_file_ckpt:
+        if not single_file_ckpt and output_loading_info:
             return model, loading_info
 
         return model

From 45ab4399cc1c158b4394b93df6eab4f2fe2bcd4b Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 15 Feb 2024 17:14:58 +0530
Subject: [PATCH 5/8] Empty-Commit


From e9e41981d713719ff169d127028877f5e852daa1 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Mon, 19 Feb 2024 13:18:34 +0530
Subject: [PATCH 6/8] fix: posix

---
 src/diffusers/utils/loading_utils.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/diffusers/utils/loading_utils.py b/src/diffusers/utils/loading_utils.py
index c962abb78448..45f3516fe643 100644
--- a/src/diffusers/utils/loading_utils.py
+++ b/src/diffusers/utils/loading_utils.py
@@ -15,6 +15,7 @@ def is_valid_url(url):
     if result.scheme and result.netloc:
         return True
 
+    filepath = str(filepath)
     if filepath.endswith(_ACCEPTED_SINGLE_FILE_FORMATS):
         if is_valid_url(filepath):
             return True

From d7d757a38791c9737ba953a002121d455dcf0db4 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Mon, 26 Feb 2024 15:14:35 +0530
Subject: [PATCH 7/8] make single file loader cleaner for models

---
 src/diffusers/models/modeling_utils.py | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index a304e16f45b1..d31c6f4d5be9 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -500,7 +500,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         You should probably TRAIN this model on a down-stream task to be able to use it for predictions and
         inference.
``` """ - single_file_ckpt = False if is_single_file_checkpoint(pretrained_model_name_or_path): if cls.__name__ not in SINGLE_FILE_LOADABLE_CLASSES: raise ValueError( @@ -508,8 +507,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P ) logger.info("Single file checkpoint detected...") model = cls.from_single_file(pretrained_model_name_or_path, **kwargs) - single_file_ckpt = True - + model = model.eval() + return model else: cache_dir = kwargs.pop("cache_dir", None) ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) @@ -778,12 +777,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P model.register_to_config(_name_or_path=pretrained_model_name_or_path) - # Set model in evaluation mode to deactivate DropOut modules by default - model.eval() - if not single_file_ckpt and output_loading_info: - return model, loading_info + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + if output_loading_info: + return model, loading_info - return model + return model @classmethod def _load_pretrained_model( From 4d2ca28e24494b1ce469dadee28d58ebac08374d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 26 Feb 2024 15:16:19 +0530 Subject: [PATCH 8/8] ditto for pipelines. --- src/diffusers/pipelines/pipeline_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 69dd404d6e0b..61daa651df73 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1076,6 +1076,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P ) logger.info("Single file checkpoint detected...") model = cls.from_single_file(pretrained_model_name_or_path, **kwargs) + return model + else: cache_dir = kwargs.pop("cache_dir", None) resume_download = kwargs.pop("resume_download", False) @@ -1393,7 +1395,7 @@ def get_connected_passed_kwargs(prefix): # 9. Save where the model was instantiated from model.register_to_config(_name_or_path=pretrained_model_name_or_path) - return model + return model @property def name_or_path(self) -> str:
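
Taken together, the series makes `from_pretrained()` transparently dispatch single-file checkpoints (any URL, or existing local path, ending in one of `_ACCEPTED_SINGLE_FILE_FORMATS`: `.safetensors`, `.ckpt`, `.bin`, `.pth`, `.pt`) to `from_single_file()`. A minimal usage sketch of the resulting behavior; the checkpoint URL is illustrative and assumes a reachable single-file checkpoint:

```py
from diffusers import StableDiffusionPipeline

# A URL or local file path with an accepted single-file extension is picked up
# by is_single_file_checkpoint() and routed to from_single_file().
pipe = StableDiffusionPipeline.from_pretrained(
    "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors"
)

# Repo ids and local folders keep the existing loading path.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
```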