diff --git a/invokeai/backend/model_management/model_hash.py b/invokeai/backend/model_management/model_hash.py
new file mode 100644
index 00000000000..93b0cbed4b7
--- /dev/null
+++ b/invokeai/backend/model_management/model_hash.py
@@ -0,0 +1,73 @@
+# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
+"""
+Fast hashing of diffusers and checkpoint-style models.
+
+Usage:
+from invokeai.backend.model_management.model_hash import FastModelHash
+>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
+'a8e693a126ea5b831c96064dc569956f'
+"""
+
+import os
+import hashlib
+from imohash import hashfile
+from pathlib import Path
+from typing import Dict, Union
+
+
+class FastModelHash(object):
+    """FastModelHash object provides one public class method, hash()."""
+
+    # When traversing directories, ignore files smaller than this
+    # minimum value.
+    MINIMUM_FILE_SIZE = 100000
+
+    @classmethod
+    def hash(cls, model_location: Union[str, Path]) -> str:
+        """
+        Return the hexdigest string for the model located at model_location.
+
+        :param model_location: Path to the model
+        """
+        model_location = Path(model_location)
+        if model_location.is_file():
+            return cls._hash_file(model_location)
+        elif model_location.is_dir():
+            return cls._hash_dir(model_location)
+        else:
+            # avoid a circular import
+            from .models import InvalidModelException
+
+            raise InvalidModelException(f"Not a valid file or directory: {model_location}")
+
+    @classmethod
+    def _hash_file(cls, model_location: Union[str, Path]) -> str:
+        """
+        Fasthash a single file and return its hexdigest.
+
+        :param model_location: Path to the model file
+        """
+        # We return the sha256 of the fast file hash in order to be
+        # consistent with the length of hashes returned by _hash_dir().
+        return hashlib.sha256(hashfile(model_location)).hexdigest()
+
+    @classmethod
+    def _hash_dir(cls, model_location: Union[str, Path]) -> str:
+        components: Dict[str, str] = {}
+
+        for root, dirs, files in os.walk(model_location):
+            for file in files:
+                # Only pay attention to the big files. The config
+                # files contain things like the diffusers point version,
+                # which changes locally.
+                path = Path(root) / file
+                if path.stat().st_size < cls.MINIMUM_FILE_SIZE:
+                    continue
+                fast_hash = cls._hash_file(path)
+                components.update({str(path): fast_hash})
+
+        # Hash all the per-file hashes together, in alphabetic path order.
+        sha = hashlib.sha256()
+        for path, fast_hash in sorted(components.items()):
+            sha.update(fast_hash.encode("utf-8"))
+        return sha.hexdigest()
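
Reviewer note: a minimal sketch, assuming imohash is installed, of how the directory digest above behaves. It hashes only the large weight files and combines the per-file hashes in sorted path order, so the result is independent of os.walk() traversal order. The names here (sketch_dir_digest, min_size) are illustrative, not part of the PR:

    # Illustrative re-derivation of FastModelHash._hash_dir() (not part of the PR).
    import hashlib
    from pathlib import Path

    from imohash import hashfile  # sampled hashing: roughly O(1) in file size

    def sketch_dir_digest(model_dir: str, min_size: int = 100_000) -> str:
        digests = {}
        for p in Path(model_dir).rglob("*"):
            # Skip small files: configs record things like the local
            # diffusers version, which would make the digest unstable.
            if p.is_file() and p.stat().st_size >= min_size:
                digests[str(p)] = hashlib.sha256(hashfile(p)).hexdigest()
        # Sorting by path makes the digest independent of traversal order.
        sha = hashlib.sha256()
        for _path, d in sorted(digests.items()):
            sha.update(d.encode("utf-8"))
        return sha.hexdigest()
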
diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py
index 0bad714a171..3125a2a8a32
--- a/invokeai/backend/model_management/model_manager.py
+++ b/invokeai/backend/model_management/model_manager.py
@@ -260,6 +260,7 @@ class ConfigMeta(BaseModel):
     InvalidModelException,
     DuplicateModelException,
 )
+from .model_hash import FastModelHash
 
 # We are only starting to number the config file with release 3.
 # The config file version doesn't have to start at release version, but it will help
@@ -364,6 +365,8 @@ def _read_models(self, config: Optional[DictConfig] = None):
             model_class = MODEL_CLASSES[base_model][model_type]
             # alias for config file
             model_config["model_format"] = model_config.pop("format")
+            if not model_config.get("hash"):
+                model_config["hash"] = FastModelHash.hash(self.resolve_model_path(model_config["path"]))
             self.models[model_key] = model_class.create_config(**model_config)
 
         # check config version number and update on disk/RAM if necessary
@@ -431,6 +434,30 @@ def initialize_model_config(cls, config_path: Path):
         with open(config_path, "w") as yaml_file:
             yaml_file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}}))
 
+    def get_model_by_hash(
+        self,
+        model_hash: str,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> ModelInfo:
+        """
+        Given a model's unique hash, return its ModelInfo.
+
+        :param model_hash: Unique hash for this model.
+        :param submodel_type: Optional submodel to return.
+        """
+        info = self.list_models()
+        keys = [x for x in info if x["hash"] == model_hash]
+        if len(keys) == 0:
+            raise InvalidModelException(f"No model with hash {model_hash} found")
+        if len(keys) > 1:
+            raise DuplicateModelException(f"Duplicate models detected: {keys}")
+        return self.get_model(
+            keys[0]["model_name"],
+            base_model=keys[0]["base_model"],
+            model_type=keys[0]["model_type"],
+            submodel_type=submodel_type,
+        )
+
     def get_model(
         self,
         model_name: str,
@@ -500,14 +525,12 @@ def get_model(
             self.cache_keys[model_key] = set()
         self.cache_keys[model_key].add(model_context.key)
 
-        model_hash = ""  # TODO:
-
         return ModelInfo(
             context=model_context,
             name=model_name,
             base_model=base_model,
             type=submodel_type or model_type,
-            hash=model_hash,
+            hash=model_config.hash,
             location=model_path,  # TODO:
             precision=self.cache.precision,
             _cache=self.cache,
@@ -660,12 +683,24 @@ def add_model(
         if path := model_attributes.get("path"):
            model_attributes["path"] = str(self.relative_model_path(Path(path)))
 
+        if not model_attributes.get("hash"):
+            model_attributes["hash"] = FastModelHash.hash(self.resolve_model_path(model_attributes["path"]))
+
         model_class = MODEL_CLASSES[base_model][model_type]
         model_config = model_class.create_config(**model_attributes)
         model_key = self.create_key(model_name, base_model, model_type)
 
-        if model_key in self.models and not clobber:
-            raise Exception(f'Attempt to overwrite existing model definition "{model_key}"')
+        if not clobber:
+            if model_key in self.models:
+                raise Exception(f'Attempt to overwrite existing model definition "{model_key}"')
+            try:
+                i = self.get_model_by_hash(model_attributes["hash"])
+                raise DuplicateModelException(
+                    f"There is already a model with hash {model_attributes['hash']}: {i.name}"
+                )
+            except InvalidModelException:
+                # no model with this hash yet; safe to add
+                pass
 
         old_model = self.models.pop(model_key, None)
         if old_model is not None:
@@ -941,7 +974,11 @@ def scan_models_directory(
                         raise DuplicateModelException(f"Model with key {model_key} added twice")
 
                     model_path = self.relative_model_path(model_path)
-                    model_config: ModelConfigBase = model_class.probe_config(str(model_path))
+                    model_config: ModelConfigBase = model_class.probe_config(
+                        str(model_path),
+                        hash=FastModelHash.hash(model_path),
+                        model_base=cur_base_model,
+                    )
                     self.models[model_key] = model_config
                     new_models_found = True
                 except DuplicateModelException as e:
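
Reviewer note: a hedged usage sketch of the new hash-based lookup path. The config path and hash value are placeholders, and the ModelManager constructor arguments are assumed to follow the existing models.yaml-based API; the ModelInfo fields shown are the ones populated in get_model() above:

    # Hypothetical usage of the new get_model_by_hash() (path and hash are placeholders).
    from invokeai.backend.model_management.model_manager import ModelManager

    mgr = ModelManager("/path/to/models.yaml")  # assumed config location
    info = mgr.get_model_by_hash("a8e693a126ea5b831c96064dc569956f")
    print(info.name, info.base_model, info.type)
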
diff --git a/invokeai/backend/model_management/model_probe.py b/invokeai/backend/model_management/model_probe.py
index 21462cf6e63..7209cb30f13
--- a/invokeai/backend/model_management/model_probe.py
+++ b/invokeai/backend/model_management/model_probe.py
@@ -345,8 +345,12 @@ def get_base_type(self) -> BaseModelType:
             return BaseModelType.StableDiffusion1
         elif lora_token_vector_length == 1024:
             return BaseModelType.StableDiffusion2
+        elif lora_token_vector_length is None:  # variant w/o the text encoder!
+            return BaseModelType.StableDiffusion1
         else:
-            raise InvalidModelException(f"Unknown LoRA type")
+            raise InvalidModelException(
+                f"Unknown LoRA type: {self.checkpoint_path}, lora_token_vector_length={lora_token_vector_length}"
+            )
 
 
 class TextualInversionCheckpointProbe(CheckpointProbeBase):
diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_management/models/base.py
index d335b645c8c..15e3d510516
--- a/invokeai/backend/model_management/models/base.py
+++ b/invokeai/backend/model_management/models/base.py
@@ -89,6 +89,7 @@ class ModelConfigBase(BaseModel):
     path: str  # or Path
     description: Optional[str] = Field(None)
     model_format: Optional[str] = Field(None)
+    hash: Optional[str] = Field(None)
     error: Optional[ModelError] = Field(None)
 
     class Config:
@@ -197,15 +198,17 @@ def _get_configs(cls):
     def create_config(cls, **kwargs) -> ModelConfigBase:
         if "model_format" not in kwargs:
             raise Exception("Field 'model_format' not found in model config")
 
         configs = cls._get_configs()
-        return configs[kwargs["model_format"]](**kwargs)
+        config = configs[kwargs["model_format"]](**kwargs)
+        return config
 
     @classmethod
     def probe_config(cls, path: str, **kwargs) -> ModelConfigBase:
         return cls.create_config(
             path=path,
             model_format=cls.detect_format(path),
+            hash=kwargs.get("hash"),
         )
 
     @classmethod
diff --git a/invokeai/backend/model_management/models/sdxl.py b/invokeai/backend/model_management/models/sdxl.py
index 7fc3efb77c7..1d54e2882ae
--- a/invokeai/backend/model_management/models/sdxl.py
+++ b/invokeai/backend/model_management/models/sdxl.py
@@ -13,8 +13,11 @@
     read_checkpoint_meta,
     classproperty,
 )
+from invokeai.app.services.config import InvokeAIAppConfig
 from omegaconf import OmegaConf
 
+app_config = InvokeAIAppConfig.get_config()
+
 
 class StableDiffusionXLModelFormat(str, Enum):
     Checkpoint = "checkpoint"
@@ -22,7 +25,7 @@ class StableDiffusionXLModelFormat(str, Enum):
     Diffusers = "diffusers"
 
 class StableDiffusionXLModel(DiffusersModel):
-    # TODO: check that configs overwriten properly
+    # TODO: check that configs are overwritten properly
     class DiffusersConfig(ModelConfigBase):
         model_format: Literal[StableDiffusionXLModelFormat.Diffusers]
         vae: Optional[str] = Field(None)
@@ -79,14 +82,19 @@ def probe_config(cls, path: str, **kwargs):
             else:
                 raise Exception("Unkown stable diffusion 2.* model format")
 
-            if ckpt_config_path is None:
-                # TO DO: implement picking
-                pass
+            if ckpt_config_path is None and "model_base" in kwargs:
+                ckpt_config_path = (
+                    app_config.legacy_conf_path / "sd_xl_base.yaml"
+                    if kwargs["model_base"] == BaseModelType.StableDiffusionXL
+                    else app_config.legacy_conf_path / "sd_xl_refiner.yaml"
+                    if kwargs["model_base"] == BaseModelType.StableDiffusionXLRefiner
+                    else None
+                )
 
         return cls.create_config(
             path=path,
             model_format=model_format,
-            config=ckpt_config_path,
+            config=str(ckpt_config_path) if ckpt_config_path else None,
             variant=variant,
         )
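
Reviewer note: the chained conditional expression in probe_config() above works, but it is easy to misread. The same base-model-to-legacy-config mapping can be written as a dict lookup; this is a sketch only (pick_legacy_config and LEGACY_CONFIGS are not part of the PR):

    # Equivalent lookup-table form of the legacy-config selection (illustrative).
    from invokeai.app.services.config import InvokeAIAppConfig
    from invokeai.backend.model_management.models.base import BaseModelType

    app_config = InvokeAIAppConfig.get_config()

    LEGACY_CONFIGS = {
        BaseModelType.StableDiffusionXL: "sd_xl_base.yaml",
        BaseModelType.StableDiffusionXLRefiner: "sd_xl_refiner.yaml",
    }

    def pick_legacy_config(model_base):
        # Returns None for bases with no known legacy config, matching the diff.
        name = LEGACY_CONFIGS.get(model_base)
        return app_config.legacy_conf_path / name if name else None
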
for "mediapipeface" controlnet model