Reimplement model config storage backend, download and install modules #4252
Closed
Commits (154)
ae56c00
define model configuration classes
b5d97b1
blackify
e8edb0d
add ABC for config storage
e8815a1
rename ModelConfig to ModelConfigFactory
32958db
add YAML file storage backend
b2894b5
add class docstring and blackify
6c9b9e1
Merge branch 'main' into lstein/model-manager-refactor
lstein 0c74300
change paths to str to make json serializable
5434dcd
fix test to work with string paths
1ea0ccb
add SQL backend
51e84e6
Merge branch 'main' into lstein/model-manager-refactor
lstein 81da3d3
change model field name "hash" to "id"
155d9fc
Merge branch 'lstein/model-manager-refactor' of github.com:invoke-ai/…
c56fb38
added ability to force config class returned by make_config()
7db71ed
rename modules
1c7d9db
start installer module
e83d005
module skeleton written
916cc26
partial rewrite of checkpoint template creator
0deb3f9
Merge branch 'main' into lstein/model-manager-refactor
1784aeb
fix flake8 errors
f023e34
added main templates
6f9bf87
reimplement and clean up probe class
4b3d54d
install ABC written
9adc897
added install module
055ad01
merge with main; resolve conflicts
93cef55
blackify
97f2e77
make ModelSearch pydantic
8396bf7
Merge branch 'main' into lstein/model-manager-refactor
e6512e1
add ABC for download manager
869f310
download of individual files working
8fc2092
added download manager service and began repo_id download
d1c5990
merge and resolve conflicts
8f51adc
chore: black
psychedelicious 57552de
threaded repo_id download working; error conditions not tested
ca6d248
resolve merge conflicts
e907417
add unit tests for queued model download
404cfe0
add download manager to invoke services
626acd5
remove unecessary HTTP probe for repo_id model component sizes
3448eda
fix progress reporting for repo_ids
82499d4
fix various typing errors in api dependencies initialization
11ead34
fix flake8 warnings
d979c50
Merge branch 'main' into lstein/model-manager-refactor
lstein c9a016f
more flake8 fixes
b09e012
Merge branch 'lstein/model-manager-refactor' of github.com:invoke-ai/…
79b2423
last flake8 fix - why is local flake8 not identical to git flake8?
a7aca29
implement regression tests for pause/cancel/error conditions
2165d55
add checks for malformed URLs and malicious content dispositions
b7ca983
blackify
598fe81
wire together download and install; now need to write install events
64424c6
install of repo_ids records author, tags and license
3582cfa
make download manager optional in InvokeAIServices during development
b2892f9
incorporate civitai metadata into model config
b7a6a53
fix flake8 warnings
8636015
increase download chunksize for better speed
8052f2e
Merge branch 'main' into lstein/model-manager-refactor
f454304
make it possible to pause/resume repo_id downloads
b583bdd
loading works -- web app broken
7430d87
loader working
6d8b2a7
pytests mostly working; model_manager_service needs rewriting
lstein 4b932b2
refactor create_download_job; override probe info in install call
lstein 27dcd89
merge with main; model_manager_service.py needs to be rewritten
ac88863
fix exception traceback reporting
lstein 171d789
model loader autoscans models_dir on initialization
lstein 716a1b6
model_manager_service now mostly type correct
lstein a033ccc
blackify
lstein 3529925
services rewritten; starting work on routes
b7789bb
list_models() API call now working
08952b9
Merge branch 'main' into lstein/model-manager-refactor
b9a90fb
blackify and isort
db7fdc3
fix more isort issues
c090c5f
update_model and delete_model working; convert is WIP
dc68347
loading and conversions of checkpoints working
c029534
all methods in router API now tested and working
539776a
import_model API now working
e880f4b
add logs to confirm that event info is being sent to bus
f0ce559
add install job control to web API
238d7fa
add models.yaml conversion script
d051c08
attempt to fix flake8 lint errors
151ba02
fix models.yaml version assertion error in pytests
d1382f2
fasthash produces same results on windows & linux
lstein 0c88491
Merge branch 'main' into lstein/model-manager-refactor
73bc088
blackify
de666fd
move incorrectly placed models into correct directory at startup time
ed91f48
TUI installer more or less working
lstein 3402cf6
preserve description in metadata when installing a starter model
3199409
TUI installer functional; minor cosmetic work needed
30aea54
remove debug statement
c9cd418
add/delete from command line working; training words downloaded
07ddd60
fix install of models with relative paths
d2cdbe5
configure script now working
d5d517d
correctly download the selected version of a civitai model
ab58eb2
resolve conflicts with ip-adapter change
6edee2d
automatically convert models.yaml to new format
8bc1ca0
allow priority to be set at install job submission time
f9b92dd
resolve conflicts with get_logger() code changes from main
ac46340
merge with main & resolve conflicts
effced8
added `cancel_all` and `prune` model install operations to router API
1d6a4e7
add tests for model installation events
2e9a7b0
Merge branch 'main' into lstein/model-manager-refactor
lstein 0b75a4f
resolve merge conflicts
81fce18
reorder pytests to prevent fixture race condition
2f16a2c
fix migrate script and type mismatches in probe, config and loader
3b832f1
fix one more type mismatch in probe module
4555aec
remove unused code from invokeai.backend.model_manager.storage.yaml
cbf0310
add README explaining reorg of tests directory
208d390
almost all type mismatches fixed
807ae82
more type mismatch fixes
acaaff4
make model merge script work with new model manager
c025c9c
speed up model scanning at startup
230ee18
do not ignore keyboard interrupt while scanning models
c91429d
merge with main
63f6c12
make merge script read invokeai.yaml when default root passed
48c3d92
make textual inversion training work with new model manager
062a6ed
prevent crash on windows due to lack of os.pathconf call
e3912e8
replace config.ram_cache_size with config.ram and similarly for vram
459f023
multiple minor fixes
4624de0
Merge branch 'main' into lstein/model-manager-refactor
lstein de90d40
Merge branch 'lstein/model-manager-refactor' of github.com:invoke-ai/…
16ec7a3
fix type mismatches in download_manager service
a180c0f
check model hash before and after moving in filesystem
cb0fdf3
refactor model install job class hierarchy
cd5d3e3
refactor model_manager_service.py into small functional modules
9cbc62d
fix reorganized module dependencies
8e06088
refactor services
6303f74
allow user to select main database or external file for model record/…
00e85bc
make autoimport directory optional, defaulting to inactive
4421638
fix conversion call
432231e
merge with main
7f68f58
restore printing of version when invokeai-web and invokeai called wit…
5106054
support clipvision image encoder downloading
a64a34b
add support for repo_id subfolders
e5b2bc8
refactor download queue jobs
bccfe8b
fix some type mismatches introduces by reorg
ce2baa3
port support for AutoencoderTiny models
a80ff75
Update invokeai/app/invocations/model.py
lstein fe10386
address all PR 4252 comments from ryan through October 5
3644d40
Merge branch 'lstein/model-manager-refactor' of github.com:invoke-ai/…
3962914
merge with main
33d4756
improve selection of huggingface repo id files to download
4149d35
refactor installer class hierarchy
e50a257
merge with main
4bab724
fix broken import
67607f0
fix issues with module import order breaking pytest node tests
71e7e61
add documentation for model record service and loader
76aa19a
first draft of documentation finished
e079cc9
add back source URL validation to download job hierarchy
0a0412f
restore CLI to broken state
a2079bd
Update docs/installation/050_INSTALLING_MODELS.md
lstein aace679
Update invokeai/app/services/model_convert.py
lstein b708aef
misc small fixes requested by Ryan
5f80d4d
Merge branch 'lstein/model-manager-refactor' of github.com:invoke-ai/…
a51b165
clean up model downloader status locking to avoid race conditions
0f9c676
remove download queue change_priority() calls completely
53e1199
prevent potential infinite recursion on exceptions raised by event ha…
New file (+276 lines) — model configuration definitions:

```python
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Configuration definitions for image generation models.

Typical usage:

  from invokeai.backend.model_management2.model_config import ModelConfigFactory
  raw = dict(path='models/sd-1/main/foo.ckpt',
             name='foo',
             base_model='sd-1',
             model_type='main',
             config='configs/stable-diffusion/v1-inference.yaml',
             model_variant='normal',
             model_format='checkpoint'
             )
  config = ModelConfigFactory.make_config(raw)
  print(config.name)

Validation errors will raise an InvalidModelConfigException error.
"""
import pydantic
from enum import Enum
from pydantic import BaseModel, Field, Extra
from typing import Optional, Literal, List, Union, Type
from pydantic.error_wrappers import ValidationError
from omegaconf.listconfig import ListConfig  # to support the yaml backend


class InvalidModelConfigException(Exception):
    """Exception raised when the config parser doesn't recognize the passed
    combination of model type and format."""

    pass


class BaseModelType(str, Enum):
    """Base model type."""

    StableDiffusion1 = "sd-1"
    StableDiffusion2 = "sd-2"
    StableDiffusionXL = "sdxl"
    StableDiffusionXLRefiner = "sdxl-refiner"
    # Kandinsky2_1 = "kandinsky-2.1"


class ModelType(str, Enum):
    """Model type."""

    ONNX = "onnx"
    Main = "main"
    Vae = "vae"
    Lora = "lora"
    ControlNet = "controlnet"  # used by model_probe
    TextualInversion = "embedding"


class SubModelType(str, Enum):
    """Submodel type."""

    UNet = "unet"
    TextEncoder = "text_encoder"
    TextEncoder2 = "text_encoder_2"
    Tokenizer = "tokenizer"
    Tokenizer2 = "tokenizer_2"
    Vae = "vae"
    VaeDecoder = "vae_decoder"
    VaeEncoder = "vae_encoder"
    Scheduler = "scheduler"
    SafetyChecker = "safety_checker"


class ModelVariantType(str, Enum):
    """Variant type."""

    Normal = "normal"
    Inpaint = "inpaint"
    Depth = "depth"


class ModelFormat(str, Enum):
    """Storage format of model."""

    Diffusers = "diffusers"
    Checkpoint = "checkpoint"
    Lycoris = "lycoris"
    Onnx = "onnx"
    Olive = "olive"
    EmbeddingFile = "embedding_file"
    EmbeddingFolder = "embedding_folder"


class SchedulerPredictionType(str, Enum):
    """Scheduler prediction type."""

    Epsilon = "epsilon"
    VPrediction = "v_prediction"
    Sample = "sample"


class ModelConfigBase(BaseModel):
    """Base class for model configuration information."""

    path: str
    name: str
    base_model: BaseModelType
    model_type: ModelType
    model_format: ModelFormat
    id: Optional[str] = Field(None)  # this may get added by the store
    description: Optional[str] = Field(None)
    author: Optional[str] = Field(description="Model author")
    thumbnail_url: Optional[str] = Field(description="URL of thumbnail image")
    license_url: Optional[str] = Field(description="URL of license")
    source_url: Optional[str] = Field(description="Model download source")
    tags: Optional[List[str]] = Field(description="Descriptive tags")  # Set would be better, but not JSON serializable

    class Config:
        """Pydantic configuration hint."""

        use_enum_values = True
        extra = Extra.forbid
        validate_assignment = True

    @pydantic.validator("tags", pre=True)
    @classmethod
    def _fix_tags(cls, v):
        if isinstance(v, ListConfig):  # to support yaml backend
            v = list(v)
        return v


class CheckpointConfig(ModelConfigBase):
    """Model config for checkpoint-style models."""

    model_format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
    config: str = Field(description="path to the checkpoint model config file")


class DiffusersConfig(ModelConfigBase):
    """Model config for diffusers-style models."""

    model_format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers


class LoRAConfig(ModelConfigBase):
    """Model config for LoRA/Lycoris models."""

    model_format: Literal[ModelFormat.Lycoris, ModelFormat.Diffusers]


class VaeCheckpointConfig(ModelConfigBase):
    """Model config for standalone VAE models."""

    model_format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint


class VaeDiffusersConfig(ModelConfigBase):
    """Model config for standalone VAE models (diffusers version)."""

    model_format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers


class TextualInversionConfig(ModelConfigBase):
    """Model config for textual inversion embeddings."""

    model_format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder]


class MainConfig(ModelConfigBase):
    """Model config for main models."""

    vae: Optional[str] = Field(None)
    model_variant: ModelVariantType = ModelVariantType.Normal


class MainCheckpointConfig(CheckpointConfig, MainConfig):
    """Model config for main checkpoint models."""

    pass


class MainDiffusersConfig(DiffusersConfig, MainConfig):
    """Model config for main diffusers models."""

    pass


class ONNXSD1Config(MainConfig):
    """Model config for ONNX format models based on sd-1."""

    model_format: Literal[ModelFormat.Onnx, ModelFormat.Olive]


class ONNXSD2Config(MainConfig):
    """Model config for ONNX format models based on sd-2."""

    model_format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
    # No yaml config file for ONNX, so these are part of config
    prediction_type: SchedulerPredictionType
    upcast_attention: bool


class ModelConfigFactory(object):
    """Class for parsing config dicts into StableDiffusion*Config objects."""

    _class_map: dict = {
        ModelFormat.Checkpoint: {
            ModelType.Main: MainCheckpointConfig,
            ModelType.Vae: VaeCheckpointConfig,
        },
        ModelFormat.Diffusers: {
            ModelType.Main: MainDiffusersConfig,
            ModelType.Lora: LoRAConfig,
            ModelType.Vae: VaeDiffusersConfig,
        },
        ModelFormat.Lycoris: {
            ModelType.Lora: LoRAConfig,
        },
        ModelFormat.Onnx: {
            ModelType.ONNX: {
                BaseModelType.StableDiffusion1: ONNXSD1Config,
                BaseModelType.StableDiffusion2: ONNXSD2Config,
            },
        },
        ModelFormat.Olive: {
            ModelType.ONNX: {
                BaseModelType.StableDiffusion1: ONNXSD1Config,
                BaseModelType.StableDiffusion2: ONNXSD2Config,
            },
        },
        ModelFormat.EmbeddingFile: {
            ModelType.TextualInversion: TextualInversionConfig,
        },
        ModelFormat.EmbeddingFolder: {
            ModelType.TextualInversion: TextualInversionConfig,
        },
    }

    @classmethod
    def make_config(
        cls,
        model_data: Union[dict, ModelConfigBase],
        dest_class: Optional[Type] = None,
    ) -> Union[
        MainCheckpointConfig,
        MainDiffusersConfig,
        LoRAConfig,
        TextualInversionConfig,
        ONNXSD1Config,
        ONNXSD2Config,
    ]:
        """
        Return the appropriate config object from raw dict values.

        :param model_data: A raw dict corresponding to the object fields to be
        parsed into a ModelConfigBase object (or descendant), or a ModelConfigBase
        object, which will be passed through unchanged.
        :param dest_class: The config class to be returned. If not provided, will
        be selected automatically.
        """
        if isinstance(model_data, ModelConfigBase):
            return model_data
        try:
            model_format = model_data.get("model_format")
            model_type = model_data.get("model_type")
            model_base = model_data.get("base_model")
            class_to_return = dest_class or cls._class_map[model_format][model_type]
            if isinstance(class_to_return, dict):  # additional level allowed
                class_to_return = class_to_return[model_base]
            return class_to_return.parse_obj(model_data)
        except KeyError:
            raise InvalidModelConfigException(
                f"Unknown combination of model_format '{model_format}' and model_type '{model_type}'"
            )
        except ValidationError as e:
            raise InvalidModelConfigException(f"Invalid model configuration passed: {str(e)}") from e
```
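The core of `ModelConfigFactory.make_config` is a two-level (sometimes three-level) dispatch table: model format selects a dict keyed by model type, and for ONNX models a further level keyed by base model. The sketch below demonstrates that dispatch pattern in isolation, with plain dataclasses standing in for the real pydantic config classes (the class names `CheckpointMain`/`DiffusersMain` and the string keys are illustrative, not part of the PR):

```python
from dataclasses import dataclass


# Simplified stand-ins for the real pydantic config classes.
@dataclass
class CheckpointMain:
    path: str
    name: str


@dataclass
class DiffusersMain:
    path: str
    name: str


class InvalidModelConfig(Exception):
    pass


# Dispatch table: format -> type -> config class, mirroring
# the shape of ModelConfigFactory._class_map above.
CLASS_MAP = {
    "checkpoint": {"main": CheckpointMain},
    "diffusers": {"main": DiffusersMain},
}


def make_config(raw: dict):
    """Pick the config class from the raw dict's format/type fields."""
    try:
        cls = CLASS_MAP[raw["model_format"]][raw["model_type"]]
    except KeyError as e:
        # Unknown format/type combinations surface as one domain exception,
        # as in the real factory.
        raise InvalidModelConfig(f"unknown format/type combination: {e}") from e
    return cls(path=raw["path"], name=raw["name"])


config = make_config(
    {"model_format": "checkpoint", "model_type": "main",
     "path": "models/sd-1/main/foo.ckpt", "name": "foo"}
)
print(type(config).__name__, config.name)  # CheckpointMain foo
```

The real factory additionally accepts an already-constructed `ModelConfigBase` (passed through unchanged) and a `dest_class` override, and delegates field validation to pydantic's `parse_obj`.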
New file (+115 lines) — abstract base class for config storage backends:

```python
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Abstract base class for storing and retrieving model configuration records.
"""

from abc import ABC, abstractmethod
from typing import Union, Set, List, Optional

from ..model_config import ModelConfigBase, BaseModelType, ModelType


class DuplicateModelException(Exception):
    """Raised on an attempt to add a model with the same key twice."""

    pass


class UnknownModelException(Exception):
    """Raised on an attempt to delete a model with a nonexistent key."""

    pass


class ModelConfigStore(ABC):
    """Abstract base class for storage and retrieval of model configs."""

    @abstractmethod
    def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> None:
        """
        Add a model to the database.

        :param key: Unique key for the model
        :param config: Model configuration record, either a dict with the
        required fields or a ModelConfigBase instance.

        Can raise DuplicateModelException and InvalidModelConfig exceptions.
        """
        pass

    @abstractmethod
    def del_model(self, key: str) -> None:
        """
        Delete a model.

        :param key: Unique key for the model to be deleted

        Can raise an UnknownModelException
        """
        pass

    @abstractmethod
    def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
        """
        Update the model, returning the updated version.

        :param key: Unique key for the model to be updated
        :param config: Model configuration record. Either a dict with the
        required fields, or a ModelConfigBase instance.
        """
        pass

    @abstractmethod
    def get_model(self, key: str) -> ModelConfigBase:
        """
        Retrieve the ModelConfigBase instance for the indicated model.

        :param key: Key of model config to be fetched.

        Exceptions: UnknownModelException
        """
        pass

    @abstractmethod
    def exists(self, key: str) -> bool:
        """
        Return True if a model with the indicated key exists in the database.

        :param key: Unique key for the model to be checked
        """
        pass

    @abstractmethod
    def search_by_tag(self, tags: Set[str]) -> List[ModelConfigBase]:
        """
        Return models containing all of the listed tags.

        :param tags: Set of tags to search on.
        """
        pass

    @abstractmethod
    def search_by_type(
        self,
        model_name: Optional[str] = None,
        base_model: Optional[BaseModelType] = None,
        model_type: Optional[ModelType] = None,
    ) -> List[ModelConfigBase]:
        """
        Return models matching name, base and/or type.

        :param model_name: Filter by name of model (optional)
        :param base_model: Filter by base model (optional)
        :param model_type: Filter by type of model (optional)

        If none of the optional filters are passed, will return all
        models in the database.
        """
        pass

    def all_models(self) -> List[ModelConfigBase]:
        """Return all the model configs in the database."""
        return self.search_by_type()
```
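The PR ships YAML-file and SQL implementations of this contract. To illustrate the expected semantics (duplicate-key and missing-key errors, conjunctive tag matching, filterless search returning everything), here is a hypothetical in-memory backend; configs are held as plain dicts rather than `ModelConfigBase` instances, and all class and exception names below are stand-ins, not part of the PR:

```python
from typing import Dict, List, Optional, Set


class DuplicateModel(Exception):
    pass


class UnknownModel(Exception):
    pass


class MemoryModelConfigStore:
    """Toy in-memory store following the ModelConfigStore contract,
    holding configs as plain dicts keyed by a unique string."""

    def __init__(self) -> None:
        self._configs: Dict[str, dict] = {}

    def add_model(self, key: str, config: dict) -> None:
        if key in self._configs:
            raise DuplicateModel(f"key '{key}' already present")
        self._configs[key] = dict(config)

    def del_model(self, key: str) -> None:
        if key not in self._configs:
            raise UnknownModel(f"no model with key '{key}'")
        del self._configs[key]

    def get_model(self, key: str) -> dict:
        try:
            return self._configs[key]
        except KeyError:
            raise UnknownModel(f"no model with key '{key}'") from None

    def exists(self, key: str) -> bool:
        return key in self._configs

    def search_by_tag(self, tags: Set[str]) -> List[dict]:
        # A model matches only if it carries *all* of the requested tags.
        return [c for c in self._configs.values()
                if tags.issubset(set(c.get("tags", [])))]

    def search_by_type(
        self,
        model_name: Optional[str] = None,
        base_model: Optional[str] = None,
        model_type: Optional[str] = None,
    ) -> List[dict]:
        # Each filter is applied only when supplied; with no filters,
        # every stored config is returned (this backs all_models()).
        return [
            c for c in self._configs.values()
            if (model_name is None or c.get("name") == model_name)
            and (base_model is None or c.get("base_model") == base_model)
            and (model_type is None or c.get("model_type") == model_type)
        ]


store = MemoryModelConfigStore()
store.add_model("abc123", {"name": "foo", "base_model": "sd-1",
                           "model_type": "main", "tags": ["sd", "main"]})
store.add_model("def456", {"name": "bar", "base_model": "sd-2",
                           "model_type": "vae", "tags": ["sd"]})
print(len(store.search_by_tag({"sd"})))              # 2
print(len(store.search_by_type(model_type="main")))  # 1
```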