Skip to content
Closed
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
154 commits
Select commit Hold shift + click to select a range
ae56c00
define model configuration classes
Aug 12, 2023
b5d97b1
blackify
Aug 12, 2023
e8edb0d
add ABC for config storage
Aug 12, 2023
e8815a1
rename ModelConfig to ModelConfigFactory
Aug 12, 2023
32958db
add YAML file storage backend
Aug 13, 2023
b2894b5
add class docstring and blackify
Aug 13, 2023
6c9b9e1
Merge branch 'main' into lstein/model-manager-refactor
lstein Aug 13, 2023
0c74300
change paths to str to make json serializable
Aug 13, 2023
5434dcd
fix test to work with string paths
Aug 13, 2023
1ea0ccb
add SQL backend
Aug 13, 2023
51e84e6
Merge branch 'main' into lstein/model-manager-refactor
lstein Aug 13, 2023
81da3d3
change model field name "hash" to "id"
Aug 13, 2023
155d9fc
Merge branch 'lstein/model-manager-refactor' of github.com:invoke-ai/…
Aug 13, 2023
c56fb38
added ability to force config class returned by make_config()
Aug 13, 2023
7db71ed
rename modules
Aug 15, 2023
1c7d9db
start installer module
Aug 15, 2023
e83d005
module skeleton written
Aug 15, 2023
916cc26
partial rewrite of checkpoint template creator
Aug 17, 2023
0deb3f9
Merge branch 'main' into lstein/model-manager-refactor
Aug 20, 2023
1784aeb
fix flake8 errors
Aug 20, 2023
f023e34
added main templates
Aug 21, 2023
6f9bf87
reimplement and clean up probe class
Aug 23, 2023
4b3d54d
install ABC written
Aug 23, 2023
9adc897
added install module
Aug 23, 2023
055ad01
merge with main; resolve conflicts
Aug 23, 2023
93cef55
blackify
Aug 23, 2023
97f2e77
make ModelSearch pydantic
Aug 24, 2023
8396bf7
Merge branch 'main' into lstein/model-manager-refactor
Aug 30, 2023
e6512e1
add ABC for download manager
Aug 30, 2023
869f310
download of individual files working
Sep 2, 2023
8fc2092
added download manager service and began repo_id download
Sep 4, 2023
d1c5990
merge and resolve conflicts
Sep 4, 2023
8f51adc
chore: black
psychedelicious Sep 5, 2023
57552de
threaded repo_id download working; error conditions not tested
Sep 5, 2023
ca6d248
resolve merge conflicts
Sep 5, 2023
e907417
add unit tests for queued model download
Sep 6, 2023
404cfe0
add download manager to invoke services
Sep 6, 2023
626acd5
remove unecessary HTTP probe for repo_id model component sizes
Sep 6, 2023
3448eda
fix progress reporting for repo_ids
Sep 6, 2023
82499d4
fix various typing errors in api dependencies initialization
Sep 7, 2023
11ead34
fix flake8 warnings
Sep 7, 2023
d979c50
Merge branch 'main' into lstein/model-manager-refactor
lstein Sep 7, 2023
c9a016f
more flake8 fixes
Sep 7, 2023
b09e012
Merge branch 'lstein/model-manager-refactor' of github.com:invoke-ai/…
Sep 7, 2023
79b2423
last flake8 fix - why is local flake8 not identical to git flake8?
Sep 7, 2023
a7aca29
implement regression tests for pause/cancel/error conditions
Sep 7, 2023
2165d55
add checks for malformed URLs and malicious content dispositions
Sep 8, 2023
b7ca983
blackify
Sep 8, 2023
598fe81
wire together download and install; now need to write install events
Sep 9, 2023
64424c6
install of repo_ids records author, tags and license
Sep 9, 2023
3582cfa
make download manager optional in InvokeAIServices during development
Sep 9, 2023
b2892f9
incorporate civitai metadata into model config
Sep 10, 2023
b7a6a53
fix flake8 warnings
Sep 10, 2023
8636015
increase download chunksize for better speed
Sep 10, 2023
8052f2e
Merge branch 'main' into lstein/model-manager-refactor
Sep 10, 2023
f454304
make it possible to pause/resume repo_id downloads
Sep 10, 2023
b583bdd
loading works -- web app broken
Sep 11, 2023
7430d87
loader working
Sep 11, 2023
6d8b2a7
pytests mostly working; model_manager_service needs rewriting
lstein Sep 12, 2023
4b932b2
refactor create_download_job; override probe info in install call
lstein Sep 13, 2023
27dcd89
merge with main; model_manager_service.py needs to be rewritten
Sep 14, 2023
ac88863
fix exception traceback reporting
lstein Sep 14, 2023
171d789
model loader autoscans models_dir on initialization
lstein Sep 14, 2023
716a1b6
model_manager_service now mostly type correct
lstein Sep 15, 2023
a033ccc
blackify
lstein Sep 15, 2023
3529925
services rewritten; starting work on routes
Sep 15, 2023
b7789bb
list_models() API call now working
Sep 16, 2023
08952b9
Merge branch 'main' into lstein/model-manager-refactor
Sep 16, 2023
b9a90fb
blackify and isort
Sep 16, 2023
db7fdc3
fix more isort issues
Sep 16, 2023
c090c5f
update_model and delete_model working; convert is WIP
Sep 16, 2023
dc68347
loading and conversions of checkpoints working
Sep 16, 2023
c029534
all methods in router API now tested and working
Sep 16, 2023
539776a
import_model API now working
Sep 17, 2023
e880f4b
add logs to confirm that event info is being sent to bus
Sep 17, 2023
f0ce559
add install job control to web API
Sep 17, 2023
238d7fa
add models.yaml conversion script
Sep 17, 2023
d051c08
attempt to fix flake8 lint errors
Sep 17, 2023
151ba02
fix models.yaml version assertion error in pytests
Sep 17, 2023
d1382f2
fasthash produces same results on windows & linux
lstein Sep 19, 2023
0c88491
Merge branch 'main' into lstein/model-manager-refactor
Sep 19, 2023
73bc088
blackify
Sep 19, 2023
de666fd
move incorrectly placed models into correct directory at startup time
Sep 19, 2023
ed91f48
TUI installer more or less working
lstein Sep 21, 2023
3402cf6
preserve description in metadata when installing a starter model
Sep 21, 2023
3199409
TUI installer functional; minor cosmetic work needed
Sep 21, 2023
30aea54
remove debug statement
Sep 21, 2023
c9cd418
add/delete from command line working; training words downloaded
Sep 21, 2023
07ddd60
fix install of models with relative paths
Sep 22, 2023
d2cdbe5
configure script now working
Sep 23, 2023
d5d517d
correctly download the selected version of a civitai model
Sep 23, 2023
ab58eb2
resolve conflicts with ip-adapter change
Sep 23, 2023
6edee2d
automatically convert models.yaml to new format
Sep 23, 2023
8bc1ca0
allow priority to be set at install job submission time
Sep 24, 2023
f9b92dd
resolve conflicts with get_logger() code changes from main
Sep 24, 2023
ac46340
merge with main & resolve conflicts
Sep 25, 2023
effced8
added `cancel_all` and `prune` model install operations to router API
Sep 25, 2023
1d6a4e7
add tests for model installation events
Sep 26, 2023
2e9a7b0
Merge branch 'main' into lstein/model-manager-refactor
lstein Sep 26, 2023
0b75a4f
resolve merge conflicts
Sep 28, 2023
81fce18
reorder pytests to prevent fixture race condition
Sep 28, 2023
2f16a2c
fix migrate script and type mismatches in probe, config and loader
Sep 29, 2023
3b832f1
fix one more type mismatch in probe module
Sep 29, 2023
4555aec
remove unused code from invokeai.backend.model_manager.storage.yaml
Sep 29, 2023
cbf0310
add README explaining reorg of tests directory
Sep 29, 2023
208d390
almost all type mismatches fixed
Sep 29, 2023
807ae82
more type mismatch fixes
Sep 30, 2023
acaaff4
make model merge script work with new model manager
Sep 30, 2023
c025c9c
speed up model scanning at startup
Sep 30, 2023
230ee18
do not ignore keyboard interrupt while scanning models
Sep 30, 2023
c91429d
merge with main
Oct 3, 2023
63f6c12
make merge script read invokeai.yaml when default root passed
Oct 3, 2023
48c3d92
make textual inversion training work with new model manager
Oct 3, 2023
062a6ed
prevent crash on windows due to lack of os.pathconf call
Oct 3, 2023
e3912e8
replace config.ram_cache_size with config.ram and similarly for vram
Oct 3, 2023
459f023
multiple minor fixes
Oct 4, 2023
4624de0
Merge branch 'main' into lstein/model-manager-refactor
lstein Oct 4, 2023
de90d40
Merge branch 'lstein/model-manager-refactor' of github.com:invoke-ai/…
Oct 4, 2023
16ec7a3
fix type mismatches in download_manager service
Oct 4, 2023
a180c0f
check model hash before and after moving in filesystem
Oct 4, 2023
cb0fdf3
refactor model install job class hierarchy
Oct 4, 2023
cd5d3e3
refactor model_manager_service.py into small functional modules
Oct 5, 2023
9cbc62d
fix reorganized module dependencies
Oct 5, 2023
8e06088
refactor services
Oct 6, 2023
6303f74
allow user to select main database or external file for model record/…
Oct 7, 2023
00e85bc
make autoimport directory optional, defaulting to inactive
Oct 7, 2023
4421638
fix conversion call
Oct 7, 2023
432231e
merge with main
Oct 7, 2023
7f68f58
restore printing of version when invokeai-web and invokeai called wit…
Oct 7, 2023
5106054
support clipvision image encoder downloading
Oct 7, 2023
a64a34b
add support for repo_id subfolders
Oct 8, 2023
e5b2bc8
refactor download queue jobs
Oct 8, 2023
bccfe8b
fix some type mismatches introduces by reorg
Oct 8, 2023
ce2baa3
port support for AutoencoderTiny models
Oct 8, 2023
a80ff75
Update invokeai/app/invocations/model.py
lstein Oct 9, 2023
fe10386
address all PR 4252 comments from ryan through October 5
Oct 9, 2023
3644d40
Merge branch 'lstein/model-manager-refactor' of github.com:invoke-ai/…
Oct 9, 2023
3962914
merge with main
Oct 9, 2023
33d4756
improve selection of huggingface repo id files to download
Oct 9, 2023
4149d35
refactor installer class hierarchy
Oct 9, 2023
e50a257
merge with main
Oct 9, 2023
4bab724
fix broken import
Oct 9, 2023
67607f0
fix issues with module import order breaking pytest node tests
Oct 10, 2023
71e7e61
add documentation for model record service and loader
Oct 10, 2023
76aa19a
first draft of documentation finished
Oct 11, 2023
e079cc9
add back source URL validation to download job hierarchy
Oct 12, 2023
0a0412f
restore CLI to broken state
Oct 12, 2023
a2079bd
Update docs/installation/050_INSTALLING_MODELS.md
lstein Oct 12, 2023
aace679
Update invokeai/app/services/model_convert.py
lstein Oct 12, 2023
b708aef
misc small fixes requested by Ryan
Oct 12, 2023
5f80d4d
Merge branch 'lstein/model-manager-refactor' of github.com:invoke-ai/…
Oct 12, 2023
a51b165
clean up model downloader status locking to avoid race conditions
Oct 12, 2023
0f9c676
remove download queue change_priority() calls completely
Oct 12, 2023
53e1199
prevent potential infinite recursion on exceptions raised by event ha…
Oct 12, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
276 changes: 276 additions & 0 deletions invokeai/backend/model_management2/model_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,276 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Configuration definitions for image generation models.

Typical usage:

from invokeai.backend.model_management2.model_config import ModelConfigFactory
raw = dict(path='models/sd-1/main/foo.ckpt',
name='foo',
base_model='sd-1',
model_type='main',
config='configs/stable-diffusion/v1-inference.yaml',
model_variant='normal',
model_format='checkpoint'
)
config = ModelConfigFactory.make_config(raw)
print(config.name)

Validation errors will raise an InvalidModelConfigException error.

"""
import pydantic
from enum import Enum
from pydantic import BaseModel, Field, Extra
from typing import Optional, Literal, List, Union, Type
from pydantic.error_wrappers import ValidationError
from omegaconf.listconfig import ListConfig # to support the yaml backend


class InvalidModelConfigException(Exception):
"""Exception raised when the config parser doesn't recognize the passed
combination of model type and format."""

pass


class BaseModelType(str, Enum):
"""Base model type."""

StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2"
StableDiffusionXL = "sdxl"
StableDiffusionXLRefiner = "sdxl-refiner"
# Kandinsky2_1 = "kandinsky-2.1"


class ModelType(str, Enum):
"""Model type."""

ONNX = "onnx"
Main = "main"
Vae = "vae"
Lora = "lora"
ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding"


class SubModelType(str, Enum):
"""Submodel type."""

UNet = "unet"
TextEncoder = "text_encoder"
TextEncoder2 = "text_encoder_2"
Tokenizer = "tokenizer"
Tokenizer2 = "tokenizer_2"
Vae = "vae"
VaeDecoder = "vae_decoder"
VaeEncoder = "vae_encoder"
Scheduler = "scheduler"
SafetyChecker = "safety_checker"


class ModelVariantType(str, Enum):
"""Variant type."""

Normal = "normal"
Inpaint = "inpaint"
Depth = "depth"


class ModelFormat(str, Enum):
"""Storage format of model."""

Diffusers = "diffusers"
Checkpoint = "checkpoint"
Lycoris = "lycoris"
Onnx = "onnx"
Olive = "olive"
EmbeddingFile = "embedding_file"
EmbeddingFolder = "embedding_folder"


class SchedulerPredictionType(str, Enum):
"""Scheduler prediction type."""

Epsilon = "epsilon"
VPrediction = "v_prediction"
Sample = "sample"


class ModelConfigBase(BaseModel):
"""Base class for model configuration information."""

path: str
name: str
base_model: BaseModelType
model_type: ModelType
model_format: ModelFormat
id: Optional[str] = Field(None) # this may get added by the store
description: Optional[str] = Field(None)
author: Optional[str] = Field(description="Model author")
thumbnail_url: Optional[str] = Field(description="URL of thumbnail image")
license_url: Optional[str] = Field(description="URL of license")
source_url: Optional[str] = Field(description="Model download source")
tags: Optional[List[str]] = Field(description="Descriptive tags") # Set would be better, but not JSON serializable

class Config:
"""Pydantic configuration hint."""

use_enum_values = True
extra = Extra.forbid
validate_assignment = True

@pydantic.validator("tags", pre=True)
@classmethod
def _fix_tags(cls, v):
if isinstance(v, ListConfig): # to support yaml backend
v = list(v)
return v


class CheckpointConfig(ModelConfigBase):
"""Model config for checkpoint-style models."""

model_format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
config: str = Field(description="path to the checkpoint model config file")


class DiffusersConfig(ModelConfigBase):
"""Model config for diffusers-style models."""

model_format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers


class LoRAConfig(ModelConfigBase):
"""Model config for LoRA/Lycoris models."""

model_format: Literal[ModelFormat.Lycoris, ModelFormat.Diffusers]


class VaeCheckpointConfig(ModelConfigBase):
"""Model config for standalone VAE models."""

model_format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint


class VaeDiffusersConfig(ModelConfigBase):
"""Model config for standalone VAE models (diffusers version)."""

model_format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers


class TextualInversionConfig(ModelConfigBase):
"""Model config for textual inversion embeddings."""

model_format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder]


class MainConfig(ModelConfigBase):
"""Model config for main models."""

vae: Optional[str] = Field(None)
model_variant: ModelVariantType = ModelVariantType.Normal


class MainCheckpointConfig(CheckpointConfig, MainConfig):
"""Model config for main checkpoint models."""

pass


class MainDiffusersConfig(DiffusersConfig, MainConfig):
"""Model config for main diffusers models."""

pass


class ONNXSD1Config(MainConfig):
"""Model config for ONNX format models based on sd-1."""

model_format: Literal[ModelFormat.Onnx, ModelFormat.Olive]


class ONNXSD2Config(MainConfig):
"""Model config for ONNX format models based on sd-2."""

model_format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
# No yaml config file for ONNX, so these are part of config
prediction_type: SchedulerPredictionType
upcast_attention: bool


class ModelConfigFactory(object):
"""Class for parsing config dicts into StableDiffusion*Config obects."""

_class_map: dict = {
ModelFormat.Checkpoint: {
ModelType.Main: MainCheckpointConfig,
ModelType.Vae: VaeCheckpointConfig,
},
ModelFormat.Diffusers: {
ModelType.Main: MainDiffusersConfig,
ModelType.Lora: LoRAConfig,
ModelType.Vae: VaeDiffusersConfig,
},
ModelFormat.Lycoris: {
ModelType.Lora: LoRAConfig,
},
ModelFormat.Onnx: {
ModelType.ONNX: {
BaseModelType.StableDiffusion1: ONNXSD1Config,
BaseModelType.StableDiffusion2: ONNXSD2Config,
},
},
ModelFormat.Olive: {
ModelType.ONNX: {
BaseModelType.StableDiffusion1: ONNXSD1Config,
BaseModelType.StableDiffusion2: ONNXSD2Config,
},
},
ModelFormat.EmbeddingFile: {
ModelType.TextualInversion: TextualInversionConfig,
},
ModelFormat.EmbeddingFolder: {
ModelType.TextualInversion: TextualInversionConfig,
},
}

@classmethod
def make_config(
cls,
model_data: Union[dict, ModelConfigBase],
dest_class: Optional[Type] = None,
) -> Union[
MainCheckpointConfig,
MainDiffusersConfig,
LoRAConfig,
TextualInversionConfig,
ONNXSD1Config,
ONNXSD2Config,
]:
"""
Return the appropriate config object from raw dict values.

:param model_data: A raw dict corresponding the obect fields to be
parsed into a ModelConfigBase obect (or descendent), or a ModelConfigBase
object, which will be passed through unchanged.
:param dest_class: The config class to be returned. If not provided, will
be selected automatically.
"""
if isinstance(model_data, ModelConfigBase):
return model_data
try:
model_format = model_data.get("model_format")
model_type = model_data.get("model_type")
model_base = model_data.get("base_model")
class_to_return = dest_class or cls._class_map[model_format][model_type]
if isinstance(class_to_return, dict): # additional level allowed
class_to_return = class_to_return[model_base]
return class_to_return.parse_obj(model_data)
except KeyError:
raise InvalidModelConfigException(
f"Unknown combination of model_format '{model_format}' and model_type '{model_type}'"
)
except ValidationError as e:
raise InvalidModelConfigException(f"Invalid model configuration passed: {str(e)}") from e
115 changes: 115 additions & 0 deletions invokeai/backend/model_management2/storage/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Abstract base class for storing and retrieving model configuration records.
"""


from abc import ABC, abstractmethod
from typing import Union, Set, List, Optional

from ..model_config import ModelConfigBase, BaseModelType, ModelType


class DuplicateModelException(Exception):
"""Raised on an attempt to add a model with the same key twice."""

pass


class UnknownModelException(Exception):
"""Raised on an attempt to delete a model with a nonexistent key."""

pass


class ModelConfigStore(ABC):
"""Abstract base class for storage and retrieval of model configs."""

@abstractmethod
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> None:
"""
Add a model to the database.

:param key: Unique key for the model
:param config: Model configuration record, either a dict with the
required fields or a ModelConfigBase instance.

Can raise DuplicateModelException and InvalidModelConfig exceptions.
"""
pass

@abstractmethod
def del_model(self, key: str) -> None:
"""
Delete a model.

:param key: Unique key for the model to be deleted

Can raise an UnknownModelException
"""
pass

@abstractmethod
def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
"""
Update the model, returning the updated version.

:param key: Unique key for the model to be updated
:param config: Model configuration record. Either a dict with the
required fields, or a ModelConfigBase instance.
"""
pass

@abstractmethod
def get_model(self, key: str) -> ModelConfigBase:
"""
Retrieve the ModelConfigBase instance for the indicated model.

:param key: Key of model config to be fetched.

Exceptions: UnknownModelException
"""
pass

@abstractmethod
def exists(self, key: str) -> bool:
"""
Return True if a model with the indicated key exists in the databse.

:param key: Unique key for the model to be deleted
"""
pass

@abstractmethod
def search_by_tag(self, tags: Set[str]) -> List[ModelConfigBase]:
"""
Return models containing all of the listed tags.

:param tags: Set of tags to search on.
"""
pass

@abstractmethod
def search_by_type(
self,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
) -> List[ModelConfigBase]:
"""
Return models matching name, base and/or type.

:param model_name: Filter by name of model (optional)
:param base_model: Filter by base model (optional)
:param model_type: Filter by type of model (optional)

If none of the optional filters are passed, will return all
models in the database.
"""
pass

def all_models(self) -> List[ModelConfigBase]:
"""
Return all the model configs in the database.
"""
return self.search_by_type()
Loading