73 changes: 73 additions & 0 deletions invokeai/backend/model_management/model_hash.py
@@ -0,0 +1,73 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Fast hashing of diffusers and checkpoint-style models.

Usage:
>>> from invokeai.backend.model_management.model_hash import FastModelHash
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
'a8e693a126ea5b831c96064dc569956f'
"""

import hashlib
import os
from pathlib import Path
from typing import Dict, Union

from imohash import hashfile


class FastModelHash:
    """FastModelHash object provides one public class method, hash()."""

    # When traversing directories, ignore files smaller than this
    # minimum value.
    MINIMUM_FILE_SIZE = 100000

    @classmethod
    def hash(cls, model_location: Union[str, Path]) -> str:
        """
        Return hexdigest string for model located at model_location.

        :param model_location: Path to the model
        """
        model_location = Path(model_location)
        if model_location.is_file():
            return cls._hash_file(model_location)
        elif model_location.is_dir():
            return cls._hash_dir(model_location)
        else:
            # avoid circular import
            from .models import InvalidModelException

            raise InvalidModelException(f"Not a valid file or directory: {model_location}")

    @classmethod
    def _hash_file(cls, model_location: Union[str, Path]) -> str:
        """
        Fasthash a single file and return its hexdigest.

        :param model_location: Path to the model file
        """
        # we return the sha256 hash of the fast file hash in order to be
        # consistent with the length of hashes returned by _hash_dir()
        return hashlib.sha256(hashfile(model_location)).hexdigest()

    @classmethod
    def _hash_dir(cls, model_location: Union[str, Path]) -> str:
        """
        Fasthash a directory of model files and return a combined hexdigest.

        :param model_location: Path to the model directory
        """
        components: Dict[str, str] = {}

        for root, dirs, files in os.walk(model_location):
            for file in files:
                # Only pay attention to the big files. The config files
                # contain things like the diffusers version, which can
                # change from one local installation to another.
                path = Path(root) / file
                if path.stat().st_size < cls.MINIMUM_FILE_SIZE:
                    continue
                fast_hash = cls._hash_file(path)
                components.update({str(path): fast_hash})

        # hash all the per-file hashes together, in alphabetic path order
        sha = hashlib.sha256()
        for path, fast_hash in sorted(components.items()):
            sha.update(fast_hash.encode("utf-8"))
        return sha.hexdigest()
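
A quick smoke test of the two entry points (paths are hypothetical); both return a sha256 hexdigest of the same length, whether the model is a single checkpoint file or a diffusers directory:

from invokeai.backend.model_management.model_hash import FastModelHash

# Diffusers directory: large weight files are fast-hashed and combined;
# small config files are skipped via MINIMUM_FILE_SIZE.
print(FastModelHash.hash("/home/models/stable-diffusion-v1.5"))

# Single checkpoint file: the imohash digest is re-hashed with sha256
# so that file and directory hashes are interchangeable.
print(FastModelHash.hash("/home/models/v1-5-pruned-emaonly.safetensors"))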
49 changes: 43 additions & 6 deletions invokeai/backend/model_management/model_manager.py
@@ -260,6 +260,7 @@ class ConfigMeta(BaseModel):
    InvalidModelException,
    DuplicateModelException,
)
from .model_hash import FastModelHash

# We are only starting to number the config file with release 3.
# The config file version doesn't have to start at release version, but it will help
@@ -364,6 +365,8 @@ def _read_models(self, config: Optional[DictConfig] = None):
            model_class = MODEL_CLASSES[base_model][model_type]
            # alias for config file
            model_config["model_format"] = model_config.pop("format")
            if not model_config.get("hash"):
                model_config["hash"] = FastModelHash.hash(self.resolve_model_path(model_config["path"]))
            self.models[model_key] = model_class.create_config(**model_config)

        # check config version number and update on disk/RAM if necessary
@@ -431,6 +434,28 @@ def initialize_model_config(cls, config_path: Path):
        with open(config_path, "w") as yaml_file:
            yaml_file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}}))

    def get_model_by_hash(
        self,
        model_hash: str,
        submodel_type: Optional[SubModelType] = None,
    ) -> ModelInfo:
        """
        Given a model's unique hash, return its ModelInfo.

        :param model_hash: Unique hash for this model.
        :param submodel_type: Optional submodel to fetch (e.g. the VAE of a diffusers pipeline).
        """
        matches = [x for x in self.list_models() if x["hash"] == model_hash]
        if len(matches) == 0:
            raise InvalidModelException(f"No model with hash {model_hash} found")
        if len(matches) > 1:
            raise DuplicateModelException(f"Duplicate models detected: {matches}")
        return self.get_model(
            matches[0]["model_name"],
            base_model=matches[0]["base_model"],
            model_type=matches[0]["model_type"],
            submodel_type=submodel_type,
        )

    def get_model(
        self,
        model_name: str,
@@ -500,14 +525,12 @@ def get_model(
            self.cache_keys[model_key] = set()
        self.cache_keys[model_key].add(model_context.key)

        model_hash = "<NO_HASH>"  # TODO:

        return ModelInfo(
            context=model_context,
            name=model_name,
            base_model=base_model,
            type=submodel_type or model_type,
            hash=model_hash,
            hash=model_config.hash,
            location=model_path,  # TODO:
            precision=self.cache.precision,
            _cache=self.cache,
@@ -660,12 +683,22 @@ def add_model(
        if path := model_attributes.get("path"):
            model_attributes["path"] = str(self.relative_model_path(Path(path)))

        if not model_attributes.get("hash"):
            model_attributes["hash"] = FastModelHash.hash(self.resolve_model_path(model_attributes["path"]))

        model_class = MODEL_CLASSES[base_model][model_type]
        model_config = model_class.create_config(**model_attributes)
        model_key = self.create_key(model_name, base_model, model_type)

        if model_key in self.models and not clobber:
            raise Exception(f'Attempt to overwrite existing model definition "{model_key}"')
        if not clobber:
            if model_key in self.models:
                raise Exception(f'Attempt to overwrite existing model definition "{model_key}"')
            try:
                existing = self.get_model_by_hash(model_attributes["hash"])
            except InvalidModelException:
                pass  # no model with this hash installed yet; safe to add
            else:
                raise DuplicateModelException(
                    f"There is already a model with hash {model_attributes['hash']}: {existing.name}"
                )

        old_model = self.models.pop(model_key, None)
        if old_model is not None:
@@ -941,7 +974,11 @@ def scan_models_directory(
raise DuplicateModelException(f"Model with key {model_key} added twice")

model_path = self.relative_model_path(model_path)
model_config: ModelConfigBase = model_class.probe_config(str(model_path))
model_config: ModelConfigBase = model_class.probe_config(
str(model_path),
hash=FastModelHash.hash(model_path),
model_base=cur_base_model,
)
self.models[model_key] = model_config
new_models_found = True
except DuplicateModelException as e:
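A minimal usage sketch for the new lookup, assuming an initialized ModelManager instance named mgr (hypothetical); the hash value is illustrative, and the import path follows the relative imports used in the diff:

from invokeai.backend.model_management.models import InvalidModelException, DuplicateModelException

try:
    info = mgr.get_model_by_hash("a8e693a126ea5b831c96064dc569956f")
except InvalidModelException:
    print("no installed model with that hash")
except DuplicateModelException:
    print("more than one installed model shares that hash")
else:
    # ModelInfo carries the loaded context plus the metadata set in get_model()
    print(info.name, info.location, info.hash)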
6 changes: 5 additions & 1 deletion invokeai/backend/model_management/model_probe.py
@@ -345,8 +345,12 @@ def get_base_type(self) -> BaseModelType:
            return BaseModelType.StableDiffusion1
        elif lora_token_vector_length == 1024:
            return BaseModelType.StableDiffusion2
        elif lora_token_vector_length is None:  # variant w/o the text encoder!
            return BaseModelType.StableDiffusion1
        else:
            raise InvalidModelException(f"Unknown LoRA type")
            raise InvalidModelException(
                f"Unknown LoRA type: {self.checkpoint_path}, lora_token_vector_length={lora_token_vector_length}"
            )


class TextualInversionCheckpointProbe(CheckpointProbeBase):
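The probe's base-type decision reduces to a small mapping from the detected token-vector length; a standalone sketch (string tags stand in for the BaseModelType enum values):

from typing import Optional

def base_type_for_lora(token_vector_length: Optional[int]) -> str:
    # Mirrors the branch in get_base_type() above.
    if token_vector_length == 768:
        return "sd-1"
    if token_vector_length == 1024:
        return "sd-2"
    if token_vector_length is None:
        return "sd-1"  # LoRA variant shipped without a text encoder
    raise ValueError(f"Unknown LoRA type: token_vector_length={token_vector_length}")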
6 changes: 4 additions & 2 deletions invokeai/backend/model_management/models/base.py
@@ -89,6 +89,7 @@ class ModelConfigBase(BaseModel):
    path: str  # or Path
    description: Optional[str] = Field(None)
    model_format: Optional[str] = Field(None)
    hash: Optional[str] = Field(None)
    error: Optional[ModelError] = Field(None)

    class Config:
@@ -197,15 +198,16 @@ def _get_configs(cls):
    def create_config(cls, **kwargs) -> ModelConfigBase:
        if "model_format" not in kwargs:
            raise Exception("Field 'model_format' not found in model config")

        configs = cls._get_configs()
        return configs[kwargs["model_format"]](**kwargs)
        config = configs[kwargs["model_format"]](**kwargs)
        return config

    @classmethod
    def probe_config(cls, path: str, **kwargs) -> ModelConfigBase:
        return cls.create_config(
            path=path,
            model_format=cls.detect_format(path),
            hash=kwargs.get("hash"),
        )

    @classmethod
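A sketch of the probe/create flow under the new hash kwarg; SomeModelClass is a hypothetical stand-in for any concrete ModelBase subclass, and the path is illustrative:

from invokeai.backend.model_management.model_hash import FastModelHash

path = "/models/some-model"  # hypothetical
config = SomeModelClass.probe_config(path, hash=FastModelHash.hash(path))

# probe_config() fills in the path, the detected model_format, and the
# supplied hash on the resulting pydantic config object.
assert config.path == path
assert config.hash == FastModelHash.hash(path)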
18 changes: 13 additions & 5 deletions invokeai/backend/model_management/models/sdxl.py
@@ -13,16 +13,19 @@
    read_checkpoint_meta,
    classproperty,
)
from invokeai.app.services.config import InvokeAIAppConfig
from omegaconf import OmegaConf

app_config = InvokeAIAppConfig.get_config()


class StableDiffusionXLModelFormat(str, Enum):
    Checkpoint = "checkpoint"
    Diffusers = "diffusers"


class StableDiffusionXLModel(DiffusersModel):
    # TODO: check that configs overwriten properly
    # TODO: check that configs overwritten properly
    class DiffusersConfig(ModelConfigBase):
        model_format: Literal[StableDiffusionXLModelFormat.Diffusers]
        vae: Optional[str] = Field(None)
@@ -79,14 +82,19 @@ def probe_config(cls, path: str, **kwargs):
        else:
            raise Exception("Unknown stable diffusion XL model format")

        if ckpt_config_path is None:
            # TO DO: implement picking
            pass
        if ckpt_config_path is None and "model_base" in kwargs:
            if kwargs["model_base"] == BaseModelType.StableDiffusionXL:
                ckpt_config_path = app_config.legacy_conf_path / "sd_xl_base.yaml"
            elif kwargs["model_base"] == BaseModelType.StableDiffusionXLRefiner:
                ckpt_config_path = app_config.legacy_conf_path / "sd_xl_refiner.yaml"

        return cls.create_config(
            path=path,
            model_format=model_format,
            config=ckpt_config_path,
            config=str(ckpt_config_path) if ckpt_config_path else None,
            variant=variant,
        )

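As a design note, the two-way branch above could also be table-driven, which would make adding future base types a one-line change; a hypothetical refactor, not part of this diff:

LEGACY_XL_CONFIGS = {
    BaseModelType.StableDiffusionXL: "sd_xl_base.yaml",
    BaseModelType.StableDiffusionXLRefiner: "sd_xl_refiner.yaml",
}

if ckpt_config_path is None and "model_base" in kwargs:
    name = LEGACY_XL_CONFIGS.get(kwargs["model_base"])
    if name is not None:
        ckpt_config_path = app_config.legacy_conf_path / name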
1 change: 1 addition & 0 deletions pyproject.toml
@@ -55,6 +55,7 @@ dependencies = [
"flask_socketio==5.3.0",
"flaskwebgui==1.0.3",
"huggingface-hub>=0.11.1",
"imohash~=1.0.0",
"invisible-watermark~=0.2.0", # needed to install SDXL base and refiner using their repo_ids
"matplotlib", # needed for plotting of Penner easing functions
"mediapipe", # needed for "mediapipeface" controlnet model
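
The new imohash dependency is what makes the hashing fast: hashfile() samples fixed-size chunks from the start, middle, and end of a file and mixes in the file size, rather than reading the whole file. A minimal sketch (path hypothetical):

from imohash import hashfile

digest = hashfile("/models/some-checkpoint.safetensors")  # returns raw bytes
print(digest.hex())  # 128-bit hash as a 32-character hex string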